WIP: LLM risk analysis

2024-09-20 18:33:33 +08:00
parent fff2a32d7e
commit 341435e603
2 changed files with 107 additions and 0 deletions

View File

@@ -0,0 +1,105 @@
from typing import Iterable

import asyncio
import hashlib
import json

import openai
import tqdm
from openai.types.chat import ChatCompletionMessageParam

from cucyuqing.utils import print  # project print wrapper; shadows the builtin on purpose
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
from cucyuqing.pg import pool, get_cur
from cucyuqing.mysql import mysql

async def main():
    await pool.open()
    # Smoke test: classify two sample texts.
    print(await batch_risk_analyze(["你是老师", "我是初音未来"]))

async def get_docs() -> list[dict]:
    # TODO: fetch the documents to analyze from MySQL; query not written yet.
    raise NotImplementedError
    await mysql.execute("""
    """)
    return []

async def batch_risk_analyze(
    texts: list[str], model: str = "gpt-4o-mini", threads: int = 10
) -> list[str]:
    tasks = [{"input": text} for text in texts]
    bar = tqdm.tqdm(total=len(tasks))
    queue = asyncio.Queue()

    async def llm_worker():
        while True:
            task = await queue.get()
            if task is None:
                # Sentinel: this worker is done.
                break
            # NOTE: if risk_analyze raises, task_done() is never called and
            # queue.join() stalls; acceptable for a WIP script.
            task["response"] = await risk_analyze(task["input"], model)
            queue.task_done()
            bar.update(1)

    async def producer():
        for task in tasks:
            await queue.put(task)

    workers = [asyncio.create_task(llm_worker()) for _ in range(threads)]
    await producer()
    await queue.join()
    # One sentinel per worker to shut the pool down.
    for _ in workers:
        await queue.put(None)
    await asyncio.gather(*workers)
    return [task["response"] for task in tasks]

async def risk_analyze(text: str, model: str) -> str:
    llm = openai.AsyncOpenAI(
        api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
    )
    # System prompt (zh): "You are a news risk analyzer. Judge whether the
    # following text is risky; answer only yes (是) or no (否)."
    system_message = (
        "你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。"
    )
    # Cache key covers model, input, and prompt, so changing any of them
    # invalidates the cached answer.
    cache_key = hashlib.md5(
        model.encode() + b"|" + text.encode() + b"|" + system_message.encode()
    ).hexdigest()
    # Check the cache first.
    async with get_cur() as cur:
        await cur.execute(
            "SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (cache_key,)
        )
        row = await cur.fetchone()
        if row:
            return row[0]
    messages: Iterable[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": text},
    ]
    resp = await llm.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,
        stop="\n",
    )
    completions = resp.choices[0].message.content or ""
    usage = resp.usage.model_dump_json() if resp.usage else None
    # Cache the result.
    async with get_cur() as cur:
        await cur.execute(
            "INSERT INTO llm_cache (id, messages, model, response, usage) VALUES (%s, %s, %s, %s, %s)",
            (cache_key, json.dumps(messages, ensure_ascii=False), model, completions, usage),
        )
    return completions
if __name__ == "__main__":
asyncio.run(main())
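
The cache lookup and insert above assume an llm_cache table that this commit does not create. A minimal setup sketch follows; the table name and column names come from the queries, while the column types and the helper itself are assumptions, not part of the commit:

async def ensure_llm_cache_table():
    # Hypothetical one-off migration: creates the cache table that
    # risk_analyze reads and writes. Types are guesses; jsonb could be
    # used for messages/usage since both hold serialized JSON.
    async with get_cur() as cur:
        await cur.execute(
            """
            CREATE TABLE IF NOT EXISTS llm_cache (
                id       text PRIMARY KEY,  -- md5 over model|text|system prompt
                messages text,              -- request messages as JSON
                model    text,
                response text,              -- the model's 是/否 answer
                usage    text               -- token usage as JSON, may be NULL
            )
            """
        )

Keying the cache on the MD5 digest means a repeated (model, prompt, text) triple is answered from Postgres instead of hitting the API again.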

View File

@@ -24,3 +24,5 @@ PG_DSN = must_get_env("PG_DSN")
MYSQL_DSN = must_get_env("MYSQL_DSN")
OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
OPENAI_RISK_LLM_API_KEY = must_get_env("OPENAI_RISK_LLM_API_KEY")
OPENAI_RISK_LLM_BASE_URL = get_env_with_default("OPENAI_RISK_LLM_BASE_URL", "https://api.openai.com/v1")
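
For context, must_get_env and get_env_with_default are defined earlier in cucyuqing/config.py, outside this hunk; judging by their names and usage they behave roughly like the sketch below (only the names are real, the bodies are assumed):

import os

def must_get_env(name: str) -> str:
    # Assumed behavior: hard-fail at import time when a required key is absent.
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"missing required environment variable: {name}")
    return value

def get_env_with_default(name: str, default: str) -> str:
    # Assumed behavior: fall back to the given default when the key is absent.
    return os.environ.get(name, default)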