feat: 分析结果入mysql
This commit is contained in:
@@ -26,12 +26,33 @@ async def main():
|
||||
|
||||
dbscan_result = await run_dbscan()
|
||||
docs = [cluster[0] for cluster in dbscan_result.clusters]
|
||||
print("共有待分析文档:", len(docs), "噪声", len(dbscan_result.noise))
|
||||
|
||||
risks_to_update: dict[str, set[str]] = {}
|
||||
analyze_result = await batch_risk_analyze(docs, risk_types)
|
||||
for task in analyze_result:
|
||||
if "是" not in task.response:
|
||||
continue
|
||||
print(f"风险: {task.risk_type.name} 标题: {task.doc.title}")
|
||||
print(f"风险: {task.risk_type.name} 标题: {task.doc.title} {task.doc.id}")
|
||||
|
||||
# 合并每个文档的风险到一个set
|
||||
if task.doc.id not in risks_to_update:
|
||||
risks_to_update[task.doc.id] = set()
|
||||
risks_to_update[task.doc.id].add(task.risk_type.name)
|
||||
|
||||
# 更新数据库
|
||||
for doc_id, risks in risks_to_update.items():
|
||||
await mysql.execute(
|
||||
"""
|
||||
UPDATE risk_news
|
||||
SET risk_types = :risk_types, updated_at = now()
|
||||
WHERE es_id = :es_id
|
||||
""",
|
||||
{
|
||||
"es_id": doc_id,
|
||||
"risk_types": json.dumps(list(risks), ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -8,28 +8,34 @@ from sklearn.metrics import pairwise_distances
|
||||
|
||||
from cucyuqing.pg import pool, get_cur
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
id: int
|
||||
id: str
|
||||
"""ID 是 ES 中的 32 为 hex ID"""
|
||||
|
||||
title: str
|
||||
content: str
|
||||
similarity: float = 0.0
|
||||
|
||||
def get_text_for_llm(self) -> str:
|
||||
"""只使用标题进行风险分析
|
||||
|
||||
|
||||
对于空标题,在入库时已经处理过。
|
||||
如果入库时标题为空,则使用content的前20个字符或第一句中文作为标题。
|
||||
"""
|
||||
return self.title
|
||||
return self.title
|
||||
|
||||
|
||||
@dataclass
|
||||
class DBScanResult:
|
||||
noise: list[Document]
|
||||
clusters: list[list[Document]]
|
||||
|
||||
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
|
||||
async def run_dbscan() -> DBScanResult:
|
||||
# 从 PG 数据库获取数据
|
||||
async with get_cur() as cur:
|
||||
@@ -45,16 +51,17 @@ async def run_dbscan() -> DBScanResult:
|
||||
)
|
||||
rows = await cur.fetchall()
|
||||
docs: list[Document] = [
|
||||
Document(row[0], row[1], row[2])
|
||||
for row in rows
|
||||
Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
|
||||
]
|
||||
embeddings = [numpy.array(json.loads(row[3])) for row in rows]
|
||||
|
||||
# 计算余弦距离矩阵
|
||||
cosine_distances = pairwise_distances(embeddings, metric='cosine')
|
||||
cosine_distances = pairwise_distances(embeddings, metric="cosine")
|
||||
|
||||
# 初始化DBSCAN模型
|
||||
dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed') # Adjust eps as needed
|
||||
dbscan = DBSCAN(
|
||||
eps=0.25, min_samples=2, metric="precomputed"
|
||||
) # Adjust eps as needed
|
||||
|
||||
# 进行聚类
|
||||
dbscan.fit(cosine_distances)
|
||||
@@ -66,9 +73,9 @@ async def run_dbscan() -> DBScanResult:
|
||||
ret: DBScanResult = DBScanResult(noise=[], clusters=[])
|
||||
unique_labels = set(labels)
|
||||
for label in unique_labels:
|
||||
class_member_mask = (labels == label)
|
||||
class_member_mask = labels == label
|
||||
cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore
|
||||
cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore
|
||||
cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore
|
||||
|
||||
if label == -1:
|
||||
# -1 is the label for noise points
|
||||
@@ -86,17 +93,19 @@ async def run_dbscan() -> DBScanResult:
|
||||
doc.similarity = similarities[i]
|
||||
sorted_cluster_docs.append(doc)
|
||||
ret.clusters.append(sorted_cluster_docs)
|
||||
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
async def main():
|
||||
await pool.open()
|
||||
result = await run_dbscan()
|
||||
print(f"噪声文档: {len(result.noise)}")
|
||||
for i, cluster in enumerate(result.clusters):
|
||||
print('----------------')
|
||||
print("----------------")
|
||||
for doc in cluster:
|
||||
print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user