feat: write analysis results to MySQL

commit a051674f2e
parent 47ae4dc8d5
Date: 2024-10-18 16:01:29 +08:00
2 changed files with 44 additions and 14 deletions

View File

@@ -26,12 +26,33 @@ async def main():
     dbscan_result = await run_dbscan()
     docs = [cluster[0] for cluster in dbscan_result.clusters]
+    print("Documents to analyze:", len(docs), "noise:", len(dbscan_result.noise))
+    risks_to_update: dict[str, set[str]] = {}
     analyze_result = await batch_risk_analyze(docs, risk_types)
     for task in analyze_result:
         if "" not in task.response:
             continue
-        print(f"Risk: {task.risk_type.name} Title: {task.doc.title}")
+        print(f"Risk: {task.risk_type.name} Title: {task.doc.title} {task.doc.id}")
+        # Merge each document's risks into a single set
+        if task.doc.id not in risks_to_update:
+            risks_to_update[task.doc.id] = set()
+        risks_to_update[task.doc.id].add(task.risk_type.name)
+    # Update the database
+    for doc_id, risks in risks_to_update.items():
+        await mysql.execute(
+            """
+            UPDATE risk_news
+            SET risk_types = :risk_types, updated_at = now()
+            WHERE es_id = :es_id
+            """,
+            {
+                "es_id": doc_id,
+                "risk_types": json.dumps(list(risks), ensure_ascii=False),
+            },
+        )

 @dataclass
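The :name placeholders suggest a `databases`-style async interface behind the `mysql` object; a minimal sketch under that assumption (the DSN, database name, and values here are made up for illustration, not taken from this commit):

    # Hypothetical sketch of the `mysql` helper used above, assuming the
    # `databases` package, whose execute() accepts :name placeholders.
    import asyncio
    import json

    from databases import Database

    mysql = Database("mysql+asyncmy://user:pass@localhost:3306/yuqing")  # assumed DSN

    async def demo() -> None:
        await mysql.connect()
        await mysql.execute(
            """
            UPDATE risk_news
            SET risk_types = :risk_types, updated_at = now()
            WHERE es_id = :es_id
            """,
            {
                "es_id": "0" * 32,  # placeholder 32-char hex id
                # ensure_ascii=False keeps Chinese risk names readable in the JSON column
                "risk_types": json.dumps(["sample risk"], ensure_ascii=False),
            },
        )
        await mysql.disconnect()

    if __name__ == "__main__":
        asyncio.run(demo())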

View File

@@ -8,9 +8,12 @@ from sklearn.metrics import pairwise_distances
 from cucyuqing.pg import pool, get_cur

 @dataclass
 class Document:
-    id: int
+    id: str
+    """The ID is the 32-character hex ID from ES."""
     title: str
     content: str
     similarity: float = 0.0
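Side note on the new `id: str`: a hunk below fills it with `str(row[0]).replace("-", "")`, i.e. the PG uuid with hyphens stripped, which for a valid UUID is exactly `UUID.hex`. A quick check with a made-up value:

    from uuid import UUID

    pg_id = UUID("8f14e45f-ceea-4672-a0bb-7ab9c1a9e1f3")  # made-up example value
    es_id = str(pg_id).replace("-", "")  # the conversion applied in the hunk below
    assert es_id == pg_id.hex  # UUID.hex is the same 32-char lowercase hex form
    assert len(es_id) == 32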
@@ -23,13 +26,16 @@ class Document:
""" """
return self.title return self.title
@dataclass @dataclass
class DBScanResult: class DBScanResult:
noise: list[Document] noise: list[Document]
clusters: list[list[Document]] clusters: list[list[Document]]
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
async def run_dbscan() -> DBScanResult: async def run_dbscan() -> DBScanResult:
# 从 PG 数据库获取数据 # 从 PG 数据库获取数据
async with get_cur() as cur: async with get_cur() as cur:
@@ -45,16 +51,17 @@ async def run_dbscan() -> DBScanResult:
         )
         rows = await cur.fetchall()
     docs: list[Document] = [
-        Document(row[0], row[1], row[2])
-        for row in rows
+        Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
     ]
     embeddings = [numpy.array(json.loads(row[3])) for row in rows]
     # Compute the cosine distance matrix
-    cosine_distances = pairwise_distances(embeddings, metric='cosine')
+    cosine_distances = pairwise_distances(embeddings, metric="cosine")
     # Initialize the DBSCAN model
-    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed')  # Adjust eps as needed
+    dbscan = DBSCAN(
+        eps=0.25, min_samples=2, metric="precomputed"
+    )  # Adjust eps as needed
     # Run the clustering
     dbscan.fit(cosine_distances)
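The precomputed-distance pattern above in self-contained form (toy 2-D vectors standing in for the embeddings; same eps and min_samples; -1 in labels_ marks noise):

    import numpy
    from sklearn.cluster import DBSCAN
    from sklearn.metrics import pairwise_distances

    # Three near-parallel vectors and one orthogonal outlier.
    embeddings = numpy.array(
        [[1.0, 0.0], [0.99, 0.05], [0.98, 0.1], [0.0, 1.0]]
    )
    cosine_distances = pairwise_distances(embeddings, metric="cosine")
    dbscan = DBSCAN(eps=0.25, min_samples=2, metric="precomputed")
    dbscan.fit(cosine_distances)
    print(dbscan.labels_)  # [0 0 0 -1]: one cluster of three, one noise point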
@@ -66,7 +73,7 @@ async def run_dbscan() -> DBScanResult:
     ret: DBScanResult = DBScanResult(noise=[], clusters=[])
     unique_labels = set(labels)
     for label in unique_labels:
-        class_member_mask = (labels == label)
+        class_member_mask = labels == label
         cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
         cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
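For readers skimming the mask logic: `labels` comes from `dbscan.labels_`, where -1 means noise, so the loop amounts to the partition sketched here (a rough equivalent for orientation, not the commit's exact code):

    import numpy

    def partition(labels: numpy.ndarray, docs: list) -> tuple[list, list[list]]:
        # -1 is DBSCAN's noise label; every other label is a cluster id.
        noise = [doc for doc, label in zip(docs, labels) if label == -1]
        clusters = [
            [doc for doc, label in zip(docs, labels) if label == cluster_id]
            for cluster_id in sorted(set(labels) - {-1})
        ]
        return noise, clusters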
@@ -89,14 +96,16 @@ async def run_dbscan() -> DBScanResult:
     return ret

 async def main():
     await pool.open()
     result = await run_dbscan()
     print(f"Noise documents: {len(result.noise)}")
     for i, cluster in enumerate(result.clusters):
-        print('----------------')
+        print("----------------")
         for doc in cluster:
             print(f"Cluster {i} doc: {doc.title} similarity: {doc.similarity}")

-if __name__ == '__main__':
+if __name__ == "__main__":
     asyncio.run(main())