From d52f37e5f0be23bde2556999e41e34d6c96f61ef Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Fri, 18 Oct 2024 15:35:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9A=82=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 二位数组转一维,文档ID标记的问题还没解决 --- cucyuqing/cmd/risk-analyze.py | 76 +++++++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/cucyuqing/cmd/risk-analyze.py b/cucyuqing/cmd/risk-analyze.py index 4d2ebe9..24c314c 100644 --- a/cucyuqing/cmd/risk-analyze.py +++ b/cucyuqing/cmd/risk-analyze.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from os import system from typing import Iterable, Required import openai @@ -14,18 +15,63 @@ from cucyuqing.dbscan import run_dbscan async def main(): - await pool.open() + await asyncio.gather( + pool.open(), + mysql.connect(), + ) + # 获取一个风险类型和对应的提示词 + risk_types = await get_risk_type_prompt() + print("共有风险类型:", len(risk_types)) + dbscan_result = await run_dbscan() docs = [cluster[0] for cluster in dbscan_result.clusters] - analyze_rusult = await batch_risk_analyze([doc.title for doc in docs]) - for result, doc in zip(analyze_rusult, docs): + + analyze_result = await batch_risk_analyze([doc.title for doc in docs], risk_types) + for result, doc in zip(analyze_result, docs): + if "是" not in result: + continue print(f"风险: {result} 标题: {doc.title}") +@dataclass +class RiskType: + name: str + prompt: str + + +@dataclass +class Task: + text: str + risk_type: RiskType + response: str = "" + + +async def get_risk_type_prompt() -> list[RiskType]: + """从数据库中获取风险类型和对应的提示词""" + rows = await mysql.fetch_all( + """ + SELECT rp.content, rt.name + FROM risk_prompt rp + JOIN risk_type rt ON rp.risk_type_id = rt.id + ORDER BY rp.id DESC + """ + ) + + return [RiskType(prompt=row[0], name=row[1]) for row in rows] + + async def batch_risk_analyze( - texts: list, model: str = "gpt-4o-mini", threads: int = 10 -) -> list: - tasks = [{"input": text} for text in texts] + texts: list, + risk_types: list[RiskType], + model: str = "gpt-4o-mini", + threads: int = 10, +) -> list[str]: + """文本风险分析(并行批批处理)""" + + # 从 text, risk_types 两个列表交叉生成任务列表 + tasks: list[Task] = [ + Task(text=text, risk_type=rt) for text in texts for rt in risk_types + ] bar = tqdm.tqdm(total=len(tasks)) queue = asyncio.Queue() @@ -34,7 +80,7 @@ async def batch_risk_analyze( task = await queue.get() if task is None: break - task["response"] = await risk_analyze(task["input"], model) + task.response = await risk_analyze(task, model) queue.task_done() bar.update(1) @@ -50,19 +96,16 @@ async def batch_risk_analyze( await queue.put(None) await asyncio.gather(*workers) - return [task["response"] for task in tasks] + return [task.response for task in tasks] -async def risk_analyze(text: str, model: str) -> str: +async def risk_analyze(task: Task, model: str) -> str: + """对一条文本进行风险分析""" llm = openai.AsyncOpenAI( api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL ) - system_message = ( - "你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。" - ) - hash = hashlib.md5( - model.encode() + b"|" + text.encode() + b"|" + system_message.encode() + f"{model}|{task.text}|{task.risk_type.prompt}".encode() ).hexdigest() # 查询缓存 @@ -75,14 +118,15 @@ async def risk_analyze(text: str, model: str) -> str: return row[0] messages: Iterable[ChatCompletionMessageParam] = [ - {"role": "system", "content": system_message}, - {"role": "user", "content": text}, + {"role": "system", "content": task.risk_type.prompt}, + {"role": "user", "content": task.text}, ] resp = await llm.chat.completions.create( messages=messages, model=model, temperature=0, stop="\n", + max_tokens=10, ) completions = resp.choices[0].message.content or ""