feature: 数据库连接与SQL集中管理,提高代码可读性
This commit is contained in:
80
db/base.py
Normal file
80
db/base.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple, Union
|
||||
|
||||
from db.connection import DBConnectionManager
|
||||
|
||||
|
||||
class BaseDBOperator:
|
||||
"""基础数据库操作类"""
|
||||
|
||||
def __init__(self, db_manager: DBConnectionManager):
|
||||
self.db_manager = db_manager
|
||||
self.LOG = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
def execute_query(self, sql: str, params: Optional[tuple] = None, fetch_one: bool = False) -> Union[List[Dict], Dict, None]:
|
||||
"""执行查询SQL"""
|
||||
conn = self.db_manager.get_mysql_connection()
|
||||
try:
|
||||
with conn.cursor(dictionary=True) as cursor:
|
||||
cursor.execute(sql, params or ())
|
||||
if fetch_one:
|
||||
return cursor.fetchone()
|
||||
return cursor.fetchall()
|
||||
except Exception as e:
|
||||
self.LOG.error(f"执行查询SQL出错: {e}, SQL: {sql}, 参数: {params}")
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_update(self, sql: str, params: Optional[tuple] = None) -> bool:
|
||||
"""执行更新SQL"""
|
||||
conn = self.db_manager.get_mysql_connection()
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(sql, params or ())
|
||||
conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.LOG.error(f"执行更新SQL出错: {e}, SQL: {sql}, 参数: {params}")
|
||||
conn.rollback()
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_batch(self, sql: str, params_list: List[tuple]) -> bool:
|
||||
"""批量执行SQL"""
|
||||
if not params_list:
|
||||
return True
|
||||
|
||||
conn = self.db_manager.get_mysql_connection()
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.executemany(sql, params_list)
|
||||
conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.LOG.error(f"批量执行SQL出错: {e}, SQL: {sql}, 参数数量: {len(params_list)}")
|
||||
conn.rollback()
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_transaction(self, operations: List[Tuple[str, tuple]]) -> bool:
|
||||
"""执行事务"""
|
||||
if not operations:
|
||||
return True
|
||||
|
||||
conn = self.db_manager.get_mysql_connection()
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
for sql, params in operations:
|
||||
cursor.execute(sql, params)
|
||||
conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.LOG.error(f"执行事务出错: {e}, 操作数量: {len(operations)}")
|
||||
conn.rollback()
|
||||
return False
|
||||
finally:
|
||||
conn.close()
|
||||
51
db/connection.py
Normal file
51
db/connection.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import logging
|
||||
import mysql.connector.pooling
|
||||
import redis
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DBConnectionManager:
|
||||
"""数据库连接管理类,用于管理MySQL和Redis连接池"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(DBConnectionManager, cls).__new__(cls)
|
||||
cls._instance.initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, mysql_config: Optional[dict] = None, redis_config: Optional[dict] = None):
|
||||
if self.initialized:
|
||||
return
|
||||
|
||||
self.LOG = logging.getLogger("DBConnectionManager")
|
||||
|
||||
# 初始化MySQL连接池
|
||||
if mysql_config:
|
||||
self.mysql_pool = mysql.connector.pooling.MySQLConnectionPool(**mysql_config)
|
||||
self.LOG.info(f"MySQL连接池初始化完成: {mysql_config}")
|
||||
else:
|
||||
self.mysql_pool = None
|
||||
|
||||
# 初始化Redis连接池
|
||||
if redis_config:
|
||||
self.redis_pool = redis.ConnectionPool(**redis_config)
|
||||
self.LOG.info(f"Redis连接池初始化完成: {redis_config}")
|
||||
else:
|
||||
self.redis_pool = None
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def get_mysql_connection(self):
|
||||
"""获取MySQL连接"""
|
||||
if not self.mysql_pool:
|
||||
raise Exception("MySQL连接池未初始化")
|
||||
return self.mysql_pool.get_connection()
|
||||
|
||||
def get_redis_connection(self):
|
||||
"""获取Redis连接"""
|
||||
if not self.redis_pool:
|
||||
raise Exception("Redis连接池未初始化")
|
||||
return redis.Redis(connection_pool=self.redis_pool)
|
||||
154
db/encyclopedia.py
Normal file
154
db/encyclopedia.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from db.base import BaseDBOperator
|
||||
from db.connection import DBConnectionManager
|
||||
|
||||
|
||||
class EncyclopediaDB(BaseDBOperator):
|
||||
"""百科答题游戏相关数据库操作"""
|
||||
|
||||
def __init__(self, db_manager: DBConnectionManager):
|
||||
super().__init__(db_manager)
|
||||
|
||||
def add_group(self, group_id: str) -> bool:
|
||||
"""添加群组"""
|
||||
sql = "INSERT INTO t_encyclopedia_groups (group_id) VALUES (%s)"
|
||||
return self.execute_update(sql, (group_id,))
|
||||
|
||||
def check_group_exists(self, group_id: str) -> bool:
|
||||
"""检查群组是否存在"""
|
||||
sql = "SELECT 1 FROM t_encyclopedia_groups WHERE group_id = %s"
|
||||
result = self.execute_query(sql, (group_id,), fetch_one=True)
|
||||
return result is not None
|
||||
|
||||
def add_player(self, player_id: str, group_id: str, player_name: str) -> bool:
|
||||
"""添加玩家"""
|
||||
sql = """
|
||||
INSERT INTO t_encyclopedia_players (player_id, group_id, player_name)
|
||||
VALUES (%s, %s, %s)
|
||||
ON DUPLICATE KEY UPDATE player_name = VALUES(player_name)
|
||||
"""
|
||||
return self.execute_update(sql, (player_id, group_id, player_name))
|
||||
|
||||
def get_active_task(self, group_id: str) -> Optional[Dict]:
|
||||
"""获取群组的活跃任务"""
|
||||
sql = """
|
||||
SELECT active_task_id, group_id, question, answer, score, description, holder_id, assigned_at, status
|
||||
FROM t_encyclopedia_active_tasks
|
||||
WHERE group_id = %s AND status = 'pending'
|
||||
ORDER BY assigned_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
return self.execute_query(sql, (group_id,), fetch_one=True)
|
||||
|
||||
def add_task_history(self, group_id: str, task_id: int, player_id: str,
|
||||
answer: str, is_correct: bool, points_earned: int) -> bool:
|
||||
"""添加任务历史记录"""
|
||||
sql = """
|
||||
INSERT INTO t_encyclopedia_task_history
|
||||
(group_id, active_task_id, player_id, answer, is_correct, points_earned)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
params = (group_id, task_id, player_id, answer, 1 if is_correct else 0, points_earned)
|
||||
return self.execute_update(sql, params)
|
||||
|
||||
def get_player_ranking(self, group_id: str, limit: int = 10) -> List[Dict]:
|
||||
"""获取玩家排名"""
|
||||
sql = """
|
||||
SELECT player_name, points
|
||||
FROM t_encyclopedia_players
|
||||
WHERE group_id = %s
|
||||
ORDER BY points DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
return self.execute_query(sql, (group_id, limit)) or []
|
||||
|
||||
def create_active_task(self, group_id: str, question: str, answer: str,
|
||||
score: int, description: str, holder_id: str) -> Optional[int]:
|
||||
"""创建活跃任务"""
|
||||
sql = """
|
||||
INSERT INTO t_encyclopedia_active_tasks
|
||||
(group_id, question, answer, score, description, holder_id)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
if self.execute_update(sql, (group_id, question, answer, score, description, holder_id)):
|
||||
# 获取最新创建的任务ID
|
||||
get_id_sql = """
|
||||
SELECT active_task_id
|
||||
FROM t_encyclopedia_active_tasks
|
||||
WHERE group_id = %s AND question = %s AND holder_id = %s
|
||||
ORDER BY assigned_at DESC LIMIT 1
|
||||
"""
|
||||
result = self.execute_query(get_id_sql, (group_id, question, holder_id), fetch_one=True)
|
||||
return result['active_task_id'] if result else None
|
||||
return None
|
||||
|
||||
def update_player_points(self, player_id: str, group_id: str, points: int) -> bool:
|
||||
"""更新玩家积分"""
|
||||
if points > 0:
|
||||
sql = """
|
||||
UPDATE t_encyclopedia_players
|
||||
SET points = points + %s
|
||||
WHERE group_id = %s AND player_id = %s
|
||||
"""
|
||||
else:
|
||||
sql = """
|
||||
UPDATE t_encyclopedia_players
|
||||
SET points = GREATEST(points + %s, 0)
|
||||
WHERE group_id = %s AND player_id = %s
|
||||
"""
|
||||
return self.execute_update(sql, (points, group_id, player_id))
|
||||
|
||||
def complete_task(self, active_task_id: int) -> bool:
|
||||
"""完成任务"""
|
||||
sql = """
|
||||
UPDATE t_encyclopedia_active_tasks
|
||||
SET status = 'completed'
|
||||
WHERE active_task_id = %s
|
||||
"""
|
||||
return self.execute_update(sql, (active_task_id,))
|
||||
|
||||
def get_player(self, player_id: str, group_id: str) -> Optional[Dict]:
|
||||
"""获取玩家信息"""
|
||||
sql = """
|
||||
SELECT player_id, player_name, points
|
||||
FROM t_encyclopedia_players
|
||||
WHERE group_id = %s AND player_id = %s
|
||||
"""
|
||||
return self.execute_query(sql, (group_id, player_id), fetch_one=True)
|
||||
|
||||
def get_all_groups(self) -> List[str]:
|
||||
"""获取所有群组ID"""
|
||||
sql = "SELECT group_id FROM t_encyclopedia_groups"
|
||||
results = self.execute_query(sql)
|
||||
return [row['group_id'] for row in results] if results else []
|
||||
|
||||
def get_all_players_in_group(self, group_id: str) -> List[Dict]:
|
||||
"""获取群组中的所有玩家"""
|
||||
sql = "SELECT player_id, player_name FROM t_encyclopedia_players WHERE group_id = %s"
|
||||
return self.execute_query(sql, (group_id,)) or []
|
||||
|
||||
def get_active_tasks_in_group(self, group_id: str) -> List[Dict]:
|
||||
"""获取群组中的所有活跃任务"""
|
||||
sql = """
|
||||
SELECT a.active_task_id, a.question, p.player_name, p.player_id
|
||||
FROM t_encyclopedia_active_tasks a
|
||||
JOIN t_encyclopedia_players p ON a.holder_id = p.player_id AND a.group_id = p.group_id
|
||||
WHERE a.group_id = %s AND a.status = 'pending'
|
||||
"""
|
||||
return self.execute_query(sql, (group_id,)) or []
|
||||
|
||||
def get_task_by_id(self, group_id: str, task_id: int) -> Optional[Dict]:
|
||||
"""根据ID获取任务"""
|
||||
sql = """
|
||||
SELECT question, answer, score, holder_id, status
|
||||
FROM t_encyclopedia_active_tasks
|
||||
WHERE group_id = %s AND active_task_id = %s
|
||||
"""
|
||||
return self.execute_query(sql, (group_id, task_id), fetch_one=True)
|
||||
|
||||
def get_task_holder(self, group_id: str, holder_id: str) -> Optional[Dict]:
|
||||
"""获取任务持有者信息"""
|
||||
sql = "SELECT player_name FROM t_encyclopedia_players WHERE group_id = %s AND player_id = %s"
|
||||
return self.execute_query(sql, (group_id, holder_id), fetch_one=True)
|
||||
69
db/message_storage.py
Normal file
69
db/message_storage.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from wcferry import WxMsg
|
||||
|
||||
from db.base import BaseDBOperator
|
||||
from db.connection import DBConnectionManager
|
||||
|
||||
|
||||
class MessageStorageDB(BaseDBOperator):
|
||||
"""消息存储相关数据库操作"""
|
||||
|
||||
def __init__(self, db_manager: DBConnectionManager):
|
||||
super().__init__(db_manager)
|
||||
|
||||
def archive_message(self, msg: WxMsg) -> bool:
|
||||
"""存档消息"""
|
||||
now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
sql = """
|
||||
INSERT INTO messages (group_id, timestamp, sender, content, message_type, attachment_url, message_id, message_xml, message_thumb)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
params = (msg.roomid, now_time, msg.sender, msg.content, msg.type, msg.extra, msg.id, msg.xml, msg.thumb)
|
||||
result = self.execute_update(sql, params)
|
||||
if result:
|
||||
self.LOG.info(f"消息存档成功: {now_time}:{msg.roomid}:{msg.sender}")
|
||||
return result
|
||||
|
||||
def get_recent_messages(self, group_id: str, hours_ago: int = 8, min_content_length: int = 6) -> List[Dict]:
|
||||
"""获取最近的消息"""
|
||||
sql = """
|
||||
SELECT timestamp, sender, content, message_type
|
||||
FROM messages
|
||||
WHERE timestamp >= DATE_SUB(NOW(), INTERVAL %s HOUR)
|
||||
AND message_type in (1, 49)
|
||||
AND group_id = %s
|
||||
AND length(content) > %s
|
||||
AND content NOT LIKE '/%'
|
||||
"""
|
||||
params = (hours_ago, group_id, min_content_length)
|
||||
return self.execute_query(sql, params) or []
|
||||
|
||||
def get_message_count_by_date(self, date: str) -> List[Dict]:
|
||||
"""获取指定日期的消息统计"""
|
||||
sql = """
|
||||
SELECT group_id, sender, COUNT(*) as count
|
||||
FROM messages
|
||||
WHERE DATE(timestamp) = %s
|
||||
GROUP BY group_id, sender
|
||||
"""
|
||||
return self.execute_query(sql, (date,)) or []
|
||||
|
||||
# 在 MessageStorageDB 类中添加以下方法
|
||||
|
||||
def get_speech_ranking(self, date: str, group_id: str, limit: int = 20) -> List[Dict]:
|
||||
"""获取指定日期和群组的发言排名"""
|
||||
sql = """
|
||||
SELECT wx_id, count AS speech_count
|
||||
FROM speech_counts
|
||||
WHERE date = %s
|
||||
AND group_id = %s
|
||||
GROUP BY wx_id
|
||||
ORDER BY count DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
params = (date, group_id, limit)
|
||||
results = self.execute_query(sql, params)
|
||||
return results or []
|
||||
64
db/sign_in.py
Normal file
64
db/sign_in.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from db.base import BaseDBOperator
|
||||
from db.connection import DBConnectionManager
|
||||
|
||||
|
||||
class SignInDB(BaseDBOperator):
|
||||
"""签到系统相关数据库操作"""
|
||||
|
||||
def __init__(self, db_manager: DBConnectionManager):
|
||||
super().__init__(db_manager)
|
||||
|
||||
def initialize_table(self) -> bool:
|
||||
"""初始化签到表"""
|
||||
sql = """
|
||||
CREATE TABLE IF NOT EXISTS t_sign_record (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
wx_id VARCHAR(100) NOT NULL,
|
||||
group_id VARCHAR(100) NOT NULL,
|
||||
wx_nick_name VARCHAR(100) NOT NULL,
|
||||
points INT DEFAULT 0,
|
||||
sign_stat DATETIME,
|
||||
signin_streak INT DEFAULT 0,
|
||||
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY unique_sign (wx_id, group_id)
|
||||
)
|
||||
"""
|
||||
return self.execute_update(sql)
|
||||
|
||||
def get_user_record(self, wx_id: str, group_id: str) -> Optional[Dict]:
|
||||
"""获取用户签到记录"""
|
||||
sql = """
|
||||
SELECT wx_id, group_id, wx_nick_name, points, sign_stat, signin_streak
|
||||
FROM t_sign_record
|
||||
WHERE wx_id = %s AND group_id = %s
|
||||
"""
|
||||
return self.execute_query(sql, (wx_id, group_id), fetch_one=True)
|
||||
|
||||
def update_sign_record(self, wx_id: str, group_id: str, wx_nick_name: str,
|
||||
points_to_add: int, sign_time: datetime, streak: int) -> bool:
|
||||
"""更新签到记录"""
|
||||
sql = """
|
||||
UPDATE t_sign_record
|
||||
SET wx_nick_name = %s, points = points + %s,
|
||||
sign_stat = %s, signin_streak = %s,
|
||||
update_time = %s
|
||||
WHERE wx_id = %s AND group_id = %s
|
||||
"""
|
||||
params = (wx_nick_name, points_to_add, sign_time, streak, sign_time, wx_id, group_id)
|
||||
return self.execute_update(sql, params)
|
||||
|
||||
def create_sign_record(self, wx_id: str, group_id: str, wx_nick_name: str,
|
||||
points: int, sign_time: datetime, streak: int) -> bool:
|
||||
"""创建签到记录"""
|
||||
sql = """
|
||||
INSERT INTO t_sign_record
|
||||
(wx_id, group_id, wx_nick_name, points, sign_stat, signin_streak)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
params = (wx_id, group_id, wx_nick_name, points, sign_time, streak)
|
||||
return self.execute_update(sql, params)
|
||||
77
db/sign_in_redis.py
Normal file
77
db/sign_in_redis.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from db.connection import DBConnectionManager
|
||||
|
||||
|
||||
class SignInRedisDB:
|
||||
"""签到系统Redis相关操作"""
|
||||
|
||||
def __init__(self, db_manager: DBConnectionManager):
|
||||
self.db_manager = db_manager
|
||||
self.prefix = "group:sign_in:"
|
||||
|
||||
def get_redis_connection(self):
|
||||
"""获取Redis连接"""
|
||||
return self.db_manager.get_redis_connection()
|
||||
|
||||
def load_signin_count(self) -> Dict[str, int]:
|
||||
"""加载签到人数数据"""
|
||||
signin_count = {}
|
||||
with self.get_redis_connection() as redis_client:
|
||||
keys = redis_client.keys(f'{self.prefix}*')
|
||||
for key in keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode('utf-8')
|
||||
if key == f'{self.prefix}last_reset_date':
|
||||
continue
|
||||
group_id = key.replace(self.prefix, '')
|
||||
count = redis_client.get(key)
|
||||
if count is not None:
|
||||
if isinstance(count, bytes):
|
||||
count = count.decode('utf-8')
|
||||
signin_count[group_id] = int(count)
|
||||
return signin_count
|
||||
|
||||
def save_signin_count(self, group_id: str, count: int) -> bool:
|
||||
"""保存签到人数"""
|
||||
try:
|
||||
with self.get_redis_connection() as redis_client:
|
||||
redis_client.set(f'{self.prefix}{group_id}', count)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_last_reset_date(self) -> Optional[datetime.date]:
|
||||
"""获取最后重置日期"""
|
||||
with self.get_redis_connection() as redis_client:
|
||||
date_str = redis_client.get(f'{self.prefix}last_reset_date')
|
||||
if date_str:
|
||||
if isinstance(date_str, bytes):
|
||||
date_str = date_str.decode('utf-8')
|
||||
return datetime.strptime(date_str, '%Y-%m-%d').date()
|
||||
return None
|
||||
|
||||
def save_last_reset_date(self, date: datetime.date) -> bool:
|
||||
"""保存最后重置日期"""
|
||||
try:
|
||||
with self.get_redis_connection() as redis_client:
|
||||
redis_client.set(f'{self.prefix}last_reset_date', date.strftime('%Y-%m-%d'))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def reset_daily_counts(self) -> bool:
|
||||
"""重置每日签到计数"""
|
||||
try:
|
||||
with self.get_redis_connection() as redis_client:
|
||||
keys = redis_client.keys(f'{self.prefix}*')
|
||||
for key in keys:
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode('utf-8')
|
||||
if key != f'{self.prefix}last_reset_date':
|
||||
redis_client.delete(key)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
76
db/tasks.py
Normal file
76
db/tasks.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from db.base import BaseDBOperator
|
||||
from db.connection import DBConnectionManager
|
||||
|
||||
|
||||
class TasksDB(BaseDBOperator):
|
||||
"""任务管理相关数据库操作"""
|
||||
|
||||
def __init__(self, db_manager: DBConnectionManager):
|
||||
super().__init__(db_manager)
|
||||
|
||||
def initialize_table(self) -> bool:
|
||||
"""初始化任务表"""
|
||||
sql = """
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
task_description VARCHAR(255) NOT NULL,
|
||||
reminder_time TIME NOT NULL,
|
||||
task_type ENUM('single', 'recurring') DEFAULT 'single',
|
||||
status ENUM('pending', 'completed') DEFAULT 'pending',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
return self.execute_update(sql)
|
||||
|
||||
def add_task(self, description: str, reminder_time: str, task_type: str = 'single') -> Optional[int]:
|
||||
"""添加任务"""
|
||||
sql = """
|
||||
INSERT INTO tasks (task_description, reminder_time, task_type)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
conn = self.db_manager.get_mysql_connection()
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(sql, (description, reminder_time, task_type))
|
||||
task_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
return task_id
|
||||
except Exception as e:
|
||||
self.LOG.error(f"添加任务出错: {e}")
|
||||
conn.rollback()
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_pending_tasks(self) -> List[Dict]:
|
||||
"""获取待办任务"""
|
||||
sql = """
|
||||
SELECT task_id, task_description, reminder_time, task_type, status, created_at
|
||||
FROM tasks
|
||||
WHERE status = 'pending'
|
||||
ORDER BY reminder_time
|
||||
"""
|
||||
return self.execute_query(sql) or []
|
||||
|
||||
def complete_task(self, task_id: int) -> bool:
|
||||
"""完成任务"""
|
||||
sql = """
|
||||
UPDATE tasks
|
||||
SET status = 'completed'
|
||||
WHERE task_id = %s
|
||||
"""
|
||||
return self.execute_update(sql, (task_id,))
|
||||
|
||||
def get_tasks_by_time(self, current_time: str) -> List[Dict]:
|
||||
"""获取指定时间的任务"""
|
||||
sql = """
|
||||
SELECT task_id, task_description, reminder_time, task_type
|
||||
FROM tasks
|
||||
WHERE TIME_FORMAT(reminder_time, '%H:%i') = %s
|
||||
AND (status = 'pending' OR (status = 'completed' AND task_type = 'recurring'))
|
||||
"""
|
||||
return self.execute_query(sql, (current_time,)) or []
|
||||
Reference in New Issue
Block a user