fix: 二维task文档匹配问题
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from os import system
|
from os import system
|
||||||
from typing import Iterable, Required
|
from typing import Iterable, Required
|
||||||
|
from typing_extensions import Doc
|
||||||
import openai
|
import openai
|
||||||
import asyncio
|
import asyncio
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
@@ -11,7 +12,7 @@ from cucyuqing.utils import print
|
|||||||
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
|
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
|
||||||
from cucyuqing.pg import pool, get_cur
|
from cucyuqing.pg import pool, get_cur
|
||||||
from cucyuqing.mysql import mysql
|
from cucyuqing.mysql import mysql
|
||||||
from cucyuqing.dbscan import run_dbscan
|
from cucyuqing.dbscan import Document, run_dbscan
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@@ -26,11 +27,11 @@ async def main():
|
|||||||
dbscan_result = await run_dbscan()
|
dbscan_result = await run_dbscan()
|
||||||
docs = [cluster[0] for cluster in dbscan_result.clusters]
|
docs = [cluster[0] for cluster in dbscan_result.clusters]
|
||||||
|
|
||||||
analyze_result = await batch_risk_analyze([doc.title for doc in docs], risk_types)
|
analyze_result = await batch_risk_analyze(docs, risk_types)
|
||||||
for result, doc in zip(analyze_result, docs):
|
for task in analyze_result:
|
||||||
if "是" not in result:
|
if "是" not in task.response:
|
||||||
continue
|
continue
|
||||||
print(f"风险: {result} 标题: {doc.title}")
|
print(f"风险: {task.risk_type.name} 标题: {task.doc.title}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -41,7 +42,7 @@ class RiskType:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Task:
|
class Task:
|
||||||
text: str
|
doc: Document
|
||||||
risk_type: RiskType
|
risk_type: RiskType
|
||||||
response: str = ""
|
response: str = ""
|
||||||
|
|
||||||
@@ -61,16 +62,16 @@ async def get_risk_type_prompt() -> list[RiskType]:
|
|||||||
|
|
||||||
|
|
||||||
async def batch_risk_analyze(
|
async def batch_risk_analyze(
|
||||||
texts: list,
|
docs: list[Document],
|
||||||
risk_types: list[RiskType],
|
risk_types: list[RiskType],
|
||||||
model: str = "gpt-4o-mini",
|
model: str = "gpt-4o-mini",
|
||||||
threads: int = 10,
|
threads: int = 10,
|
||||||
) -> list[str]:
|
) -> list[Task]:
|
||||||
"""文本风险分析(并行批批处理)"""
|
"""文本风险分析(并行批批处理)"""
|
||||||
|
|
||||||
# 从 text, risk_types 两个列表交叉生成任务列表
|
# 从 docs, risk_types 两个列表交叉生成任务列表
|
||||||
tasks: list[Task] = [
|
tasks: list[Task] = [
|
||||||
Task(text=text, risk_type=rt) for text in texts for rt in risk_types
|
Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
|
||||||
]
|
]
|
||||||
bar = tqdm.tqdm(total=len(tasks))
|
bar = tqdm.tqdm(total=len(tasks))
|
||||||
queue = asyncio.Queue()
|
queue = asyncio.Queue()
|
||||||
@@ -96,7 +97,7 @@ async def batch_risk_analyze(
|
|||||||
await queue.put(None)
|
await queue.put(None)
|
||||||
await asyncio.gather(*workers)
|
await asyncio.gather(*workers)
|
||||||
|
|
||||||
return [task.response for task in tasks]
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
async def risk_analyze(task: Task, model: str) -> str:
|
async def risk_analyze(task: Task, model: str) -> str:
|
||||||
@@ -105,7 +106,7 @@ async def risk_analyze(task: Task, model: str) -> str:
|
|||||||
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
|
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
|
||||||
)
|
)
|
||||||
hash = hashlib.md5(
|
hash = hashlib.md5(
|
||||||
f"{model}|{task.text}|{task.risk_type.prompt}".encode()
|
f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode()
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
# 查询缓存
|
# 查询缓存
|
||||||
@@ -119,7 +120,7 @@ async def risk_analyze(task: Task, model: str) -> str:
|
|||||||
|
|
||||||
messages: Iterable[ChatCompletionMessageParam] = [
|
messages: Iterable[ChatCompletionMessageParam] = [
|
||||||
{"role": "system", "content": task.risk_type.prompt},
|
{"role": "system", "content": task.risk_type.prompt},
|
||||||
{"role": "user", "content": task.text},
|
{"role": "user", "content": task.doc.get_text_for_llm()},
|
||||||
]
|
]
|
||||||
resp = await llm.chat.completions.create(
|
resp = await llm.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|||||||
@@ -15,6 +15,14 @@ class Document:
|
|||||||
content: str
|
content: str
|
||||||
similarity: float = 0.0
|
similarity: float = 0.0
|
||||||
|
|
||||||
|
def get_text_for_llm(self) -> str:
|
||||||
|
"""只使用标题进行风险分析
|
||||||
|
|
||||||
|
对于空标题,在入库时已经处理过。
|
||||||
|
如果入库时标题为空,则使用content的前20个字符或第一句中文作为标题。
|
||||||
|
"""
|
||||||
|
return self.title
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DBScanResult:
|
class DBScanResult:
|
||||||
noise: list[Document]
|
noise: list[Document]
|
||||||
|
|||||||
Reference in New Issue
Block a user