support logprobs

2024-02-23 19:00:20 +08:00
parent d01d7c747b
commit c9c51a85cf
6 changed files with 143 additions and 64 deletions

View File

@@ -2,7 +2,7 @@ import { IDBPDatabase, openDB } from "idb";
 import { useEffect, useState } from "preact/hooks";
 import "./global.css";
-import { calculate_token_length, Message } from "./chatgpt";
+import { calculate_token_length, Logprobs, Message } from "./chatgpt";
 import getDefaultParams from "./getDefaultParam";
 import ChatBOX from "./chatbox";
 import models, { defaultModel } from "./models";
@@ -15,6 +15,7 @@ export interface ChatStoreMessage extends Message {
   token: number;
   example: boolean;
   audio: Blob | null;
+  logprobs: Logprobs | null;
 }
 
 export interface TemplateAPI {
@@ -63,6 +64,7 @@ export interface ChatStore {
   image_gen_api: string;
   image_gen_key: string;
   json_mode: boolean;
+  logprobs: boolean;
 }
 
 const _defaultAPIEndpoint = "https://api.openai.com/v1/chat/completions";
@@ -84,7 +86,8 @@ export const newChatStore = (
   toolsString = "",
   image_gen_api = "https://api.openai.com/v1/images/generations",
   image_gen_key = "",
-  json_mode = false
+  json_mode = false,
+  logprobs = true
 ): ChatStore => {
   return {
     chatgpt_api_web_version: CHATGPT_API_WEB_VERSION,
@@ -124,6 +127,7 @@ export const newChatStore = (
     image_gen_key: image_gen_key,
     json_mode: json_mode,
     tts_format: tts_format,
+    logprobs,
   };
 };
 
@@ -285,7 +289,8 @@ export function App() {
       chatStore.toolsString,
       chatStore.image_gen_api,
       chatStore.image_gen_key,
-      chatStore.json_mode
+      chatStore.json_mode,
+      chatStore.logprobs
     )
   );
   setSelectedChatIndex(newKey as number);
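
A rough sketch of what a stored assistant message looks like after this change (illustrative values only; other ChatStoreMessage fields such as hide are omitted, and user-entered messages simply get logprobs: null):

    // Hypothetical example, not part of the commit.
    const saved = {
      role: "assistant",
      content: "Hi!",
      token: 2,
      example: false,
      audio: null,
      logprobs: { content: [{ token: "Hi", logprob: -0.11 }] },
    };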

View File

@@ -22,6 +22,7 @@ import ChatGPT, {
   Message as MessageType,
   MessageDetail,
   ToolCall,
+  Logprobs,
 } from "./chatgpt";
 import Message from "./message";
 import models from "./models";
@@ -82,15 +83,29 @@ export default function ChatBOX(props: {
     const allChunkMessage: string[] = [];
     const allChunkTool: ToolCall[] = [];
     setShowGenerating(true);
+    const logprobs: Logprobs = {
+      content: [],
+    };
     for await (const i of client.processStreamResponse(response)) {
       chatStore.responseModelName = i.model;
       responseTokenCount += 1;
-      // skip if choice is empty (e.g. azure)
-      if (!i.choices[0]) continue;
-      allChunkMessage.push(i.choices[0].delta.content ?? "");
-      const tool_calls = i.choices[0].delta.tool_calls;
+
+      const c = i.choices[0];
+
+      // skip if choice is empty (e.g. azure)
+      if (!c) continue;
+
+      const logprob = c?.logprobs?.content[0]?.logprob;
+      if (logprob !== undefined) {
+        logprobs.content.push({
+          token: c.delta.content ?? "",
+          logprob,
+        });
+        console.log(c.delta.content, logprob);
+      }
+
+      allChunkMessage.push(c.delta.content ?? "");
+      const tool_calls = c.delta.tool_calls;
       if (tool_calls) {
         for (const tool_call of tool_calls) {
           // init
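
The optional chaining above walks one parsed SSE chunk; per the Choices and Logprobs interfaces added in chatgpt.ts, the relevant shape is roughly the sketch below. Only content[0] of each delta is read, which assumes the API attaches at most one logprob entry per streamed delta; real chunks also carry fields that are ignored here.

    // Illustrative chunk only; the model name and values are placeholders.
    const chunk = {
      model: "gpt-3.5-turbo",
      choices: [
        {
          index: 0,
          delta: { content: "Hello" },
          finish_reason: null,
          logprobs: { content: [{ token: "Hello", logprob: -0.12 }] },
        },
      ],
    };
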
@@ -149,6 +164,7 @@ export default function ChatBOX(props: {
     chatStore.cost += cost;
     addTotalCost(cost);
 
+    console.log("save logprobs", logprobs);
     const newMsg: ChatStoreMessage = {
       role: "assistant",
       content,
@@ -156,6 +172,7 @@ export default function ChatBOX(props: {
       token: responseTokenCount,
       example: false,
       audio: null,
+      logprobs,
     };
 
     if (allChunkTool.length > 0) newMsg.tool_calls = allChunkTool;
@@ -210,6 +227,7 @@ export default function ChatBOX(props: {
         data.usage.completion_tokens ?? calculate_token_length(msg.content),
       example: false,
       audio: null,
+      logprobs: data.choices[0]?.logprobs,
     });
     setShowGenerating(false);
   };
@@ -257,7 +275,10 @@ export default function ChatBOX(props: {
     try {
       setShowGenerating(true);
-      const response = await client._fetch(chatStore.streamMode);
+      const response = await client._fetch(
+        chatStore.streamMode,
+        chatStore.logprobs
+      );
       const contentType = response.headers.get("content-type");
       if (contentType?.startsWith("text/event-stream")) {
         await _completeWithStreamMode(response);
@@ -306,6 +327,7 @@ export default function ChatBOX(props: {
       token: calculate_token_length(inputMsg.trim()),
       example: false,
       audio: null,
+      logprobs: null,
     });
 
     // manually calculate token length
@@ -972,6 +994,7 @@ export default function ChatBOX(props: {
       hide: false,
       example: false,
       audio: null,
+      logprobs: null,
     });
     update_total_tokens();
     setInputMsg("");
@@ -1066,6 +1089,7 @@ export default function ChatBOX(props: {
       hide: false,
       example: false,
       audio: null,
+      logprobs: null,
     });
     update_total_tokens();
     setChatStore({ ...chatStore });

View File

@@ -35,6 +35,16 @@ interface Choices {
   index: number;
   delta: Delta;
   finish_reason: string | null;
+  logprobs: Logprobs | null;
 }
 
+export interface Logprobs {
+  content: LogprobsContent[];
+}
+
+interface LogprobsContent {
+  token: string;
+  logprob: number;
+}
+
 export interface StreamingResponseChunk {
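
These interfaces model the logprobs object the chat completions API attaches to each choice: one entry per generated token, where logprob is the natural log of that token's probability. A minimal value of the new Logprobs type (illustrative only, probabilities rounded):

    const example: Logprobs = {
      content: [
        { token: "Hello", logprob: -0.12 }, // exp(-0.12) ≈ 0.89, i.e. ~89% probable
        { token: "!", logprob: -1.6 }, // exp(-1.6) ≈ 0.20, i.e. ~20% probable
      ],
    };
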
@@ -85,6 +95,7 @@ export interface FetchResponse {
     message: Message | undefined;
     finish_reason: "stop" | "length";
     index: number | undefined;
+    logprobs: Logprobs | null;
   }[];
 }
 
@@ -174,7 +185,7 @@ class Chat {
     this.json_mode = json_mode;
   }
 
-  _fetch(stream = false) {
+  _fetch(stream = false, logprobs = false) {
     // perform role type check
     let hasNonSystemMessage = false;
     for (const msg of this.messages) {
@@ -208,6 +219,7 @@ class Chat {
       model: this.model,
       messages,
       stream,
+      logprobs,
       presence_penalty: this.presence_penalty,
       frequency_penalty: this.frequency_penalty,
     };
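
With the new parameter threaded through, the request body assembled here would look roughly as follows (a sketch limited to the fields visible in this hunk; the model and messages values are placeholders, and other options set elsewhere in the class are omitted):

    const body = {
      model: "gpt-3.5-turbo", // placeholder
      messages: [{ role: "user", content: "Hi" }],
      stream: true,
      logprobs: true, // newly added: ask the API to return per-token log probabilities
      presence_penalty: 0,
      frequency_penalty: 0,
    };
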
@@ -253,15 +265,6 @@ class Chat {
     });
   }
 
-  async fetch(): Promise<FetchResponse> {
-    const resp = await this._fetch();
-    const j = await resp.json();
-    if (j.error !== undefined) {
-      throw JSON.stringify(j.error);
-    }
-    return j;
-  }
-
   async *processStreamResponse(resp: Response) {
     const reader = resp?.body?.pipeThrough(new TextDecoderStream()).getReader();
     if (reader === undefined) {

src/logprob.tsx Normal file (+16 lines)
View File

@@ -0,0 +1,16 @@
import React from "react";

const logprobToColor = (logprob: number) => {
  // convert the logprob to a percentage
  const percent = Math.exp(logprob) * 100;
  // compute the color value:
  // green is rgb(0, 255, 0), red is rgb(255, 0, 0)
  const red = Math.round(255 * (1 - percent / 100));
  const green = Math.round(255 * (percent / 100));
  const color = `rgb(${red}, ${green}, 0)`;
  return color;
};

export default logprobToColor;
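
logprobToColor maps a token's log probability onto a red-green gradient: Math.exp(logprob) recovers the probability, which is then interpolated linearly between red (0%) and green (100%) with the blue channel fixed at 0. A few worked values (arithmetic only, not part of the commit):

    logprobToColor(-0.05); // exp(-0.05) ≈ 0.95 → "rgb(12, 243, 0)" (near-certain, green)
    logprobToColor(-1.6); // exp(-1.6) ≈ 0.20 → "rgb(204, 51, 0)" (unlikely, reddish)
    logprobToColor(-4); // exp(-4) ≈ 0.02 → "rgb(250, 5, 0)" (very unlikely, red)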

View File

@@ -9,6 +9,7 @@ import { MessageDetail } from "./messageDetail";
 import { MessageToolCall } from "./messageToolCall";
 import { MessageToolResp } from "./messageToolResp";
 import { EditMessage } from "./editMessage";
+import logprobToColor from "./logprob";
 
 export const isVailedJSON = (str: string): boolean => {
   try {
@@ -32,6 +33,7 @@ export default function Message(props: Props) {
   const [showEdit, setShowEdit] = useState(false);
   const [showCopiedHint, setShowCopiedHint] = useState(false);
   const [renderMarkdown, setRenderWorkdown] = useState(false);
+  const [renderColor, setRenderColor] = useState(false);
   const DeleteIcon = () => (
     <button
       onClick={() => {
@@ -125,7 +127,21 @@ export default function Message(props: Props) {
         {
           // only show when content is string or list of message
           // this check is used to avoid rendering tool call
-          chat.content && getMessageText(chat)
+          chat.content &&
+            (chat.logprobs && renderColor
+              ? chat.logprobs.content
+                  .filter((c) => c.token)
+                  .map((c) => (
+                    <div
+                      style={{
+                        color: logprobToColor(c.logprob),
+                        display: "inline",
+                      }}
+                    >
+                      {c.token}
+                    </div>
+                  ))
+              : getMessageText(chat))
         }
       </div>
     )}
@@ -200,6 +216,10 @@ export default function Message(props: Props) {
<label className="dark:text-white">{Tr("render")}</label> <label className="dark:text-white">{Tr("render")}</label>
<input type="checkbox" checked={renderMarkdown} /> <input type="checkbox" checked={renderMarkdown} />
</span> </span>
<span onClick={(event: any) => setRenderColor(!renderColor)}>
<label className="dark:text-white">{Tr("color")}</label>
<input type="checkbox" checked={renderColor} />
</span>
</div> </div>
)} )}
</div> </div>

View File

@@ -47,45 +47,54 @@ const SelectModel = (props: {
   setChatStore: (cs: ChatStore) => void;
   help: string;
 }) => {
-  let shouldIUseCustomModel: boolean = true
+  let shouldIUseCustomModel: boolean = true;
   for (const model in models) {
     if (props.chatStore.model === model) {
-      shouldIUseCustomModel = false
+      shouldIUseCustomModel = false;
     }
   }
   const [useCustomModel, setUseCustomModel] = useState(shouldIUseCustomModel);
   return (
     <Help help={props.help}>
       <label className="m-2 p-2">Model</label>
-      <span onClick={() => {
-        setUseCustomModel(!useCustomModel);
-      }} className="m-2 p-2">
+      <span
+        onClick={() => {
+          setUseCustomModel(!useCustomModel);
+        }}
+        className="m-2 p-2"
+      >
         <label>{Tr("Custom")}</label>
         <input className="" type="checkbox" checked={useCustomModel} />
       </span>
-      {
-        useCustomModel ?
-          <input
-            className="m-2 p-2 border rounded focus w-32 md:w-fit"
-            value={props.chatStore.model} onChange={(event: any) => {
-              const model = event.target.value as string;
-              props.chatStore.model = model;
-              props.setChatStore({ ...props.chatStore });
-            }} /> : <select
-              className="m-2 p-2"
-              value={props.chatStore.model}
-              onChange={(event: any) => {
-                const model = event.target.value as string;
-                props.chatStore.model = model;
-                props.chatStore.maxTokens = getDefaultParams('max', models[model].maxToken);
-                props.setChatStore({ ...props.chatStore });
-              }}
-            >
-              {Object.keys(models).map((opt) => (
-                <option value={opt}>{opt}</option>
-              ))}
-            </select>
-      }
+      {useCustomModel ? (
+        <input
+          className="m-2 p-2 border rounded focus w-32 md:w-fit"
+          value={props.chatStore.model}
+          onChange={(event: any) => {
+            const model = event.target.value as string;
+            props.chatStore.model = model;
+            props.setChatStore({ ...props.chatStore });
+          }}
+        />
+      ) : (
+        <select
+          className="m-2 p-2"
+          value={props.chatStore.model}
+          onChange={(event: any) => {
+            const model = event.target.value as string;
+            props.chatStore.model = model;
+            props.chatStore.maxTokens = getDefaultParams(
+              "max",
+              models[model].maxToken
+            );
+            props.setChatStore({ ...props.chatStore });
+          }}
+        >
+          {Object.keys(models).map((opt) => (
+            <option value={opt}>{opt}</option>
+          ))}
+        </select>
+      )}
     </Help>
   );
 };
@@ -118,14 +127,14 @@ const Input = (props: {
   chatStore: ChatStore;
   setChatStore: (cs: ChatStore) => void;
   field:
     | "apiKey"
     | "apiEndpoint"
     | "whisper_api"
     | "whisper_key"
     | "tts_api"
     | "tts_key"
     | "image_gen_api"
     | "image_gen_key";
   help: string;
 }) => {
   const [hideInput, setHideInput] = useState(true);
@@ -225,13 +234,13 @@ const Number = (props: {
   chatStore: ChatStore;
   setChatStore: (cs: ChatStore) => void;
   field:
     | "totalTokens"
     | "maxTokens"
     | "maxGenTokens"
     | "tokenMargin"
     | "postBeginIndex"
     | "presence_penalty"
     | "frequency_penalty";
   readOnly: boolean;
   help: string;
 }) => {
@@ -275,7 +284,7 @@ const Number = (props: {
 const Choice = (props: {
   chatStore: ChatStore;
   setChatStore: (cs: ChatStore) => void;
-  field: "streamMode" | "develop_mode" | "json_mode";
+  field: "streamMode" | "develop_mode" | "json_mode" | "logprobs";
   help: string;
 }) => {
   return (
@@ -319,7 +328,8 @@ export default (props: {
     location.pathname +
     `?key=${encodeURIComponent(
       props.chatStore.apiKey
-    )}&api=${encodeURIComponent(props.chatStore.apiEndpoint)}&mode=${props.chatStore.streamMode ? "stream" : "fetch"
+    )}&api=${encodeURIComponent(props.chatStore.apiEndpoint)}&mode=${
+      props.chatStore.streamMode ? "stream" : "fetch"
     }&model=${props.chatStore.model}&sys=${encodeURIComponent(
       props.chatStore.systemMessageContent
     )}`;
@@ -467,6 +477,7 @@ export default (props: {
help="流模式,使用 stream mode 将可以动态看到生成内容,但无法准确计算 token 数量,在 token 数量过多时可能会裁切过多或过少历史消息" help="流模式,使用 stream mode 将可以动态看到生成内容,但无法准确计算 token 数量,在 token 数量过多时可能会裁切过多或过少历史消息"
{...props} {...props}
/> />
<Choice field="logprobs" help="返回每个Token的概率" {...props} />
<Choice <Choice
field="develop_mode" field="develop_mode"
help="开发者模式,开启后会显示更多选项及功能" help="开发者模式,开启后会显示更多选项及功能"