# -*- coding: utf-8 -*- """趣味指令规则服务层。 服务层职责: 1. 聚合 MySQL + Redis + 应用内存三层缓存读写。 2. 提供规则匹配能力给插件层。 3. 提供后台 CRUD 的统一入口,确保变更后缓存立刻刷新。 """ import json import re import threading import time from datetime import date, datetime from typing import Any, Dict, List, Optional, Tuple from loguru import logger from db.fun_command_rule_db import FunCommandRuleDBOperator class FunCommandRuleService: """趣味指令规则服务。""" # Redis 键定义:统一集中,便于后续迁移命名。 REDIS_RULES_KEY = "fun:command:rules:all" def __init__(self, db_operator: FunCommandRuleDBOperator, redis_client, local_ttl_seconds: int = 30): self.db = db_operator self.redis = redis_client self.local_ttl_seconds = max(int(local_ttl_seconds or 30), 5) # 进程内缓存:热路径优先命中,避免每条消息都打 Redis。 self._local_lock = threading.RLock() self._local_rules: List[Dict[str, Any]] = [] self._local_expire_at: float = 0.0 # 命中冷却缓存:用于每条规则的简单限频。 self._cooldown_lock = threading.RLock() self._cooldown_map: Dict[str, float] = {} def init_tables(self) -> bool: """初始化底层数据表。""" return self.db.init_tables() # --------------------------- 缓存层 --------------------------- def _load_rules_from_db(self) -> List[Dict[str, Any]]: """从 MySQL 回源全量启用规则。""" return self.db.list_rules(enabled=True) def _make_json_safe(self, value: Any) -> Any: """将任意对象转换为可 JSON 序列化结构。 重点处理: 1. datetime/date:统一转成字符串,避免 json.dumps 抛异常。 2. dict/list:递归处理,确保嵌套结构中的时间字段也可序列化。 """ if isinstance(value, datetime): return value.strftime("%Y-%m-%d %H:%M:%S") if isinstance(value, date): return value.strftime("%Y-%m-%d") if isinstance(value, dict): return {str(k): self._make_json_safe(v) for k, v in value.items()} if isinstance(value, list): return [self._make_json_safe(item) for item in value] return value def _write_redis_rules(self, rules: List[Dict[str, Any]]) -> None: """写入 Redis 持久缓存。 注意:这里不设置过期时间,Redis 作为跨进程共享缓存常驻。 """ try: # 规则行里可能带 created_at/updated_at(datetime),这里先做 JSON 安全转换。 safe_rules = self._make_json_safe(rules or []) self.redis.set(self.REDIS_RULES_KEY, json.dumps(safe_rules, ensure_ascii=False)) except Exception as e: logger.warning(f"写入 Redis 规则缓存失败: {e}") def _read_redis_rules(self) -> Optional[List[Dict[str, Any]]]: """从 Redis 读取规则缓存。""" try: text = self.redis.get(self.REDIS_RULES_KEY) if not text: return None data = json.loads(text) if not isinstance(data, list): return None return data except Exception as e: logger.warning(f"读取 Redis 规则缓存失败: {e}") return None def _set_local_cache(self, rules: List[Dict[str, Any]]) -> None: """更新应用内缓存。""" with self._local_lock: self._local_rules = list(rules or []) self._local_expire_at = time.time() + self.local_ttl_seconds def _get_local_cache(self) -> Optional[List[Dict[str, Any]]]: """读取应用内缓存。 仅当未过期才返回,确保后台更新后最长 local_ttl_seconds 可见。 """ with self._local_lock: if time.time() < self._local_expire_at: return list(self._local_rules) return None def refresh_cache(self) -> List[Dict[str, Any]]: """强制刷新缓存(DB -> Redis -> Local)。""" rules = self._load_rules_from_db() self._write_redis_rules(rules) self._set_local_cache(rules) return rules def get_enabled_rules(self) -> List[Dict[str, Any]]: """获取启用规则。 读取顺序: 1. 本地缓存命中直接返回(最高性能)。 2. Redis 命中则回填本地缓存。 3. MySQL 回源并回填 Redis + 本地缓存。 """ local_rules = self._get_local_cache() if local_rules is not None: return local_rules redis_rules = self._read_redis_rules() if redis_rules is not None: self._set_local_cache(redis_rules) return redis_rules return self.refresh_cache() # --------------------------- 管理端 CRUD --------------------------- def list_rules(self, scope_type: str = "", scope_id: str = "", enabled: Optional[bool] = None) -> List[Dict[str, Any]]: """后台使用:按条件列规则。""" return self.db.list_rules(scope_type=scope_type, scope_id=scope_id, enabled=enabled) def create_rule(self, payload: Dict[str, Any]) -> bool: """创建规则并刷新缓存。""" ok = self.db.create_rule(payload) if ok: self.refresh_cache() return ok def update_rule(self, rule_id: int, payload: Dict[str, Any]) -> bool: """更新规则并刷新缓存。""" ok = self.db.update_rule(rule_id, payload) if ok: self.refresh_cache() return ok def delete_rule(self, rule_id: int) -> bool: """删除规则并刷新缓存。""" ok = self.db.delete_rule(rule_id) if ok: self.refresh_cache() return ok def toggle_rule(self, rule_id: int, enabled: bool, updated_by: str = "system") -> bool: """启停规则并刷新缓存。""" ok = self.db.toggle_rule(rule_id=rule_id, enabled=enabled, updated_by=updated_by) if ok: self.refresh_cache() return ok # --------------------------- 规则匹配 --------------------------- @staticmethod def _scope_match(rule: Dict[str, Any], scope_type: str, scope_id: str) -> bool: """判断规则作用域是否命中当前会话。""" rule_scope_type = str(rule.get("scope_type", "global") or "global").strip().lower() rule_scope_id = str(rule.get("scope_id", "") or "").strip() # global:全局可用。 if rule_scope_type == "global": return True # group/private:必须同时匹配 scope_id。 if rule_scope_type == scope_type: return not rule_scope_id or rule_scope_id == scope_id return False @staticmethod def _trigger_match(rule: Dict[str, Any], event_key: str, content: str) -> bool: """判断规则触发条件是否命中。 支持: - event(事件触发,如 PAT) - exact/prefix/contains/regex(文本触发) """ trigger_type = str(rule.get("trigger_type", "exact") or "exact").strip().lower() trigger_text = str(rule.get("trigger_text", "") or "") target_event_key = str(rule.get("event_key", "") or "").strip().upper() normalized_content = str(content or "").strip() if trigger_type == "event": return bool(target_event_key) and target_event_key == str(event_key or "").strip().upper() if not normalized_content: return False if trigger_type == "exact": return normalized_content == trigger_text if trigger_type == "prefix": return normalized_content.startswith(trigger_text) if trigger_type == "contains": return trigger_text in normalized_content if trigger_type == "regex": try: return re.search(trigger_text, normalized_content) is not None except re.error: # 正则配置非法时直接视为不匹配,避免打断主流程。 return False return False def _cooldown_key(self, rule_id: int, session_key: str) -> str: """构造冷却键。 session_key 采用群ID/私聊ID,确保不同会话互不干扰。 """ return f"{int(rule_id)}::{session_key}" def _check_and_mark_cooldown(self, rule: Dict[str, Any], session_key: str) -> bool: """检查并写入冷却窗口。 返回 True 表示允许触发;False 表示仍在冷却中。 """ cooldown_seconds = int(rule.get("cooldown_seconds", 0) or 0) if cooldown_seconds <= 0: return True now = time.time() key = self._cooldown_key(int(rule.get("id", 0) or 0), session_key) with self._cooldown_lock: expired_at = self._cooldown_map.get(key, 0) if now < expired_at: return False self._cooldown_map[key] = now + cooldown_seconds # 轻量清理,防止 map 长期膨胀。 if len(self._cooldown_map) > 5000: stale_keys = [k for k, v in self._cooldown_map.items() if v < now] for stale_key in stale_keys[:1000]: self._cooldown_map.pop(stale_key, None) return True def match_rule( self, scope_type: str, scope_id: str, content: str, event_key: str, session_key: str, ) -> Optional[Dict[str, Any]]: """匹配首条可执行规则。 设计为“首条命中即返回”,通过 priority 实现可控顺序。 """ rules = self.get_enabled_rules() if not rules: return None normalized_scope_type = str(scope_type or "global").strip().lower() normalized_scope_id = str(scope_id or "").strip() for rule in rules: if not self._scope_match(rule, normalized_scope_type, normalized_scope_id): continue if not self._trigger_match(rule, event_key=event_key, content=content): continue if not self._check_and_mark_cooldown(rule, session_key=session_key): continue return rule return None def validate_responses(self, responses_json: Any) -> Tuple[bool, str, List[Dict[str, Any]]]: """校验响应动作数组。 返回: - ok: 是否通过 - message: 错误说明 - normalized: 归一化后的响应列表 """ if isinstance(responses_json, str): try: responses_json = json.loads(responses_json) except Exception: return False, "responses_json 不是合法 JSON", [] if not isinstance(responses_json, list) or not responses_json: return False, "responses_json 必须是非空数组", [] normalized: List[Dict[str, Any]] = [] allowed_types = {"text", "image", "voice", "video", "link", "app"} for idx, item in enumerate(responses_json): if not isinstance(item, dict): return False, f"第 {idx + 1} 条响应必须是对象", [] action_type = str(item.get("type", "") or "").strip().lower() if action_type not in allowed_types: return False, f"第 {idx + 1} 条响应 type 非法,仅支持 {sorted(allowed_types)}", [] delay_ms = int(item.get("delay_ms", 0) or 0) if delay_ms < 0: delay_ms = 0 normalized_item = dict(item) normalized_item["type"] = action_type normalized_item["delay_ms"] = delay_ms normalized.append(normalized_item) return True, "ok", normalized