二维数组转一维,文档ID标记的问题还没解决
This commit is contained in:
2024-10-18 15:35:54 +08:00
parent ad3ef8e504
commit d52f37e5f0

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from os import system from os import system
from typing import Iterable, Required from typing import Iterable, Required
import openai import openai
@@ -14,18 +15,63 @@ from cucyuqing.dbscan import run_dbscan
async def main():
    """Run the clustering and risk-analysis pipeline and print the results.

    Opens the ``pool`` and ``mysql`` connections concurrently, loads the risk
    types with their prompts, runs ``run_dbscan`` and takes the first element
    of every cluster, then asks ``batch_risk_analyze`` to judge each of those
    documents' titles against every risk type.
    """
    await asyncio.gather(
        pool.open(),
        mysql.connect(),
    )

    # Fetch the risk types together with their prompts.
    risk_types = await get_risk_type_prompt()
    print("共有风险类型:", len(risk_types))

    dbscan_result = await run_dbscan()
    docs = [cluster[0] for cluster in dbscan_result.clusters]

    analyze_result = await batch_risk_analyze(
        [doc.title for doc in docs], risk_types
    )

    # batch_risk_analyze returns ONE flat list with an entry per
    # (text, risk type) pair, ordered text-major: all risk types for the first
    # text, then all risk types for the second, and so on. Zipping that list
    # directly against ``docs`` pairs results with the wrong documents, so
    # slice out each document's run of results instead.
    per_doc = len(risk_types)
    for doc_index, doc in enumerate(docs):
        start = doc_index * per_doc
        doc_results = analyze_result[start : start + per_doc]
        for risk_type, result in zip(risk_types, doc_results):
            # NOTE(review): the empty string is contained in every str, so
            # this condition is never true and nothing is skipped. The literal
            # presumably held a marker character that was lost; confirm.
            if "" not in result:
                continue
            print(
                f"风险类型: {risk_type.name} 风险: {result} 标题: {doc.title}"
            )
@dataclass
class RiskType:
    """A risk category paired with the prompt used to test a text for it."""

    # Name of the risk type (read from ``risk_type.name`` by the loader query).
    name: str
    # Prompt text; sent to the LLM as the system message.
    prompt: str
@dataclass
class Task:
    """One unit of work: a single text checked against a single risk type."""

    # Text to analyze; sent to the LLM as the user message.
    text: str
    # Risk type whose prompt is applied to the text.
    risk_type: RiskType
    # Model reply; stays empty until a worker stores risk_analyze's result.
    response: str = ""
async def get_risk_type_prompt() -> list[RiskType]:
    """Load the risk types and their prompts from the database.

    Joins ``risk_prompt`` with ``risk_type`` and returns one ``RiskType`` per
    row, in the order the query yields them (prompt id descending).
    """
    query = """
        SELECT rp.content, rt.name
        FROM risk_prompt rp
        JOIN risk_type rt ON rp.risk_type_id = rt.id
        ORDER BY rp.id DESC
        """
    records = await mysql.fetch_all(query)

    risk_types: list[RiskType] = []
    for record in records:
        # Column 0 is the prompt content, column 1 the risk type name.
        risk_types.append(RiskType(name=record[1], prompt=record[0]))
    return risk_types
async def batch_risk_analyze( async def batch_risk_analyze(
texts: list, model: str = "gpt-4o-mini", threads: int = 10 texts: list,
) -> list: risk_types: list[RiskType],
tasks = [{"input": text} for text in texts] model: str = "gpt-4o-mini",
threads: int = 10,
) -> list[str]:
"""文本风险分析(并行批批处理)"""
# 从 text, risk_types 两个列表交叉生成任务列表
tasks: list[Task] = [
Task(text=text, risk_type=rt) for text in texts for rt in risk_types
]
bar = tqdm.tqdm(total=len(tasks)) bar = tqdm.tqdm(total=len(tasks))
queue = asyncio.Queue() queue = asyncio.Queue()
@@ -34,7 +80,7 @@ async def batch_risk_analyze(
task = await queue.get() task = await queue.get()
if task is None: if task is None:
break break
task["response"] = await risk_analyze(task["input"], model) task.response = await risk_analyze(task, model)
queue.task_done() queue.task_done()
bar.update(1) bar.update(1)
@@ -50,19 +96,16 @@ async def batch_risk_analyze(
await queue.put(None) await queue.put(None)
await asyncio.gather(*workers) await asyncio.gather(*workers)
return [task["response"] for task in tasks] return [task.response for task in tasks]
async def risk_analyze(text: str, model: str) -> str: async def risk_analyze(task: Task, model: str) -> str:
"""对一条文本进行风险分析"""
llm = openai.AsyncOpenAI( llm = openai.AsyncOpenAI(
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
) )
system_message = (
"你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。"
)
hash = hashlib.md5( hash = hashlib.md5(
model.encode() + b"|" + text.encode() + b"|" + system_message.encode() f"{model}|{task.text}|{task.risk_type.prompt}".encode()
).hexdigest() ).hexdigest()
# 查询缓存 # 查询缓存
@@ -75,14 +118,15 @@ async def risk_analyze(text: str, model: str) -> str:
return row[0] return row[0]
messages: Iterable[ChatCompletionMessageParam] = [ messages: Iterable[ChatCompletionMessageParam] = [
{"role": "system", "content": system_message}, {"role": "system", "content": task.risk_type.prompt},
{"role": "user", "content": text}, {"role": "user", "content": task.text},
] ]
resp = await llm.chat.completions.create( resp = await llm.chat.completions.create(
messages=messages, messages=messages,
model=model, model=model,
temperature=0, temperature=0,
stop="\n", stop="\n",
max_tokens=10,
) )
completions = resp.choices[0].message.content or "" completions = resp.choices[0].message.content or ""