WIP: LLM 风险判断
This commit is contained in:
105
cucyuqing/cmd/risk-analyze.py
Normal file
105
cucyuqing/cmd/risk-analyze.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
from os import system
|
||||||
|
from typing import Iterable, Required
|
||||||
|
import openai
|
||||||
|
import asyncio
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
import tqdm
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
from cucyuqing.utils import print
|
||||||
|
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
|
||||||
|
from cucyuqing.pg import pool, get_cur
|
||||||
|
from cucyuqing.mysql import mysql
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Smoke-test entry point: open the PG pool and risk-check two sample texts."""
    await pool.open()
    results = await batch_risk_analyze(["你是老师", "我是初音未来"])
    print(results)
|
||||||
|
|
||||||
|
async def get_docs() -> list[dict]:
    """Fetch the documents to be risk-analyzed from MySQL (WIP stub).

    Returns:
        A list of document rows as dicts (currently always empty).

    Raises:
        NotImplementedError: always, until the query below is written.
    """
    # TODO: write the actual query; everything after the raise is dead stub code.
    # Bug fix: the original `raise NotImplemented` raises TypeError at runtime,
    # because NotImplemented is a comparison sentinel, not an exception class.
    raise NotImplementedError
    await mysql.execute("""
    """)
    return []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def batch_risk_analyze(
    texts: list, model: str = "gpt-4o-mini", threads: int = 10
) -> list:
    """Risk-analyze *texts* concurrently with a bounded pool of worker tasks.

    Args:
        texts: input strings to classify.
        model: chat model name, passed through to risk_analyze().
        threads: number of concurrent worker tasks.

    Returns:
        LLM responses in the same order as *texts*.
    """
    tasks = [{"input": text} for text in texts]
    bar = tqdm.tqdm(total=len(tasks))
    queue: asyncio.Queue = asyncio.Queue()

    async def llm_worker():
        # Consume tasks until the None sentinel. task_done() must run even when
        # risk_analyze() raises — otherwise `await queue.join()` below would
        # deadlock forever (the original code had exactly that hang).
        while True:
            task = await queue.get()
            if task is None:
                break
            try:
                task["response"] = await risk_analyze(task["input"], model)
            finally:
                queue.task_done()
                bar.update(1)

    # The queue is unbounded, so enqueueing never blocks; a separate producer
    # coroutine is unnecessary.
    for task in tasks:
        queue.put_nowait(task)

    workers = [asyncio.create_task(llm_worker()) for _ in range(threads)]

    await queue.join()
    # One sentinel per worker so every loop exits cleanly.
    for _ in workers:
        await queue.put(None)
    # If a worker died on an exception, gather() re-raises it here (fail loudly
    # instead of silently returning partial results).
    await asyncio.gather(*workers)
    bar.close()

    return [task["response"] for task in tasks]
|
||||||
|
|
||||||
|
|
||||||
|
async def risk_analyze(text: str, model: str) -> str:
    """Ask the LLM whether *text* is risky; returns its raw answer (是/否).

    Results are cached in the Postgres `llm_cache` table, keyed by an MD5 of
    (model, text, system prompt), so repeated questions never hit the API.

    Args:
        text: the text to classify.
        model: chat model name to query.

    Returns:
        The model's answer string ("" if the API returned no content).
    """
    # NOTE(review): constructing a fresh client per call works but is wasteful;
    # consider a module-level client if this becomes a hot path.
    llm = openai.AsyncOpenAI(
        api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
    )
    system_message = (
        "你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。"
    )

    # MD5 is acceptable here: it is a cache key, not a security boundary.
    # (Renamed from `hash`, which shadowed the builtin.)
    cache_key = hashlib.md5(
        model.encode() + b"|" + text.encode() + b"|" + system_message.encode()
    ).hexdigest()

    # Check the cache first.
    async with get_cur() as cur:
        await cur.execute(
            "SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (cache_key,)
        )
        row = await cur.fetchone()
        if row:
            return row[0]

    messages: Iterable[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": text},
    ]
    resp = await llm.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,  # deterministic answers make the cache meaningful
        stop="\n",
    )

    completions = resp.choices[0].message.content or ""
    usage = resp.usage.model_dump_json() if resp.usage else None

    # Store the result. ON CONFLICT DO NOTHING prevents a duplicate-key crash
    # when concurrent workers analyze identical text (assumes llm_cache.id has
    # a PK/unique constraint — TODO confirm against the schema).
    async with get_cur() as cur:
        await cur.execute(
            "INSERT INTO llm_cache (id, messages, model, response, usage) "
            "VALUES (%s, %s, %s, %s, %s) ON CONFLICT (id) DO NOTHING",
            (cache_key, json.dumps(messages, ensure_ascii=False), model, completions, usage),
        )

    return completions
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the async main() on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
||||||
@@ -24,3 +24,5 @@ PG_DSN = must_get_env("PG_DSN")
|
|||||||
MYSQL_DSN = must_get_env("MYSQL_DSN")
OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")

OPENAI_RISK_LLM_API_KEY = must_get_env("OPENAI_RISK_LLM_API_KEY")
OPENAI_RISK_LLM_BASE_URL = get_env_with_default("OPENAI_RISK_LLM_BASE_URL", "https://api.openai.com/v1")
|
||||||
|
|||||||
Reference in New Issue
Block a user