fix: 二维task文档匹配问题
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from os import system
|
||||
from typing import Iterable, Required
|
||||
from typing_extensions import Doc
|
||||
import openai
|
||||
import asyncio
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
@@ -11,7 +12,7 @@ from cucyuqing.utils import print
|
||||
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
|
||||
from cucyuqing.pg import pool, get_cur
|
||||
from cucyuqing.mysql import mysql
|
||||
from cucyuqing.dbscan import run_dbscan
|
||||
from cucyuqing.dbscan import Document, run_dbscan
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -26,11 +27,11 @@ async def main():
|
||||
dbscan_result = await run_dbscan()
|
||||
docs = [cluster[0] for cluster in dbscan_result.clusters]
|
||||
|
||||
analyze_result = await batch_risk_analyze([doc.title for doc in docs], risk_types)
|
||||
for result, doc in zip(analyze_result, docs):
|
||||
if "是" not in result:
|
||||
analyze_result = await batch_risk_analyze(docs, risk_types)
|
||||
for task in analyze_result:
|
||||
if "是" not in task.response:
|
||||
continue
|
||||
print(f"风险: {result} 标题: {doc.title}")
|
||||
print(f"风险: {task.risk_type.name} 标题: {task.doc.title}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -41,7 +42,7 @@ class RiskType:
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
text: str
|
||||
doc: Document
|
||||
risk_type: RiskType
|
||||
response: str = ""
|
||||
|
||||
@@ -61,16 +62,16 @@ async def get_risk_type_prompt() -> list[RiskType]:
|
||||
|
||||
|
||||
async def batch_risk_analyze(
|
||||
texts: list,
|
||||
docs: list[Document],
|
||||
risk_types: list[RiskType],
|
||||
model: str = "gpt-4o-mini",
|
||||
threads: int = 10,
|
||||
) -> list[str]:
|
||||
) -> list[Task]:
|
||||
"""文本风险分析(并行批批处理)"""
|
||||
|
||||
# 从 text, risk_types 两个列表交叉生成任务列表
|
||||
# 从 docs, risk_types 两个列表交叉生成任务列表
|
||||
tasks: list[Task] = [
|
||||
Task(text=text, risk_type=rt) for text in texts for rt in risk_types
|
||||
Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
|
||||
]
|
||||
bar = tqdm.tqdm(total=len(tasks))
|
||||
queue = asyncio.Queue()
|
||||
@@ -96,7 +97,7 @@ async def batch_risk_analyze(
|
||||
await queue.put(None)
|
||||
await asyncio.gather(*workers)
|
||||
|
||||
return [task.response for task in tasks]
|
||||
return tasks
|
||||
|
||||
|
||||
async def risk_analyze(task: Task, model: str) -> str:
|
||||
@@ -105,7 +106,7 @@ async def risk_analyze(task: Task, model: str) -> str:
|
||||
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
|
||||
)
|
||||
hash = hashlib.md5(
|
||||
f"{model}|{task.text}|{task.risk_type.prompt}".encode()
|
||||
f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode()
|
||||
).hexdigest()
|
||||
|
||||
# 查询缓存
|
||||
@@ -119,7 +120,7 @@ async def risk_analyze(task: Task, model: str) -> str:
|
||||
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
{"role": "system", "content": task.risk_type.prompt},
|
||||
{"role": "user", "content": task.text},
|
||||
{"role": "user", "content": task.doc.get_text_for_llm()},
|
||||
]
|
||||
resp = await llm.chat.completions.create(
|
||||
messages=messages,
|
||||
|
||||
@@ -15,6 +15,14 @@ class Document:
|
||||
content: str
|
||||
similarity: float = 0.0
|
||||
|
||||
def get_text_for_llm(self) -> str:
|
||||
"""只使用标题进行风险分析
|
||||
|
||||
对于空标题,在入库时已经处理过。
|
||||
如果入库时标题为空,则使用content的前20个字符或第一句中文作为标题。
|
||||
"""
|
||||
return self.title
|
||||
|
||||
@dataclass
|
||||
class DBScanResult:
|
||||
noise: list[Document]
|
||||
|
||||
Reference in New Issue
Block a user