import os
from contextlib import contextmanager
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool

from .models import Base


class DatabaseManager:
    """Database manager owning a SQLite engine and a MySQL/MariaDB engine.

    Each backend gets a pooled engine, a session factory, and a
    transactional session context manager (``sqlite_session`` /
    ``mysql_session``) that commits on success and rolls back on error.
    """

    def __init__(self, settings):
        """Initialize both database engines and their session factories.

        Args:
            settings: Scrapy settings object (anything exposing
                ``get(key, default)``). Keys read: ``SQLITE_FILE``,
                ``MYSQL_HOST``, ``MYSQL_PORT``, ``MYSQL_USER``,
                ``MYSQL_PASSWORD``, ``MYSQL_DATABASE``.
        """
        self.sqlite_file = settings.get('SQLITE_FILE', 'videos.db')
        self.mysql_config = {
            'host': settings.get('MYSQL_HOST', 'localhost'),
            'port': settings.get('MYSQL_PORT', 3306),
            'user': settings.get('MYSQL_USER', 'root'),
            'password': settings.get('MYSQL_PASSWORD', ''),
            'database': settings.get('MYSQL_DATABASE', 'crawler'),
        }

        # Initialize the database engines.
        self._init_sqlite()
        self._init_mysql()

        # Create the session factories, one per engine.
        self.sqlite_session_maker = sessionmaker(bind=self.sqlite_engine)
        self.mysql_session_maker = sessionmaker(bind=self.mysql_engine)

    def _init_sqlite(self):
        """Create the SQLite engine and all tables declared on ``Base``."""
        # Make sure the directory holding the database file exists;
        # exist_ok avoids a race between an existence check and creation.
        db_dir = os.path.dirname(self.sqlite_file)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        self.sqlite_engine = create_engine(
            f'sqlite:///{self.sqlite_file}',
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
        )

        # Create every table registered on Base's metadata (no-op for
        # tables that already exist).
        Base.metadata.create_all(self.sqlite_engine)

    def _mysql_url(self):
        """Build the MySQL connection URL from ``self.mysql_config``.

        The user name and password are percent-encoded so that characters
        such as ``@``, ``:``, ``/``, ``#`` or ``%`` in credentials do not
        break URL parsing. Plain alphanumeric credentials are unchanged.

        Returns:
            str: A ``mysql+pymysql://`` URL using the utf8mb4 charset.
        """
        cfg = self.mysql_config
        return (
            'mysql+pymysql://{user}:{password}@{host}:{port}/{database}'
            '?charset=utf8mb4'
        ).format(
            user=quote(str(cfg['user']), safe=''),
            password=quote(str(cfg['password']), safe=''),
            host=cfg['host'],
            port=cfg['port'],
            database=cfg['database'],
        )

    def _init_mysql(self):
        """Create the MySQL/MariaDB engine (tables are not created here)."""
        self.mysql_engine = create_engine(
            self._mysql_url(),
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
            pool_pre_ping=True,  # Test connections on checkout to drop stale ones.
        )

    @staticmethod
    @contextmanager
    def _session_scope(session_maker):
        """Yield a session from ``session_maker`` inside a transaction.

        Commits when the ``with`` body finishes normally, rolls back and
        re-raises on any exception, and always closes the session.
        """
        session = session_maker()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    @contextmanager
    def sqlite_session(self):
        """Transactional SQLite session context manager.

        Yields:
            Session: a SQLite session; committed on success, rolled back
            and re-raised on error, closed in all cases.
        """
        with self._session_scope(self.sqlite_session_maker) as session:
            yield session

    @contextmanager
    def mysql_session(self):
        """Transactional MySQL session context manager.

        Yields:
            Session: a MySQL session; committed on success, rolled back
            and re-raised on error, closed in all cases.
        """
        with self._session_scope(self.mysql_session_maker) as session:
            yield session

    def close(self):
        """Dispose both engines, closing their pooled connections."""
        self.sqlite_engine.dispose()
        self.mysql_engine.dispose()