代码优化

This commit is contained in:
严浪
2023-11-13 14:29:57 +08:00
parent db3ff14c71
commit 9ca9d31fee
10 changed files with 210 additions and 143 deletions

View File

@@ -7,85 +7,84 @@ import openai
import json
import os
import random
from tool.tool_registry import get_tools, dispatch_tool,extract_code
from tool.code_kernel import execute,CodeKernel
from wcferry import Wcf, WxMsg
from chatglm.tool_registry import get_tools, dispatch_tool, extract_code
from chatglm.code_kernel import execute, CodeKernel
from typing import Dict, Union, Optional, Tuple
from wcferry import Wcf
functions = get_tools()
functions=get_tools()
class ChatGLM():
def __init__(self, wcf: Wcf, config={},max_retry=5) -> None:
openai.api_key = config.get('key')
def __init__(self, config={}, wcf: Optional[Wcf] = None, max_retry=5) -> None:
openai.api_key = config.get('key', 'XXX')
# 自己搭建或第三方代理的接口
openai.api_base = config.get('api')
if config.get('proxy',None):
openai.proxy = {"http": config.get('proxy',None), "https": config.get('proxy',None)}
openai.api_base = config.get('api', 'http://localhost:8000/v1')
if config.get('proxy', None):
openai.proxy = {"http": config.get(
'proxy', None), "https": config.get('proxy', None)}
self.conversation_list = {}
self.chat_type={}
self.max_retry=max_retry
self.wcf=wcf
self.filePath=config.get('file_path')
self.chat_type = {}
self.max_retry = max_retry
self.wcf = wcf
self.filePath = config.get('file_path', 'temp')
self.kernel = CodeKernel()
self.system_content_msg = {"chat":[{"role": "system", "content": config.get('prompt')}],
"tool":[{"role": "system", "content": "Answer the following questions as best as you can. You have access to the following tools:"}],
"code":[{"role": "system", "content": "你是一位智能AI助手你叫ChatGLM你连接着一台电脑但请注意不能联网。在使用Python解决任务时你可以运行代码并得到结果如果运行结果有错误你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件文件默认存储路径是{}".format(self.filePath)}]}
code0='''
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-1, 1, 50)
y = x * x + 1
plt.plot(x, y)
plt.show()
'''
res_type, res = execute(code0, self.kernel) #第一次画图不返回图片问题
print(res_type, res)
self.system_content_msg = {"chat": [{"role": "system", "content": config.get('prompt', '你是智能聊天机器人,你叫小薇')}],
"tool": [{"role": "system", "content": "Answer the following questions as best as you can. You have access to the following tools:"}],
"code": [{"role": "system", "content": "你是一位智能AI助手你叫ChatGLM你连接着一台电脑但请注意不能联网。在使用Python解决任务时你可以运行代码并得到结果如果运行结果有错误你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件文件默认存储路径是{}".format(self.filePath)}]}
def get_answer(self, question: str, wxid: str) -> str:
# wxid或者roomid,个人时为微信id群消息时为群id
if '#帮助'==question:
if '#帮助' == question:
return '本助手有三种模式,#聊天模式 = #1 #工具模式 = #2 #代码模式 = #3 , #清除模式会话 = #4 , #清除全部会话 = #5 可用发送#对应模式 或者 #编号 进行切换'
elif '#聊天模式'==question or '#1'==question:
self.chat_type[wxid]='chat'
elif '#聊天模式' == question or '#1' == question:
self.chat_type[wxid] = 'chat'
return '已切换#聊天模式'
elif '#工具模式'==question or '#2'==question:
self.chat_type[wxid]='tool'
elif '#工具模式' == question or '#2' == question:
self.chat_type[wxid] = 'tool'
return '已切换#工具模式 \n工具有:查看天气,日期,新闻,comfyUI文生图。例如\n帮我生成一张小鸟的图片,提示词必须是英文'
elif '#代码模式'==question or '#3'==question:
self.chat_type[wxid]='code'
elif '#代码模式' == question or '#3' == question:
self.chat_type[wxid] = 'code'
return '已切换#代码模式 \n代码模式可以用于写python代码例如\n用python画一个爱心'
elif '#清除模式会话'==question or '#4'==question:
self.conversation_list[wxid][self.chat_type[wxid]]=self.system_content_msg[self.chat_type[wxid]]
elif '#清除模式会话' == question or '#4' == question:
self.conversation_list[wxid][self.chat_type[wxid]
] = self.system_content_msg[self.chat_type[wxid]]
return '已清除'
elif '#清除全部会话'==question or '#5'==question:
self.conversation_list[wxid]=self.system_content_msg
elif '#清除全部会话' == question or '#5' == question:
self.conversation_list[wxid] = self.system_content_msg
return '已清除'
self.updateMessage(wxid, question, "user")
try:
params = dict(model="chatglm3",temperature=1.0, messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False)
if 'tool'==self.chat_type[wxid]:
params = dict(model="chatglm3", temperature=1.0,
messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False)
if 'tool' == self.chat_type[wxid]:
params["functions"] = functions
response = openai.ChatCompletion.create(**params)
for _ in range(self.max_retry):
if response.choices[0].message.get("function_call"):
function_call = response.choices[0].message.function_call
print(f"Function Call Response: {function_call.to_dict_recursive()}")
print(
f"Function Call Response: {function_call.to_dict_recursive()}")
function_args = json.loads(function_call.arguments)
observation = dispatch_tool(function_call.name, function_args)
if isinstance(observation,dict):
res_type=observation['res_type'] if 'res_type' in observation else 'text'
res=observation['res'] if 'res_type' in observation else str(observation )
observation = dispatch_tool(
function_call.name, function_args)
if isinstance(observation, dict):
res_type = observation['res_type'] if 'res_type' in observation else 'text'
res = observation['res'] if 'res_type' in observation else str(
observation)
if res_type == 'image':
filename= observation['filename']
filePath=os.path.join(self.filePath,filename)
filename = observation['filename']
filePath = os.path.join(self.filePath, filename)
res.save(filePath)
self.wcf.send_image(filePath,wxid)
tool_response='[Image]' if res_type == 'image' else res
self.wcf and self.wcf.send_image(filePath, wxid)
tool_response = '[Image]' if res_type == 'image' else res
else:
tool_response=observation if isinstance(observation,str) else str(observation)
tool_response = observation if isinstance(
observation, str) else str(observation)
print(f"Tool Call Response: {tool_response}")
params["messages"].append(response.choices[0].message)
@@ -98,24 +97,25 @@ plt.show()
)
self.updateMessage(wxid, tool_response, "function")
response = openai.ChatCompletion.create(**params)
elif response.choices[0].message.content.find('interpreter')!=-1:
output_text=response.choices[0].message.content
elif response.choices[0].message.content.find('interpreter') != -1:
output_text = response.choices[0].message.content
code = extract_code(output_text)
self.wcf.send_text('代码如下:\n'+code,wxid)
self.wcf.send_text('执行代码...',wxid)
self.wcf and self.wcf.send_text('代码如下:\n' + code, wxid)
self.wcf and self.wcf.send_text('执行代码...', wxid)
try:
res_type, res = execute(code, self.kernel)
except Exception as e:
rsp=f'代码执行错误: {e}'
rsp = f'代码执行错误: {e}'
break
if res_type == 'image':
filename= '{}.png'.format(''.join(random.sample('abcdefghijklmnopqrstuvwxyz1234567890',8)))
filePath=os.path.join(self.filePath,filename)
filename = '{}.png'.format(''.join(random.sample(
'abcdefghijklmnopqrstuvwxyz1234567890', 8)))
filePath = os.path.join(self.filePath, filename)
res.save(filePath)
self.wcf.send_image(filePath,wxid)
self.wcf and self.wcf.send_image(filePath, wxid)
else:
self.wcf.send_text("执行结果:\n"+res,wxid)
tool_response='[Image]' if res_type == 'image' else res
self.wcf and self.wcf.send_text("执行结果:\n" + res, wxid)
tool_response = '[Image]' if res_type == 'image' else res
print("Received:", res_type, res)
params["messages"].append(response.choices[0].message)
params["messages"].append(
@@ -144,11 +144,12 @@ plt.show()
if wxid not in self.conversation_list.keys():
self.conversation_list[wxid] = self.system_content_msg
if wxid not in self.chat_type.keys():
self.chat_type[wxid]='chat'
self.chat_type[wxid] = 'chat'
# 当前问题
content_question_ = {"role": role, "content": question}
self.conversation_list[wxid][self.chat_type[wxid]].append(content_question_)
self.conversation_list[wxid][self.chat_type[wxid]].append(
content_question_)
# 只存储10条记录超过滚动清除
i = len(self.conversation_list[wxid][self.chat_type[wxid]])
@@ -160,16 +161,11 @@ plt.show()
if __name__ == "__main__":
from configuration import Config
config = Config().CHATGPT
config = Config().CHATGLM
if not config:
exit(0)
key = config.get("key")
api = config.get("api")
proxy = config.get("proxy")
prompt = config.get("prompt")
chat = ChatGLM(key, api, proxy, prompt)
chat = ChatGLM(config)
while True:
q = input(">>> ")
@@ -178,6 +174,7 @@ if __name__ == "__main__":
print(chat.get_answer(q, "wxid"))
time_end = datetime.now() # 记录结束时间
print(f"{round((time_end - time_start).total_seconds(), 2)}s") # 计算的时间差为程序的执行时间,单位为秒/s
# 计算的时间差为程序的执行时间,单位为秒/s
print(f"{round((time_end - time_start).total_seconds(), 2)}s")
except Exception as e:
print(e)