diff --git a/cucyuqing/cmd/risk-analyze.py b/cucyuqing/cmd/risk-analyze.py
new file mode 100644
index 0000000..ac7c5a4
--- /dev/null
+++ b/cucyuqing/cmd/risk-analyze.py
@@ -0,0 +1,114 @@
+from typing import Iterable
+import openai
+import asyncio
+from openai.types.chat import ChatCompletionMessageParam
+import tqdm
+import json
+import hashlib
+from cucyuqing.utils import print
+from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
+from cucyuqing.pg import pool, get_cur
+from cucyuqing.mysql import mysql
+
+# Reuse a single client instead of constructing one per request.
+llm = openai.AsyncOpenAI(
+    api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
+)
+
+
+async def main():
+    await pool.open()
+    # Smoke test: "you are a teacher" / "I am Hatsune Miku".
+    print(await batch_risk_analyze(["你是老师", "我是初音未来"]))
+
+
+async def get_docs() -> list[dict]:
+    # [TODO] fetch the documents to analyze from MySQL
+    raise NotImplementedError
+    await mysql.execute("""
+    """)
+    return []
+
+
+async def batch_risk_analyze(
+    texts: list[str], model: str = "gpt-4o-mini", threads: int = 10
+) -> list[str]:
+    tasks = [{"input": text} for text in texts]
+    bar = tqdm.tqdm(total=len(tasks))
+    queue = asyncio.Queue()
+
+    async def llm_worker():
+        while True:
+            task = await queue.get()
+            if task is None:
+                break
+            try:
+                task["response"] = await risk_analyze(task["input"], model)
+            finally:
+                # Mark the item done even on failure so queue.join() cannot hang.
+                queue.task_done()
+                bar.update(1)
+
+    async def producer():
+        for task in tasks:
+            await queue.put(task)
+
+    workers = [asyncio.create_task(llm_worker()) for _ in range(threads)]
+
+    await producer()
+    await queue.join()
+    for _ in workers:
+        await queue.put(None)
+    await asyncio.gather(*workers)
+    bar.close()
+
+    return [task["response"] for task in tasks]
+
+
+async def risk_analyze(text: str, model: str) -> str:
+    # System prompt (zh): "You are a news risk analyzer. Judge whether the
+    # following text is risky. Answer only yes (是) or no (否)."
+    system_message = (
+        "你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。"
+    )
+
+    cache_key = hashlib.md5(
+        model.encode() + b"|" + text.encode() + b"|" + system_message.encode()
+    ).hexdigest()
+
+    # Check the cache first.
+    async with get_cur() as cur:
+        await cur.execute(
+            "SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (cache_key,)
+        )
+        row = await cur.fetchone()
+        if row:
+            return row[0]
+
+    messages: Iterable[ChatCompletionMessageParam] = [
+        {"role": "system", "content": system_message},
+        {"role": "user", "content": text},
+    ]
+    resp = await llm.chat.completions.create(
+        messages=messages,
+        model=model,
+        temperature=0,
+        stop="\n",
+    )
+
+    completions = resp.choices[0].message.content or ""
+    usage = resp.usage.model_dump_json() if resp.usage else None
+
+    # Cache the result; workers may race on duplicate inputs, so ignore
+    # conflicting inserts.
+    async with get_cur() as cur:
+        await cur.execute(
+            "INSERT INTO llm_cache (id, messages, model, response, usage) VALUES (%s, %s, %s, %s, %s) ON CONFLICT (id) DO NOTHING",
+            (cache_key, json.dumps(messages, ensure_ascii=False), model, completions, usage),
+        )
+
+    return completions
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/cucyuqing/config.py b/cucyuqing/config.py
index b1812c8..3eee5a7 100644
--- a/cucyuqing/config.py
+++ b/cucyuqing/config.py
@@ -24,3 +24,5 @@ PG_DSN = must_get_env("PG_DSN")
 MYSQL_DSN = must_get_env("MYSQL_DSN")
 OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
 OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
+OPENAI_RISK_LLM_API_KEY = must_get_env("OPENAI_RISK_LLM_API_KEY")
+OPENAI_RISK_LLM_BASE_URL = get_env_with_default("OPENAI_RISK_LLM_BASE_URL", "https://api.openai.com/v1")