# Doubao chat client: posts chat-completions requests to the configured API endpoint.
import json
import logging
import sys
from datetime import datetime

import requests


class Doubao:
    """Chat client for the Doubao chat-completions HTTP API.

    Keeps a short per-``wxid`` conversation history: the system prompt, a
    current-time hint, and the most recent turns. The whole history is sent
    with every request so the model sees the earlier turns.

    ``conf`` keys read here: ``key`` (bearer token), ``api`` (endpoint URL),
    ``prompt`` (system prompt), ``model`` (sent as the request's ``model``)
    and the optional ``timeout`` (HTTP timeout in seconds, default 60).
    """

    # Maximum number of messages kept per conversation, system messages included.
    MAX_HISTORY = 10
    # HTTP timeout (seconds) used when ``conf`` does not provide ``timeout``.
    DEFAULT_TIMEOUT = 60
    # Prefix of the system message that tells the model the current time.
    TIME_MARK = "当需要回答时间时请直接参考回复:"

    def __init__(self, conf: dict) -> None:
        self.key = conf.get("key")
        self.api = conf.get("api")
        prompt = conf.get("prompt")
        self.model = conf.get("model")
        # Without a timeout, a stalled connection would block the caller forever.
        self.timeout = conf.get("timeout", self.DEFAULT_TIMEOUT)
        self.LOG = logging.getLogger("doubao")
        # wxid -> list of chat messages, each {"role": ..., "content": ...}.
        self.conversation_list = {}
        self.system_content_msg = {"role": "system", "content": prompt}

    def __repr__(self):
        return 'Doubao'

    def _build_payload(self, wxid: str) -> dict:
        """Return the request body for ``wxid``: the model plus its stored history."""
        history = self.conversation_list.get(wxid, [self.system_content_msg])
        return {"model": self.model, "messages": list(history)}

    def get_answer(self, question: str, wxid: str) -> str:
        """Ask the model ``question`` on behalf of ``wxid`` and return its reply.

        The question is appended to the ``wxid`` history and the whole history
        is sent, so earlier turns give the model context. A non-empty reply is
        recorded in the history as an ``assistant`` message.

        :param question: the user's message.
        :param wxid: conversation key; each key has its own history.
        :return: the reply text, or ``""`` if the request failed, returned an
            HTTP error, or could not be parsed (the failure is logged).
        """
        self.updateMessage(wxid, question, "user")
        rsp = ""
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
                "Authorization": f"Bearer {self.key}",
            }
            response = requests.post(
                self.api,
                headers=headers,
                # json.dumps escapes non-ASCII by default, so the body is plain ASCII.
                data=json.dumps(self._build_payload(wxid)),
                timeout=self.timeout,
            )
            response.encoding = 'utf-8'
            if not response.ok:
                # Include the body: it may carry the API's error details.
                self.LOG.error(
                    "Doubao request failed: HTTP %s %s",
                    response.status_code,
                    response.text,
                )
                return rsp
            # extract_content returns None for an unparsable body; normalize to "".
            rsp = extract_content(response.text) or ""
            if rsp:
                self.updateMessage(wxid, rsp, "assistant")
        except Exception as e0:
            self.LOG.error(f"发生未知错误:{str(e0)}")
        return rsp

    def updateMessage(self, wxid: str, question: str, role: str) -> None:
        """Append a ``role`` message to the ``wxid`` history and keep it bounded.

        A new history starts with the system prompt and a current-time hint.
        The time hint is refreshed on every call. When the history exceeds
        ``MAX_HISTORY`` messages, the oldest non-system messages are dropped,
        so the system prompt and time hint are never discarded.
        """
        now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        time_mk = self.TIME_MARK

        history = self.conversation_list.get(wxid)
        if history is None:
            # New conversation: system prompt first, then the time hint.
            history = [
                self.system_content_msg,
                {"role": "system", "content": time_mk + now_time},
            ]
            self.conversation_list[wxid] = history

        history.append({"role": role, "content": question})

        # Refresh the time hint. The isinstance check guards a None prompt.
        for cont in history:
            content = cont["content"]
            if (
                cont["role"] == "system"
                and isinstance(content, str)
                and content.startswith(time_mk)
            ):
                cont["content"] = time_mk + now_time

        # Roll the window: drop the oldest non-system messages until it fits.
        if len(history) > self.MAX_HISTORY:
            self.LOG.info("滚动清除微信记录:%s", wxid)
            while len(history) > self.MAX_HISTORY:
                oldest = next(
                    (i for i, m in enumerate(history) if m["role"] != "system"),
                    None,
                )
                if oldest is None:
                    break  # only system messages left; nothing safe to drop
                del history[oldest]

    @staticmethod
    def value_check(conf: dict) -> bool:
        """Return True when ``conf`` has non-empty ``key``, ``api`` and ``prompt``."""
        if conf:
            if conf.get("key") and conf.get("api") and conf.get("prompt"):
                return True
        return False


# Parse a chat-completions JSON response body.
def extract_content(data_string):
    """Return the reply text from a chat-completions JSON response body.

    :param data_string: raw response body.
    :return: ``choices[0].message.content``, or ``""`` when that field is
        missing or null. Returns ``None`` when ``data_string`` is not valid JSON.
    :raises KeyError, IndexError, TypeError: when the JSON lacks the
        ``choices[0].message`` structure; ``Doubao.get_answer`` catches
        these in its broad exception handler.
    """
    try:
        data = json.loads(data_string)
    except json.JSONDecodeError:
        # Log instead of print so the message goes through the logging setup.
        logging.getLogger("doubao").warning("Invalid JSON")
        return None
    message = data["choices"][0]["message"]
    # A JSON null content would otherwise surface as None.
    return message.get("content") or ""


if __name__ == '__main__':
    # Interactive smoke test: read questions from stdin and print the replies.
    from configuration import Config

    config = Config().DOUBAO
    if not config:
        # No Doubao configuration available; nothing to try.
        sys.exit(0)

    chat = Doubao(config)

    while True:
        try:
            q = input(">>> ")
        except (EOFError, KeyboardInterrupt):
            # End the session on Ctrl-D / Ctrl-C instead of dumping a traceback.
            print()
            break
        try:
            time_start = datetime.now()  # start time
            print(chat.get_answer(q, "Jyunere"))
            time_end = datetime.now()  # end time

            # Elapsed wall-clock time of the request, in seconds.
            print(f"{round((time_end - time_start).total_seconds(), 2)}s")
        except Exception as e:
            print(e)


# Example request with curl (do not commit real API keys; use a placeholder):
# curl https://ark.cn-beijing.volces.com/api/v3/chat/completions \
#   -H "Content-Type: application/json" \
#   -H "Authorization: Bearer <YOUR_API_KEY>" \
#   -d '{
#     "model": "doubao-1-5-lite-32k-250115",
#     "messages": [
#       {"role": "system", "content": "你是人工智能助手."},
#       {"role": "user", "content": "常见的十字花科植物有哪些?"}
#     ]
#   }'