feat: write analysis results to MySQL
@@ -26,12 +26,33 @@ async def main():
 
     dbscan_result = await run_dbscan()
     docs = [cluster[0] for cluster in dbscan_result.clusters]
+    print("共有待分析文档:", len(docs), "噪声", len(dbscan_result.noise))
 
+    risks_to_update: dict[str, set[str]] = {}
     analyze_result = await batch_risk_analyze(docs, risk_types)
     for task in analyze_result:
         if "是" not in task.response:
             continue
-        print(f"风险: {task.risk_type.name} 标题: {task.doc.title}")
+        print(f"风险: {task.risk_type.name} 标题: {task.doc.title} {task.doc.id}")
+
+        # Merge each document's risks into one set
+        if task.doc.id not in risks_to_update:
+            risks_to_update[task.doc.id] = set()
+        risks_to_update[task.doc.id].add(task.risk_type.name)
+
+    # Update the database
+    for doc_id, risks in risks_to_update.items():
+        await mysql.execute(
+            """
+            UPDATE risk_news
+            SET risk_types = :risk_types, updated_at = now()
+            WHERE es_id = :es_id
+            """,
+            {
+                "es_id": doc_id,
+                "risk_types": json.dumps(list(risks), ensure_ascii=False),
+            },
+        )
 
 
 @dataclass
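A note on the new merge-and-update step: the loop accumulates one set of risk names per document, then writes each set out as a JSON array. A minimal standalone sketch of the same pattern (the doc IDs and risk names below are made up, and dict.setdefault stands in for the explicit membership check used in the commit):

import json

# Hypothetical (doc_id, risk_name) hits standing in for the AnalyzeTask
# objects in the commit; every value here is made up.
hits = [("a" * 32, "fraud"), ("a" * 32, "phishing"), ("b" * 32, "fraud")]

risks_to_update: dict[str, set[str]] = {}
for doc_id, risk in hits:
    # Equivalent to: if doc_id not in risks_to_update: ... = set()
    risks_to_update.setdefault(doc_id, set()).add(risk)

for doc_id, risks in risks_to_update.items():
    # ensure_ascii=False keeps non-ASCII risk names readable in the JSON column
    print(doc_id, json.dumps(sorted(risks), ensure_ascii=False))

Whatever client object mysql is here (its import is not shown in the diff), the :es_id / :risk_types placeholders are bound from the dict argument, which keeps the values out of the SQL string itself.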
@@ -8,9 +8,12 @@ from sklearn.metrics import pairwise_distances
 
 from cucyuqing.pg import pool, get_cur
 
+
 @dataclass
 class Document:
-    id: int
+    id: str
+    """The ID is the 32-character hex ID used in ES"""
+
     title: str
     content: str
     similarity: float = 0.0
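The id field switches from int to str because the MySQL update above matches on es_id. Assuming the PG primary key is a UUID, which the str(row[0]).replace("-", "") conversion in a later hunk suggests, stripping the dashes yields exactly the 32-character hex form that ES uses:

import uuid

# A made-up UUID: 36 characters with dashes, 32 hex characters without.
pg_id = uuid.UUID("123e4567-e89b-12d3-a456-426614174000")
es_id = str(pg_id).replace("-", "")
assert es_id == pg_id.hex
assert len(es_id) == 32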
@@ -23,13 +26,16 @@ class Document:
         """
         return self.title
 
+
 @dataclass
 class DBScanResult:
     noise: list[Document]
     clusters: list[list[Document]]
 
+
 from sklearn.metrics.pairwise import cosine_similarity
 
+
 async def run_dbscan() -> DBScanResult:
     # Fetch the data from the PG database
     async with get_cur() as cur:
@@ -45,16 +51,17 @@ async def run_dbscan() -> DBScanResult:
         )
         rows = await cur.fetchall()
         docs: list[Document] = [
-            Document(row[0], row[1], row[2])
-            for row in rows
+            Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
         ]
         embeddings = [numpy.array(json.loads(row[3])) for row in rows]
 
     # Compute the cosine distance matrix
-    cosine_distances = pairwise_distances(embeddings, metric='cosine')
+    cosine_distances = pairwise_distances(embeddings, metric="cosine")
 
     # Initialize the DBSCAN model
-    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed')  # Adjust eps as needed
+    dbscan = DBSCAN(
+        eps=0.25, min_samples=2, metric="precomputed"
+    )  # Adjust eps as needed
 
     # Run the clustering
     dbscan.fit(cosine_distances)
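The clustering logic is unchanged here, only reformatted: pairwise_distances(..., metric="cosine") builds a square matrix of cosine distances (1 - cosine similarity), and metric="precomputed" makes DBSCAN read that matrix directly, so eps=0.25 caps the cosine distance between neighbors. A toy end-to-end sketch with made-up embeddings:

import numpy
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

# Three made-up embeddings: the first two nearly parallel, the third orthogonal.
embeddings = numpy.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])

# distances[i][j] = 1 - cosine_similarity(embeddings[i], embeddings[j])
distances = pairwise_distances(embeddings, metric="cosine")

# min_samples=2 counts the point itself, so two close docs already cluster.
dbscan = DBSCAN(eps=0.25, min_samples=2, metric="precomputed").fit(distances)
print(dbscan.labels_)  # [ 0  0 -1]: one cluster of two docs, one noise point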
@@ -66,7 +73,7 @@ async def run_dbscan() -> DBScanResult:
     ret: DBScanResult = DBScanResult(noise=[], clusters=[])
     unique_labels = set(labels)
     for label in unique_labels:
-        class_member_mask = (labels == label)
+        class_member_mask = labels == label
         cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
         cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
 
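The dropped parentheses are purely cosmetic: labels is a numpy array, so labels == label already evaluates element-wise to the boolean mask that the two list comprehensions index with. A small illustration with made-up labels:

import numpy

labels = numpy.array([0, 0, 1, -1])
docs = ["doc_a", "doc_b", "doc_c", "doc_d"]

for label in set(labels):
    mask = labels == label  # one boolean per document
    print(label, [docs[i] for i in range(len(labels)) if mask[i]])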
@@ -89,14 +96,16 @@ async def run_dbscan() -> DBScanResult:
 
     return ret
 
+
 async def main():
     await pool.open()
     result = await run_dbscan()
     print(f"噪声文档: {len(result.noise)}")
     for i, cluster in enumerate(result.clusters):
-        print('----------------')
+        print("----------------")
         for doc in cluster:
             print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}")
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     asyncio.run(main())