修改chatglm接入为使用新版openai语法,解决函数调用报错的问题

This commit is contained in:
caojingchen
2024-05-11 16:01:26 +08:00
parent d8913613d8
commit 7e8e760774
2 changed files with 19 additions and 14 deletions

View File

@@ -51,7 +51,7 @@ def register_tool(func: callable):
tool_def = { tool_def = {
"name": tool_name, "name": tool_name,
"description": tool_description, "description": tool_description,
"params": tool_params "parameters": tool_params
} }
# print("[registered tool] " + pformat(tool_def)) # print("[registered tool] " + pformat(tool_def))

View File

@@ -6,8 +6,8 @@ import os
import random import random
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
import httpx
import openai from openai import OpenAI
from base.chatglm.code_kernel import CodeKernel, execute from base.chatglm.code_kernel import CodeKernel, execute
from base.chatglm.tool_registry import dispatch_tool, extract_code, get_tools from base.chatglm.tool_registry import dispatch_tool, extract_code, get_tools
from wcferry import Wcf from wcferry import Wcf
@@ -18,12 +18,13 @@ functions = get_tools()
class ChatGLM: class ChatGLM:
def __init__(self, config={}, wcf: Optional[Wcf] = None, max_retry=5) -> None: def __init__(self, config={}, wcf: Optional[Wcf] = None, max_retry=5) -> None:
openai.api_key = config.get("key", "empty") key = config.get("key", 'empty')
# 自己搭建或第三方代理的接口 api = config.get("api")
openai.api_base = config["api"]
proxy = config.get("proxy") proxy = config.get("proxy")
if proxy: if proxy:
openai.proxy = {"http": proxy, "https": proxy} self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
else:
self.client = OpenAI(api_key=key, base_url=api)
self.conversation_list = {} self.conversation_list = {}
self.chat_type = {} self.chat_type = {}
self.max_retry = max_retry self.max_retry = max_retry
@@ -31,8 +32,11 @@ class ChatGLM:
self.filePath = config["file_path"] self.filePath = config["file_path"]
self.kernel = CodeKernel() self.kernel = CodeKernel()
self.system_content_msg = {"chat": [{"role": "system", "content": config["prompt"]}], self.system_content_msg = {"chat": [{"role": "system", "content": config["prompt"]}],
"tool": [{"role": "system", "content": "Answer the following questions as best as you can. You have access to the following tools:"}], "tool": [{"role": "system",
"code": [{"role": "system", "content": "你是一位智能AI助手你叫ChatGLM你连接着一台电脑但请注意不能联网。在使用Python解决任务时你可以运行代码并得到结果如果运行结果有错误你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件文件默认存储路径是{}".format(self.filePath)}]} "content": "Answer the following questions as best as you can. You have access to the following tools:"}],
"code": [{"role": "system",
"content": "你是一位智能AI助手你叫ChatGLM你连接着一台电脑但请注意不能联网。在使用Python解决任务时你可以运行代码并得到结果如果运行结果有错误你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件文件默认存储路径是{}".format(
self.filePath)}]}
def __repr__(self): def __repr__(self):
return 'ChatGLM' return 'ChatGLM'
@@ -59,7 +63,7 @@ class ChatGLM:
return '已切换#代码模式 \n代码模式可以用于写python代码例如\n用python画一个爱心' return '已切换#代码模式 \n代码模式可以用于写python代码例如\n用python画一个爱心'
elif '#清除模式会话' == question or '#4' == question: elif '#清除模式会话' == question or '#4' == question:
self.conversation_list[wxid][self.chat_type[wxid] self.conversation_list[wxid][self.chat_type[wxid]
] = self.system_content_msg[self.chat_type[wxid]] ] = self.system_content_msg[self.chat_type[wxid]]
return '已清除' return '已清除'
elif '#清除全部会话' == question or '#5' == question: elif '#清除全部会话' == question or '#5' == question:
self.conversation_list[wxid] = self.system_content_msg self.conversation_list[wxid] = self.system_content_msg
@@ -71,8 +75,8 @@ class ChatGLM:
params = dict(model="chatglm3", temperature=1.0, params = dict(model="chatglm3", temperature=1.0,
messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False) messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False)
if 'tool' == self.chat_type[wxid]: if 'tool' == self.chat_type[wxid]:
params["functions"] = functions params["tools"] = [dict(type='function', function=d) for d in functions.values()]
response = openai.ChatCompletion.create(**params) response = self.client.chat.completions.create(**params)
for _ in range(self.max_retry): for _ in range(self.max_retry):
if response.choices[0].message.get("function_call"): if response.choices[0].message.get("function_call"):
function_call = response.choices[0].message.function_call function_call = response.choices[0].message.function_call
@@ -106,7 +110,7 @@ class ChatGLM:
} }
) )
self.updateMessage(wxid, tool_response, "function") self.updateMessage(wxid, tool_response, "function")
response = openai.ChatCompletion.create(**params) response = self.client.chat.completions.create(**params)
elif response.choices[0].message.content.find('interpreter') != -1: elif response.choices[0].message.content.find('interpreter') != -1:
output_text = response.choices[0].message.content output_text = response.choices[0].message.content
code = extract_code(output_text) code = extract_code(output_text)
@@ -136,7 +140,7 @@ class ChatGLM:
} }
) )
self.updateMessage(wxid, tool_response, "function") self.updateMessage(wxid, tool_response, "function")
response = openai.ChatCompletion.create(**params) response = self.client.chat.completions.create(**params)
else: else:
rsp = response.choices[0].message.content rsp = response.choices[0].message.content
break break
@@ -171,6 +175,7 @@ class ChatGLM:
if __name__ == "__main__": if __name__ == "__main__":
from configuration import Config from configuration import Config
config = Config().CHATGLM config = Config().CHATGLM
if not config: if not config:
exit(0) exit(0)