import json
import logging
import sys
from datetime import datetime

import requests


class Doubao:
    """Chat client for the Doubao chat-completions HTTP API.

    Per ``wxid`` it keeps a short rolling conversation (system prompt, a
    current-time hint, then the most recent user/assistant turns) and sends
    that history with every request.

    Config keys read from ``conf``: ``key`` (bearer token), ``api`` (endpoint
    URL), ``prompt`` (system prompt), ``model`` (model name) and the optional
    ``timeout`` (HTTP timeout in seconds).
    """

    # Upper bound on stored messages per conversation, system messages included.
    _MAX_MESSAGES = 10
    # Prefix of the per-conversation system message carrying the current time.
    _TIME_MARK = "当需要回答时间时请直接参考回复:"
    # Used when conf has no "timeout"; without one, requests may wait forever.
    _DEFAULT_TIMEOUT = 120

    def __init__(self, conf: dict) -> None:
        self.key = conf.get("key")
        self.api = conf.get("api")
        prompt = conf.get("prompt")
        self.model = conf.get("model")
        self.timeout = conf.get("timeout", self._DEFAULT_TIMEOUT)
        self.LOG = logging.getLogger("doubao")
        # wxid -> list of {"role": ..., "content": ...} messages.
        self.conversation_list = {}
        # Shared by every conversation's history (same dict object).
        self.system_content_msg = {"role": "system", "content": prompt}

    def __repr__(self):
        return 'Doubao'

    def get_answer(self, question: str, wxid: str) -> str:
        """Ask *question* within *wxid*'s conversation and return the reply text.

        Returns "" when the request fails, the server answers with a non-2xx
        status, or the response cannot be parsed; such failures are logged
        rather than raised.
        """
        self.updateMessage(wxid, question, "user")

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {self.key}",
        }
        # Send the stored history (system prompt, time hint, recent turns);
        # it already ends with the question appended just above.
        payload = {
            "model": self.model,
            "messages": self.conversation_list[wxid],
        }

        try:
            response = requests.post(
                self.api, headers=headers, json=payload, timeout=self.timeout
            )
            response.encoding = "utf-8"
            self.LOG.debug("HTTP status: %s", response.status_code)
            if not response.ok:
                self.LOG.error(
                    "请求失败: HTTP %s, %s", response.status_code, response.text
                )
                return ""
            rsp = extract_content(response.text) or ""
        except Exception as e0:
            # Top-level boundary for the bot: log and degrade to an empty reply.
            self.LOG.error(f"发生未知错误:{str(e0)}")
            return ""

        # Only store real replies; an empty/None assistant message would be
        # re-sent with every later request.
        if rsp:
            self.updateMessage(wxid, rsp, "assistant")
        return rsp

    def updateMessage(self, wxid: str, question: str, role: str) -> None:
        """Append a ``role`` message with content *question* to *wxid*'s history.

        On first use the history is seeded with the system prompt and a
        current-time hint. Each call refreshes the time hint and then drops
        the oldest non-system messages so that at most ``_MAX_MESSAGES``
        entries remain.
        """
        now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        time_msg = self._TIME_MARK + now_time

        history = self.conversation_list.get(wxid)
        if history is None:
            # New conversation: shared system prompt + per-conversation time hint.
            history = [
                self.system_content_msg,
                {"role": "system", "content": time_msg},
            ]
            self.conversation_list[wxid] = history

        history.append({"role": role, "content": question})

        # Keep the time hint current. The prompt may be None, hence the type check.
        for msg in history:
            content = msg["content"]
            if (
                msg["role"] == "system"
                and isinstance(content, str)
                and content.startswith(self._TIME_MARK)
            ):
                msg["content"] = time_msg

        # Roll off the oldest non-system messages. Deleting a fixed index
        # would remove the time-hint system message on the first overflow.
        while len(history) > self._MAX_MESSAGES:
            oldest = next(
                (i for i, msg in enumerate(history) if msg["role"] != "system"),
                None,
            )
            if oldest is None:
                break
            self.LOG.info("滚动清除微信记录:" + wxid)
            del history[oldest]

    @staticmethod
    def value_check(conf: dict) -> bool:
        """Return True if *conf* has non-empty "key", "api" and "prompt" values."""
        return bool(
            conf and conf.get("key") and conf.get("api") and conf.get("prompt")
        )


def extract_content(data_string):
    """Return ``choices[0].message.content`` from a chat-completions JSON body.

    Returns None (and logs) when *data_string* is not valid JSON or lacks the
    ``choices[0].message`` structure (e.g. an API error body). A message
    without a "content" key yields "".
    """
    try:
        data = json.loads(data_string)
    except json.JSONDecodeError:
        logging.getLogger("doubao").error("Invalid JSON")
        return None
    try:
        return data["choices"][0]["message"].get("content", "")
    except (KeyError, IndexError, TypeError, AttributeError):
        logging.getLogger("doubao").error("Unexpected response: %s", data_string)
        return None


if __name__ == '__main__':
    from configuration import Config

    config = Config().DOUBAO
    if not config:
        sys.exit(0)

    chat = Doubao(config)
    while True:
        q = input(">>> ")
        try:
            time_start = datetime.now()  # start time
            print(chat.get_answer(q, "Jyunere"))
            time_end = datetime.now()  # end time
            # Elapsed wall-clock time of the request, in seconds.
            print(f"{round((time_end - time_start).total_seconds(), 2)}s")
        except Exception as e:
            print(e)

# Equivalent raw request (never commit a real API key here):
# curl https://ark.cn-beijing.volces.com/api/v3/chat/completions \
#   -H "Content-Type: application/json" \
#   -H "Authorization: Bearer $ARK_API_KEY" \
#   -d '{
#     "model": "doubao-1-5-lite-32k-250115",
#     "messages": [
#       {"role": "system", "content": "你是人工智能助手."},
#       {"role": "user", "content": "常见的十字花科植物有哪些?"}
#     ]
#   }'