diff --git a/cucyuqing/cmd/risk-analyze.py b/cucyuqing/cmd/risk-analyze.py index 14d16a6..962e573 100644 --- a/cucyuqing/cmd/risk-analyze.py +++ b/cucyuqing/cmd/risk-analyze.py @@ -26,12 +26,33 @@ async def main(): dbscan_result = await run_dbscan() docs = [cluster[0] for cluster in dbscan_result.clusters] + print("共有待分析文档:", len(docs), "噪声", len(dbscan_result.noise)) + risks_to_update: dict[str, set[str]] = {} analyze_result = await batch_risk_analyze(docs, risk_types) for task in analyze_result: if "是" not in task.response: continue - print(f"风险: {task.risk_type.name} 标题: {task.doc.title}") + print(f"风险: {task.risk_type.name} 标题: {task.doc.title} {task.doc.id}") + + # 合并每个文档的风险到一个set + if task.doc.id not in risks_to_update: + risks_to_update[task.doc.id] = set() + risks_to_update[task.doc.id].add(task.risk_type.name) + + # 更新数据库 + for doc_id, risks in risks_to_update.items(): + await mysql.execute( + """ + UPDATE risk_news + SET risk_types = :risk_types, updated_at = now() + WHERE es_id = :es_id + """, + { + "es_id": doc_id, + "risk_types": json.dumps(list(risks), ensure_ascii=False), + }, + ) @dataclass diff --git a/cucyuqing/dbscan.py b/cucyuqing/dbscan.py index 0de484f..f782958 100644 --- a/cucyuqing/dbscan.py +++ b/cucyuqing/dbscan.py @@ -8,28 +8,34 @@ from sklearn.metrics import pairwise_distances from cucyuqing.pg import pool, get_cur + @dataclass class Document: - id: int + id: str + """ID 是 ES 中的 32 位 hex ID""" + title: str content: str similarity: float = 0.0 def get_text_for_llm(self) -> str: """只使用标题进行风险分析 - + 对于空标题,在入库时已经处理过。 如果入库时标题为空,则使用content的前20个字符或第一句中文作为标题。 """ - return self.title + return self.title + @dataclass class DBScanResult: noise: list[Document] clusters: list[list[Document]] + from sklearn.metrics.pairwise import cosine_similarity + async def run_dbscan() -> DBScanResult: # 从 PG 数据库获取数据 async with get_cur() as cur: @@ -45,16 +51,17 @@ async def run_dbscan() -> DBScanResult: ) rows = await cur.fetchall() docs: list[Document] = [ - 
Document(row[0], row[1], row[2]) - for row in rows + Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows ] embeddings = [numpy.array(json.loads(row[3])) for row in rows] # 计算余弦距离矩阵 - cosine_distances = pairwise_distances(embeddings, metric='cosine') + cosine_distances = pairwise_distances(embeddings, metric="cosine") # 初始化DBSCAN模型 - dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed') # Adjust eps as needed + dbscan = DBSCAN( + eps=0.25, min_samples=2, metric="precomputed" + ) # Adjust eps as needed # 进行聚类 dbscan.fit(cosine_distances) @@ -66,9 +73,9 @@ async def run_dbscan() -> DBScanResult: ret: DBScanResult = DBScanResult(noise=[], clusters=[]) unique_labels = set(labels) for label in unique_labels: - class_member_mask = (labels == label) + class_member_mask = labels == label cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore - cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore + cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore if label == -1: # -1 is the label for noise points @@ -86,17 +93,19 @@ async def run_dbscan() -> DBScanResult: doc.similarity = similarities[i] sorted_cluster_docs.append(doc) ret.clusters.append(sorted_cluster_docs) - + return ret + async def main(): await pool.open() result = await run_dbscan() print(f"噪声文档: {len(result.noise)}") for i, cluster in enumerate(result.clusters): - print('----------------') + print("----------------") for doc in cluster: print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}") -if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file + +if __name__ == "__main__": + asyncio.run(main())