# -*- coding: utf-8 -*-
"""Service layer for fun-command rules.

Responsibilities:
1. Aggregate reads/writes across three cache tiers: MySQL, Redis and
   in-process memory.
2. Provide rule matching to the plugin layer.
3. Provide a single entry point for admin CRUD so that caches are refreshed
   immediately after every change.
"""

import json
import re
import threading
import time
from typing import Any, Dict, List, Optional, Tuple

from loguru import logger

from db.fun_command_rule_db import FunCommandRuleDBOperator


class FunCommandRuleService:
    """Fun-command rule service: cached rule lookup, matching and admin CRUD."""

    # Redis key definitions, kept in one place to ease later renaming.
    REDIS_RULES_KEY = "fun:command:rules:all"

    def __init__(self, db_operator: FunCommandRuleDBOperator, redis_client, local_ttl_seconds: int = 30):
        """Create the service.

        Args:
            db_operator: Database operator used as the source of truth.
            redis_client: Client object exposing ``get`` / ``set``; used as
                the shared cache.
            local_ttl_seconds: Lifetime of the in-process cache. Falsy values
                fall back to 30 and the result is clamped to at least 5.
        """
        self.db = db_operator
        self.redis = redis_client
        self.local_ttl_seconds = max(int(local_ttl_seconds or 30), 5)

        # In-process cache: hit first on the hot path so that not every
        # message has to go to Redis.
        self._local_lock = threading.RLock()
        self._local_rules: List[Dict[str, Any]] = []
        # Deadline on the time.monotonic() clock; 0.0 means "already expired".
        self._local_expire_at: float = 0.0

        # Cooldown cache: simple per-rule, per-session rate limiting.
        # Maps cooldown key -> time.monotonic() deadline.
        self._cooldown_lock = threading.RLock()
        self._cooldown_map: Dict[str, float] = {}

    def init_tables(self) -> bool:
        """Initialise the underlying tables by delegating to the DB operator."""
        return self.db.init_tables()

    # --------------------------- Cache layer ---------------------------
    def _load_rules_from_db(self) -> List[Dict[str, Any]]:
        """Load all enabled rules from the database (source of truth)."""
        return self.db.list_rules(enabled=True)

    def _write_redis_rules(self, rules: List[Dict[str, Any]]) -> None:
        """Write the rule list to Redis as JSON.

        No expiry is set: the Redis entry is meant to stay resident as the
        cache shared across processes. Failures are logged and swallowed so
        that a Redis outage does not break the caller.
        """
        try:
            self.redis.set(self.REDIS_RULES_KEY, json.dumps(rules or [], ensure_ascii=False))
        except Exception as e:
            logger.warning(f"写入 Redis 规则缓存失败: {e}")

    def _read_redis_rules(self) -> Optional[List[Dict[str, Any]]]:
        """Read the rule list from Redis.

        Returns:
            The decoded list, or ``None`` when the key is missing/empty, the
            payload is not a JSON array, or the read fails (logged).
        """
        try:
            text = self.redis.get(self.REDIS_RULES_KEY)
            if not text:
                return None
            data = json.loads(text)
            if not isinstance(data, list):
                return None
            return data
        except Exception as e:
            logger.warning(f"读取 Redis 规则缓存失败: {e}")
            return None

    def _set_local_cache(self, rules: List[Dict[str, Any]]) -> None:
        """Replace the in-process cache and restart its TTL."""
        with self._local_lock:
            # Shallow copy so later changes to the caller's list do not leak in.
            self._local_rules = list(rules or [])
            # Monotonic clock: the TTL must not be affected by wall-clock jumps.
            self._local_expire_at = time.monotonic() + self.local_ttl_seconds

    def _get_local_cache(self) -> Optional[List[Dict[str, Any]]]:
        """Return a copy of the in-process cache, or ``None`` once it expired.

        Because entries are only served while fresh, a stale local copy lives
        for at most ``local_ttl_seconds``.
        """
        with self._local_lock:
            if time.monotonic() < self._local_expire_at:
                return list(self._local_rules)
        return None

    def refresh_cache(self) -> List[Dict[str, Any]]:
        """Force a refresh: database -> Redis -> in-process cache.

        Returns:
            The freshly loaded rule list (an empty list when the database
            layer returns nothing).
        """
        # Normalise to a list so the return value matches the annotation; the
        # two cache writers below already treat a falsy result as "no rules".
        rules = self._load_rules_from_db() or []
        self._write_redis_rules(rules)
        self._set_local_cache(rules)
        return rules

    def get_enabled_rules(self) -> List[Dict[str, Any]]:
        """Return the enabled rules.

        Lookup order:
        1. In-process cache hit: return immediately.
        2. Redis hit: backfill the in-process cache.
        3. Otherwise load from the database and backfill Redis + local cache.
        """
        local_rules = self._get_local_cache()
        if local_rules is not None:
            return local_rules

        redis_rules = self._read_redis_rules()
        if redis_rules is not None:
            self._set_local_cache(redis_rules)
            return redis_rules

        return self.refresh_cache()

    # --------------------------- Admin CRUD ---------------------------
    def list_rules(self, scope_type: str = "", scope_id: str = "", enabled: Optional[bool] = None) -> List[Dict[str, Any]]:
        """Admin use: list rules by filter, straight from the database."""
        return self.db.list_rules(scope_type=scope_type, scope_id=scope_id, enabled=enabled)

    def create_rule(self, payload: Dict[str, Any]) -> bool:
        """Create a rule; refresh the caches when the database reports success."""
        ok = self.db.create_rule(payload)
        if ok:
            self.refresh_cache()
        return ok

    def update_rule(self, rule_id: int, payload: Dict[str, Any]) -> bool:
        """Update a rule; refresh the caches when the database reports success."""
        ok = self.db.update_rule(rule_id, payload)
        if ok:
            self.refresh_cache()
        return ok

    def delete_rule(self, rule_id: int) -> bool:
        """Delete a rule; refresh the caches when the database reports success."""
        ok = self.db.delete_rule(rule_id)
        if ok:
            self.refresh_cache()
        return ok

    def toggle_rule(self, rule_id: int, enabled: bool, updated_by: str = "system") -> bool:
        """Enable/disable a rule; refresh the caches when the database reports success."""
        ok = self.db.toggle_rule(rule_id=rule_id, enabled=enabled, updated_by=updated_by)
        if ok:
            self.refresh_cache()
        return ok

    # --------------------------- Rule matching ---------------------------
    @staticmethod
    def _scope_match(rule: Dict[str, Any], scope_type: str, scope_id: str) -> bool:
        """Return whether the rule's scope covers the current conversation.

        ``scope_type`` / ``scope_id`` are expected to be already normalised
        (lower-cased / stripped) by the caller.
        """
        rule_scope_type = str(rule.get("scope_type", "global") or "global").strip().lower()
        rule_scope_id = str(rule.get("scope_id", "") or "").strip()

        # "global" rules apply everywhere.
        if rule_scope_type == "global":
            return True

        # Any other scope type must equal the conversation's type; a rule
        # without scope_id covers every conversation of that type, otherwise
        # the ids must be equal as well.
        if rule_scope_type == scope_type:
            return not rule_scope_id or rule_scope_id == scope_id

        return False

    @staticmethod
    def _trigger_match(rule: Dict[str, Any], event_key: str, content: str) -> bool:
        """Return whether the rule's trigger condition is met.

        Supported ``trigger_type`` values:
        - ``event``: compares the rule's ``event_key`` with the given one,
          case-insensitively (e.g. PAT).
        - ``exact`` / ``prefix`` / ``contains`` / ``regex``: compares
          ``trigger_text`` with the stripped message content.

        Unknown trigger types never match.
        """
        trigger_type = str(rule.get("trigger_type", "exact") or "exact").strip().lower()
        trigger_text = str(rule.get("trigger_text", "") or "")
        target_event_key = str(rule.get("event_key", "") or "").strip().upper()
        normalized_content = str(content or "").strip()

        if trigger_type == "event":
            return bool(target_event_key) and target_event_key == str(event_key or "").strip().upper()

        # Text triggers never fire on empty content.
        if not normalized_content:
            return False

        # NOTE(review): an empty trigger_text makes prefix/contains/regex match
        # every non-empty message - confirm whether that catch-all is intended.
        if trigger_type == "exact":
            return normalized_content == trigger_text
        if trigger_type == "prefix":
            return normalized_content.startswith(trigger_text)
        if trigger_type == "contains":
            return trigger_text in normalized_content
        if trigger_type == "regex":
            try:
                return re.search(trigger_text, normalized_content) is not None
            except re.error:
                # An invalid pattern counts as "no match" so that a bad rule
                # cannot interrupt the main flow.
                return False

        return False

    def _cooldown_key(self, rule_id: int, session_key: str) -> str:
        """Build the cooldown map key.

        ``session_key`` is the group/private chat id, which keeps cooldowns
        of different conversations independent.
        """
        return f"{int(rule_id)}::{session_key}"

    def _check_and_mark_cooldown(self, rule: Dict[str, Any], session_key: str) -> bool:
        """Check the rule's cooldown window and start a new one if allowed.

        Returns:
            ``True`` when the rule may fire (a new window is recorded unless
            the rule has no positive ``cooldown_seconds``); ``False`` while
            the previous window is still running.
        """
        cooldown_seconds = int(rule.get("cooldown_seconds", 0) or 0)
        if cooldown_seconds <= 0:
            return True

        # Monotonic clock: cooldown windows must not stretch or collapse when
        # the wall clock is adjusted.
        now = time.monotonic()
        key = self._cooldown_key(int(rule.get("id", 0) or 0), session_key)

        with self._cooldown_lock:
            expires_at = self._cooldown_map.get(key, 0)
            if now < expires_at:
                return False
            self._cooldown_map[key] = now + cooldown_seconds

            # Lightweight cleanup so the map cannot grow without bound:
            # once it is large, drop up to 1000 already-expired entries.
            if len(self._cooldown_map) > 5000:
                stale_keys = [k for k, v in self._cooldown_map.items() if v < now]
                for stale_key in stale_keys[:1000]:
                    self._cooldown_map.pop(stale_key, None)

        return True

    def match_rule(
        self,
        scope_type: str,
        scope_id: str,
        content: str,
        event_key: str,
        session_key: str,
    ) -> Optional[Dict[str, Any]]:
        """Return the first rule that can fire for this message, else ``None``.

        Rules are tried in the order returned by ``get_enabled_rules`` and the
        first one passing scope, trigger and cooldown checks wins; ordering
        (priority) is therefore decided by whoever produced that list.
        A rule's cooldown window is started as a side effect of matching.
        """
        rules = self.get_enabled_rules()
        if not rules:
            return None

        normalized_scope_type = str(scope_type or "global").strip().lower()
        normalized_scope_id = str(scope_id or "").strip()

        for rule in rules:
            if not self._scope_match(rule, normalized_scope_type, normalized_scope_id):
                continue
            if not self._trigger_match(rule, event_key=event_key, content=content):
                continue
            # Checked last on purpose: only a rule that actually matches may
            # consume its cooldown window.
            if not self._check_and_mark_cooldown(rule, session_key=session_key):
                continue
            return rule

        return None

    def validate_responses(self, responses_json: Any) -> Tuple[bool, str, List[Dict[str, Any]]]:
        """Validate and normalise an array of response actions.

        Args:
            responses_json: Either a list of action dicts or a JSON string
                encoding such a list.

        Returns:
            ``(ok, message, normalized)`` where ``ok`` tells whether validation
            passed, ``message`` describes the failure (``"ok"`` on success) and
            ``normalized`` is the list of copied actions with a lower-cased
            ``type`` and a non-negative integer ``delay_ms`` (empty on failure).
        """
        if isinstance(responses_json, str):
            try:
                responses_json = json.loads(responses_json)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError; RecursionError covers
                # pathologically nested input.
                return False, "responses_json 不是合法 JSON", []

        if not isinstance(responses_json, list) or not responses_json:
            return False, "responses_json 必须是非空数组", []

        normalized: List[Dict[str, Any]] = []
        allowed_types = {"text", "image", "voice", "video", "link", "app"}

        for idx, item in enumerate(responses_json):
            if not isinstance(item, dict):
                return False, f"第 {idx + 1} 条响应必须是对象", []

            action_type = str(item.get("type", "") or "").strip().lower()
            if action_type not in allowed_types:
                return False, f"第 {idx + 1} 条响应 type 非法,仅支持 {sorted(allowed_types)}", []

            # A validator must report bad input, not raise: non-numeric values
            # ("abc", [1], Infinity, ...) become a validation error.
            try:
                delay_ms = int(item.get("delay_ms", 0) or 0)
            except (TypeError, ValueError, OverflowError):
                return False, f"第 {idx + 1} 条响应 delay_ms 必须是整数", []

            # Copy so the caller's dict is left untouched; negative delays are
            # clamped to zero.
            normalized_item = dict(item)
            normalized_item["type"] = action_type
            normalized_item["delay_ms"] = max(delay_ms, 0)
            normalized.append(normalized_item)

        return True, "ok", normalized