Merge pull request #61 from jnchen/master

修改chatglm接入为使用新版openai语法,解决函数调用报错的问题
This commit is contained in:
Changhua
2024-05-11 16:54:57 +08:00
committed by GitHub
2 changed files with 19 additions and 14 deletions

View File

@@ -51,7 +51,7 @@ def register_tool(func: callable):
tool_def = { tool_def = {
"name": tool_name, "name": tool_name,
"description": tool_description, "description": tool_description,
"params": tool_params "parameters": tool_params
} }
# print("[registered tool] " + pformat(tool_def)) # print("[registered tool] " + pformat(tool_def))

View File

@@ -6,8 +6,8 @@ import os
import random import random
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
import httpx
import openai from openai import OpenAI
from base.chatglm.code_kernel import CodeKernel, execute from base.chatglm.code_kernel import CodeKernel, execute
from base.chatglm.tool_registry import dispatch_tool, extract_code, get_tools from base.chatglm.tool_registry import dispatch_tool, extract_code, get_tools
from wcferry import Wcf from wcferry import Wcf
@@ -18,12 +18,13 @@ functions = get_tools()
class ChatGLM: class ChatGLM:
def __init__(self, config={}, wcf: Optional[Wcf] = None, max_retry=5) -> None: def __init__(self, config={}, wcf: Optional[Wcf] = None, max_retry=5) -> None:
openai.api_key = config.get("key", "empty") key = config.get("key", 'empty')
# 自己搭建或第三方代理的接口 api = config.get("api")
openai.api_base = config["api"]
proxy = config.get("proxy") proxy = config.get("proxy")
if proxy: if proxy:
openai.proxy = {"http": proxy, "https": proxy} self.client = OpenAI(api_key=key, base_url=api, http_client=httpx.Client(proxy=proxy))
else:
self.client = OpenAI(api_key=key, base_url=api)
self.conversation_list = {} self.conversation_list = {}
self.chat_type = {} self.chat_type = {}
self.max_retry = max_retry self.max_retry = max_retry
@@ -31,8 +32,11 @@ class ChatGLM:
self.filePath = config["file_path"] self.filePath = config["file_path"]
self.kernel = CodeKernel() self.kernel = CodeKernel()
self.system_content_msg = {"chat": [{"role": "system", "content": config["prompt"]}], self.system_content_msg = {"chat": [{"role": "system", "content": config["prompt"]}],
"tool": [{"role": "system", "content": "Answer the following questions as best as you can. You have access to the following tools:"}], "tool": [{"role": "system",
"code": [{"role": "system", "content": "你是一位智能AI助手你叫ChatGLM你连接着一台电脑但请注意不能联网。在使用Python解决任务时你可以运行代码并得到结果如果运行结果有错误你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件文件默认存储路径是{}".format(self.filePath)}]} "content": "Answer the following questions as best as you can. You have access to the following tools:"}],
"code": [{"role": "system",
"content": "你是一位智能AI助手你叫ChatGLM你连接着一台电脑但请注意不能联网。在使用Python解决任务时你可以运行代码并得到结果如果运行结果有错误你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件文件默认存储路径是{}".format(
self.filePath)}]}
def __repr__(self): def __repr__(self):
return 'ChatGLM' return 'ChatGLM'
@@ -71,8 +75,8 @@ class ChatGLM:
params = dict(model="chatglm3", temperature=1.0, params = dict(model="chatglm3", temperature=1.0,
messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False) messages=self.conversation_list[wxid][self.chat_type[wxid]], stream=False)
if 'tool' == self.chat_type[wxid]: if 'tool' == self.chat_type[wxid]:
params["functions"] = functions params["tools"] = [dict(type='function', function=d) for d in functions.values()]
response = openai.ChatCompletion.create(**params) response = self.client.chat.completions.create(**params)
for _ in range(self.max_retry): for _ in range(self.max_retry):
if response.choices[0].message.get("function_call"): if response.choices[0].message.get("function_call"):
function_call = response.choices[0].message.function_call function_call = response.choices[0].message.function_call
@@ -106,7 +110,7 @@ class ChatGLM:
} }
) )
self.updateMessage(wxid, tool_response, "function") self.updateMessage(wxid, tool_response, "function")
response = openai.ChatCompletion.create(**params) response = self.client.chat.completions.create(**params)
elif response.choices[0].message.content.find('interpreter') != -1: elif response.choices[0].message.content.find('interpreter') != -1:
output_text = response.choices[0].message.content output_text = response.choices[0].message.content
code = extract_code(output_text) code = extract_code(output_text)
@@ -136,7 +140,7 @@ class ChatGLM:
} }
) )
self.updateMessage(wxid, tool_response, "function") self.updateMessage(wxid, tool_response, "function")
response = openai.ChatCompletion.create(**params) response = self.client.chat.completions.create(**params)
else: else:
rsp = response.choices[0].message.content rsp = response.choices[0].message.content
break break
@@ -171,6 +175,7 @@ class ChatGLM:
if __name__ == "__main__": if __name__ == "__main__":
from configuration import Config from configuration import Config
config = Config().CHATGLM config = Config().CHATGLM
if not config: if not config:
exit(0) exit(0)