WIP: LLM 风险判断
This commit is contained in:
105
cucyuqing/cmd/risk-analyze.py
Normal file
105
cucyuqing/cmd/risk-analyze.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from os import system
|
||||
from typing import Iterable, Required
|
||||
import openai
|
||||
import asyncio
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
import tqdm
|
||||
import json
|
||||
import hashlib
|
||||
from cucyuqing.utils import print
|
||||
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
|
||||
from cucyuqing.pg import pool, get_cur
|
||||
from cucyuqing.mysql import mysql
|
||||
|
||||
|
||||
async def main():
    """Entry point: open the Postgres pool and run a smoke-test batch analysis."""
    await pool.open()
    sample_texts = ["你是老师", "我是初音未来"]
    results = await batch_risk_analyze(sample_texts)
    print(results)
|
||||
|
||||
async def get_docs() -> list[dict]:
    """Fetch the documents to risk-analyze from MySQL.

    Not implemented yet; the query below is a placeholder.

    Raises:
        NotImplementedError: always, until the query is written.
    """
    # TODO: implement the actual document query.
    # NotImplementedError (the exception class), not the NotImplemented
    # singleton — raising the latter is a TypeError at runtime.
    raise NotImplementedError
    await mysql.execute("""
    """)
    return []
|
||||
|
||||
|
||||
|
||||
async def batch_risk_analyze(
    texts: list, model: str = "gpt-4o-mini", threads: int = 10
) -> list:
    """Run risk_analyze over *texts* concurrently with a worker pool.

    Args:
        texts: Input strings to classify.
        model: Chat-completion model name, passed through to risk_analyze.
        threads: Number of concurrent worker tasks.

    Returns:
        One response string per input, in the same order as *texts*.
    """
    tasks = [{"input": text} for text in texts]
    bar = tqdm.tqdm(total=len(tasks))
    queue = asyncio.Queue()

    async def llm_worker():
        # Pull tasks until the None sentinel arrives.
        while True:
            task = await queue.get()
            if task is None:
                break
            try:
                task["response"] = await risk_analyze(task["input"], model)
            finally:
                # Always acknowledge the task: if risk_analyze raised and
                # task_done() were skipped, queue.join() below would hang
                # forever.
                queue.task_done()
                bar.update(1)

    async def producer():
        for task in tasks:
            await queue.put(task)

    workers = [asyncio.create_task(llm_worker()) for _ in range(threads)]

    await producer()
    await queue.join()
    # Wake each worker with a sentinel so its loop exits.
    for _ in workers:
        await queue.put(None)
    # Propagate any worker exception (e.g. API errors) to the caller.
    await asyncio.gather(*workers)

    return [task["response"] for task in tasks]
|
||||
|
||||
|
||||
async def risk_analyze(text: str, model: str) -> str:
    """Ask the LLM whether *text* is risky; returns its yes/no answer.

    Results are cached in the llm_cache table, keyed by an MD5 of
    (model, text, system prompt), so repeated inputs skip the API call.

    Args:
        text: The text to classify.
        model: Chat-completion model name.

    Returns:
        The model's answer (expected "是" or "否"), or "" if the
        completion had no content.
    """
    llm = openai.AsyncOpenAI(
        api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
    )
    system_message = (
        "你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。"
    )

    # `hash` shadowed the builtin; use a descriptive name instead.
    cache_key = hashlib.md5(
        model.encode() + b"|" + text.encode() + b"|" + system_message.encode()
    ).hexdigest()

    # Check the cache first.
    async with get_cur() as cur:
        await cur.execute(
            "SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (cache_key,)
        )
        row = await cur.fetchone()
        if row:
            return row[0]

    messages: Iterable[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": text},
    ]
    resp = await llm.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,
        stop="\n",
    )

    completions = resp.choices[0].message.content or ""
    usage = resp.usage.model_dump_json() if resp.usage else None

    # Cache the result. ON CONFLICT guards against a concurrent worker
    # having cached the same key between the SELECT above and this INSERT —
    # without it the duplicate key would raise and kill the worker.
    async with get_cur() as cur:
        await cur.execute(
            "INSERT INTO llm_cache (id, messages, model, response, usage) VALUES (%s, %s, %s, %s, %s) ON CONFLICT (id) DO NOTHING",
            (cache_key, json.dumps(messages, ensure_ascii=False), model, completions, usage),
        )

    return completions
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the async entry point when executed as a script.
    asyncio.run(main())
|
||||
@@ -24,3 +24,5 @@ PG_DSN = must_get_env("PG_DSN")
|
||||
MYSQL_DSN = must_get_env("MYSQL_DSN")
OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
# Base URL is overridable to point at a proxy or alternate provider.
OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
# Credentials/endpoint for the risk-analysis LLM (used by cmd/risk-analyze.py).
OPENAI_RISK_LLM_API_KEY = must_get_env("OPENAI_RISK_LLM_API_KEY")
OPENAI_RISK_LLM_BASE_URL = get_env_with_default("OPENAI_RISK_LLM_BASE_URL", "https://api.openai.com/v1")
|
||||
|
||||
Reference in New Issue
Block a user