diff --git a/src/app.tsx b/src/app.tsx index 3ad6f4c..5533d31 100644 --- a/src/app.tsx +++ b/src/app.tsx @@ -2,7 +2,7 @@ import { IDBPDatabase, openDB } from "idb"; import { useEffect, useState } from "preact/hooks"; import "./global.css"; -import { calculate_token_length, Message } from "./chatgpt"; +import { calculate_token_length, Logprobs, Message } from "./chatgpt"; import getDefaultParams from "./getDefaultParam"; import ChatBOX from "./chatbox"; import models, { defaultModel } from "./models"; @@ -15,6 +15,7 @@ export interface ChatStoreMessage extends Message { token: number; example: boolean; audio: Blob | null; + logprobs: Logprobs | null; } export interface TemplateAPI { @@ -63,6 +64,7 @@ export interface ChatStore { image_gen_api: string; image_gen_key: string; json_mode: boolean; + logprobs: boolean; } const _defaultAPIEndpoint = "https://api.openai.com/v1/chat/completions"; @@ -84,7 +86,8 @@ export const newChatStore = ( toolsString = "", image_gen_api = "https://api.openai.com/v1/images/generations", image_gen_key = "", - json_mode = false + json_mode = false, + logprobs = true ): ChatStore => { return { chatgpt_api_web_version: CHATGPT_API_WEB_VERSION, @@ -124,6 +127,7 @@ export const newChatStore = ( image_gen_key: image_gen_key, json_mode: json_mode, tts_format: tts_format, + logprobs, }; }; @@ -285,7 +289,8 @@ export function App() { chatStore.toolsString, chatStore.image_gen_api, chatStore.image_gen_key, - chatStore.json_mode + chatStore.json_mode, + chatStore.logprobs ) ); setSelectedChatIndex(newKey as number); diff --git a/src/chatbox.tsx b/src/chatbox.tsx index 44018c0..d0cbae1 100644 --- a/src/chatbox.tsx +++ b/src/chatbox.tsx @@ -22,6 +22,7 @@ import ChatGPT, { Message as MessageType, MessageDetail, ToolCall, + Logprobs, } from "./chatgpt"; import Message from "./message"; import models from "./models"; @@ -82,15 +83,29 @@ export default function ChatBOX(props: { const allChunkMessage: string[] = []; const allChunkTool: ToolCall[] = []; 
setShowGenerating(true); + const logprobs: Logprobs = { + content: [], + }; for await (const i of client.processStreamResponse(response)) { chatStore.responseModelName = i.model; responseTokenCount += 1; - // skip if choice is empty (e.g. azure) - if (!i.choices[0]) continue; + const c = i.choices[0]; - allChunkMessage.push(i.choices[0].delta.content ?? ""); - const tool_calls = i.choices[0].delta.tool_calls; + // skip if choice is empty (e.g. azure) + if (!c) continue; + + const logprob = c?.logprobs?.content[0]?.logprob; + if (logprob !== undefined) { + logprobs.content.push({ + token: c.delta.content ?? "", + logprob, + }); + console.log(c.delta.content, logprob); + } + + allChunkMessage.push(c.delta.content ?? ""); + const tool_calls = c.delta.tool_calls; if (tool_calls) { for (const tool_call of tool_calls) { // init @@ -149,6 +164,7 @@ export default function ChatBOX(props: { chatStore.cost += cost; addTotalCost(cost); + console.log("save logprobs", logprobs); const newMsg: ChatStoreMessage = { role: "assistant", content, @@ -156,6 +172,7 @@ export default function ChatBOX(props: { token: responseTokenCount, example: false, audio: null, + logprobs, }; if (allChunkTool.length > 0) newMsg.tool_calls = allChunkTool; @@ -210,6 +227,7 @@ export default function ChatBOX(props: { data.usage.completion_tokens ?? 
calculate_token_length(msg.content), example: false, audio: null, + logprobs: data.choices[0]?.logprobs, }); setShowGenerating(false); }; @@ -257,7 +275,10 @@ export default function ChatBOX(props: { try { setShowGenerating(true); - const response = await client._fetch(chatStore.streamMode); + const response = await client._fetch( + chatStore.streamMode, + chatStore.logprobs + ); const contentType = response.headers.get("content-type"); if (contentType?.startsWith("text/event-stream")) { await _completeWithStreamMode(response); @@ -306,6 +327,7 @@ export default function ChatBOX(props: { token: calculate_token_length(inputMsg.trim()), example: false, audio: null, + logprobs: null, }); // manually calculate token length @@ -972,6 +994,7 @@ export default function ChatBOX(props: { hide: false, example: false, audio: null, + logprobs: null, }); update_total_tokens(); setInputMsg(""); @@ -1066,6 +1089,7 @@ export default function ChatBOX(props: { hide: false, example: false, audio: null, + logprobs: null, }); update_total_tokens(); setChatStore({ ...chatStore }); diff --git a/src/chatgpt.ts b/src/chatgpt.ts index b27b673..3c84a86 100644 --- a/src/chatgpt.ts +++ b/src/chatgpt.ts @@ -35,6 +35,16 @@ interface Choices { index: number; delta: Delta; finish_reason: string | null; + logprobs: Logprobs | null; +} + +export interface Logprobs { + content: LogprobsContent[]; +} + +interface LogprobsContent { + token: string; + logprob: number; } export interface StreamingResponseChunk { @@ -85,6 +95,7 @@ export interface FetchResponse { message: Message | undefined; finish_reason: "stop" | "length"; index: number | undefined; + logprobs: Logprobs | null; }[]; } @@ -174,7 +185,7 @@ class Chat { this.json_mode = json_mode; } - _fetch(stream = false) { + _fetch(stream = false, logprobs = false) { // perform role type check let hasNonSystemMessage = false; for (const msg of this.messages) { @@ -208,6 +219,7 @@ class Chat { model: this.model, messages, stream, + logprobs, 
presence_penalty: this.presence_penalty, frequency_penalty: this.frequency_penalty, }; @@ -253,15 +265,6 @@ class Chat { }); } - async fetch(): Promise { - const resp = await this._fetch(); - const j = await resp.json(); - if (j.error !== undefined) { - throw JSON.stringify(j.error); - } - return j; - } - async *processStreamResponse(resp: Response) { const reader = resp?.body?.pipeThrough(new TextDecoderStream()).getReader(); if (reader === undefined) { diff --git a/src/logprob.tsx b/src/logprob.tsx new file mode 100644 index 0000000..610ccd7 --- /dev/null +++ b/src/logprob.tsx @@ -0,0 +1,16 @@ +import React from "react"; + +const logprobToColor = (logprob: number) => { + // Convert the logprob to a percentage + const percent = Math.exp(logprob) * 100; + + // Compute the color value + // Green has RGB (0, 255, 0), red has RGB (255, 0, 0) + const red = Math.round(255 * (1 - percent / 100)); + const green = Math.round(255 * (percent / 100)); + const color = `rgb(${red}, ${green}, 0)`; + + return color; +}; + +export default logprobToColor; diff --git a/src/message.tsx b/src/message.tsx index 7ccff1d..7cb378d 100644 --- a/src/message.tsx +++ b/src/message.tsx @@ -9,6 +9,7 @@ import { MessageDetail } from "./messageDetail"; import { MessageToolCall } from "./messageToolCall"; import { MessageToolResp } from "./messageToolResp"; import { EditMessage } from "./editMessage"; +import logprobToColor from "./logprob"; export const isVailedJSON = (str: string): boolean => { try { @@ -32,6 +33,7 @@ export default function Message(props: Props) { const [showEdit, setShowEdit] = useState(false); const [showCopiedHint, setShowCopiedHint] = useState(false); const [renderMarkdown, setRenderWorkdown] = useState(false); + const [renderColor, setRenderColor] = useState(false); const DeleteIcon = () => (