From 968e6602f7f6d9e750193259068f79f83518a192 Mon Sep 17 00:00:00 2001
From: heimoshuiyu
Date: Thu, 10 Aug 2023 18:36:29 +0800
Subject: [PATCH] async generator stream

---
Review fix-ups folded into this patch (blob ids on the "index" lines are
those of the original commit and predate these fix-ups):

 * processStreamResponse: the partial-line buffer now lives outside the
   read loop, so a JSON payload split across two network reads is
   reassembled instead of being discarded.
 * processStreamResponse: an unparsable line no longer reaches a
   `.choices[0]` dereference on null.
 * processStreamResponse: only "data:" lines are parsed, and the prefix is
   stripped exactly once.
 * _completeWithStreamMode: responseModelName is updated from the stream
   again; the cost estimate depends on it.

 src/chatbox.tsx | 129 ++++++++++++++----------------------------------
 src/chatgpt.ts  |  55 +++++++++++++++++--
 2 files changed, 91 insertions(+), 93 deletions(-)

diff --git a/src/chatbox.tsx b/src/chatbox.tsx
index 9f9a9df..2c9a602 100644
--- a/src/chatbox.tsx
+++ b/src/chatbox.tsx
@@ -54,101 +54,50 @@ export default function ChatBOX(props: {
   const _completeWithStreamMode = async (response: Response) => {
     let responseTokenCount = 0;
     chatStore.streamMode = true;
-    // call api, return reponse text
-    console.log("response", response);
-    const reader = response.body?.getReader();
     const allChunkMessage: string[] = [];
-    new ReadableStream({
-      async start() {
-        let lastText = "";
-        while (true) {
-          let responseDone = false;
-          let state = await reader?.read();
-          let done = state?.done;
-          let value = state?.value;
-          if (done) break;
-          let text = lastText + new TextDecoder().decode(value);
-          // console.log("text:", text);
-          const lines = text
-            .trim()
-            .split("\n")
-            .map((line) => line.trim())
-            .filter((i) => {
-              if (!i) return false;
-              if (i === "data: [DONE]" || i === "data:[DONE]") {
-                responseDone = true;
-                responseTokenCount += 1;
-                return false;
-              }
-              return true;
-            });
-          responseTokenCount += lines.length;
-          console.log("lines", lines);
-          const jsons: ChunkMessage[] = lines
-            .map((line) => {
-              try {
-                const ret = JSON.parse(line.trim().slice("data:".length));
-                lastText = "";
-                return ret;
-              } catch (e) {
-                console.log(`Chunk parse error at: ${line}`);
-                lastText = line;
-                return null;
-              }
-            })
-            .filter((i) => i);
-          console.log("jsons", jsons);
-          for (const { model } of jsons) {
-            if (model) chatStore.responseModelName = model;
-          }
-          const chunkText = jsons
-            .map((j) => j.choices[0].delta.content ?? "")
-            .join("");
-          // console.log("chunk text", chunkText);
-          allChunkMessage.push(chunkText);
-          setShowGenerating(true);
-          setGeneratingMessage(allChunkMessage.join(""));
-          if (responseDone) break;
-        }
-        setShowGenerating(false);
+    setShowGenerating(true);
+    for await (const i of client.processStreamResponse(response)) {
+      // remember the model named by the stream; the cost estimate reads it
+      if (i.model) chatStore.responseModelName = i.model;
+      responseTokenCount += 1;
+      allChunkMessage.push(i.choices[0].delta.content ?? "");
+      setGeneratingMessage(allChunkMessage.join(""));
+    }
+    setShowGenerating(false);
+    const content = allChunkMessage.join("");
 
-        // console.log("push to history", allChunkMessage);
-        const content = allChunkMessage.join("");
+    // estimate cost
+    let cost = 0;
+    if (chatStore.responseModelName) {
+      cost +=
+        responseTokenCount *
+        (models[chatStore.responseModelName]?.price?.completion ?? 0);
+      let sum = 0;
+      for (const msg of chatStore.history
+        .filter(({ hide }) => !hide)
+        .slice(chatStore.postBeginIndex)) {
+        sum += msg.token;
+      }
+      cost += sum * (models[chatStore.responseModelName]?.price?.prompt ?? 0);
+    }
 
-        // estimate cost
-        let cost = 0;
-        if (chatStore.responseModelName) {
-          cost +=
-            responseTokenCount *
-            (models[chatStore.responseModelName]?.price?.completion ?? 0);
-          let sum = 0;
-          for (const msg of chatStore.history
-            .filter(({ hide }) => !hide)
-            .slice(chatStore.postBeginIndex)) {
-            sum += msg.token;
-          }
-          cost +=
-            sum * (models[chatStore.responseModelName]?.price?.prompt ?? 0);
-        }
-        chatStore.cost += cost;
-        addTotalCost(cost);
+    chatStore.cost += cost;
+    addTotalCost(cost);
 
-        chatStore.history.push({
-          role: "assistant",
-          content,
-          hide: false,
-          token: responseTokenCount,
-          example: false,
-        });
-        // manually copy status from client to chatStore
-        chatStore.maxTokens = client.max_tokens;
-        chatStore.tokenMargin = client.tokens_margin;
-        update_total_tokens();
-        setChatStore({ ...chatStore });
-        setGeneratingMessage("");
-        setShowGenerating(false);
-      },
+    chatStore.history.push({
+      role: "assistant",
+      content,
+      hide: false,
+      token: responseTokenCount,
+      example: false,
     });
+    // manually copy status from client to chatStore
+    chatStore.maxTokens = client.max_tokens;
+    chatStore.tokenMargin = client.tokens_margin;
+    update_total_tokens();
+    setChatStore({ ...chatStore });
+    setGeneratingMessage("");
+    setShowGenerating(false);
   };
 
   const _completeWithFetchMode = async (response: Response) => {
diff --git a/src/chatgpt.ts b/src/chatgpt.ts
index 8d87d41..afbcd22 100644
--- a/src/chatgpt.ts
+++ b/src/chatgpt.ts
@@ -63,7 +63,7 @@ class Chat {
       top_p = 1,
       presence_penalty = 0,
       frequency_penalty = 0,
-    } = {}
+    } = {},
   ) {
     if (OPENAI_API_KEY === undefined) {
       throw "OPENAI_API_KEY is undefined";
@@ -95,14 +95,14 @@ class Chat {
       }
       if (msg.role === "system") {
         console.log(
-          "Warning: detected system message in the middle of history"
+          "Warning: detected system message in the middle of history",
         );
       }
     }
     for (const msg of this.messages) {
       if (msg.name && msg.role !== "system") {
         console.log(
-          "Warning: detected message where name field set but role is system"
+          "Warning: detected message where name field set but role is system",
         );
       }
     }
     }
@@ -127,6 +127,7 @@ class Chat {
     });
   }
 
+
   async fetch(): Promise<FetchResponse> {
     const resp = await this._fetch();
     const j = await resp.json();
@@ -142,6 +143,54 @@ class Chat {
     return this.messages.slice(-1)[0].content;
   }
 
+  /**
+   * Parse a streaming (server-sent events) completion response and yield
+   * every chunk whose first choice carries non-empty delta content.
+   *
+   * Complete lines are handled as soon as their newline arrives; a trailing
+   * partial line is buffered until the next read, so a JSON payload split
+   * across network chunks is parsed once it is whole. Iteration ends at the
+   * "[DONE]" sentinel or when the response body is exhausted.
+   *
+   * @param resp Response whose body is the event stream.
+   */
+  async *processStreamResponse(resp: Response) {
+    const reader = resp?.body?.pipeThrough(new TextDecoderStream()).getReader();
+    if (reader === undefined) {
+      console.log("reader is undefined");
+      return;
+    }
+    // Must outlive a single read(): it holds the unfinished last line.
+    let buffer = "";
+    let receiving = true;
+    while (receiving) {
+      const { value, done } = await reader.read();
+      // End of stream acts as a final newline so the last line is flushed.
+      buffer += done ? "\n" : value;
+      const lines = buffer.split("\n");
+      buffer = lines.pop() ?? "";
+      for (const rawLine of lines) {
+        const line = rawLine.trim();
+        // Skip blank separators and anything that is not a "data:" field.
+        if (!line.startsWith("data:")) continue;
+        const payload = line.slice("data:".length).trim();
+        if (payload === "[DONE]") {
+          receiving = false;
+          break;
+        }
+        let chunk: ChunkMessage | null = null;
+        try {
+          chunk = JSON.parse(payload);
+        } catch (e) {
+          console.log(`Chunk parse error at: ${payload}`);
+        }
+        // Optional chaining guards unparsable or choice-less chunks.
+        if (chunk?.choices?.[0]?.delta?.content) yield chunk;
+      }
+      if (done) break;
+    }
+  }
+
   processFetchResponse(resp: FetchResponse): string {
     if (resp.error !== undefined) {
       throw JSON.stringify(resp.error);