99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
import os
from contextlib import contextmanager
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool

from .models import Base
|
|
|
|
class DatabaseManager:
    """Database manager owning one SQLite engine and one MySQL/MariaDB engine.

    Sessions are obtained through the ``sqlite_session()`` and
    ``mysql_session()`` context managers, which commit on success, roll back
    on any exception (re-raising it) and always close the session.
    """

    def __init__(self, settings):
        """Initialize the database manager.

        Args:
            settings: Scrapy settings object. Only ``settings.get(key, default)``
                is used; the keys read are ``SQLITE_FILE``, ``MYSQL_HOST``,
                ``MYSQL_PORT``, ``MYSQL_USER``, ``MYSQL_PASSWORD`` and
                ``MYSQL_DATABASE``.
        """
        self.sqlite_file = settings.get('SQLITE_FILE', 'videos.db')
        self.mysql_config = {
            'host': settings.get('MYSQL_HOST', 'localhost'),
            'port': settings.get('MYSQL_PORT', 3306),
            'user': settings.get('MYSQL_USER', 'root'),
            'password': settings.get('MYSQL_PASSWORD', ''),
            'database': settings.get('MYSQL_DATABASE', 'crawler'),
        }

        # Build both database engines.
        self._init_sqlite()
        self._init_mysql()

        # Session factories bound to each engine.
        self.sqlite_session_maker = sessionmaker(bind=self.sqlite_engine)
        self.mysql_session_maker = sessionmaker(bind=self.mysql_engine)

    def _init_sqlite(self):
        """Create the SQLite engine and all tables registered on ``Base``."""
        # Make sure the directory holding the database file exists.
        # exist_ok=True avoids the check-then-create race of a separate
        # os.path.exists() test.
        db_dir = os.path.dirname(self.sqlite_file)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        self.sqlite_engine = create_engine(
            f'sqlite:///{self.sqlite_file}',
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
        )

        # Create every table declared on Base's metadata.
        Base.metadata.create_all(self.sqlite_engine)

    @staticmethod
    def _build_mysql_url(config):
        """Return the PyMySQL connection URL for *config*.

        The user name and password are percent-encoded (nothing is treated as
        safe), so characters such as ``@``, ``:``, ``/``, ``%``, ``#`` or
        spaces in credentials cannot break the URL structure. An empty
        password produces ``user:@host`` exactly as before.

        Args:
            config: Mapping with ``host``, ``port``, ``user``, ``password``
                and ``database`` keys.

        Returns:
            str: ``mysql+pymysql://user:password@host:port/database?charset=utf8mb4``
        """
        url_template = (
            'mysql+pymysql://{user}:{password}@{host}:{port}/{database}'
            '?charset=utf8mb4'
        )
        return url_template.format(
            user=quote(str(config['user']), safe=''),
            password=quote(str(config['password']), safe=''),
            host=config['host'],
            port=config['port'],
            database=config['database'],
        )

    def _init_mysql(self):
        """Create the MySQL/MariaDB engine (PyMySQL driver, utf8mb4 charset)."""
        self.mysql_engine = create_engine(
            self._build_mysql_url(self.mysql_config),
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
            pool_pre_ping=True,  # automatically detect dropped connections
        )

    @staticmethod
    @contextmanager
    def _session_scope(session_maker):
        """Yield a new session from *session_maker* as a transactional scope.

        Commits when the ``with`` body finishes normally. On any exception
        (including one raised by ``commit()``), rolls back and re-raises.
        The session is closed in every case.
        """
        session = session_maker()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def sqlite_session(self):
        """SQLite session context manager.

        Yields:
            Session: a SQLite database session that is committed on success,
            rolled back on error and always closed.
        """
        return self._session_scope(self.sqlite_session_maker)

    def mysql_session(self):
        """MySQL session context manager.

        Yields:
            Session: a MySQL database session that is committed on success,
            rolled back on error and always closed.
        """
        return self._session_scope(self.mysql_session_maker)

    def close(self):
        """Dispose of the connection pools of both engines.

        Intended to be called when the manager is no longer needed.
        """
        self.sqlite_engine.dispose()
        self.mysql_engine.dispose()