diff --git a/src/app.tsx b/src/app.tsx index d12c633..f54320b 100644 --- a/src/app.tsx +++ b/src/app.tsx @@ -21,9 +21,11 @@ export interface TemplateAPI { key: string; endpoint: string; } + export interface ChatStore { chatgpt_api_web_version: string; systemMessageContent: string; + toolsString: string; history: ChatStoreMessage[]; postBeginIndex: number; tokenMargin: number; @@ -67,11 +69,13 @@ export const newChatStore = ( tts_api = "", tts_key = "", tts_speed = 1.0, - tts_speed_enabled = false + tts_speed_enabled = false, + toolsString = "" ): ChatStore => { return { chatgpt_api_web_version: CHATGPT_API_WEB_VERSION, systemMessageContent: getDefaultParams("sys", systemMessageContent), + toolsString, history: [], postBeginIndex: 0, tokenMargin: 1024, @@ -173,6 +177,7 @@ export function App() { if (ret.maxGenTokens_enabled === undefined) ret.maxGenTokens_enabled = true; if (ret.model === undefined) ret.model = "gpt-3.5-turbo"; if (ret.responseModelName === undefined) ret.responseModelName = ""; + if (ret.toolsString === undefined) ret.toolsString = ""; if (ret.chatgpt_api_web_version === undefined) // this is from old version becasue it is undefined, // so no higher than v1.3.0 @@ -250,7 +255,8 @@ export function App() { chatStore.tts_api, chatStore.tts_key, chatStore.tts_speed, - chatStore.tts_speed_enabled + chatStore.tts_speed_enabled, + chatStore.toolsString ) ); setSelectedChatIndex(newKey as number); diff --git a/src/chatbox.tsx b/src/chatbox.tsx index 057f207..5aa2381 100644 --- a/src/chatbox.tsx +++ b/src/chatbox.tsx @@ -13,7 +13,9 @@ import ChatGPT, { calculate_token_length, ChunkMessage, FetchResponse, + Message as MessageType, MessageDetail, + ToolCall, } from "./chatgpt"; import Message from "./message"; import models from "./models"; @@ -41,6 +43,9 @@ export default function ChatBOX(props: { const [generatingMessage, setGeneratingMessage] = useState(""); const [showRetry, setShowRetry] = useState(false); const [isRecording, setIsRecording] = useState("Mic"); + const [showAddToolMsg, setShowAddToolMsg] = useState(false); + const [newToolCallID, setNewToolCallID] = useState(""); + const [newToolContent, setNewToolContent] = useState(""); const mediaRef = createRef(); const messagesEndRef = createRef(); @@ -67,12 +72,48 @@ export default function ChatBOX(props: { let responseTokenCount = 0; chatStore.streamMode = true; const allChunkMessage: string[] = []; + const allChunkTool: ToolCall[] = []; setShowGenerating(true); for await (const i of client.processStreamResponse(response)) { chatStore.responseModelName = i.model; responseTokenCount += 1; allChunkMessage.push(i.choices[0].delta.content ?? ""); - setGeneratingMessage(allChunkMessage.join("")); + const tool_calls = i.choices[0].delta.tool_calls; + if (tool_calls) { + for (const tool_call of tool_calls) { + // init + if (tool_call.id) { + allChunkTool.push({ + id: tool_call.id, + type: tool_call.type, + index: tool_call.index, + function: { + name: tool_call.function.name, + arguments: "", + }, + }); + continue; + } + + // update tool call arguments + const tool = allChunkTool.find( + (tool) => tool.index === tool_call.index + ); + + if (!tool) { + console.log("tool (by index) not found", tool_call.index); + continue; + } + + tool.function.arguments += tool_call.function.arguments; + } + } + setGeneratingMessage( + allChunkMessage.join("") + + allChunkTool.map((tool) => { + return `Tool Call ID: ${tool.id}\nType: ${tool.type}\nFunction: ${tool.function.name}\nArguments: ${tool.function.arguments}`; + }) + ); } setShowGenerating(false); const content = allChunkMessage.join(""); @@ -99,6 +140,7 @@ export default function ChatBOX(props: { chatStore.history.push({ role: "assistant", content, + tool_calls: allChunkTool, hide: false, token: responseTokenCount, example: false, @@ -127,7 +169,7 @@ export default function ChatBOX(props: { chatStore.cost += cost; addTotalCost(cost); } - const content = client.processFetchResponse(data); + const msg = client.processFetchResponse(data); // estimate user's input message token let aboveToken = 0; @@ -147,9 +189,11 @@ export default function ChatBOX(props: { chatStore.history.push({ role: "assistant", - content, + content: msg.content, + tool_calls: msg.tool_calls, hide: false, - token: data.usage.completion_tokens ?? calculate_token_length(content), + token: + data.usage.completion_tokens ?? calculate_token_length(msg.content), example: false, }); setShowGenerating(false); @@ -160,6 +204,7 @@ export default function ChatBOX(props: { // manually copy status from chatStore to client client.apiEndpoint = chatStore.apiEndpoint; client.sysMessageContent = chatStore.systemMessageContent; + client.toolsString = chatStore.toolsString; client.tokens_margin = chatStore.tokenMargin; client.temperature = chatStore.temperature; client.enable_temperature = chatStore.temperature_enabled; @@ -172,18 +217,22 @@ export default function ChatBOX(props: { .filter(({ hide }) => !hide) .slice(chatStore.postBeginIndex) // only copy content and role attribute to client for posting - .map(({ content, role, example }) => { - if (example) { - return { - content, - role: "system", - name: role === "assistant" ? "example_assistant" : "example_user", - }; - } - return { + .map(({ content, role, example, tool_call_id, tool_calls }) => { + const ret: MessageType = { content, role, + tool_calls, }; + + if (example) { + ret.name = + ret.role === "assistant" ? "example_assistant" : "example_user"; + ret.role = "system"; + } + + if (tool_call_id) ret.tool_call_id = tool_call_id; + + return ret; }); client.model = chatStore.model; client.max_tokens = chatStore.maxTokens; @@ -406,6 +455,7 @@ export default function ChatBOX(props: { className="mx-2 underline cursor-pointer" onClick={() => { chatStore.systemMessageContent = ""; + chatStore.toolsString = ""; chatStore.history = []; setChatStore({ ...chatStore }); }} @@ -943,6 +993,93 @@ export default function ChatBOX(props: { {Tr("User")} )} + {chatStore.develop_mode && ( + + )} + {showAddToolMsg && ( +
{ + setShowAddToolMsg(false); + }} + > +
{ + event.stopPropagation(); + }} + > +

Add Tool Message

+
+ + + + setNewToolCallID(event.target.value) + } + /> + + + + + + + + + +
+
+ )} ); diff --git a/src/chatgpt.ts b/src/chatgpt.ts index 1f604d7..748476a 100644 --- a/src/chatgpt.ts +++ b/src/chatgpt.ts @@ -8,13 +8,53 @@ export interface MessageDetail { text?: string; image_url?: ImageURL; } +export interface ToolCall { + index: number; + id?: string; + type: string; + function: { + name: string; + arguments: string; + }; +} export interface Message { - role: "system" | "user" | "assistant" | "function"; + role: "system" | "user" | "assistant" | "tool"; content: string | MessageDetail[]; name?: "example_user" | "example_assistant"; + tool_calls?: ToolCall[]; + tool_call_id?: string; +} + +interface Delta { + role?: string; + content?: string; + tool_calls?: ToolCall[]; +} + +interface Choices { + index: number; + delta: Delta; + finish_reason: string | null; +} + +export interface StreamingResponseChunk { + id: string; + object: string; + created: number; + model: string; + system_fingerprint: string; + choices: Choices[]; } export const getMessageText = (message: Message): string => { if (typeof message.content === "string") { + // function call message + if (message.tool_calls) { + return message.tool_calls + .map((tc) => { + return `Tool Call ID: ${tc.id}\nType: ${tc.type}\nFunction: ${tc.function.name}\nArguments: ${tc.function.arguments}}`; + }) + .join("\n"); + } return message.content; } return message.content @@ -78,6 +118,7 @@ class Chat { OPENAI_API_KEY: string; messages: Message[]; sysMessageContent: string; + toolsString: string; total_tokens: number; max_tokens: number; max_gen_tokens: number; @@ -96,6 +137,7 @@ class Chat { OPENAI_API_KEY: string | undefined, { systemMessage = "", + toolsString = "", max_tokens = 4096, max_gen_tokens = 2048, enable_max_gen_tokens = true, @@ -121,6 +163,7 @@ class Chat { this.enable_max_gen_tokens = enable_max_gen_tokens; this.tokens_margin = tokens_margin; this.sysMessageContent = systemMessage; + this.toolsString = toolsString; this.apiEndpoint = apiEndPoint; this.model = model; this.temperature = temperature; @@ -178,6 +221,25 @@ class Chat { body["max_tokens"] = this.max_gen_tokens; } + // parse toolsString to function call format + const ts = this.toolsString.trim(); + if (ts) { + try { + const fcList: any[] = JSON.parse(ts); + body["tools"] = fcList.map((fc) => { + return { + type: "function", + function: fc, + }; + }); + } catch (e) { + console.log("toolsString parse error"); + throw ( + "Function call toolsString parse error, not a valied json list: " + e + ); + } + } + return fetch(this.apiEndpoint, { method: "POST", headers: { @@ -224,7 +286,7 @@ class Chat { console.log("line", line); try { const jsonStr = line.slice("data:".length).trim(); - const json = JSON.parse(jsonStr); + const json = JSON.parse(jsonStr) as StreamingResponseChunk; yield json; } catch (e) { console.log(`Chunk parse error at: ${line}`); @@ -234,7 +296,7 @@ class Chat { } } - processFetchResponse(resp: FetchResponse): string { + processFetchResponse(resp: FetchResponse): Message { if (resp.error !== undefined) { throw JSON.stringify(resp.error); } @@ -249,15 +311,19 @@ class Chat { this.forgetSomeMessages(); } - return ( - (resp?.choices[0]?.message?.content as string) ?? - `Error: ${JSON.stringify(resp)}` - ); - } + let content = resp.choices[0].message?.content ?? ""; + if ( + !resp.choices[0]?.message?.content && + !resp.choices[0]?.message?.tool_calls + ) { + content = `Unparsed response: ${JSON.stringify(resp)}`; + } - async complete(): Promise { - const resp = await this.fetch(); - return this.processFetchResponse(resp); + return { + role: "assistant", + content, + tool_calls: resp?.choices[0]?.message?.tool_calls, + }; } completeWithSteam() { diff --git a/src/global.css b/src/global.css index f67213a..c951762 100644 --- a/src/global.css +++ b/src/global.css @@ -28,6 +28,6 @@ body::-webkit-scrollbar { display: none; } -p.message-content { +.message-content { white-space: pre-wrap; } diff --git a/src/message.tsx b/src/message.tsx index 3fdb679..f163b74 100644 --- a/src/message.tsx +++ b/src/message.tsx @@ -12,35 +12,97 @@ interface EditMessageProps { setChatStore: (cs: ChatStore) => void; } +export const isVailedJSON = (str: string): boolean => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; + function EditMessage(props: EditMessageProps) { const { setShowEdit, chat, setChatStore, chatStore } = props; return (
setShowEdit(false)} > -
+
{ + event.stopPropagation(); + }} + > {typeof chat.content === "string" ? ( - +
+ {chat.tool_call_id && ( + + + { + chat.tool_call_id = event.target.value; + setChatStore({ ...chatStore }); + }} + /> + + )} + {chat.tool_calls && + chat.tool_calls.map((tool_call) => ( +
+ + + + + + + + + + + + Vailed JSON:{" "} + {isVailedJSON(tool_call.function.arguments) ? "🆗" : "❌"} + + + +
+
+ ))} + +
) : (
); - const CopyIcon = () => { + const copyToClipboard = (text: string) => { + navigator.clipboard.writeText(text); + setShowCopiedHint(true); + setTimeout(() => setShowCopiedHint(false), 1000); + }; + + const CopyIcon = ({ textToCopy }: { textToCopy: string }) => { return ( <>
{showEdit && ( diff --git a/src/settings.tsx b/src/settings.tsx index 1aa901a..119d595 100644 --- a/src/settings.tsx +++ b/src/settings.tsx @@ -5,6 +5,7 @@ import models from "./models"; import { TemplateChatStore } from "./chatbox"; import { tr, Tr, langCodeContext, LANG_OPTIONS } from "./translate"; import p from "preact-markdown"; +import { isVailedJSON } from "./message"; const TTS_VOICES: string[] = [ "alloy", @@ -60,7 +61,7 @@ const SelectModel = (props: { const LongInput = (props: { chatStore: ChatStore; setChatStore: (cs: ChatStore) => void; - field: "systemMessageContent"; + field: "systemMessageContent" | "toolsString"; help: string; }) => { return ( @@ -373,6 +374,15 @@ export default (props: { help="系统消息,用于指示ChatGPT的角色和一些前置条件,例如“你是一个有帮助的人工智能助理”,或者“你是一个专业英语翻译,把我的话全部翻译成英语”,详情参考 OPEAN AI API 文档" {...props} /> + + Valied JSON:{" "} + {isVailedJSON(props.chatStore.toolsString) ? "🆗" : "❌"} + +