暂存
二位数组转一维,文档ID标记的问题还没解决
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
from os import system
|
from os import system
|
||||||
from typing import Iterable, Required
|
from typing import Iterable, Required
|
||||||
import openai
|
import openai
|
||||||
@@ -14,18 +15,63 @@ from cucyuqing.dbscan import run_dbscan
|
|||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
await pool.open()
|
await asyncio.gather(
|
||||||
|
pool.open(),
|
||||||
|
mysql.connect(),
|
||||||
|
)
|
||||||
|
# 获取一个风险类型和对应的提示词
|
||||||
|
risk_types = await get_risk_type_prompt()
|
||||||
|
print("共有风险类型:", len(risk_types))
|
||||||
|
|
||||||
dbscan_result = await run_dbscan()
|
dbscan_result = await run_dbscan()
|
||||||
docs = [cluster[0] for cluster in dbscan_result.clusters]
|
docs = [cluster[0] for cluster in dbscan_result.clusters]
|
||||||
analyze_rusult = await batch_risk_analyze([doc.title for doc in docs])
|
|
||||||
for result, doc in zip(analyze_rusult, docs):
|
analyze_result = await batch_risk_analyze([doc.title for doc in docs], risk_types)
|
||||||
|
for result, doc in zip(analyze_result, docs):
|
||||||
|
if "是" not in result:
|
||||||
|
continue
|
||||||
print(f"风险: {result} 标题: {doc.title}")
|
print(f"风险: {result} 标题: {doc.title}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RiskType:
|
||||||
|
name: str
|
||||||
|
prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Task:
|
||||||
|
text: str
|
||||||
|
risk_type: RiskType
|
||||||
|
response: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_risk_type_prompt() -> list[RiskType]:
|
||||||
|
"""从数据库中获取风险类型和对应的提示词"""
|
||||||
|
rows = await mysql.fetch_all(
|
||||||
|
"""
|
||||||
|
SELECT rp.content, rt.name
|
||||||
|
FROM risk_prompt rp
|
||||||
|
JOIN risk_type rt ON rp.risk_type_id = rt.id
|
||||||
|
ORDER BY rp.id DESC
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
return [RiskType(prompt=row[0], name=row[1]) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
async def batch_risk_analyze(
|
async def batch_risk_analyze(
|
||||||
texts: list, model: str = "gpt-4o-mini", threads: int = 10
|
texts: list,
|
||||||
) -> list:
|
risk_types: list[RiskType],
|
||||||
tasks = [{"input": text} for text in texts]
|
model: str = "gpt-4o-mini",
|
||||||
|
threads: int = 10,
|
||||||
|
) -> list[str]:
|
||||||
|
"""文本风险分析(并行批批处理)"""
|
||||||
|
|
||||||
|
# 从 text, risk_types 两个列表交叉生成任务列表
|
||||||
|
tasks: list[Task] = [
|
||||||
|
Task(text=text, risk_type=rt) for text in texts for rt in risk_types
|
||||||
|
]
|
||||||
bar = tqdm.tqdm(total=len(tasks))
|
bar = tqdm.tqdm(total=len(tasks))
|
||||||
queue = asyncio.Queue()
|
queue = asyncio.Queue()
|
||||||
|
|
||||||
@@ -34,7 +80,7 @@ async def batch_risk_analyze(
|
|||||||
task = await queue.get()
|
task = await queue.get()
|
||||||
if task is None:
|
if task is None:
|
||||||
break
|
break
|
||||||
task["response"] = await risk_analyze(task["input"], model)
|
task.response = await risk_analyze(task, model)
|
||||||
queue.task_done()
|
queue.task_done()
|
||||||
bar.update(1)
|
bar.update(1)
|
||||||
|
|
||||||
@@ -50,19 +96,16 @@ async def batch_risk_analyze(
|
|||||||
await queue.put(None)
|
await queue.put(None)
|
||||||
await asyncio.gather(*workers)
|
await asyncio.gather(*workers)
|
||||||
|
|
||||||
return [task["response"] for task in tasks]
|
return [task.response for task in tasks]
|
||||||
|
|
||||||
|
|
||||||
async def risk_analyze(text: str, model: str) -> str:
|
async def risk_analyze(task: Task, model: str) -> str:
|
||||||
|
"""对一条文本进行风险分析"""
|
||||||
llm = openai.AsyncOpenAI(
|
llm = openai.AsyncOpenAI(
|
||||||
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
|
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
|
||||||
)
|
)
|
||||||
system_message = (
|
|
||||||
"你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。"
|
|
||||||
)
|
|
||||||
|
|
||||||
hash = hashlib.md5(
|
hash = hashlib.md5(
|
||||||
model.encode() + b"|" + text.encode() + b"|" + system_message.encode()
|
f"{model}|{task.text}|{task.risk_type.prompt}".encode()
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
# 查询缓存
|
# 查询缓存
|
||||||
@@ -75,14 +118,15 @@ async def risk_analyze(text: str, model: str) -> str:
|
|||||||
return row[0]
|
return row[0]
|
||||||
|
|
||||||
messages: Iterable[ChatCompletionMessageParam] = [
|
messages: Iterable[ChatCompletionMessageParam] = [
|
||||||
{"role": "system", "content": system_message},
|
{"role": "system", "content": task.risk_type.prompt},
|
||||||
{"role": "user", "content": text},
|
{"role": "user", "content": task.text},
|
||||||
]
|
]
|
||||||
resp = await llm.chat.completions.create(
|
resp = await llm.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
stop="\n",
|
stop="\n",
|
||||||
|
max_tokens=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
completions = resp.choices[0].message.content or ""
|
completions = resp.choices[0].message.content or ""
|
||||||
|
|||||||
Reference in New Issue
Block a user