crawler_81tv/scrapy_proj/database.py
2025-06-08 16:25:53 +08:00

99 lines
3.0 KiB
Python

import os
from contextlib import contextmanager
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool

from .models import Base
class DatabaseManager:
    """Database manager holding one SQLite and one MySQL/MariaDB engine.

    Construction reads the connection settings, builds both SQLAlchemy
    engines and binds a session factory to each. ``sqlite_session`` and
    ``mysql_session`` are context managers that wrap a unit of work in a
    commit-on-success / rollback-on-error scope.
    """

    def __init__(self, settings):
        """Initialise the database manager.

        Args:
            settings: Scrapy settings object; only ``get(key, default)`` is
                used. Recognised keys are ``SQLITE_FILE``, ``MYSQL_HOST``,
                ``MYSQL_PORT``, ``MYSQL_USER``, ``MYSQL_PASSWORD`` and
                ``MYSQL_DATABASE``.
        """
        self.sqlite_file = settings.get('SQLITE_FILE', 'videos.db')
        self.mysql_config = {
            'host': settings.get('MYSQL_HOST', 'localhost'),
            'port': settings.get('MYSQL_PORT', 3306),
            'user': settings.get('MYSQL_USER', 'root'),
            'password': settings.get('MYSQL_PASSWORD', ''),
            'database': settings.get('MYSQL_DATABASE', 'crawler'),
        }

        # The engines must exist before the session factories bind to them.
        self._init_sqlite()
        self._init_mysql()

        # Session factories, one per backend.
        self.sqlite_session_maker = sessionmaker(bind=self.sqlite_engine)
        self.mysql_session_maker = sessionmaker(bind=self.mysql_engine)

    def _init_sqlite(self):
        """Create the SQLite engine and create all tables declared on ``Base``."""
        # Make sure the directory that will hold the database file exists.
        # exist_ok=True avoids the check-then-create race of testing
        # os.path.exists() first (another process may create it in between).
        db_dir = os.path.dirname(self.sqlite_file)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        self.sqlite_engine = create_engine(
            f'sqlite:///{self.sqlite_file}',
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
        )

        # Create every table that does not exist yet.
        Base.metadata.create_all(self.sqlite_engine)

    def _mysql_url(self):
        """Return the SQLAlchemy URL for the configured MySQL/MariaDB server.

        The user name and password are percent-encoded so that characters
        with a meaning inside a URL (``@``, ``/``, ``:``, ``%``, ...) cannot
        corrupt it. ``quote(..., safe='')`` is used rather than
        ``quote_plus`` so a space becomes ``%20`` and never ``+``.

        Returns:
            str: A ``mysql+pymysql://`` URL with ``charset=utf8mb4``.
        """
        cfg = self.mysql_config
        user = quote(str(cfg['user']), safe='')
        password = quote(str(cfg['password']), safe='')
        return (
            f"mysql+pymysql://{user}:{password}"
            f"@{cfg['host']}:{cfg['port']}/{cfg['database']}?charset=utf8mb4"
        )

    def _init_mysql(self):
        """Create the MySQL/MariaDB engine."""
        self.mysql_engine = create_engine(
            self._mysql_url(),
            poolclass=QueuePool,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
            pool_pre_ping=True,  # detect dropped connections on checkout
        )

    @staticmethod
    @contextmanager
    def _session_scope(session_maker):
        """Run a unit of work in a session obtained from ``session_maker``.

        The session is committed when the ``with`` body finishes normally,
        rolled back (and the exception re-raised) when an ``Exception``
        escapes, and closed in every case.

        Args:
            session_maker: Zero-argument callable returning a new session.

        Yields:
            The newly created session.
        """
        session = session_maker()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def sqlite_session(self):
        """Context manager for a SQLite session.

        Yields:
            Session: SQLite database session, committed on success and
            rolled back on error.
        """
        return self._session_scope(self.sqlite_session_maker)

    def mysql_session(self):
        """Context manager for a MySQL session.

        Yields:
            Session: MySQL database session, committed on success and
            rolled back on error.
        """
        return self._session_scope(self.mysql_session_maker)