Files
abot/base/func_doubao.py
2025-04-30 13:22:33 +08:00

161 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import requests
import json
from loguru import logger
from datetime import datetime
class Doubao:
    """Chat client for the Doubao chat-completions HTTP endpoint.

    Settings are read from a config dict: ``key`` (sent as a Bearer token),
    ``api`` (URL the request is POSTed to), ``prompt`` (system prompt),
    ``model`` and an optional ``timeout`` in seconds.  A short rolling
    message history is kept per ``wxid`` and sent with every request so
    follow-up questions are answered in context.
    """

    # Maximum number of messages (including the two leading system messages)
    # kept per wxid before the oldest chat turns are dropped.
    MAX_HISTORY = 10
    # Default HTTP timeout in seconds; overridden by conf["timeout"] if given.
    timeout = 120
    # Prefix of the per-conversation system message that carries the time.
    _TIME_MARK = "当需要回答时间时请直接参考回复:"
    # Separator that extract_content() puts in front of its token footer.
    _TOKENS_MARK = "\n\n【tokens】"

    def __init__(self, conf: dict) -> None:
        """Store connection settings from *conf*; start with no history."""
        self.key = conf.get("key")
        self.api = conf.get("api")
        prompt = conf.get("prompt")
        self.model = conf.get("model")
        self.timeout = conf.get("timeout", self.timeout)
        self.LOG = logger
        # wxid -> list of {"role": ..., "content": ...} messages.
        self.conversation_list = {}
        self.system_content_msg = {"role": "system", "content": prompt}

    def __repr__(self):
        return 'Doubao'

    def get_answer(self, question: str, wxid: str) -> str:
        """Send *question* in the context of conversation *wxid*.

        The question is appended to the wxid's history and the whole history
        (system prompt, time hint, earlier turns, this question) is posted.

        Returns:
            The reply text followed by the token-usage footer produced by
            ``extract_content``, or ``""`` when the request fails, the server
            answers with a non-200 status, or the body cannot be parsed.
        """
        self.updateMessage(wxid, question, "user")
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {self.key}",
        }
        data = {
            "model": self.model,
            # Send the full rolling history, not just the latest question,
            # so the context maintained by updateMessage reaches the model.
            "messages": self.conversation_list[wxid],
        }
        try:
            # The timeout keeps a stalled connection from blocking forever.
            response = requests.post(
                self.api,
                headers=headers,
                data=json.dumps(data),
                timeout=self.timeout,
            )
            response.encoding = 'utf-8'
            self.LOG.debug(f"Doubao HTTP status: {response.status_code}")
            if response.status_code != 200:
                self.LOG.error(
                    f"豆包接口请求失败:HTTP {response.status_code} {response.text}"
                )
                return ""
            rsp = extract_content(response.text)
        except Exception as e0:
            self.LOG.error(f"发生未知错误:{str(e0)}")
            return ""
        if rsp is None:
            # extract_content() returns None for a non-JSON body; do not
            # store a null assistant message (likely rejected by the API on
            # the next request) and keep the declared str return type.
            return ""
        # Keep the token statistics out of the context sent back to the model.
        self.updateMessage(wxid, self._strip_tokens_footer(rsp), "assistant")
        return rsp

    @classmethod
    def _strip_tokens_footer(cls, text: str) -> str:
        """Return *text* without the trailing token-usage footer, if present."""
        head, sep, _ = text.rpartition(cls._TOKENS_MARK)
        return head if sep else text

    def updateMessage(self, wxid: str, question: str, role: str) -> None:
        """Append ``{"role": role, "content": question}`` to *wxid*'s history.

        A new history starts with the configured system prompt followed by a
        system message holding the current time; that time message is
        refreshed on every call.  When the history exceeds ``MAX_HISTORY``
        messages, the oldest non-system messages are removed, so both system
        messages are always kept.
        """
        now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        time_mk = self._TIME_MARK
        history = self.conversation_list.get(wxid)
        if history is None:
            history = [
                self.system_content_msg,
                {"role": "system", "content": time_mk + now_time},
            ]
            self.conversation_list[wxid] = history
        history.append({"role": role, "content": question})

        # Refresh the time hint.  The isinstance check tolerates a missing
        # (None) system prompt instead of crashing on .startswith.
        for msg in history:
            content = msg["content"]
            if (
                msg["role"] == "system"
                and isinstance(content, str)
                and content.startswith(time_mk)
            ):
                msg["content"] = time_mk + now_time

        # Roll off the oldest chat turns.  Index 1 is the time hint, not a
        # chat turn, so delete the first non-system message instead.
        while len(history) > self.MAX_HISTORY:
            oldest = next(
                (i for i, msg in enumerate(history) if msg["role"] != "system"),
                None,
            )
            if oldest is None:
                break
            self.LOG.info("滚动清除微信记录:" + wxid)
            del history[oldest]

    @staticmethod
    def value_check(conf: dict) -> bool:
        """Return True when *conf* has non-empty ``key``, ``api`` and ``prompt``."""
        return bool(
            conf and conf.get("key") and conf.get("api") and conf.get("prompt")
        )
def extract_content(data_string):
    """Build the reply text from a chat-completions JSON response body.

    Returns the first choice's message content followed by a token-usage
    footer, or ``None`` when *data_string* is not valid JSON.  A body that
    is valid JSON but lacks ``choices``/``message`` still raises
    (KeyError/IndexError/TypeError), so callers can report the failure.
    """
    try:
        data = json.loads(data_string)
    except json.JSONDecodeError:
        print("Invalid JSON")
        return None

    # .get("content", "") still yields None when the key is present with a
    # null value; that would make the concatenation below raise TypeError.
    content = data["choices"][0]["message"].get("content") or ""

    # Token usage is optional; anything other than a dict counts as absent.
    tokens_usage = data.get("usage", {})
    if isinstance(tokens_usage, dict):
        prompt_tokens = tokens_usage.get("prompt_tokens", 0)
        completion_tokens = tokens_usage.get("completion_tokens", 0)
        total_tokens = tokens_usage.get("total_tokens", 0)
    else:
        prompt_tokens = completion_tokens = total_tokens = 0

    if prompt_tokens == 0 and completion_tokens == 0 and total_tokens == 0:
        tokens_info = "\n\n【tokens】暂无数据"
    else:
        tokens_info = (f"\n\n【tokens】输入: {prompt_tokens} 生成: {completion_tokens} 总: {total_tokens}")
    return content + tokens_info
if __name__ == '__main__':
    # Interactive smoke test: read questions from stdin and print answers
    # together with the elapsed time of each request.
    import sys

    from configuration import Config

    config = Config().DOUBAO
    if not config:
        # sys.exit, unlike the site-provided exit(), is always available.
        sys.exit(0)
    chat = Doubao(config)
    while True:
        try:
            q = input(">>> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt end the session without a traceback.
            print()
            break
        try:
            time_start = datetime.now()  # request start
            print(chat.get_answer(q, "Jyunere"))
            time_end = datetime.now()  # request end
            # Elapsed wall-clock time, in seconds.
            print(f"{round((time_end - time_start).total_seconds(), 2)}s")
        except Exception as e:
            print(e)
# Equivalent raw request with curl (replace <YOUR_API_KEY> with your own key):
# curl https://ark.cn-beijing.volces.com/api/v3/chat/completions \
#   -H "Content-Type: application/json" \
#   -H "Authorization: Bearer <YOUR_API_KEY>" \
#   -d '{
#     "model": "doubao-1-5-lite-32k-250115",
#     "messages": [
#       {"role": "system", "content": "你是人工智能助手."},
#       {"role": "user", "content": "常见的十字花科植物有哪些?"}
#     ]
#   }'
# NOTE: never commit a real API key here; any key that was committed should be rotated.