export interface Message { role: "system" | "user" | "assistant" | "function"; content: string; name?: "example_user" | "example_assistant"; } export interface ChunkMessage { model: string; choices: { delta: { role: "assitant" | undefined; content: string | undefined }; }[]; } export interface FetchResponse { error?: any; id: string; object: string; created: number; model: string; usage: { prompt_tokens: number | undefined; completion_tokens: number | undefined; total_tokens: number | undefined; }; choices: { message: Message | undefined; finish_reason: "stop" | "length"; index: number | undefined; }[]; } // https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them export function calculate_token_length(content: string): number { const totalCount = content.length; const chineseCount = content.match(/[\u00ff-\uffff]|\S+/g)?.length ?? 0; const englishCount = totalCount - chineseCount; const tokenLength = englishCount / 4 + (chineseCount * 4) / 3; return ~~tokenLength; } class Chat { OPENAI_API_KEY: string; messages: Message[]; sysMessageContent: string; total_tokens: number; max_tokens: number; tokens_margin: number; apiEndpoint: string; model: string; temperature: number; top_p: number; presence_penalty: number; frequency_penalty: number; constructor( OPENAI_API_KEY: string | undefined, { systemMessage = "你是一个有用的人工智能助理", max_tokens = 4096, tokens_margin = 1024, apiEndPoint = "https://api.openai.com/v1/chat/completions", model = "gpt-3.5-turbo", temperature = 0.7, top_p = 1, presence_penalty = 0, frequency_penalty = 0, } = {}, ) { if (OPENAI_API_KEY === undefined) { throw "OPENAI_API_KEY is undefined"; } this.OPENAI_API_KEY = OPENAI_API_KEY; this.messages = []; this.total_tokens = calculate_token_length(systemMessage); this.max_tokens = max_tokens; this.tokens_margin = tokens_margin; this.sysMessageContent = systemMessage; this.apiEndpoint = apiEndPoint; this.model = model; this.temperature = temperature; this.top_p = top_p; this.presence_penalty = presence_penalty; this.frequency_penalty = frequency_penalty; } _fetch(stream = false) { // perform role type check let hasNonSystemMessage = false; for (const msg of this.messages) { if (msg.role === "system" && !hasNonSystemMessage) { continue; } if (!hasNonSystemMessage) { hasNonSystemMessage = true; continue; } if (msg.role === "system") { console.log( "Warning: detected system message in the middle of history", ); } } for (const msg of this.messages) { if (msg.name && msg.role !== "system") { console.log( "Warning: detected message where name field set but role is system", ); } } return fetch(this.apiEndpoint, { method: "POST", headers: { Authorization: `Bearer ${this.OPENAI_API_KEY}`, "Content-Type": "application/json", }, body: JSON.stringify({ model: this.model, messages: [ { role: "system", content: this.sysMessageContent }, ...this.messages, ], stream, temperature: this.temperature, top_p: this.top_p, presence_penalty: this.presence_penalty, frequency_penalty: this.frequency_penalty, }), }); } async fetch(): Promise { const resp = await this._fetch(); const j = await resp.json(); if (j.error !== undefined) { throw JSON.stringify(j.error); } return j; } async say(content: string): Promise { this.messages.push({ role: "user", content }); await this.complete(); return this.messages.slice(-1)[0].content; } async *processStreamResponse(resp: Response) { const reader = resp?.body?.pipeThrough(new TextDecoderStream()).getReader(); if (reader === undefined) { console.log("reader is undefined"); return; } let receiving = true; while (receiving) { let lastText = ""; const { value, done } = await reader.read(); if (done) break; const lines = (lastText + value) .trim() .split("\n") .filter((line) => line.trim()) .map((line) => line.slice("data:".length)) .map((line) => line.trim()) .filter((i) => { if (i === "[DONE]") { receiving = false; return false; } return true; }); const jsons: ChunkMessage[] = lines .map((line) => { try { const ret = JSON.parse(line.trim()); lastText = ""; return ret; } catch (e) { console.log(`Chunk parse error at: ${line}`); lastText = line; return null; } }) .filter((i) => i.choices[0].delta.content); for (const j of jsons) { yield j; } } } processFetchResponse(resp: FetchResponse): string { if (resp.error !== undefined) { throw JSON.stringify(resp.error); } this.total_tokens = resp?.usage?.total_tokens ?? 0; if (resp?.choices[0]?.message) { this.messages.push(resp?.choices[0]?.message); } if (resp.choices[0]?.finish_reason === "length") { this.forceForgetSomeMessages(); } else { this.forgetSomeMessages(); } return ( resp?.choices[0]?.message?.content ?? `Error: ${JSON.stringify(resp)}` ); } async complete(): Promise { const resp = await this.fetch(); return this.processFetchResponse(resp); } completeWithSteam() { this.total_tokens = this.messages .map((msg) => this.calculate_token_length(msg.content) + 20) .reduce((a, v) => a + v); return this._fetch(true); } calculate_token_length(content: string): number { return calculate_token_length(content); } user(...messages: string[]) { for (const msg of messages) { this.messages.push({ role: "user", content: msg }); this.total_tokens += this.calculate_token_length(msg); this.forgetSomeMessages(); } } assistant(...messages: string[]) { for (const msg of messages) { this.messages.push({ role: "assistant", content: msg }); this.total_tokens += this.calculate_token_length(msg); this.forgetSomeMessages(); } } forgetSomeMessages() { // forget occur condition if (this.total_tokens + this.tokens_margin >= this.max_tokens) { this.forceForgetSomeMessages(); } } forceForgetSomeMessages() { this.messages = [ ...this.messages.slice(Math.max(~~(this.messages.length / 4), 2)), ]; } forgetAllMessage() { this.messages = []; } stats(): string { return ( `total_tokens: ${this.total_tokens}` + "\n" + `max_tokens: ${this.max_tokens}` + "\n" + `tokens_margin: ${this.tokens_margin}` + "\n" + `messages.length: ${this.messages.length}` ); } } export default Chat;