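"""Risk classification pipeline.

Loads risk types and their detection prompts from MySQL, picks one
representative document per DBSCAN cluster, asks an LLM whether each
document matches each risk type, and writes the matched risk types back
to the `risk_news` table. LLM responses are cached in a Postgres
`llm_cache` table keyed by an MD5 of the (model, text, prompt) triple.
"""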
from dataclasses import dataclass
from typing import Iterable

import asyncio
import hashlib
import json

import openai
import tqdm
from openai.types.chat import ChatCompletionMessageParam

from cucyuqing.utils import print  # shadows the builtin print
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
from cucyuqing.pg import pool, get_cur
from cucyuqing.mysql import mysql
from cucyuqing.dbscan import Document, run_dbscan


async def main():
    await asyncio.gather(
        pool.open(),
        mysql.connect(),
    )
    # Load the risk types and their classification prompts.
    risk_types = await get_risk_type_prompt()
    print("risk types:", len(risk_types))

    # Deduplicate: analyze only the first document of each DBSCAN cluster.
    dbscan_result = await run_dbscan()
    docs = [cluster[0] for cluster in dbscan_result.clusters]
    print("documents to analyze:", len(docs), "noise:", len(dbscan_result.noise))

    risks_to_update: dict[str, set[str]] = {}
    analyze_result = await batch_risk_analyze(docs, risk_types)
    for task in analyze_result:
        # The prompts expect a yes/no verdict in Chinese; "是" means "yes".
        if "是" not in task.response:
            continue
        print(f"risk: {task.risk_type.name} title: {task.doc.title} {task.doc.id}")

        # Merge all detected risks of a document into one set.
        risks_to_update.setdefault(task.doc.id, set()).add(task.risk_type.name)

    # Write the merged risk types back to MySQL.
    for doc_id, risks in risks_to_update.items():
        await mysql.execute(
            """
            UPDATE risk_news
            SET risk_types = :risk_types, updated_at = now()
            WHERE es_id = :es_id
            """,
            {
                "es_id": doc_id,
                "risk_types": json.dumps(list(risks), ensure_ascii=False),
            },
        )


@dataclass
class RiskType:
    """A risk category and the LLM prompt used to detect it."""

    name: str
    prompt: str


@dataclass
class Task:
    """One (document, risk type) pair to be judged by the LLM."""

    doc: Document
    risk_type: RiskType
    response: str = ""


async def get_risk_type_prompt() -> list[RiskType]:
    """Fetch the risk types and their prompts from the database.

    The join returns one row per prompt, so a risk type with several
    prompts yields several RiskType entries.
    """
    rows = await mysql.fetch_all(
        """
        SELECT rp.content, rt.name
        FROM risk_prompt rp
        JOIN risk_type rt ON rp.risk_type_id = rt.id
        ORDER BY rp.id DESC
        """
    )

    return [RiskType(prompt=row[0], name=row[1]) for row in rows]


async def batch_risk_analyze(
    docs: list[Document],
    risk_types: list[RiskType],
    model: str = "gpt-4o-mini",
    threads: int = 10,
) -> list[Task]:
    """Risk-analyze texts in parallel with a bounded worker pool."""

    # Cross the two lists: one task per (document, risk type) pair.
    tasks: list[Task] = [
        Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
    ]
    bar = tqdm.tqdm(total=len(tasks))
    queue: asyncio.Queue[Task | None] = asyncio.Queue()

    async def llm_worker():
        while True:
            task = await queue.get()
            if task is None:
                # Sentinel: no more work for this worker.
                break
            try:
                task.response = await risk_analyze(task, model)
            finally:
                # Always mark the item done so queue.join() cannot hang
                # if risk_analyze() raises.
                queue.task_done()
                bar.update(1)

    async def producer():
        for task in tasks:
            await queue.put(task)

    workers = [asyncio.create_task(llm_worker()) for _ in range(threads)]

    await producer()
    await queue.join()
    # One sentinel per worker shuts the pool down cleanly.
    for _ in workers:
        await queue.put(None)
    await asyncio.gather(*workers)
    bar.close()

    return tasks
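

# A minimal alternative sketch, not used by main(): the same fan-out bounded
# by an asyncio.Semaphore instead of a queue plus sentinels. It creates one
# coroutine per task up front, which is fine at this scale and removes the
# shutdown bookkeeping above. The function name is ours, not part of the
# original pipeline.
async def batch_risk_analyze_semaphore(
    docs: list[Document],
    risk_types: list[RiskType],
    model: str = "gpt-4o-mini",
    limit: int = 10,
) -> list[Task]:
    tasks = [Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types]
    sem = asyncio.Semaphore(limit)

    async def run_one(task: Task) -> None:
        async with sem:
            # At most `limit` calls to risk_analyze() run concurrently.
            task.response = await risk_analyze(task, model)

    await asyncio.gather(*(run_one(t) for t in tasks))
    return tasks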


async def risk_analyze(task: Task, model: str) -> str:
    """Risk-analyze a single text, with a Postgres-backed response cache."""
    llm = openai.AsyncOpenAI(
        api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
    )
    # Cache key: the answer depends only on the model, the text and the prompt.
    cache_key = hashlib.md5(
        f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode()
    ).hexdigest()

    # Check the cache first.
    async with get_cur() as cur:
        await cur.execute(
            "SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (cache_key,)
        )
        row = await cur.fetchone()
        if row:
            return row[0]

    messages: Iterable[ChatCompletionMessageParam] = [
        {"role": "system", "content": task.risk_type.prompt},
        {"role": "user", "content": task.doc.get_text_for_llm()},
    ]
    # The caller only checks for "是" in the reply, so a tiny completion
    # budget and a newline stop sequence are enough.
    resp = await llm.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,
        stop="\n",
        max_tokens=10,
    )

    completions = resp.choices[0].message.content or ""
    usage = resp.usage.model_dump_json() if resp.usage else None

    # Cache the result. ON CONFLICT DO NOTHING makes concurrent workers that
    # race on the same key a no-op instead of a duplicate-key error.
    async with get_cur() as cur:
        await cur.execute(
            "INSERT INTO llm_cache (id, messages, model, response, usage)"
            " VALUES (%s, %s, %s, %s, %s) ON CONFLICT DO NOTHING",
            (cache_key, json.dumps(messages, ensure_ascii=False), model, completions, usage),
        )

    return completions
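

# Shape of the llm_cache table as implied by the queries above. This is an
# inferred sketch, not the authoritative DDL; the actual column types may
# differ.
#
#   CREATE TABLE llm_cache (
#       id       text PRIMARY KEY,  -- md5(model|text|prompt)
#       messages jsonb,
#       model    text,
#       response text,
#       usage    jsonb
#   );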


if __name__ == "__main__":
    asyncio.run(main())