fix: 二维task文档匹配问题

This commit is contained in:
2024-10-18 15:47:18 +08:00
parent d52f37e5f0
commit 47ae4dc8d5
2 changed files with 22 additions and 13 deletions

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from os import system
from typing import Iterable, Required
from typing_extensions import Doc
import openai
import asyncio
from openai.types.chat import ChatCompletionMessageParam
@@ -11,7 +12,7 @@ from cucyuqing.utils import print
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
from cucyuqing.pg import pool, get_cur
from cucyuqing.mysql import mysql
from cucyuqing.dbscan import run_dbscan
from cucyuqing.dbscan import Document, run_dbscan
async def main():
@@ -26,11 +27,11 @@ async def main():
dbscan_result = await run_dbscan()
docs = [cluster[0] for cluster in dbscan_result.clusters]
analyze_result = await batch_risk_analyze([doc.title for doc in docs], risk_types)
for result, doc in zip(analyze_result, docs):
if "" not in result:
analyze_result = await batch_risk_analyze(docs, risk_types)
for task in analyze_result:
if "" not in task.response:
continue
print(f"风险: {result} 标题: {doc.title}")
print(f"风险: {task.risk_type.name} 标题: {task.doc.title}")
@dataclass
@@ -41,7 +42,7 @@ class RiskType:
@dataclass
class Task:
text: str
doc: Document
risk_type: RiskType
response: str = ""
@@ -61,16 +62,16 @@ async def get_risk_type_prompt() -> list[RiskType]:
async def batch_risk_analyze(
texts: list,
docs: list[Document],
risk_types: list[RiskType],
model: str = "gpt-4o-mini",
threads: int = 10,
) -> list[str]:
) -> list[Task]:
"""文本风险分析(并行批批处理)"""
# 从 text, risk_types 两个列表交叉生成任务列表
# 从 docs, risk_types 两个列表交叉生成任务列表
tasks: list[Task] = [
Task(text=text, risk_type=rt) for text in texts for rt in risk_types
Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
]
bar = tqdm.tqdm(total=len(tasks))
queue = asyncio.Queue()
@@ -96,7 +97,7 @@ async def batch_risk_analyze(
await queue.put(None)
await asyncio.gather(*workers)
return [task.response for task in tasks]
return tasks
async def risk_analyze(task: Task, model: str) -> str:
@@ -105,7 +106,7 @@ async def risk_analyze(task: Task, model: str) -> str:
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
)
hash = hashlib.md5(
f"{model}|{task.text}|{task.risk_type.prompt}".encode()
f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode()
).hexdigest()
# 查询缓存
@@ -119,7 +120,7 @@ async def risk_analyze(task: Task, model: str) -> str:
messages: Iterable[ChatCompletionMessageParam] = [
{"role": "system", "content": task.risk_type.prompt},
{"role": "user", "content": task.text},
{"role": "user", "content": task.doc.get_text_for_llm()},
]
resp = await llm.chat.completions.create(
messages=messages,

View File

@@ -15,6 +15,14 @@ class Document:
content: str
similarity: float = 0.0
def get_text_for_llm(self) -> str:
"""只使用标题进行风险分析
对于空标题,在入库时已经处理过。
如果入库时标题为空则使用content的前20个字符或第一句中文作为标题。
"""
return self.title
@dataclass
class DBScanResult:
noise: list[Document]