diff --git a/cucyuqing/cmd/risk-analyze.py b/cucyuqing/cmd/risk-analyze.py index 24c314c..14d16a6 100644 --- a/cucyuqing/cmd/risk-analyze.py +++ b/cucyuqing/cmd/risk-analyze.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from os import system from typing import Iterable, Required +from typing_extensions import Doc import openai import asyncio from openai.types.chat import ChatCompletionMessageParam @@ -11,7 +12,7 @@ from cucyuqing.utils import print from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL from cucyuqing.pg import pool, get_cur from cucyuqing.mysql import mysql -from cucyuqing.dbscan import run_dbscan +from cucyuqing.dbscan import Document, run_dbscan async def main(): @@ -26,11 +27,11 @@ async def main(): dbscan_result = await run_dbscan() docs = [cluster[0] for cluster in dbscan_result.clusters] - analyze_result = await batch_risk_analyze([doc.title for doc in docs], risk_types) - for result, doc in zip(analyze_result, docs): - if "是" not in result: + analyze_result = await batch_risk_analyze(docs, risk_types) + for task in analyze_result: + if "是" not in task.response: continue - print(f"风险: {result} 标题: {doc.title}") + print(f"风险: {task.risk_type.name} 标题: {task.doc.title}") @dataclass @@ -41,7 +42,7 @@ class RiskType: @dataclass class Task: - text: str + doc: Document risk_type: RiskType response: str = "" @@ -61,16 +62,16 @@ async def get_risk_type_prompt() -> list[RiskType]: async def batch_risk_analyze( - texts: list, + docs: list[Document], risk_types: list[RiskType], model: str = "gpt-4o-mini", threads: int = 10, -) -> list[str]: +) -> list[Task]: """文本风险分析(并行批批处理)""" - # 从 text, risk_types 两个列表交叉生成任务列表 + # 从 docs, risk_types 两个列表交叉生成任务列表 tasks: list[Task] = [ - Task(text=text, risk_type=rt) for text in texts for rt in risk_types + Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types ] bar = tqdm.tqdm(total=len(tasks)) queue = asyncio.Queue() @@ -96,7 +97,7 @@ async def batch_risk_analyze( await queue.put(None) await asyncio.gather(*workers) - return [task.response for task in tasks] + return tasks async def risk_analyze(task: Task, model: str) -> str: @@ -105,7 +106,7 @@ async def risk_analyze(task: Task, model: str) -> str: api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL ) hash = hashlib.md5( - f"{model}|{task.text}|{task.risk_type.prompt}".encode() + f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode() ).hexdigest() # 查询缓存 @@ -119,7 +120,7 @@ async def risk_analyze(task: Task, model: str) -> str: messages: Iterable[ChatCompletionMessageParam] = [ {"role": "system", "content": task.risk_type.prompt}, - {"role": "user", "content": task.text}, + {"role": "user", "content": task.doc.get_text_for_llm()}, ] resp = await llm.chat.completions.create( messages=messages, diff --git a/cucyuqing/dbscan.py b/cucyuqing/dbscan.py index a02cf34..0de484f 100644 --- a/cucyuqing/dbscan.py +++ b/cucyuqing/dbscan.py @@ -15,6 +15,14 @@ class Document: content: str similarity: float = 0.0 + def get_text_for_llm(self) -> str: + """只使用标题进行风险分析 + + 对于空标题,在入库时已经处理过。 + 如果入库时标题为空,则使用content的前20个字符或第一句中文作为标题。 + """ + return self.title + @dataclass class DBScanResult: noise: list[Document]