Files
cucyuqing/cucyuqing/cmd/risk-analyze.py

169 lines
4.7 KiB
Python

from dataclasses import dataclass
from os import system
from typing import Iterable, Required
from typing_extensions import Doc
import openai
import asyncio
from openai.types.chat import ChatCompletionMessageParam
import tqdm
import json
import hashlib
from cucyuqing.utils import print
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
from cucyuqing.pg import pool, get_cur
from cucyuqing.mysql import mysql
from cucyuqing.dbscan import Document, run_dbscan
async def main():
    """Run one risk-analysis pass and write detected risk types back to MySQL.

    Steps:
    1. Open the PostgreSQL pool and the MySQL connection concurrently.
    2. Load every (risk type, prompt) pair via ``get_risk_type_prompt``.
    3. Cluster documents with ``run_dbscan`` and analyze only the first
       document of each cluster (noise documents are only counted).
    4. Ask the LLM about every (document, risk type) pair.
    5. Merge the flagged risk-type names per document id and store them as a
       JSON array in ``risk_news.risk_types`` for the row whose ``es_id``
       equals the document id.
    """
    await asyncio.gather(
        pool.open(),
        mysql.connect(),
    )
    # Load the risk types together with their prompts
    risk_types = await get_risk_type_prompt()
    print("共有风险类型:", len(risk_types))

    dbscan_result = await run_dbscan()
    # One representative document (the first member) per cluster
    docs = [cluster[0] for cluster in dbscan_result.clusters]
    print("共有待分析文档:", len(docs), "噪声", len(dbscan_result.noise))

    analyze_result = await batch_risk_analyze(docs, risk_types)

    # Merge every document's detected risk types into one set per document id
    risks_to_update: dict[str, set[str]] = {}
    for task in analyze_result:
        # NOTE(review): "" is a substring of every string, so this condition is
        # always False and every (document, risk type) pair is treated as a hit.
        # The intended marker (presumably the LLM's affirmative answer) seems
        # to be missing from the literal -- confirm against the prompts and fix.
        if "" not in task.response:
            continue
        print(f"风险: {task.risk_type.name} 标题: {task.doc.title} {task.doc.id}")
        risks_to_update.setdefault(task.doc.id, set()).add(task.risk_type.name)

    # Write the merged results back to the database
    for doc_id, risks in risks_to_update.items():
        await mysql.execute(
            """
            UPDATE risk_news
            SET risk_types = :risk_types, updated_at = now()
            WHERE es_id = :es_id
            """,
            {
                "es_id": doc_id,
                # Sorted so the stored JSON is deterministic: iteration order
                # of a set of str depends on the per-process hash seed.
                "risk_types": json.dumps(sorted(risks), ensure_ascii=False),
            },
        )
@dataclass
class RiskType:
    """A risk category paired with the LLM system prompt used to detect it.

    Built by ``get_risk_type_prompt`` from the ``risk_type`` / ``risk_prompt``
    tables.
    """

    # Risk type name (``risk_type.name``); collected into ``risk_news.risk_types``
    name: str
    # System prompt sent to the LLM for this risk type (``risk_prompt.content``)
    prompt: str
@dataclass
class Task:
    """One unit of LLM work: analyze ``doc`` against a single ``risk_type``.

    ``batch_risk_analyze`` creates one Task per (document, risk type) pair and
    the model's reply is stored in ``response``.
    """

    # Document to analyze; its get_text_for_llm() output becomes the user message
    doc: Document
    # Risk type whose prompt is used as the system message
    risk_type: RiskType
    # Raw LLM completion; stays "" until the task has been processed
    response: str = ""
async def get_risk_type_prompt() -> list[RiskType]:
    """Load every risk type together with its LLM prompt from MySQL.

    Joins ``risk_prompt`` with ``risk_type`` and returns one ``RiskType`` per
    joined prompt row, highest ``risk_prompt.id`` first. A risk type with
    several prompt rows therefore appears several times.
    """
    sql = """
        SELECT rp.content, rt.name
        FROM risk_prompt rp
        JOIN risk_type rt ON rp.risk_type_id = rt.id
        ORDER BY rp.id DESC
    """
    result: list[RiskType] = []
    for record in await mysql.fetch_all(sql):
        # Column order follows the SELECT list: content first, then name
        result.append(RiskType(name=record[1], prompt=record[0]))
    return result
async def batch_risk_analyze(
    docs: list[Document],
    risk_types: list[RiskType],
    model: str = "gpt-4o-mini",
    threads: int = 10,
) -> list[Task]:
    """Analyze every document against every risk type concurrently.

    One ``Task`` is created per (document, risk type) pair, in document-major
    order, and at most ``threads`` of them are in flight in ``risk_analyze`` at
    a time. Each task's ``response`` is filled in place.

    Args:
        docs: Documents to analyze.
        risk_types: Risk types (with prompts) to check each document against.
        model: Chat model name passed through to ``risk_analyze``.
        threads: Number of concurrent worker coroutines (not OS threads).

    Returns:
        The list of tasks, in creation order.

    Raises:
        ValueError: If ``threads`` is less than 1.
        Exception: The first error raised by ``risk_analyze``; the remaining
            workers are cancelled before it is re-raised.
    """
    if threads < 1:
        raise ValueError(f"threads must be >= 1, got {threads}")

    # Cross product docs x risk_types -> task list
    tasks: list[Task] = [
        Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
    ]
    # A single shared iterator: next() is synchronous, so within one event
    # loop every task is handed to exactly one worker.
    pending = iter(tasks)

    with tqdm.tqdm(total=len(tasks)) as bar:

        async def llm_worker() -> None:
            for task in pending:
                task.response = await risk_analyze(task, model)
                bar.update(1)

        workers = [asyncio.create_task(llm_worker()) for _ in range(threads)]
        try:
            await asyncio.gather(*workers)
        except BaseException:
            # Stop the other workers (and wait for them to unwind) before
            # surfacing the error, so nothing keeps calling the API in the
            # background and no task exception goes unretrieved.
            for worker in workers:
                worker.cancel()
            await asyncio.gather(*workers, return_exceptions=True)
            raise
    return tasks
async def risk_analyze(task: Task, model: str) -> str:
    """Run one risk check on a single document, with a PostgreSQL-backed cache.

    The system message is the risk type's prompt and the user message is the
    document's ``get_text_for_llm()`` text. Replies are cached in the
    ``llm_cache`` table under the MD5 of ``"{model}|{text}|{prompt}"``, so a
    repeated (model, text, prompt) combination does not call the API again.

    Args:
        task: The (document, risk type) pair to analyze.
        model: Chat model name.

    Returns:
        The model's reply, or ``""`` if it returned no content. Generation is
        requested with temperature 0, stops at the first newline and is capped
        at 10 tokens.
    """
    text = task.doc.get_text_for_llm()
    prompt = task.risk_type.prompt
    # Keep this key format unchanged, otherwise existing cache rows stop matching.
    cache_key = hashlib.md5(f"{model}|{text}|{prompt}".encode()).hexdigest()

    # Cache lookup
    async with get_cur() as cur:
        await cur.execute(
            "SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (cache_key,)
        )
        row = await cur.fetchone()
    if row:
        return row[0]

    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": text},
    ]
    # Create the client only on a cache miss, and close it explicitly when
    # done instead of relying on garbage collection to release its connections.
    async with openai.AsyncOpenAI(
        api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
    ) as llm:
        resp = await llm.chat.completions.create(
            messages=messages,
            model=model,
            temperature=0,
            stop="\n",
            max_tokens=10,
        )
    completions = resp.choices[0].message.content or ""
    usage = resp.usage.model_dump_json() if resp.usage else None

    # Store the result. ON CONFLICT DO NOTHING tolerates another concurrent
    # worker having cached the same key first.
    async with get_cur() as cur:
        await cur.execute(
            "INSERT INTO llm_cache (id, messages, model, response, usage) "
            "VALUES (%s, %s, %s, %s, %s) ON CONFLICT DO NOTHING",
            (
                cache_key,
                json.dumps(messages, ensure_ascii=False),
                model,
                completions,
                usage,
            ),
        )
    return completions
if __name__ == "__main__":
    # Script entry point: run the whole analyze-and-update pass via asyncio.run,
    # which creates and closes its own event loop.
    asyncio.run(main())