Compare commits

...

21 Commits

Author SHA1 Message Date
2d29d8631c fix: pool has already been opened/closed and cannot be reused 2024-10-21 10:04:31 +08:00
3dc47712e4 Fetch risk keywords from the database 2024-10-21 10:00:01 +08:00
9dc23b714c Update records even when no risk is found 2024-10-18 19:59:02 +08:00
1545a85b09 Output the risk analysis 2024-10-18 19:35:26 +08:00
115d95bbef update README.md 2024-10-18 18:38:32 +08:00
34ad16ff02 dbscan: cap input at 7 days or 100k records 2024-10-18 18:10:23 +08:00
8a6db8f8f2 add requirements_version.txt 2024-10-18 17:26:48 +08:00
cae3877048 risk-analyze main loop 2024-10-18 16:59:43 +08:00
3d36dcadf6 add README.md 2024-10-18 16:46:19 +08:00
a051674f2e feat: write analysis results to mysql 2024-10-18 16:01:29 +08:00
47ae4dc8d5 fix: 2D task/document matching issue 2024-10-18 15:47:18 +08:00
d52f37e5f0 Checkpoint
Flattening the 2D array to 1D; the document-ID labeling issue is not yet solved
2024-10-18 15:43:37 +08:00
ad3ef8e504 Add kmeans (archived) 2024-10-10 18:31:47 +08:00
b3fa8edf60 Add cmd.risk-analyze 2024-10-10 18:31:35 +08:00
004552f4d5 Add dbscan 2024-10-10 18:19:50 +08:00
a49db1b71b Use title or content as the embedding input 2024-10-08 14:35:00 +08:00
16f4469365 Adjust embedding strategy
Only process news from the last 2 weeks, using the index and descending time order
2024-09-23 15:13:33 +08:00
341435e603 WIP: LLM risk judgment 2024-09-20 18:33:56 +08:00
fff2a32d7e Filter \x00 out of results fetched from ES 2024-09-20 15:20:19 +08:00
43ddd8665c fix: ES sync API timestamps precise to the second 2024-09-20 11:44:04 +08:00
1e1f92461a typo 2024-09-19 14:35:34 +08:00
9 changed files with 541 additions and 12 deletions

View File

@@ -1 +1,97 @@
# CUC Phase III LLM Public Opinion Monitoring Project
## LLM Risk Analysis Feature Overview
The *filter keywords* and *LLM prompts* can be configured under `风险预警` (Risk Alerts) → `我的设置` (My Settings).
The data processing pipeline, in brief:
1. Based on the configured *filter keywords*, news is filtered from the Niumedia (牛媒) public-opinion data platform and imported into the database.
The import job runs once an hour, each time ingesting one hour's worth of data. To leave enough processing time for data arriving at the end of each hour, processing is delayed by one hour, so the overall ingestion delay is within 1-2 hours (see the window sketch after this list).
The time here is the time an item entered the Niumedia platform, not its publication time, which means older items from two or more hours ago may still be backfilled. This is especially common for target sites that the Niumedia crawler polls less often than every 2 hours; the downstream processing logic already accounts for it.
2. Text feature extraction
Text feature extraction runs every ten minutes and processes news whose text-embedding field in the database is still empty.
3. Clustering
DBSCAN is run on the text feature vectors to cluster the news, discarding noise items (roughly half of the total) and taking the item closest to each cluster's centroid as that cluster's representative for later analysis. Each run yields roughly 80-400 clusters; the input is all news from the last 7 days.
4. LLM risk judgment
Using the LLM prompt configured for each risk type, every cluster representative is analyzed. The prompts look like
`你是一个新闻风险分析器,分析以下新闻是否包含学术不端风险。你只能回答是或否` ("You are a news risk analyzer; determine whether the following news involves academic-misconduct risk. Answer only 是 (yes) or 否 (no).")
The program decides the result by checking whether the LLM's response contains the keyword "是" (yes) or "否" (no).
5. Persisting the analysis results
For every news item that **contains any risk**, the program updates (overwrites) its risk-category field. After Elasticsearch refreshes its index (about a minute), these categories can be filtered on the *风险监控* (Risk Monitoring) page of the web front end.
Older items that already carry risk-category information but were not chosen as cluster representatives in the current round will **not** have their risk categories updated.
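As a rough illustration of the ingestion schedule in step 1, the sketch below computes the one-hour window that an hourly run would import; the function `compute_import_window` is hypothetical and not part of this repository.
```python
import datetime

def compute_import_window(now: datetime.datetime) -> tuple[datetime.datetime, datetime.datetime]:
    """Hypothetical sketch: each hourly run imports the hour that ended one hour earlier.

    Example: a run at 12:05 imports items whose platform ingestion time falls in
    [10:00, 11:00), so the end-to-end ingestion delay stays within 1-2 hours.
    """
    current_hour = now.replace(minute=0, second=0, microsecond=0)
    window_end = current_hour - datetime.timedelta(hours=1)
    window_start = window_end - datetime.timedelta(hours=1)
    return window_start, window_end

print(compute_import_window(datetime.datetime(2024, 10, 18, 12, 5)))
# (datetime.datetime(2024, 10, 18, 10, 0), datetime.datetime(2024, 10, 18, 11, 0))
```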
## Notes on the clustering algorithm
Each text vector is a one-dimensional float16 array of dimension 1024. Similarity between vectors is measured with cosine distance.
Since the purpose of clustering is deduplication, DBSCAN is a good fit. The current parameters are eps=0.25 with a minimum cluster size of 2, so essentially any 2 duplicate or semantically similar news items end up in the same cluster, as in the sketch below.
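A minimal, self-contained sketch of this setup; the toy vectors are made up for illustration, while the real pipeline loads 1024-dimensional embeddings from PostgreSQL (see `cucyuqing/dbscan.py`).
```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

# Toy 4-dimensional vectors; the real embeddings are 1024-dimensional float16 arrays.
embeddings = np.array([
    [1.00, 0.00, 0.0, 0.0],  # doc A
    [0.99, 0.05, 0.0, 0.0],  # near-duplicate of doc A
    [0.00, 1.00, 0.0, 0.0],  # unrelated doc
], dtype=np.float32)

# Precompute the cosine-distance matrix, then cluster with eps=0.25, min_samples=2.
distances = pairwise_distances(embeddings, metric="cosine")
labels = DBSCAN(eps=0.25, min_samples=2, metric="precomputed").fit(distances).labels_
print(labels)  # e.g. [ 0  0 -1]: the two similar docs share a cluster, the unrelated doc is noise
```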
## Notes on duplicate data
Because of article spinning, reposting, and plagiarism, the same story may be published on multiple platforms. The Niumedia platform treats these as different news items (with different IDs). The clustering algorithm recognizes such duplicates at the semantic level (both exact copies and semantically similar rewrites) and groups them into one cluster.
## Deployment
### Environment variables
Settings can be supplied via system environment variables or a `.env` file (the latter takes precedence):
```
ES_API=http://<address>
PG_DSN='postgresql://username:password@address:5432/cucyuqing?sslmode=disable'
MYSQL_DSN='mysql://username:password@address:3306/niumedia'
OPENAI_EMBEDDING_API_KEY='key'
OPENAI_EMBEDDING_BASE_URL='http://<address>/v1'
OPENAI_RISK_LLM_API_KEY='key'
OPENAI_RISK_LLM_BASE_URL='https://<address>/v1'
```
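For reference, a minimal sketch of loading these settings with python-dotenv. The repository's own `cucyuqing.config` module uses its `must_get_env` / `get_env_with_default` helpers, so treat this snippet only as an illustrative stand-in; `override=True` reflects the precedence assumption above.
```python
import os
from dotenv import load_dotenv

# Load .env if present; override=True assumes the .env file wins over existing variables.
load_dotenv(override=True)

def must_get_env(name: str) -> str:
    """Fail fast at startup when a required variable is missing."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"missing required environment variable: {name}")
    return value

ES_API = must_get_env("ES_API")
PG_DSN = must_get_env("PG_DSN")
MYSQL_DSN = must_get_env("MYSQL_DSN")
```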
### Dependencies
Using a virtual environment:
```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
```
Or using Docker:
```bash
docker build -t <image-name>:latest .
```
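To run a container from that image, something along these lines should work; the `--env-file` flag is standard Docker, while the container name and the module passed as the command are only examples, not settings taken from this repository.
```bash
docker run -d --name es-sync --env-file .env <image-name>:latest python -m cmd.es-sync
```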
### Startup
Start the ES sync program:
```bash
python -m cmd.es-sync
```
Start the text feature extraction program:
```bash
python -m cmd.embedding
```
Start the LLM analysis program:
```bash
python -m cmd.risk-analyze
```

View File

@@ -1,23 +1,21 @@
import traceback
import datetime
import asyncio
from sqlalchemy.sql.ddl import exc
import tqdm
import os
from tokenizers import Tokenizer
import openai
import hashlib
from pydantic import BaseModel
from typing import Any, Literal
from cucyuqing.pg import pool, get_cur
from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL
from cucyuqing.config import OPENAI_EMBEDDING_API_KEY, OPENAI_EMBEDDING_BASE_URL
from cucyuqing.utils import print
EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-small"]
embedding_client = openai.AsyncOpenAI(
api_key=OPENAI_API_KEY,
base_url=OPENAI_BASE_URL,
api_key=OPENAI_EMBEDDING_API_KEY,
base_url=OPENAI_EMBEDDING_BASE_URL,
)
@@ -65,9 +63,8 @@ async def get_embeddings(
- quiet: whether to suppress output
"""
# For the acge-large-zh model, the text needs to be truncated to 1024 - 2
if model == "acge-large-zh":
texts = [truncate_text(model, text, 1024 - 2) for text in texts]
# For most models, the text needs to be truncated to 1024 - 2
texts = [truncate_text(model, text, 1024 - 2) for text in texts]
# Build the task list
ids = list(range(len(texts)))
@@ -177,7 +174,14 @@ async def do_update():
while True:
async with get_cur() as cur:
# Filter on embedding_updated_at is null so the query uses the index and avoids a full table scan
await cur.execute("SELECT id, title, content from risk_news where embedding_updated_at is null limit 1000")
await cur.execute("""
SELECT id, title, content
from risk_news
where embedding_updated_at is null
and time > now() - interval '14 day'
order by time desc
limit 1000
""")
docs = await cur.fetchall()
# Loop exit
@@ -185,7 +189,7 @@ async def do_update():
print(datetime.datetime.now(), "No data to update")
break
embeddings = await get_embeddings([doc[1] + " " + doc[2] for doc in docs], "acge-large-zh")
embeddings = await get_embeddings([doc[1] or doc[2] for doc in docs], "acge-large-zh", threads=10)
async with get_cur() as cur:
for doc, embedding in tqdm.tqdm(zip(docs, embeddings), total=min(len(docs), len(embeddings)), desc="Update embeddings"):
await cur.execute("UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s", (embedding, doc[0]))

View File

@@ -28,13 +28,24 @@ class ESInterval(pydantic.BaseModel):
def format_datetime(dt: datetime.datetime) -> str:
return dt.strftime("%Y%m%d%H")
return dt.strftime("%Y%m%d%H%M%S")
def parse_unixtime(unixtime: int) -> datetime.datetime:
return datetime.datetime.fromtimestamp(unixtime)
async def get_filter_query() -> str:
row = await mysql.fetch_one(
"""
select name from risk_news_keywords order by id limit 1
"""
)
if not row:
raise Exception("未找到风险关键词")
return row[0]
async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
"""
Fetch data in the given time interval, size items per request. This is a recursive function: if the number of items in the current interval equals size, there is more data, so keep requesting.
@@ -45,7 +56,7 @@ async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
es_response = await post(
url,
{
"word": "(教师|老师|教授|导师|院长) - (教育部|公告|通报|准则|建设|座谈|细则|工作|动员|专题) + (不正当|性骚扰|出轨|猥亵|不公|强迫|侮辱|举报|滥用|违法|师德|贿|造假|不端|抄袭|虚假|篡改|挪用|抑郁|威胁|霸凌|体罚)",
"word": await get_filter_query(),
"size": size,
"orders": 9,
"tmode": 2,
@@ -66,6 +77,8 @@ async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
f'用时 {int(duration)} 秒,获取到 {len(docs)} 条数据,最早时间 {parse_unixtime(docs[0]["crawled_at"])},最晚时间 {parse_unixtime(docs[-1]["crawled_at"])}'
)
for d in docs:
d["title"] = d["title"].replace("\x00", "")
d["content"] = d["content"].replace("\x00", "")
yield d
# If the number of items in the current interval equals size, there is more data, so keep requesting
# Recursion is used here

cucyuqing/cmd/kmeans.py Normal file
View File

@@ -0,0 +1,70 @@
import numpy
import asyncio
import json
from sklearn.cluster import KMeans
from cucyuqing.pg import pool, get_cur
async def main():
# Fetch data from the PG database
await pool.open()
async with get_cur() as cur:
await cur.execute(
"""
SELECT id, title, content, embedding
FROM risk_news
WHERE NOT embedding_updated_at IS NULL
AND time > now() - interval '14 day'
ORDER BY time desc
LIMIT 1000
;"""
)
rows = await cur.fetchall()
docs = [
{
"id": row[0],
"title": row[1],
"content": row[2],
}
for row in rows
]
embeddings = [numpy.array(json.loads(row[3])) for row in rows]
# Set the number of clusters
num_clusters = 50
# Initialize the KMeans model
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
# Run the clustering
kmeans.fit(embeddings)
# Get each sample's cluster label
labels: list[int] = kmeans.labels_ # type: ignore
# Compute each sample's distance to its cluster center
distances = kmeans.transform(embeddings)
# Find the document closest to the center of each cluster
closest_docs = {}
for i, label in enumerate(labels):
distance_to_center = distances[i][label]
if label not in closest_docs or distance_to_center < closest_docs[label][0]:
closest_docs[label] = (distance_to_center, docs[i])
# Print the document closest to the center of each cluster
for label, (distance, doc) in closest_docs.items():
print(f"聚类 {label} 最近的文档: {doc['title']} 距离: {distance}")
sorted_samples: list[tuple[int, int]] = sorted(enumerate(labels), key=lambda x: x[1])
# Pick a random cluster
random_cluster = numpy.random.choice(num_clusters)
for i, label in sorted_samples:
if label == random_cluster:
print(f"聚类 {label} 文档: {docs[i]['title']}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,184 @@
from dataclasses import dataclass
from typing import Iterable
import openai
import asyncio
from openai.types.chat import ChatCompletionMessageParam
import tqdm
import json
import hashlib
from cucyuqing.utils import print
from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
from cucyuqing.pg import pool, get_cur
from cucyuqing.mysql import mysql
from cucyuqing.dbscan import Document, run_dbscan
async def main():
while True:
try:
await do_analyze()
await asyncio.sleep(60 * 30)
except Exception as e:
print(e)
await asyncio.sleep(60 * 60)
async def do_analyze():
await asyncio.gather(
pool.open(),
mysql.connect(),
)
# Fetch the risk types and their corresponding prompts
risk_types = await get_risk_type_prompt()
print("共有风险类型:", len(risk_types))
dbscan_result = await run_dbscan()
docs = [cluster[0] for cluster in dbscan_result.clusters]
print("共有待分析文档:", len(docs), "噪声", len(dbscan_result.noise))
risks_to_update: dict[str, set[str]] = {}
analyze_result = await batch_risk_analyze(docs, risk_types)
for task in analyze_result:
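# The prompts instruct the model to answer only "是" (yes) or "否" (no); a response without "是" is treated as "no risk"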
if "" not in task.response:
if risks_to_update.get(task.doc.id) is None:
risks_to_update[task.doc.id] = set()
continue
print(f"风险: {task.risk_type.name} 标题: {task.doc.title} {task.doc.id}")
# Merge each document's risks into one set
if task.doc.id not in risks_to_update:
risks_to_update[task.doc.id] = set()
risks_to_update[task.doc.id].add(task.risk_type.name)
# Update the database
for doc_id, risks in risks_to_update.items():
await mysql.execute(
"""
UPDATE risk_news
SET risk_types = :risk_types, updated_at = now()
WHERE es_id = :es_id
""",
{
"es_id": doc_id,
"risk_types": (
json.dumps(list(risks), ensure_ascii=False) if risks else None
),
},
)
@dataclass
class RiskType:
name: str
prompt: str
@dataclass
class Task:
doc: Document
risk_type: RiskType
response: str = ""
async def get_risk_type_prompt() -> list[RiskType]:
"""从数据库中获取风险类型和对应的提示词"""
rows = await mysql.fetch_all(
"""
SELECT rp.content, rt.name
FROM risk_prompt rp
JOIN risk_type rt ON rp.risk_type_id = rt.id
ORDER BY rp.id DESC
"""
)
return [RiskType(prompt=row[0], name=row[1]) for row in rows]
async def batch_risk_analyze(
docs: list[Document],
risk_types: list[RiskType],
model: str = "gpt-4o-mini",
threads: int = 10,
) -> list[Task]:
"""文本风险分析(并行批批处理)"""
# Build the task list as the cross product of docs and risk_types
tasks: list[Task] = [
Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
]
bar = tqdm.tqdm(total=len(tasks))
queue = asyncio.Queue()
async def lmm_worker():
while True:
task = await queue.get()
if task is None:
break
task.response = await risk_analyze(task, model)
queue.task_done()
bar.update(1)
if bar.n % 100 == 0:
print(f"已完成 {bar.n} 条风险分析")
async def producer():
for task in tasks:
await queue.put(task)
workers = [asyncio.create_task(lmm_worker()) for _ in range(threads)]
await producer()
await queue.join()
for _ in workers:
await queue.put(None)
await asyncio.gather(*workers)
print("风险分析完成")
return tasks
async def risk_analyze(task: Task, model: str) -> str:
"""对一条文本进行风险分析"""
llm = openai.AsyncOpenAI(
api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
)
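# Cache key: md5 over the model, the document text, and the prompt, so identical requests reuse the cached LLM response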
hash = hashlib.md5(
f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode()
).hexdigest()
# Check the cache
async with get_cur() as cur:
await cur.execute(
"SELECT response FROM llm_cache WHERE id = %s LIMIT 1", (hash,)
)
row = await cur.fetchone()
if row:
return row[0]
messages: Iterable[ChatCompletionMessageParam] = [
{"role": "system", "content": task.risk_type.prompt},
{"role": "user", "content": task.doc.get_text_for_llm()},
]
resp = await llm.chat.completions.create(
messages=messages,
model=model,
temperature=0,
stop="\n",
max_tokens=10,
)
completions = resp.choices[0].message.content or ""
usage = resp.usage.model_dump_json() if resp.usage else None
# Cache the result
async with get_cur() as cur:
await cur.execute(
"INSERT INTO llm_cache (id, messages, model, response, usage) VALUES (%s, %s, %s, %s, %s)",
(hash, json.dumps(messages, ensure_ascii=False), model, completions, usage),
)
return completions
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -24,3 +24,5 @@ PG_DSN = must_get_env("PG_DSN")
MYSQL_DSN = must_get_env("MYSQL_DSN")
OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
OPENAI_RISK_LLM_API_KEY = must_get_env("OPENAI_RISK_LLM_API_KEY")
OPENAI_RISK_LLM_BASE_URL = get_env_with_default("OPENAI_RISK_LLM_BASE_URL", "https://api.openai.com/v1")

cucyuqing/dbscan.py Normal file
View File

@@ -0,0 +1,111 @@
from typing_extensions import Doc
import numpy
import asyncio
import json
from dataclasses import dataclass
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances
from cucyuqing.pg import pool, get_cur
@dataclass
class Document:
id: str
"""ID 是 ES 中的 32 为 hex ID"""
title: str
content: str
similarity: float = 0.0
def get_text_for_llm(self) -> str:
"""只使用标题进行风险分析
对于空标题,在入库时已经处理过。
如果入库时标题为空则使用content的前20个字符或第一句中文作为标题。
"""
return self.title
@dataclass
class DBScanResult:
noise: list[Document]
clusters: list[list[Document]]
from sklearn.metrics.pairwise import cosine_similarity
async def run_dbscan() -> DBScanResult:
# Fetch data from the PG database
async with get_cur() as cur:
await cur.execute(
"""
SELECT id, title, content, embedding
FROM risk_news
WHERE NOT embedding_updated_at IS NULL
AND time > now() - interval '7 day'
ORDER BY time desc
LIMIT 100000
;"""
)
rows = await cur.fetchall()
docs: list[Document] = [
Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
]
embeddings = [numpy.array(json.loads(row[3])) for row in rows]
# Compute the cosine-distance matrix
cosine_distances = pairwise_distances(embeddings, metric="cosine")
# Initialize the DBSCAN model
dbscan = DBSCAN(
eps=0.25, min_samples=2, metric="precomputed"
) # Adjust eps as needed
# Run the clustering
dbscan.fit(cosine_distances)
# Get each sample's cluster label
labels: list[int] = dbscan.labels_ # type: ignore
# Collect the documents in each cluster
ret: DBScanResult = DBScanResult(noise=[], clusters=[])
unique_labels = set(labels)
for label in unique_labels:
class_member_mask = labels == label
cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore
cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]] # type: ignore
if label == -1:
# -1 is the label for noise points
ret.noise = cluster_docs
else:
# Compute the centroid
centroid = numpy.mean(cluster_embeddings, axis=0).reshape(1, -1)
# Compute similarities to the centroid
similarities = cosine_similarity(centroid, cluster_embeddings).flatten()
# Sort by similarity
sorted_indices = numpy.argsort(similarities)[::-1]
sorted_cluster_docs = []
for i in sorted_indices:
doc = cluster_docs[i]
doc.similarity = similarities[i]
sorted_cluster_docs.append(doc)
ret.clusters.append(sorted_cluster_docs)
return ret
async def main():
await pool.open()
result = await run_dbscan()
print(f"噪声文档: {len(result.noise)}")
for i, cluster in enumerate(result.clusters):
print("----------------")
for doc in cluster:
print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -8,3 +8,4 @@ databases[aiomysql]
openai
tokenizers
tqdm
scikit-learn

requirements_version.txt Normal file
View File

@@ -0,0 +1,48 @@
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiomysql==0.2.0
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.4.0
databases==0.9.0
distro==1.9.0
fastapi==0.115.2
filelock==3.16.1
frozenlist==1.4.1
fsspec==2024.9.0
greenlet==3.1.1
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.26.0
idna==3.10
jiter==0.6.1
joblib==1.4.2
multidict==6.1.0
numpy==2.1.2
openai==1.52.0
packaging==24.1
propcache==0.2.0
psycopg==3.2.3
psycopg-binary==3.2.3
psycopg-pool==3.2.3
pydantic==2.9.2
pydantic_core==2.23.4
PyMySQL==1.1.1
python-dotenv==1.0.1
PyYAML==6.0.2
requests==2.32.3
scikit-learn==1.5.2
scipy==1.14.1
sniffio==1.3.1
SQLAlchemy==2.0.36
starlette==0.40.0
threadpoolctl==3.5.0
tokenizers==0.20.1
tqdm==4.66.5
typing_extensions==4.12.2
urllib3==2.2.3
yarl==1.15.4