# Files
# abot/db/task_db.py
# 2025-06-12 12:25:10 +08:00
#
# 591 lines
# 20 KiB
# Python
# Raw | Blame | History
#
# This file contains ambiguous Unicode characters.
# This file contains Unicode characters that might be confused with other characters. If this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import List, Dict, Optional, Tuple
from db.base import BaseDBOperator
from db.connection import DBConnectionManager
from loguru import logger
import json
class TaskDBOperator(BaseDBOperator):
    """Database operations for message-push tasks.

    Works on four tables created by ``init_tables``: ``t_push_tasks``
    (task definitions), ``t_push_task_logs`` (task action log),
    ``t_push_previews`` (preview records) and ``t_push_feedback``
    (feedback records).
    """

    def __init__(self, db_manager: DBConnectionManager):
        """Create the operator.

        Args:
            db_manager: Connection manager handed to ``BaseDBOperator``.
        """
        super().__init__(db_manager)

    @staticmethod
    def _serialize_json(value):
        """Prepare *value* for a JSON column.

        ``None`` (SQL NULL) and str/bytes values (assumed to already be JSON
        text) are returned unchanged; anything else goes through
        ``json.dumps``.
        """
        if value is None or isinstance(value, (str, bytes, bytearray)):
            return value
        return json.dumps(value)

    @staticmethod
    def _decode_json_columns(row: Dict, fields) -> None:
        """Decode JSON text stored in ``row[field]`` in place for each field.

        Only non-empty str/bytes values are decoded, so values the driver
        already returned as Python objects (or NULL / empty) are left as-is.
        """
        for field in fields:
            value = row.get(field)
            if isinstance(value, (str, bytes, bytearray)) and value:
                row[field] = json.loads(value)

    def init_tables(self) -> bool:
        """Create the push-task tables if they do not exist yet.

        The ``groups`` column is quoted with backticks because GROUPS is a
        reserved word since MySQL 8.0.2; unquoted, the CREATE TABLE fails
        on those servers.

        Returns:
            True if every CREATE TABLE call completed without raising,
            False otherwise (the error is logged). The return values of
            ``execute_update`` are not inspected.
        """
        try:
            # Task definitions.
            self.execute_update("""
                CREATE TABLE IF NOT EXISTS t_push_tasks (
                    task_id VARCHAR(36) PRIMARY KEY,
                    name VARCHAR(50) NOT NULL,
                    schedule_type ENUM('once', 'recurring') NOT NULL,
                    schedule_time DATETIME NOT NULL,
                    recurring_interval ENUM('daily', 'weekly', 'monthly') DEFAULT NULL,
                    recurring_end DATETIME DEFAULT NULL,
                    recurring_time TIME DEFAULT NULL,
                    weekly_days JSON DEFAULT NULL,
                    monthly_day INT DEFAULT NULL,
                    content_text TEXT(500),
                    content_image VARCHAR(255),
                    content_link JSON,
                    content_miniprogram JSON,
                    content_voice VARCHAR(255), -- 语音消息文件路径
                    content_video VARCHAR(255), -- 视频消息文件路径
                    `groups` JSON,
                    priority ENUM('high', 'medium', 'low') DEFAULT 'medium',
                    status ENUM('draft', 'scheduled', 'running', 'completed', 'failed', 'paused') DEFAULT 'draft',
                    creator_id VARCHAR(50) NOT NULL,
                    preview_recipients JSON,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                )
            """)
            # Task action log.
            self.execute_update("""
                CREATE TABLE IF NOT EXISTS t_push_task_logs (
                    log_id VARCHAR(36) PRIMARY KEY,
                    task_id VARCHAR(36) NOT NULL,
                    action ENUM('create', 'update', 'delete', 'pause', 'resume') NOT NULL,
                    user_id VARCHAR(50) NOT NULL,
                    changes JSON,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Preview records.
            self.execute_update("""
                CREATE TABLE IF NOT EXISTS t_push_previews (
                    preview_id VARCHAR(36) PRIMARY KEY,
                    task_id VARCHAR(36) NOT NULL,
                    content JSON NOT NULL,
                    recipients JSON NOT NULL,
                    validation JSON,
                    status ENUM('sent', 'confirmed', 'modified') DEFAULT 'sent',
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Feedback records.
            self.execute_update("""
                CREATE TABLE IF NOT EXISTS t_push_feedback (
                    feedback_id VARCHAR(36) PRIMARY KEY,
                    task_id VARCHAR(36) NOT NULL,
                    user_id VARCHAR(50) NOT NULL,
                    content TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
            return True
        except Exception as e:
            logger.error(f"初始化数据库表失败: {e}")
            return False

    def create_task(self, task_data: Dict) -> Optional[Dict]:
        """Insert a new task and return the stored row.

        Args:
            task_data: Task fields. ``task_id``, ``name``, ``schedule_type``,
                ``schedule_time`` and ``creator_id`` are required (a missing
                one makes the call fail and return None); all other keys are
                optional.

        Returns:
            The inserted row as returned by ``get_task``, or None on failure.
        """
        try:
            sql = """
                INSERT INTO t_push_tasks (
                    task_id, name, schedule_type, schedule_time, recurring_interval,
                    recurring_end, recurring_time, weekly_days, monthly_day, content_text,
                    content_image, content_link, content_miniprogram, content_voice, content_video,
                    `groups`, priority, status, creator_id, preview_recipients
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
                )
            """
            # These JSON columns are always serialized (missing keys become
            # an empty object / empty array).
            content_miniprogram = json.dumps(task_data.get('content_miniprogram', {}))
            groups = json.dumps(task_data.get('groups', []))
            preview_recipients = json.dumps(task_data.get('preview_recipients', []))

            # content_link may arrive as JSON text or as an object. Valid
            # JSON text is stored unchanged; invalid text becomes "{}".
            content_link = task_data.get('content_link', {})
            if isinstance(content_link, str):
                try:
                    json.loads(content_link)  # validation only
                except json.JSONDecodeError:
                    content_link = json.dumps({})
            else:
                content_link = json.dumps(content_link)

            params = (
                task_data['task_id'],
                task_data['name'],
                task_data['schedule_type'],
                task_data['schedule_time'],
                task_data.get('recurring_interval'),
                task_data.get('recurring_end'),
                task_data.get('recurring_time'),
                # weekly_days is a JSON column: a list must be serialized,
                # exactly as update_task does for the same column.
                self._serialize_json(task_data.get('weekly_days')),
                task_data.get('monthly_day'),
                task_data.get('content_text'),
                task_data.get('content_image'),
                content_link,
                content_miniprogram,
                task_data.get('content_voice'),
                task_data.get('content_video'),
                groups,
                task_data.get('priority', 'medium'),
                task_data.get('status', 'draft'),
                task_data['creator_id'],
                preview_recipients,
            )
            if self.execute_update(sql, params):
                return self.get_task(task_data['task_id'])
            return None
        except Exception as e:
            logger.error(f"创建任务失败: {e}")
            return None

    def get_task(self, task_id: str) -> Optional[Dict]:
        """Fetch one task row.

        Args:
            task_id: Task ID.

        Returns:
            The row returned by ``execute_query(..., fetch_one=True)``;
            presumably None when the task does not exist.
        """
        sql = "SELECT * FROM t_push_tasks WHERE task_id = %s"
        return self.execute_query(sql, (task_id,), fetch_one=True)

    def update_task(self, task_id: str, updates: Dict) -> bool:
        """Update selected columns of a task.

        Values that are None or an empty list are skipped, as are empty
        strings for the date/time columns, so this method cannot be used to
        clear a column. Non-string values for JSON columns are serialized.

        Args:
            task_id: Task ID.
            updates: Mapping of column name to new value.

        Returns:
            Whether the update succeeded. False when a key is not a plain
            identifier, when nothing is left to update, or on a database
            error.
        """
        try:
            # JSON columns: non-string values are serialized.
            json_fields = {'groups', 'content_miniprogram', 'preview_recipients', 'content_link', 'weekly_days'}
            # Date/time columns where '' means "not provided".
            datetime_fields = {'recurring_end', 'recurring_time', 'schedule_time'}

            fields = []
            values = []
            for key, value in updates.items():
                if value is None or (isinstance(value, list) and len(value) == 0):
                    continue
                if key in datetime_fields and value == '':
                    continue
                # Column names cannot be bound as parameters, so they are
                # interpolated; only plain identifiers are accepted and they
                # are backtick-quoted (handles reserved words like `groups`).
                if not (isinstance(key, str) and key.isidentifier()):
                    logger.error(f"更新任务失败: 非法字段名 {key!r}")
                    return False
                fields.append(f"`{key}` = %s")
                values.append(self._serialize_json(value) if key in json_fields else value)

            # With no fields the SQL would read "SET  WHERE ..." and fail.
            if not fields:
                logger.warning(f"更新任务失败: 没有可更新的字段 (task_id={task_id})")
                return False

            values.append(task_id)
            sql = f"""
                UPDATE t_push_tasks
                SET {', '.join(fields)}
                WHERE task_id = %s
            """
            return self.execute_update(sql, values)
        except Exception as e:
            logger.error(f"更新任务失败: {e}")
            return False

    def delete_task(self, task_id: str) -> bool:
        """Delete a task together with its action logs.

        Args:
            task_id: Task ID.

        Returns:
            The result of deleting the task row (the log deletion's result is
            not checked).
        """
        try:
            # NOTE(review): rows in t_push_previews / t_push_feedback for this
            # task are not removed, and the two DELETEs are not wrapped in a
            # single transaction — confirm this is intended.
            delete_logs_sql = "DELETE FROM t_push_task_logs WHERE task_id = %s"
            self.execute_update(delete_logs_sql, (task_id,))
            delete_task_sql = "DELETE FROM t_push_tasks WHERE task_id = %s"
            return self.execute_update(delete_task_sql, (task_id,))
        except Exception as e:
            logger.error(f"删除任务失败: {e}")
            return False

    def get_scheduled_tasks(self) -> List[Dict]:
        """Return tasks with status 'scheduled' whose schedule_time has passed.

        ``NOW()`` is evaluated by the database server, so the comparison uses
        the server's session time zone. JSON columns are returned as the
        driver delivers them (not decoded here).
        """
        sql = """
            SELECT * FROM t_push_tasks
            WHERE status = 'scheduled'
            AND schedule_time <= NOW()
        """
        return self.execute_query(sql)

    def log_task_action(self, log_data: Dict) -> bool:
        """Insert one row into the task action log.

        Args:
            log_data: Must contain ``log_id``, ``task_id``, ``action``,
                ``user_id`` and ``changes``; ``changes`` is stored as JSON
                (non-ASCII text kept readable).

        Returns:
            Whether the insert succeeded; False on any error (logged).
        """
        try:
            sql = """
                INSERT INTO t_push_task_logs (
                    log_id, task_id, action, user_id, changes
                ) VALUES (
                    %s, %s, %s, %s, %s
                )
            """
            changes_json = json.dumps(log_data['changes'], ensure_ascii=False)
            params = (
                log_data['log_id'],
                log_data['task_id'],
                log_data['action'],
                log_data['user_id'],
                changes_json,
            )
            return self.execute_update(sql, params)
        except Exception as e:
            logger.error(f"记录任务操作日志失败: {e}")
            return False

    def _query_task_logs(self, task_id: str, page: int, limit: int) -> Tuple[List[Dict], int]:
        """Fetch one page of a task's logs (newest first) and the total count.

        Database errors propagate to the caller.
        """
        count_sql = """
            SELECT COUNT(*) as total
            FROM t_push_task_logs
            WHERE task_id = %s
        """
        count_result = self.execute_query(count_sql, (task_id,), fetch_one=True)
        total = count_result['total'] if count_result else 0

        sql = """
            SELECT *
            FROM t_push_task_logs
            WHERE task_id = %s
            ORDER BY timestamp DESC
            LIMIT %s OFFSET %s
        """
        offset = (page - 1) * limit
        logs = self.execute_query(sql, (task_id, limit, offset))
        return logs, total

    def get_task_logs(self, task_id: str, page: int = 1, limit: int = 20) -> Dict:
        """Return one page of a task's action logs.

        Args:
            task_id: Task ID.
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size.

        Returns:
            ``{'total': int, 'logs': list}``. ``changes`` is left as stored
            (not decoded). On error: ``{'total': 0, 'logs': []}``.
        """
        try:
            # A page below 1 would produce a negative OFFSET (SQL error).
            page = max(page, 1)
            logs, total = self._query_task_logs(task_id, page, limit)
            return {
                'total': total,
                'logs': logs,
            }
        except Exception as e:
            logger.error(f"获取任务日志失败: {e}")
            return {'total': 0, 'logs': []}

    def create_preview(self, preview_data: Dict) -> Optional[Dict]:
        """Insert a preview record and return it.

        Args:
            preview_data: Must contain ``preview_id``, ``task_id``,
                ``content`` and ``recipients``; ``validation`` is optional.
                The three JSON columns accept JSON text or Python objects
                (objects are serialized).

        Returns:
            The stored record as returned by ``get_preview``, or None on
            failure.
        """
        try:
            sql = """
                INSERT INTO t_push_previews (
                    preview_id, task_id, content, recipients, validation
                ) VALUES (
                    %s, %s, %s, %s, %s
                )
            """
            params = (
                preview_data['preview_id'],
                preview_data['task_id'],
                self._serialize_json(preview_data['content']),
                self._serialize_json(preview_data['recipients']),
                self._serialize_json(preview_data.get('validation')),
            )
            if self.execute_update(sql, params):
                return self.get_preview(preview_data['preview_id'])
            return None
        except Exception as e:
            logger.error(f"创建预览记录失败: {e}")
            return None

    def get_preview(self, preview_id: str) -> Optional[Dict]:
        """Fetch one preview record.

        Args:
            preview_id: Preview ID.

        Returns:
            The row returned by ``execute_query(..., fetch_one=True)``;
            presumably None when it does not exist.
        """
        sql = "SELECT * FROM t_push_previews WHERE preview_id = %s"
        return self.execute_query(sql, (preview_id,), fetch_one=True)

    def update_preview_status(self, preview_id: str, status: str) -> bool:
        """Set the status of a preview record.

        Args:
            preview_id: Preview ID.
            status: New status (the column accepts 'sent', 'confirmed',
                'modified').

        Returns:
            The result of ``execute_update``.
        """
        sql = """
            UPDATE t_push_previews
            SET status = %s
            WHERE preview_id = %s
        """
        return self.execute_update(sql, (status, preview_id))

    def create_feedback(self, feedback_data: Dict) -> bool:
        """Insert a feedback record.

        Args:
            feedback_data: Must contain ``feedback_id``, ``task_id``,
                ``user_id`` and ``content``.

        Returns:
            Whether the insert succeeded; False on any error (logged).
        """
        try:
            sql = """
                INSERT INTO t_push_feedback (
                    feedback_id, task_id, user_id, content
                ) VALUES (
                    %s, %s, %s, %s
                )
            """
            params = (
                feedback_data['feedback_id'],
                feedback_data['task_id'],
                feedback_data['user_id'],
                feedback_data['content'],
            )
            return self.execute_update(sql, params)
        except Exception as e:
            logger.error(f"创建反馈记录失败: {e}")
            return False

    def get_task_feedback(self, task_id: str, start_time: Optional[str] = None,
                          end_time: Optional[str] = None) -> List[Dict]:
        """Return a task's feedback, newest first, optionally time-filtered.

        Args:
            task_id: Task ID.
            start_time: Inclusive lower bound on ``timestamp`` (optional).
            end_time: Inclusive upper bound on ``timestamp`` (optional).

        Returns:
            List of feedback rows; an empty list on error.
        """
        try:
            sql = """
                SELECT *
                FROM t_push_feedback
                WHERE task_id = %s
            """
            params = [task_id]
            if start_time:
                sql += " AND timestamp >= %s"
                params.append(start_time)
            if end_time:
                sql += " AND timestamp <= %s"
                params.append(end_time)
            sql += " ORDER BY timestamp DESC"
            return self.execute_query(sql, tuple(params))
        except Exception as e:
            logger.error(f"获取任务反馈失败: {e}")
            return []

    def get_tasks_list(self, status: Optional[str] = None, start_time: Optional[str] = None,
                       end_time: Optional[str] = None, page: int = 1,
                       limit: int = 20) -> Tuple[List[Dict], int]:
        """Return one page of tasks (newest created first) and the total count.

        Args:
            status: Filter on task status (optional).
            start_time: Inclusive lower bound on ``schedule_time`` (optional).
            end_time: Inclusive upper bound on ``schedule_time`` (optional).
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size.

        Returns:
            ``(tasks, total)``. In each task, ``groups``,
            ``content_miniprogram`` and ``preview_recipients`` are decoded
            from JSON. ``([], 0)`` on error.
        """
        try:
            # A page below 1 would produce a negative OFFSET (SQL error).
            page = max(page, 1)

            conditions = []
            params = []
            if status:
                conditions.append("status = %s")
                params.append(status)
            if start_time:
                conditions.append("schedule_time >= %s")
                params.append(start_time)
            if end_time:
                conditions.append("schedule_time <= %s")
                params.append(end_time)
            where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

            # Alias the count as "total", like every other count query here.
            count_sql = "SELECT COUNT(*) AS total FROM t_push_tasks" + where_clause
            count_result = self.execute_query(count_sql, tuple(params), fetch_one=True)
            total = count_result['total'] if count_result else 0

            sql = ("SELECT * FROM t_push_tasks" + where_clause
                   + " ORDER BY created_at DESC LIMIT %s OFFSET %s")
            tasks = self.execute_query(sql, (*params, limit, (page - 1) * limit))

            for task in tasks:
                self._decode_json_columns(task, ('groups', 'content_miniprogram', 'preview_recipients'))
            return tasks, total
        except Exception as e:
            logger.error(f"获取任务列表失败: {e}")
            return [], 0

    def get_task_logs_with_pagination(self, task_id: str, page: int = 1, limit: int = 20) -> Dict:
        """Return one page of a task's action logs with pagination info.

        Unlike ``get_task_logs``, the ``changes`` column is decoded from JSON.

        Args:
            task_id: Task ID.
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size.

        Returns:
            ``{'logs': list, 'total': int, 'page': int, 'limit': int}``;
            on error ``logs`` is empty and ``total`` is 0.
        """
        try:
            # A page below 1 would produce a negative OFFSET (SQL error).
            page = max(page, 1)
            logs, total = self._query_task_logs(task_id, page, limit)
            for log in logs:
                self._decode_json_columns(log, ('changes',))
            return {
                'logs': logs,
                'total': total,
                'page': page,
                'limit': limit,
            }
        except Exception as e:
            logger.error(f"获取任务日志失败: {e}")
            return {'logs': [], 'total': 0, 'page': page, 'limit': limit}

    def get_tasks_count(self) -> int:
        """Return the total number of tasks (0 on error)."""
        try:
            sql = "SELECT COUNT(*) as total FROM t_push_tasks"
            result = self.execute_query(sql, fetch_one=True)
            return result['total'] if result else 0
        except Exception as e:
            logger.error(f"获取任务总数失败: {e}")
            return 0

    def get_tasks_count_by_status(self, status: str) -> int:
        """Return the number of tasks with the given status (0 on error).

        Args:
            status: Task status.
        """
        try:
            sql = "SELECT COUNT(*) as total FROM t_push_tasks WHERE status = %s"
            result = self.execute_query(sql, (status,), fetch_one=True)
            return result['total'] if result else 0
        except Exception as e:
            logger.error(f"获取{status}状态任务数量失败: {e}")
            return 0

    def get_tasks_count_by_date(self, date: str) -> int:
        """Return the number of tasks created on the given date (0 on error).

        Args:
            date: Date string in ``YYYY-MM-DD`` format, compared against
                ``DATE(created_at)``.
        """
        try:
            sql = """
                SELECT COUNT(*) as total
                FROM t_push_tasks
                WHERE DATE(created_at) = %s
            """
            result = self.execute_query(sql, (date,), fetch_one=True)
            return result['total'] if result else 0
        except Exception as e:
            logger.error(f"获取{date}任务数量失败: {e}")
            return 0