Files
abot/message_sign/main.py
2025-03-05 10:29:12 +08:00

249 lines
9.8 KiB
Python

from datetime import datetime, timedelta
import logging
import mysql.connector.pooling
import tomllib
import pytz
import redis
from typing import Optional, Tuple
from wcferry import Wcf, WxMsg
from robot_cmd.robot_command import GroupBotManager, Feature, PermissionStatus
# SQL statement that creates the sign-in table when it does not exist yet.
# t_sign_record holds one row per (wx_id, group_id) pair (enforced by the
# unique_sign key): the member's accumulated points, the time of the latest
# sign-in (sign_stat) and the current consecutive-day streak (signin_streak).
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS t_sign_record (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
wx_id VARCHAR(100) NOT NULL,
group_id VARCHAR(100) NOT NULL,
wx_nick_name VARCHAR(100) NOT NULL,
points INT DEFAULT 0,
sign_stat DATETIME,
signin_streak INT DEFAULT 0,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
UNIQUE KEY unique_sign (wx_id, group_id)
)
"""
class SignInSystem:
def __init__(self, wcf: Wcf, gbm: GroupBotManager, all_contacts: dict,
db_pool: mysql.connector.pooling.MySQLConnectionPool, redis_pool: redis.ConnectionPool):
# 读取配置文件
with open('message_sign/config.toml', 'rb') as f:
self.config = tomllib.load(f)['SignIn']
self.LOG = logging.getLogger(__name__)
if not self.config['enable']:
raise Exception("签到功能未启用")
self.wcf = wcf
self.gbm = gbm
self.all_contacts = all_contacts
self.db_pool = db_pool
self.redis_pool = redis_pool
self.command = self.config['command']
self.min_point = self.config['min-point']
self.max_point = self.config['max-point']
self.streak_cycle = self.config['streak-cycle']
self.max_streak_point = self.config['max-streak-point']
# 时区设置
self.timezone = 'Asia/Shanghai'
# 从 Redis 初始化签到数据
self.today_signin_count = self._load_signin_count_from_redis()
with self._get_redis_connection() as redis_client:
last_reset_date_str = redis_client.get('group:sign_in:last_reset_date')
if last_reset_date_str:
self.last_reset_date = datetime.strptime(last_reset_date_str, '%Y-%m-%d').date()
else:
self.last_reset_date = datetime.now(tz=pytz.timezone(self.timezone)).date()
self._save_last_reset_date_to_redis()
self.LOG.info(f"[签到] 组件初始化完成 {self.command_format}")
def _get_db_connection(self):
"""从连接池获取数据库连接"""
return self.db_pool.get_connection()
def _get_redis_connection(self):
"""从连接池获取 Redis 连接"""
return redis.Redis(connection_pool=self.redis_pool)
def _load_signin_count_from_redis(self) -> dict:
"""从 Redis 加载签到人数数据"""
signin_count = {}
with self._get_redis_connection() as redis_client:
keys = redis_client.keys('group:sign_in:*')
for key in keys:
if key == 'group:sign_in:last_reset_date':
continue
group_id = key.replace('group:sign_in:', '')
count = redis_client.get(key)
if count is not None:
signin_count[group_id] = int(count)
return signin_count
def _save_signin_count_to_redis(self, group_id: str, count: int):
"""保存签到人数到 Redis"""
with self._get_redis_connection() as redis_client:
redis_client.set(f'group:sign_in:{group_id}', count)
def _save_last_reset_date_to_redis(self):
"""保存最后重置日期到 Redis"""
with self._get_redis_connection() as redis_client:
redis_client.set('group:sign_in:last_reset_date', self.last_reset_date.strftime('%Y-%m-%d'))
@property
def command_format(self):
return ','.join(self.command)
@property
def enable(self):
return self.config['enable']
def initialize_table(self):
"""初始化数据库表"""
with self._get_db_connection() as conn:
with conn.cursor(dictionary=True) as cursor: # 使用 dictionary=True 返回字典格式
cursor.execute(CREATE_TABLE_SQL)
conn.commit()
def reset_today_count_if_needed(self):
"""检查并重置每日签到计数"""
current_date = datetime.now(tz=pytz.timezone(self.timezone)).date()
if current_date != self.last_reset_date:
self.today_signin_count.clear()
with self._get_redis_connection() as redis_client:
keys = redis_client.keys('group:sign_in:*')
for key in keys:
if key != 'group:sign_in:last_reset_date':
redis_client.delete(key)
self.last_reset_date = current_date
self._save_last_reset_date_to_redis()
self.LOG.info(f"[签到] 已重置每日签到计数,日期更新为 {current_date}")
def get_today_signin_count(self, group_id: str) -> int:
"""获取群内今日签到人数(使用缓存)"""
self.reset_today_count_if_needed()
return self.today_signin_count.get(group_id, 0)
def get_user_record(self, wx_id: str, group_id: str) -> Optional[dict]:
"""获取用户签到记录"""
with self._get_db_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
query = """
SELECT wx_id, group_id, wx_nick_name, points, sign_stat, signin_streak
FROM t_sign_record
WHERE wx_id = %s AND group_id = %s
"""
cursor.execute(query, (wx_id, group_id))
return cursor.fetchone()
def calculate_points(self, streak: int) -> int:
"""根据连续签到天数计算积分"""
base_points = self.min_point
extra_points = min(streak // self.streak_cycle, self.max_streak_point)
total_points = base_points + extra_points
return min(total_points, self.max_point)
def member_sign_in(self, message: WxMsg):
"""会员签到功能"""
if not self.enable:
return
content = str(message.content).strip()
command = content.split(" ")
if not len(command) or command[0] not in self.command:
return
if self.gbm.get_group_permission(message.roomid, Feature.SIGNIN) == PermissionStatus.DISABLED:
return
# 获取当前时间,带有时区信息
current_time = datetime.now(tz=pytz.timezone(self.timezone))
# 获取当天零点的时间
today_start = current_time.replace(hour=0, minute=0, second=0, microsecond=0)
# 获取昨天的时间
yesterday = today_start - timedelta(days=1)
# 获取用户的签到记录
user_record = self.get_user_record(message.sender, message.roomid)
wx_nick_name = self.all_contacts.get(message.sender, message.sender)
# 判断用户是否已经签到过
if user_record and user_record.get('sign_stat'):
sign_stat = user_record['sign_stat']
# 确保 sign_stat 和 today_start 是同一时区对象
if isinstance(sign_stat, datetime) and sign_stat.tzinfo is None:
sign_stat = pytz.timezone(self.timezone).localize(sign_stat)
# 如果 sign_stat 已经大于或等于今天的零点,则认为用户已经签到过了
if sign_stat >= today_start:
self.wcf.send_text(
f"@{wx_nick_name} 您今天已经签到过了!当前积分:{user_record['points']}",
(message.roomid if message.from_group() else message.sender),
message.sender
)
return
streak = 0
if user_record and user_record['sign_stat']:
last_sign_date = user_record['sign_stat'].replace(hour=0, minute=0, second=0, microsecond=0)
if last_sign_date == yesterday:
streak = user_record['signin_streak'] + 1
else:
streak = 1
else:
streak = 1
today_signin_rank = self.get_today_signin_count(message.roomid) + 1
self.today_signin_count[message.roomid] = today_signin_rank
self._save_signin_count_to_redis(message.roomid, today_signin_rank)
points_to_add = self.calculate_points(streak)
with self._get_db_connection() as conn:
with conn.cursor(dictionary=True) as cursor:
if user_record:
update_sql = """
UPDATE t_sign_record
SET wx_nick_name = %s, points = points + %s,
sign_stat = %s, signin_streak = %s,
update_time = %s
WHERE wx_id = %s AND group_id = %s
"""
cursor.execute(update_sql, (
wx_nick_name, points_to_add, current_time, streak,
current_time, message.sender, message.roomid
))
else:
insert_sql = """
INSERT INTO t_sign_record
(wx_id, group_id, wx_nick_name, points, sign_stat, signin_streak)
VALUES (%s, %s, %s, %s, %s, %s)
"""
cursor.execute(insert_sql, (
message.sender, message.roomid, wx_nick_name, points_to_add, current_time, streak
))
conn.commit()
msg = (
f"@{wx_nick_name} 签到成功!\n"
f"您是今日群内第{today_signin_rank}个签到的\n"
f"连续签到{streak}天,本次获得{points_to_add}积分"
)
self.wcf.send_text(
msg,
(message.roomid if message.from_group() else message.sender),
message.sender
)
def __del__(self):
"""连接池由外部管理,不需要手动关闭"""
pass