Compare commits

...

12 Commits

Author SHA1 Message Date
2d29d8631c fix: pool has already been opened/closed and cannot be reused 2024-10-21 10:04:31 +08:00
3dc47712e4 Fetch risk keywords from the database 2024-10-21 10:00:01 +08:00
9dc23b714c Update documents even when no risk is found 2024-10-18 19:59:02 +08:00
1545a85b09 Output risk analysis results 2024-10-18 19:35:26 +08:00
115d95bbef update README.md 2024-10-18 18:38:32 +08:00
34ad16ff02 dbscan: at most 7 days or 100k records 2024-10-18 18:10:23 +08:00
8a6db8f8f2 add requirements_version.txt 2024-10-18 17:26:48 +08:00
cae3877048 risk-analyze main loop 2024-10-18 16:59:43 +08:00
3d36dcadf6 add README.md 2024-10-18 16:46:19 +08:00
a051674f2e feat: write analysis results to MySQL 2024-10-18 16:01:29 +08:00
47ae4dc8d5 fix: 2-D task/document matching issue 2024-10-18 15:47:18 +08:00
d52f37e5f0 WIP: convert 2-D array to 1-D; document-ID labeling issue not yet resolved 2024-10-18 15:43:37 +08:00
5 changed files with 290 additions and 36 deletions

View File

@@ -1 +1,97 @@
# CUC Phase III LLM Public Opinion Monitoring Project (中传三期大模型舆情监测项目)
## LLM Risk Analysis Overview
Under `风险预警` (Risk Alerts) → `我的设置` (My Settings) you can configure the *filter keywords* and the *LLM prompts*.
A brief overview of the data processing pipeline:
1. Based on the configured *filter keywords*, news is filtered from the NiuMedia (牛媒) public-opinion data platform and imported into the database.
    The import job runs once per hour, each run covering one hour of data. To leave enough processing time for data arriving near the end of each hour, processing is delayed by one hour, so the overall ingestion delay is within 1-2 hours.
    The time here is the time an item entered the NiuMedia platform, not its publication time, which means items from two hours ago or earlier may still be back-filled. This is especially likely for target sites that the NiuMedia crawler visits less often than every 2 hours. The downstream processing already accounts for this.
2. Text feature extraction
    Text feature extraction runs every ten minutes and processes news items whose text-embedding field in the database is still empty.
3. Clustering
    DBSCAN is run over the text feature vectors to cluster the news, discarding noise items (roughly half of the data) and using the item closest to each cluster's center as that cluster's representative for later analysis. Each run yields roughly 80-400 clusters. The clustering input is all news from the past 7 days.
4. LLM risk judgment
    For each risk type's LLM prompt, every cluster representative is analyzed for that risk; the prompts look like
    `你是一个新闻风险分析器,分析以下新闻是否包含学术不端风险。你只能回答是或否` ("You are a news risk analyzer; decide whether the following news involves academic-misconduct risk. Answer only 是 (yes) or 否 (no).")
    The program determines the model's verdict by checking whether the returned text contains the keyword "是" (yes) or "否" (no); see the sketch after this list.
5. Persisting analysis results
    For every news item found to carry **any risk**, the program updates (overwrites) its risk-type field. After about one minute, once Elasticsearch has refreshed its index, these items can be filtered on the *风险监控* (Risk Monitoring) page of the web frontend.
    For older data: items that already have risk-type information but were not selected as cluster representatives in the current round are **not** updated.
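Steps 4-5 come down to keyword-matching the model's reply and merging per-document risk labels. A minimal sketch of that decision logic (names here are illustrative; the actual implementation is the `risk-analyze` change further down in this diff):
```python
def merge_risk_labels(results: list[tuple[str, str, str]]) -> dict[str, set[str]]:
    """results holds (doc_id, risk_type_name, llm_response) triples.

    Returns doc_id -> set of risk-type names. Documents with no detected risk keep
    an empty set, so their risk-type field is still overwritten on update.
    """
    risks: dict[str, set[str]] = {}
    for doc_id, risk_name, response in results:
        risks.setdefault(doc_id, set())
        # The prompt instructs the model to answer only 是 (yes) or 否 (no).
        if "是" in response:
            risks[doc_id].add(risk_name)
    return risks
```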
## Notes on the Clustering Algorithm
Each text embedding is a one-dimensional float16 array of dimension 1024. Similarity between vectors is measured with cosine distance.
Since the purpose of clustering is deduplication, DBSCAN is a good fit. The current parameters are eps=0.25 with a minimum cluster size of 2, so essentially any 2 duplicated or semantically similar news items land in the same cluster.
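A minimal sketch of this configuration with scikit-learn (mirroring the parameters above; `embeddings` is assumed to be the matrix of 1024-dimensional vectors loaded from the database):
```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics import pairwise_distances

def cluster_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """embeddings: (n_docs, 1024) array. Returns one label per document; -1 marks noise."""
    # Precompute the pairwise cosine distance matrix, then cluster on it.
    distances = pairwise_distances(embeddings, metric="cosine")
    dbscan = DBSCAN(eps=0.25, min_samples=2, metric="precomputed")
    return dbscan.fit_predict(distances)
```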
## Notes on Duplicate Data
Because of article spinning, reposting, plagiarism, and so on, the same news story may appear on multiple platforms. The NiuMedia platform treats these as distinct news items (with different IDs). The clustering algorithm recognizes such duplicates at the semantic level (both exact copies and semantically similar items) and groups them into one cluster.
## Deployment
### Environment Variables
Settings can be provided via system environment variables or a `.env` file:
```
ES_API=http://<address>
PG_DSN='postgresql://username:password@address:5432/cucyuqing?sslmode=disable'
MYSQL_DSN='mysql://username:password@address:3306/niumedia'
OPENAI_EMBEDDING_API_KEY='key'
OPENAI_EMBEDDING_BASE_URL='http://<address>/v1'
OPENAI_RISK_LLM_API_KEY='key'
OPENAI_RISK_LLM_BASE_URL='https://<address>/v1'
```
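A minimal sketch of how these settings might be read at startup (assuming python-dotenv, which is listed in the requirements; the variable names match the list above):
```python
import os
from dotenv import load_dotenv

# Load .env if present; variables already set in the environment are kept.
load_dotenv()

ES_API = os.environ["ES_API"]
PG_DSN = os.environ["PG_DSN"]
MYSQL_DSN = os.environ["MYSQL_DSN"]
OPENAI_EMBEDDING_API_KEY = os.environ["OPENAI_EMBEDDING_API_KEY"]
OPENAI_EMBEDDING_BASE_URL = os.environ["OPENAI_EMBEDDING_BASE_URL"]
OPENAI_RISK_LLM_API_KEY = os.environ["OPENAI_RISK_LLM_API_KEY"]
OPENAI_RISK_LLM_BASE_URL = os.environ["OPENAI_RISK_LLM_BASE_URL"]
```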
### Dependencies
Using a virtual environment:
```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
```
Or using Docker:
```bash
docker build -t <image-name>:latest .
```
### Startup
Start the ES sync program:
```bash
python -m cmd.es-sync
```
Start the text feature extraction program:
```bash
python -m cmd.embedding
```
Start the LLM risk analysis program:
```bash
python -m cmd.risk-analyze
```

View File

@@ -35,6 +35,17 @@ def parse_unixtime(unixtime: int) -> datetime.datetime:
     return datetime.datetime.fromtimestamp(unixtime)
 
 
+async def get_filter_query() -> str:
+    row = await mysql.fetch_one(
+        """
+        select name from risk_news_keywords order by id limit 1
+        """
+    )
+    if not row:
+        raise Exception("未找到风险关键词")
+    return row[0]
+
+
 async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
     """
     Fetch data within the given interval, size items per request. This is a recursive function: if the interval returns exactly size items, there may be more data, so keep requesting.
@@ -45,7 +56,7 @@ async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
     es_response = await post(
         url,
         {
-            "word": "(教师|老师|教授|导师|院长) - (教育部|公告|通报|准则|建设|座谈|细则|工作|动员|专题) + (不正当|性骚扰|出轨|猥亵|不公|强迫|侮辱|举报|滥用|违法|师德|贿|造假|不端|抄袭|虚假|篡改|挪用|抑郁|威胁|霸凌|体罚)",
+            "word": await get_filter_query(),
             "size": size,
             "orders": 9,
             "tmode": 2,
@@ -66,8 +77,8 @@ async def fetch(interval: ESInterval, size=1000) -> AsyncIterable[dict]:
         f'用时 {int(duration)} 秒,获取到 {len(docs)} 条数据,最早时间 {parse_unixtime(docs[0]["crawled_at"])},最晚时间 {parse_unixtime(docs[-1]["crawled_at"])}'
     )
     for d in docs:
-        d['title'] = d['title'].replace('\x00', '')
-        d['content'] = d['content'].replace('\x00', '')
+        d["title"] = d["title"].replace("\x00", "")
+        d["content"] = d["content"].replace("\x00", "")
         yield d
     # If this interval returned exactly size items, there is likely more data, so keep requesting
     # (this is done recursively)
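
The fetch docstring above describes page-by-page retrieval that recurses while a full page of `size` documents comes back. A minimal, self-contained sketch of that pattern (the `request` callable and the way the next window is chosen are illustrative assumptions, not the project's actual code):
```python
from typing import AsyncIterable, Awaitable, Callable

async def fetch_pages(
    query: dict,
    request: Callable[[dict], Awaitable[list[dict]]],
    size: int = 1000,
) -> AsyncIterable[dict]:
    """Yield documents page by page; recurse while a full page (== size) is returned."""
    docs = await request({**query, "size": size})
    for d in docs:
        yield d
    if len(docs) == size:
        # A full page suggests more data: advance the window past the last
        # crawled_at timestamp and recurse for the remainder.
        async for d in fetch_pages(
            {**query, "start": docs[-1]["crawled_at"]}, request, size
        ):
            yield d
```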

View File

@@ -1,5 +1,5 @@
-from os import system
-from typing import Iterable, Required
+from dataclasses import dataclass
+from typing import Iterable
 import openai
 import asyncio
 from openai.types.chat import ChatCompletionMessageParam
@@ -10,22 +10,102 @@ from cucyuqing.utils import print
 from cucyuqing.config import OPENAI_RISK_LLM_API_KEY, OPENAI_RISK_LLM_BASE_URL
 from cucyuqing.pg import pool, get_cur
 from cucyuqing.mysql import mysql
-from cucyuqing.dbscan import run_dbscan
+from cucyuqing.dbscan import Document, run_dbscan
 
 
 async def main():
-    await pool.open()
+    while True:
+        try:
+            await do_analyze()
+            await asyncio.sleep(60 * 30)
+        except Exception as e:
+            print(e)
+            await asyncio.sleep(60 * 60)
+
+
+async def do_analyze():
+    await asyncio.gather(
+        pool.open(),
+        mysql.connect(),
+    )
+    # Fetch the risk types and their corresponding prompts
+    risk_types = await get_risk_type_prompt()
+    print("共有风险类型:", len(risk_types))
 
     dbscan_result = await run_dbscan()
     docs = [cluster[0] for cluster in dbscan_result.clusters]
-    analyze_rusult = await batch_risk_analyze([doc.title for doc in docs])
-    for result, doc in zip(analyze_rusult, docs):
-        print(f"风险: {result} 标题: {doc.title}")
+    print("共有待分析文档:", len(docs), "噪声", len(dbscan_result.noise))
+
+    risks_to_update: dict[str, set[str]] = {}
+    analyze_result = await batch_risk_analyze(docs, risk_types)
+    for task in analyze_result:
+        if "是" not in task.response:
+            if risks_to_update.get(task.doc.id) is None:
+                risks_to_update[task.doc.id] = set()
+            continue
+        print(f"风险: {task.risk_type.name} 标题: {task.doc.title} {task.doc.id}")
+        # Merge each document's risks into one set
+        if task.doc.id not in risks_to_update:
+            risks_to_update[task.doc.id] = set()
+        risks_to_update[task.doc.id].add(task.risk_type.name)
+
+    # Update the database
+    for doc_id, risks in risks_to_update.items():
+        await mysql.execute(
+            """
+            UPDATE risk_news
+            SET risk_types = :risk_types, updated_at = now()
+            WHERE es_id = :es_id
+            """,
+            {
+                "es_id": doc_id,
+                "risk_types": (
+                    json.dumps(list(risks), ensure_ascii=False) if risks else None
+                ),
+            },
+        )
+
+
+@dataclass
+class RiskType:
+    name: str
+    prompt: str
+
+
+@dataclass
+class Task:
+    doc: Document
+    risk_type: RiskType
+    response: str = ""
+
+
+async def get_risk_type_prompt() -> list[RiskType]:
+    """Fetch the risk types and their prompts from the database"""
+    rows = await mysql.fetch_all(
+        """
+        SELECT rp.content, rt.name
+        FROM risk_prompt rp
+        JOIN risk_type rt ON rp.risk_type_id = rt.id
+        ORDER BY rp.id DESC
+        """
+    )
+    return [RiskType(prompt=row[0], name=row[1]) for row in rows]
 
 
 async def batch_risk_analyze(
-    texts: list, model: str = "gpt-4o-mini", threads: int = 10
-) -> list:
-    tasks = [{"input": text} for text in texts]
+    docs: list[Document],
+    risk_types: list[RiskType],
+    model: str = "gpt-4o-mini",
+    threads: int = 10,
+) -> list[Task]:
+    """Text risk analysis (parallel batch processing)"""
+    # Build the task list as the cross product of docs and risk_types
+    tasks: list[Task] = [
+        Task(doc=doc, risk_type=rt) for doc in docs for rt in risk_types
+    ]
     bar = tqdm.tqdm(total=len(tasks))
     queue = asyncio.Queue()
@@ -34,9 +114,11 @@ async def batch_risk_analyze(
             task = await queue.get()
             if task is None:
                 break
-            task["response"] = await risk_analyze(task["input"], model)
+            task.response = await risk_analyze(task, model)
             queue.task_done()
             bar.update(1)
+            if bar.n % 100 == 0:
+                print(f"已完成 {bar.n} 条风险分析")
 
     async def producer():
         for task in tasks:
@@ -50,19 +132,18 @@ async def batch_risk_analyze(
     await queue.put(None)
     await asyncio.gather(*workers)
-    return [task["response"] for task in tasks]
+    print("风险分析完成")
+    return tasks
 
 
-async def risk_analyze(text: str, model: str) -> str:
+async def risk_analyze(task: Task, model: str) -> str:
+    """Risk-analyze a single piece of text"""
     llm = openai.AsyncOpenAI(
         api_key=OPENAI_RISK_LLM_API_KEY, base_url=OPENAI_RISK_LLM_BASE_URL
     )
 
-    system_message = (
-        "你是一个新闻风险分析器,你要判断以下文本是否有风险,你只要回答是或者否。"
-    )
-
     hash = hashlib.md5(
-        model.encode() + b"|" + text.encode() + b"|" + system_message.encode()
+        f"{model}|{task.doc.get_text_for_llm()}|{task.risk_type.prompt}".encode()
     ).hexdigest()
 
     # Check the cache
@@ -75,14 +156,15 @@ async def risk_analyze(text: str, model: str) -> str:
         return row[0]
 
     messages: Iterable[ChatCompletionMessageParam] = [
-        {"role": "system", "content": system_message},
-        {"role": "user", "content": text},
+        {"role": "system", "content": task.risk_type.prompt},
+        {"role": "user", "content": task.doc.get_text_for_llm()},
     ]
     resp = await llm.chat.completions.create(
         messages=messages,
         model=model,
         temperature=0,
         stop="\n",
+        max_tokens=10,
     )
     completions = resp.choices[0].message.content or ""

View File

@@ -8,20 +8,34 @@ from sklearn.metrics import pairwise_distances
 from cucyuqing.pg import pool, get_cur
 
 
 @dataclass
 class Document:
-    id: int
+    id: str
+    """The ID is the 32-character hex ID from ES"""
     title: str
     content: str
    similarity: float = 0.0
 
+    def get_text_for_llm(self) -> str:
+        """Use only the title for risk analysis.
+        Empty titles were already handled at ingestion time:
+        if the title was empty, the first 20 characters of the content
+        (or its first Chinese sentence) were stored as the title.
+        """
+        return self.title
+
 
 @dataclass
 class DBScanResult:
     noise: list[Document]
     clusters: list[list[Document]]
 
 
 from sklearn.metrics.pairwise import cosine_similarity
 
 
 async def run_dbscan() -> DBScanResult:
     # Fetch data from the PG database
     async with get_cur() as cur:
@@ -30,23 +44,24 @@ async def run_dbscan() -> DBScanResult:
             SELECT id, title, content, embedding
             FROM risk_news
             WHERE NOT embedding_updated_at IS NULL
-            AND time > now() - interval '14 day'
+            AND time > now() - interval '7 day'
             ORDER BY time desc
-            LIMIT 10000
+            LIMIT 100000
             ;"""
         )
         rows = await cur.fetchall()
     docs: list[Document] = [
-        Document(row[0], row[1], row[2])
-        for row in rows
+        Document(str(row[0]).replace("-", ""), row[1], row[2]) for row in rows
     ]
     embeddings = [numpy.array(json.loads(row[3])) for row in rows]
 
     # Compute the cosine distance matrix
-    cosine_distances = pairwise_distances(embeddings, metric='cosine')
+    cosine_distances = pairwise_distances(embeddings, metric="cosine")
 
     # Initialize the DBSCAN model
-    dbscan = DBSCAN(eps=0.25, min_samples=2, metric='precomputed')  # Adjust eps as needed
+    dbscan = DBSCAN(
+        eps=0.25, min_samples=2, metric="precomputed"
+    )  # Adjust eps as needed
 
     # Run the clustering
     dbscan.fit(cosine_distances)
@@ -58,9 +73,9 @@ async def run_dbscan() -> DBScanResult:
     ret: DBScanResult = DBScanResult(noise=[], clusters=[])
     unique_labels = set(labels)
     for label in unique_labels:
-        class_member_mask = (labels == label)
+        class_member_mask = labels == label
         cluster_docs = [docs[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
         cluster_embeddings = [embeddings[i] for i in range(len(labels)) if class_member_mask[i]]  # type: ignore
         if label == -1:
             # -1 is the label for noise points
@@ -78,17 +93,19 @@ async def run_dbscan() -> DBScanResult:
                 doc.similarity = similarities[i]
                 sorted_cluster_docs.append(doc)
             ret.clusters.append(sorted_cluster_docs)
 
     return ret
 
 
 async def main():
     await pool.open()
     result = await run_dbscan()
     print(f"噪声文档: {len(result.noise)}")
     for i, cluster in enumerate(result.clusters):
-        print('----------------')
+        print("----------------")
         for doc in cluster:
             print(f"聚类 {i} 文档: {doc.title} 相似度: {doc.similarity}")
 
-if __name__ == '__main__':
-    asyncio.run(main())
+
+if __name__ == "__main__":
+    asyncio.run(main())

requirements_version.txt (new file, 48 lines)
View File

@@ -0,0 +1,48 @@
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiomysql==0.2.0
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.4.0
databases==0.9.0
distro==1.9.0
fastapi==0.115.2
filelock==3.16.1
frozenlist==1.4.1
fsspec==2024.9.0
greenlet==3.1.1
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.26.0
idna==3.10
jiter==0.6.1
joblib==1.4.2
multidict==6.1.0
numpy==2.1.2
openai==1.52.0
packaging==24.1
propcache==0.2.0
psycopg==3.2.3
psycopg-binary==3.2.3
psycopg-pool==3.2.3
pydantic==2.9.2
pydantic_core==2.23.4
PyMySQL==1.1.1
python-dotenv==1.0.1
PyYAML==6.0.2
requests==2.32.3
scikit-learn==1.5.2
scipy==1.14.1
sniffio==1.3.1
SQLAlchemy==2.0.36
starlette==0.40.0
threadpoolctl==3.5.0
tokenizers==0.20.1
tqdm==4.66.5
typing_extensions==4.12.2
urllib3==2.2.3
yarl==1.15.4