feature: 数据库连接与SQL集中管理,提高代码可读性

This commit is contained in:
liuwei
2025-03-18 10:24:38 +08:00
parent fd1676b908
commit 727d2d3938
7 changed files with 571 additions and 0 deletions

80
db/base.py Normal file
View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
import logging
from typing import List, Dict, Any, Optional, Tuple, Union
from db.connection import DBConnectionManager
class BaseDBOperator:
"""基础数据库操作类"""
def __init__(self, db_manager: DBConnectionManager):
self.db_manager = db_manager
self.LOG = logging.getLogger(self.__class__.__name__)
def execute_query(self, sql: str, params: Optional[tuple] = None, fetch_one: bool = False) -> Union[List[Dict], Dict, None]:
"""执行查询SQL"""
conn = self.db_manager.get_mysql_connection()
try:
with conn.cursor(dictionary=True) as cursor:
cursor.execute(sql, params or ())
if fetch_one:
return cursor.fetchone()
return cursor.fetchall()
except Exception as e:
self.LOG.error(f"执行查询SQL出错: {e}, SQL: {sql}, 参数: {params}")
return None
finally:
conn.close()
def execute_update(self, sql: str, params: Optional[tuple] = None) -> bool:
"""执行更新SQL"""
conn = self.db_manager.get_mysql_connection()
try:
with conn.cursor() as cursor:
cursor.execute(sql, params or ())
conn.commit()
return True
except Exception as e:
self.LOG.error(f"执行更新SQL出错: {e}, SQL: {sql}, 参数: {params}")
conn.rollback()
return False
finally:
conn.close()
def execute_batch(self, sql: str, params_list: List[tuple]) -> bool:
"""批量执行SQL"""
if not params_list:
return True
conn = self.db_manager.get_mysql_connection()
try:
with conn.cursor() as cursor:
cursor.executemany(sql, params_list)
conn.commit()
return True
except Exception as e:
self.LOG.error(f"批量执行SQL出错: {e}, SQL: {sql}, 参数数量: {len(params_list)}")
conn.rollback()
return False
finally:
conn.close()
def execute_transaction(self, operations: List[Tuple[str, tuple]]) -> bool:
"""执行事务"""
if not operations:
return True
conn = self.db_manager.get_mysql_connection()
try:
with conn.cursor() as cursor:
for sql, params in operations:
cursor.execute(sql, params)
conn.commit()
return True
except Exception as e:
self.LOG.error(f"执行事务出错: {e}, 操作数量: {len(operations)}")
conn.rollback()
return False
finally:
conn.close()

51
db/connection.py Normal file
View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
import logging
import mysql.connector.pooling
import redis
from typing import Optional
class DBConnectionManager:
"""数据库连接管理类用于管理MySQL和Redis连接池"""
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(DBConnectionManager, cls).__new__(cls)
cls._instance.initialized = False
return cls._instance
def __init__(self, mysql_config: Optional[dict] = None, redis_config: Optional[dict] = None):
if self.initialized:
return
self.LOG = logging.getLogger("DBConnectionManager")
# 初始化MySQL连接池
if mysql_config:
self.mysql_pool = mysql.connector.pooling.MySQLConnectionPool(**mysql_config)
self.LOG.info(f"MySQL连接池初始化完成: {mysql_config}")
else:
self.mysql_pool = None
# 初始化Redis连接池
if redis_config:
self.redis_pool = redis.ConnectionPool(**redis_config)
self.LOG.info(f"Redis连接池初始化完成: {redis_config}")
else:
self.redis_pool = None
self.initialized = True
def get_mysql_connection(self):
"""获取MySQL连接"""
if not self.mysql_pool:
raise Exception("MySQL连接池未初始化")
return self.mysql_pool.get_connection()
def get_redis_connection(self):
"""获取Redis连接"""
if not self.redis_pool:
raise Exception("Redis连接池未初始化")
return redis.Redis(connection_pool=self.redis_pool)

154
db/encyclopedia.py Normal file
View File

@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
from typing import Dict, List, Optional, Tuple
from db.base import BaseDBOperator
from db.connection import DBConnectionManager
class EncyclopediaDB(BaseDBOperator):
"""百科答题游戏相关数据库操作"""
def __init__(self, db_manager: DBConnectionManager):
super().__init__(db_manager)
def add_group(self, group_id: str) -> bool:
"""添加群组"""
sql = "INSERT INTO t_encyclopedia_groups (group_id) VALUES (%s)"
return self.execute_update(sql, (group_id,))
def check_group_exists(self, group_id: str) -> bool:
"""检查群组是否存在"""
sql = "SELECT 1 FROM t_encyclopedia_groups WHERE group_id = %s"
result = self.execute_query(sql, (group_id,), fetch_one=True)
return result is not None
def add_player(self, player_id: str, group_id: str, player_name: str) -> bool:
"""添加玩家"""
sql = """
INSERT INTO t_encyclopedia_players (player_id, group_id, player_name)
VALUES (%s, %s, %s)
ON DUPLICATE KEY UPDATE player_name = VALUES(player_name)
"""
return self.execute_update(sql, (player_id, group_id, player_name))
def get_active_task(self, group_id: str) -> Optional[Dict]:
"""获取群组的活跃任务"""
sql = """
SELECT active_task_id, group_id, question, answer, score, description, holder_id, assigned_at, status
FROM t_encyclopedia_active_tasks
WHERE group_id = %s AND status = 'pending'
ORDER BY assigned_at DESC
LIMIT 1
"""
return self.execute_query(sql, (group_id,), fetch_one=True)
def add_task_history(self, group_id: str, task_id: int, player_id: str,
answer: str, is_correct: bool, points_earned: int) -> bool:
"""添加任务历史记录"""
sql = """
INSERT INTO t_encyclopedia_task_history
(group_id, active_task_id, player_id, answer, is_correct, points_earned)
VALUES (%s, %s, %s, %s, %s, %s)
"""
params = (group_id, task_id, player_id, answer, 1 if is_correct else 0, points_earned)
return self.execute_update(sql, params)
def get_player_ranking(self, group_id: str, limit: int = 10) -> List[Dict]:
"""获取玩家排名"""
sql = """
SELECT player_name, points
FROM t_encyclopedia_players
WHERE group_id = %s
ORDER BY points DESC
LIMIT %s
"""
return self.execute_query(sql, (group_id, limit)) or []
def create_active_task(self, group_id: str, question: str, answer: str,
score: int, description: str, holder_id: str) -> Optional[int]:
"""创建活跃任务"""
sql = """
INSERT INTO t_encyclopedia_active_tasks
(group_id, question, answer, score, description, holder_id)
VALUES (%s, %s, %s, %s, %s, %s)
"""
if self.execute_update(sql, (group_id, question, answer, score, description, holder_id)):
# 获取最新创建的任务ID
get_id_sql = """
SELECT active_task_id
FROM t_encyclopedia_active_tasks
WHERE group_id = %s AND question = %s AND holder_id = %s
ORDER BY assigned_at DESC LIMIT 1
"""
result = self.execute_query(get_id_sql, (group_id, question, holder_id), fetch_one=True)
return result['active_task_id'] if result else None
return None
def update_player_points(self, player_id: str, group_id: str, points: int) -> bool:
"""更新玩家积分"""
if points > 0:
sql = """
UPDATE t_encyclopedia_players
SET points = points + %s
WHERE group_id = %s AND player_id = %s
"""
else:
sql = """
UPDATE t_encyclopedia_players
SET points = GREATEST(points + %s, 0)
WHERE group_id = %s AND player_id = %s
"""
return self.execute_update(sql, (points, group_id, player_id))
def complete_task(self, active_task_id: int) -> bool:
"""完成任务"""
sql = """
UPDATE t_encyclopedia_active_tasks
SET status = 'completed'
WHERE active_task_id = %s
"""
return self.execute_update(sql, (active_task_id,))
def get_player(self, player_id: str, group_id: str) -> Optional[Dict]:
"""获取玩家信息"""
sql = """
SELECT player_id, player_name, points
FROM t_encyclopedia_players
WHERE group_id = %s AND player_id = %s
"""
return self.execute_query(sql, (group_id, player_id), fetch_one=True)
def get_all_groups(self) -> List[str]:
"""获取所有群组ID"""
sql = "SELECT group_id FROM t_encyclopedia_groups"
results = self.execute_query(sql)
return [row['group_id'] for row in results] if results else []
def get_all_players_in_group(self, group_id: str) -> List[Dict]:
"""获取群组中的所有玩家"""
sql = "SELECT player_id, player_name FROM t_encyclopedia_players WHERE group_id = %s"
return self.execute_query(sql, (group_id,)) or []
def get_active_tasks_in_group(self, group_id: str) -> List[Dict]:
"""获取群组中的所有活跃任务"""
sql = """
SELECT a.active_task_id, a.question, p.player_name, p.player_id
FROM t_encyclopedia_active_tasks a
JOIN t_encyclopedia_players p ON a.holder_id = p.player_id AND a.group_id = p.group_id
WHERE a.group_id = %s AND a.status = 'pending'
"""
return self.execute_query(sql, (group_id,)) or []
def get_task_by_id(self, group_id: str, task_id: int) -> Optional[Dict]:
"""根据ID获取任务"""
sql = """
SELECT question, answer, score, holder_id, status
FROM t_encyclopedia_active_tasks
WHERE group_id = %s AND active_task_id = %s
"""
return self.execute_query(sql, (group_id, task_id), fetch_one=True)
def get_task_holder(self, group_id: str, holder_id: str) -> Optional[Dict]:
"""获取任务持有者信息"""
sql = "SELECT player_name FROM t_encyclopedia_players WHERE group_id = %s AND player_id = %s"
return self.execute_query(sql, (group_id, holder_id), fetch_one=True)

69
db/message_storage.py Normal file
View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Dict, List, Optional
from wcferry import WxMsg
from db.base import BaseDBOperator
from db.connection import DBConnectionManager
class MessageStorageDB(BaseDBOperator):
"""消息存储相关数据库操作"""
def __init__(self, db_manager: DBConnectionManager):
super().__init__(db_manager)
def archive_message(self, msg: WxMsg) -> bool:
"""存档消息"""
now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sql = """
INSERT INTO messages (group_id, timestamp, sender, content, message_type, attachment_url, message_id, message_xml, message_thumb)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
params = (msg.roomid, now_time, msg.sender, msg.content, msg.type, msg.extra, msg.id, msg.xml, msg.thumb)
result = self.execute_update(sql, params)
if result:
self.LOG.info(f"消息存档成功: {now_time}:{msg.roomid}:{msg.sender}")
return result
def get_recent_messages(self, group_id: str, hours_ago: int = 8, min_content_length: int = 6) -> List[Dict]:
"""获取最近的消息"""
sql = """
SELECT timestamp, sender, content, message_type
FROM messages
WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s HOUR)
AND message_type in (1, 49)
AND group_id = %s
AND length(content) > %s
AND content NOT LIKE '/%'
"""
params = (hours_ago, group_id, min_content_length)
return self.execute_query(sql, params) or []
def get_message_count_by_date(self, date: str) -> List[Dict]:
"""获取指定日期的消息统计"""
sql = """
SELECT group_id, sender, COUNT(*) as count
FROM messages
WHERE DATE(timestamp) = %s
GROUP BY group_id, sender
"""
return self.execute_query(sql, (date,)) or []
# 在 MessageStorageDB 类中添加以下方法
def get_speech_ranking(self, date: str, group_id: str, limit: int = 20) -> List[Dict]:
"""获取指定日期和群组的发言排名"""
sql = """
SELECT wx_id, count AS speech_count
FROM speech_counts
WHERE date = %s
AND group_id = %s
GROUP BY wx_id
ORDER BY count DESC
LIMIT %s
"""
params = (date, group_id, limit)
results = self.execute_query(sql, params)
return results or []

64
db/sign_in.py Normal file
View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Dict, Optional
from db.base import BaseDBOperator
from db.connection import DBConnectionManager
class SignInDB(BaseDBOperator):
"""签到系统相关数据库操作"""
def __init__(self, db_manager: DBConnectionManager):
super().__init__(db_manager)
def initialize_table(self) -> bool:
"""初始化签到表"""
sql = """
CREATE TABLE IF NOT EXISTS t_sign_record (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
wx_id VARCHAR(100) NOT NULL,
group_id VARCHAR(100) NOT NULL,
wx_nick_name VARCHAR(100) NOT NULL,
points INT DEFAULT 0,
sign_stat DATETIME,
signin_streak INT DEFAULT 0,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY unique_sign (wx_id, group_id)
)
"""
return self.execute_update(sql)
def get_user_record(self, wx_id: str, group_id: str) -> Optional[Dict]:
"""获取用户签到记录"""
sql = """
SELECT wx_id, group_id, wx_nick_name, points, sign_stat, signin_streak
FROM t_sign_record
WHERE wx_id = %s AND group_id = %s
"""
return self.execute_query(sql, (wx_id, group_id), fetch_one=True)
def update_sign_record(self, wx_id: str, group_id: str, wx_nick_name: str,
points_to_add: int, sign_time: datetime, streak: int) -> bool:
"""更新签到记录"""
sql = """
UPDATE t_sign_record
SET wx_nick_name = %s, points = points + %s,
sign_stat = %s, signin_streak = %s,
update_time = %s
WHERE wx_id = %s AND group_id = %s
"""
params = (wx_nick_name, points_to_add, sign_time, streak, sign_time, wx_id, group_id)
return self.execute_update(sql, params)
def create_sign_record(self, wx_id: str, group_id: str, wx_nick_name: str,
points: int, sign_time: datetime, streak: int) -> bool:
"""创建签到记录"""
sql = """
INSERT INTO t_sign_record
(wx_id, group_id, wx_nick_name, points, sign_stat, signin_streak)
VALUES (%s, %s, %s, %s, %s, %s)
"""
params = (wx_id, group_id, wx_nick_name, points, sign_time, streak)
return self.execute_update(sql, params)

77
db/sign_in_redis.py Normal file
View File

@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Dict, Optional
from db.connection import DBConnectionManager
class SignInRedisDB:
"""签到系统Redis相关操作"""
def __init__(self, db_manager: DBConnectionManager):
self.db_manager = db_manager
self.prefix = "group:sign_in:"
def get_redis_connection(self):
"""获取Redis连接"""
return self.db_manager.get_redis_connection()
def load_signin_count(self) -> Dict[str, int]:
"""加载签到人数数据"""
signin_count = {}
with self.get_redis_connection() as redis_client:
keys = redis_client.keys(f'{self.prefix}*')
for key in keys:
if isinstance(key, bytes):
key = key.decode('utf-8')
if key == f'{self.prefix}last_reset_date':
continue
group_id = key.replace(self.prefix, '')
count = redis_client.get(key)
if count is not None:
if isinstance(count, bytes):
count = count.decode('utf-8')
signin_count[group_id] = int(count)
return signin_count
def save_signin_count(self, group_id: str, count: int) -> bool:
"""保存签到人数"""
try:
with self.get_redis_connection() as redis_client:
redis_client.set(f'{self.prefix}{group_id}', count)
return True
except Exception:
return False
def get_last_reset_date(self) -> Optional[datetime.date]:
"""获取最后重置日期"""
with self.get_redis_connection() as redis_client:
date_str = redis_client.get(f'{self.prefix}last_reset_date')
if date_str:
if isinstance(date_str, bytes):
date_str = date_str.decode('utf-8')
return datetime.strptime(date_str, '%Y-%m-%d').date()
return None
def save_last_reset_date(self, date: datetime.date) -> bool:
"""保存最后重置日期"""
try:
with self.get_redis_connection() as redis_client:
redis_client.set(f'{self.prefix}last_reset_date', date.strftime('%Y-%m-%d'))
return True
except Exception:
return False
def reset_daily_counts(self) -> bool:
"""重置每日签到计数"""
try:
with self.get_redis_connection() as redis_client:
keys = redis_client.keys(f'{self.prefix}*')
for key in keys:
if isinstance(key, bytes):
key = key.decode('utf-8')
if key != f'{self.prefix}last_reset_date':
redis_client.delete(key)
return True
except Exception:
return False

76
db/tasks.py Normal file
View File

@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Dict, List, Optional
from db.base import BaseDBOperator
from db.connection import DBConnectionManager
class TasksDB(BaseDBOperator):
"""任务管理相关数据库操作"""
def __init__(self, db_manager: DBConnectionManager):
super().__init__(db_manager)
def initialize_table(self) -> bool:
"""初始化任务表"""
sql = """
CREATE TABLE IF NOT EXISTS tasks (
task_id INT AUTO_INCREMENT PRIMARY KEY,
task_description VARCHAR(255) NOT NULL,
reminder_time TIME NOT NULL,
task_type ENUM('single', 'recurring') DEFAULT 'single',
status ENUM('pending', 'completed') DEFAULT 'pending',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
return self.execute_update(sql)
def add_task(self, description: str, reminder_time: str, task_type: str = 'single') -> Optional[int]:
"""添加任务"""
sql = """
INSERT INTO tasks (task_description, reminder_time, task_type)
VALUES (%s, %s, %s)
"""
conn = self.db_manager.get_mysql_connection()
try:
with conn.cursor() as cursor:
cursor.execute(sql, (description, reminder_time, task_type))
task_id = cursor.lastrowid
conn.commit()
return task_id
except Exception as e:
self.LOG.error(f"添加任务出错: {e}")
conn.rollback()
return None
finally:
conn.close()
def get_pending_tasks(self) -> List[Dict]:
"""获取待办任务"""
sql = """
SELECT task_id, task_description, reminder_time, task_type, status, created_at
FROM tasks
WHERE status = 'pending'
ORDER BY reminder_time
"""
return self.execute_query(sql) or []
def complete_task(self, task_id: int) -> bool:
"""完成任务"""
sql = """
UPDATE tasks
SET status = 'completed'
WHERE task_id = %s
"""
return self.execute_update(sql, (task_id,))
def get_tasks_by_time(self, current_time: str) -> List[Dict]:
"""获取指定时间的任务"""
sql = """
SELECT task_id, task_description, reminder_time, task_type
FROM tasks
WHERE TIME_FORMAT(reminder_time, '%H:%i') = %s
AND (status = 'pending' OR (status = 'completed' AND task_type = 'recurring'))
"""
return self.execute_query(sql, (current_time,)) or []