增加刷新 openai embedding 功能
This commit is contained in:
175
cucyuqing/cmd/embedding.py
Normal file
175
cucyuqing/cmd/embedding.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
import datetime
|
||||||
|
import asyncio
|
||||||
|
import tqdm
|
||||||
|
import os
|
||||||
|
from tokenizers import Tokenizer
|
||||||
|
import openai
|
||||||
|
import hashlib
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Any, Literal
|
||||||
|
from cucyuqing.pg import pool, get_cur
|
||||||
|
from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL
|
||||||
|
|
||||||
|
# Names of the embedding models this module knows how to request.
EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-large"]

# Shared async client for the embeddings endpoint; the API key and base URL
# come from the project configuration module.
embedding_client = openai.AsyncOpenAI(
    base_url=OPENAI_BASE_URL,
    api_key=OPENAI_API_KEY,
)

# Tokenizer loaded once at import time from the bundled acge-large-zh
# resources. The path is relative, so it is resolved against the current
# working directory of the process.
tokenizer = Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json")
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_length(text: str) -> int:
    """Estimate how many tokens ``text`` occupies.

    The count is taken from the module-level ``tokenizer`` (loaded from the
    acge-large-zh ``tokenizer.json``), so it should be treated as an
    estimate rather than an exact figure for every model.
    """
    encoding = tokenizer.encode(text)
    return len(encoding.tokens)
|
||||||
|
|
||||||
|
|
||||||
|
class Task(BaseModel):
    """One text to embed, together with its cache key and result slot."""

    # Position of the text in the caller's input list; used to restore order.
    id: int
    # Text that is sent to the embeddings endpoint.
    text: str
    # Cache key derived from the text and the model name.
    hash: str
    # Embedding vector, or None while no embedding has been obtained yet.
    embedding: list[float] | None
|
||||||
|
|
||||||
|
|
||||||
|
def hash_text(text: str, model: EmbeddingModel) -> str:
    """Return the hex MD5 digest identifying ``text`` for a given ``model``.

    The digest is computed over ``text`` and ``model`` joined by ``|``, so
    the same text hashed for two different models yields different keys.
    """
    payload = "|".join((text, model))
    digest = hashlib.md5()
    digest.update(payload.encode())
    return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_text(text: str, max_length: int) -> str:
    """Return a prefix of ``text`` that spans at most ``max_length`` tokens.

    Tokens are counted with the module-level ``tokenizer``. The result is
    always a slice of the original string: the text is cut at the character
    offset where the last kept token ends, instead of gluing token strings
    back together. Joining token strings would leak tokenizer artefacts
    (special-token markers, sub-word prefixes, normalised characters) into
    the text that is later hashed and embedded.

    Args:
        text: Text to truncate.
        max_length: Maximum number of tokens to keep, not counting special
            tokens. Values <= 0 yield an empty string.

    Returns:
        ``text`` unchanged when it already fits, otherwise its longest
        prefix made of the first ``max_length`` tokens.
    """
    if max_length <= 0:
        return ""

    # Special tokens are not part of the input text, so they are excluded
    # from both the count and the offsets.
    encoding = tokenizer.encode(text, add_special_tokens=False)
    kept_offsets = encoding.offsets[:max_length]
    if not kept_offsets:
        # Nothing was tokenized (e.g. empty input).
        return ""

    # ``overflowing`` is non-empty when the tokenizer itself truncated the
    # input; in that case the text does not fit even if few ids came back.
    if len(encoding.offsets) <= max_length and not encoding.overflowing:
        return text

    # offsets[i] is the (start, end) character span of token i in ``text``.
    return text[: kept_offsets[-1][1]]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embeddings(
    texts: list[str],
    model: EmbeddingModel,
    threads: int = 1,
    quiet: bool = False,
) -> list[list[float]]:
    """Return one embedding per input text, in input order.

    Each text is first looked up in the cache by its hash; texts without a
    cached embedding are grouped into batches and sent to the embeddings
    endpoint by ``threads`` concurrent consumers.

    Args:
        texts: Texts to embed.
        model: Embedding model to request.
        threads: Number of concurrent consumers calling the embeddings
            endpoint (asyncio coroutines, not OS threads). Values below 1
            are treated as 1 so that pending work is always processed.
        quiet: Disable the progress bars.

    Returns:
        The embeddings, aligned with ``texts``.

    Raises:
        ValueError: If a request has to be made for an unknown ``model``.
        AssertionError: If any text is left without an embedding.
    """
    # acge-large-zh: truncate every text to 1024 - 2 tokens. The bundled
    # config.json declares max_position_embeddings = 1024; the 2 presumably
    # leaves room for the special tokens added around the input (TODO confirm).
    if model == "acge-large-zh":
        texts = [truncate_text(text, 1024 - 2) for text in texts]

    # Build the task list, filling in embeddings found in the cache.
    hashes = [hash_text(text, model) for text in texts]
    cache_lookups = tqdm.tqdm(
        (get_embedding_from_cache(h) for h in hashes),
        desc="Query embeddings cache",
        disable=quiet,
        total=len(texts),
    )
    tasks: list[Task] = [
        Task(id=task_id, text=text, hash=h, embedding=await lookup)
        for task_id, (text, h, lookup) in enumerate(
            zip(texts, hashes, cache_lookups)
        )
    ]

    # Only tasks that missed the cache need a request.
    pending: list[Task] = [task for task in tasks if task.embedding is None]
    batches = _split_into_batches(pending)

    # The bar counts texts, so every finished batch advances it by the
    # number of texts it contained; the context manager closes the bar.
    with tqdm.tqdm(
        desc="Requesting embeddings", disable=quiet, total=len(pending)
    ) as pbar:

        async def consumer() -> None:
            # No await happens between the emptiness check and pop(), so
            # concurrent consumers never pop from an empty list.
            while batches:
                batch = batches.pop()
                inputs = [task.text for task in batch]
                if model == "acge-large-zh":
                    resp = await embedding_client.embeddings.create(
                        input=inputs,
                        model=model,
                    )
                elif model == "text-embedding-3-large":
                    resp = await embedding_client.embeddings.create(
                        input=inputs,
                        model="text-embedding-3-large",
                        dimensions=1024,
                    )
                else:
                    raise ValueError(
                        f"Unknown model: {model} while calculating similarities"
                    )

                # Results are matched to the inputs by position.
                for task, item in zip(batch, resp.data):
                    task.embedding = item.embedding
                pbar.update(len(batch))

        # Run the consumers concurrently.
        await asyncio.gather(*[consumer() for _ in range(max(1, threads))])

    # Restore input order by task id.
    ret: list[Task] = sorted(tasks, key=lambda task: task.id)

    # Sanity checks: nothing lost, nothing left without an embedding.
    assert len(tasks) == len(ret)
    assert all(task.embedding is not None for task in ret)

    return [task.embedding for task in ret]  # type: ignore


def _split_into_batches(
    pending: list[Task],
    max_tokens: int = 8192,
    max_size: int = 32,
) -> list[list[Task]]:
    """Group ``pending`` into request batches, preserving order.

    A batch is closed before adding a task that would push its estimated
    token total above ``max_tokens`` or its size above ``max_size``. A
    single task that alone exceeds ``max_tokens`` still gets a batch of its
    own; no empty batch is ever produced.
    """
    batches: list[list[Task]] = []
    current: list[Task] = []
    current_tokens = 0
    for task in pending:
        task_tokens = get_token_length(task.text)
        # Flush only a non-empty batch: an oversized first task must not
        # emit an empty batch, which would become a request with no input.
        if current and (
            current_tokens + task_tokens > max_tokens or len(current) >= max_size
        ):
            batches.append(current)
            current = []
            current_tokens = 0
        current.append(task)
        current_tokens += task_tokens

    # Flush whatever is left.
    if current:
        batches.append(current)
    return batches
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embedding_from_cache(hash: str) -> list[float] | None:
|
||||||
|
"""根据 哈希 从缓存中查询 embedding
|
||||||
|
|
||||||
|
hash: 查询任务和哈希值,由文本和模型名称计算得到
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
res = await redis_client.get(f"embedding-{hash}")
|
||||||
|
if res is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(res, str):
|
||||||
|
raise ValueError(f"Unexpected type: {type(res)}")
|
||||||
|
return ujson.loads(res)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Compute and store embeddings for ``risk_news`` rows that lack one.

    Reads up to 1000 rows whose ``embedding_updated_at`` is NULL, embeds
    "title content" for each with the acge-large-zh model, and writes the
    embedding back together with the update timestamp. When there is
    nothing to do, a message is printed and the coroutine returns after
    sleeping for 60 seconds.
    """
    await pool.open()

    async with get_cur() as cur:
        # Filtering on "embedding_updated_at is null" is meant to hit an
        # index and avoid a full table scan.
        await cur.execute(
            "SELECT id, title, content from risk_news where embedding_updated_at is null limit 1000"
        )
        docs = await cur.fetchall()

    # The read cursor is released at this point, so it is not held while
    # sleeping or while waiting for the embeddings endpoint.
    if not docs:
        print(datetime.datetime.now(), "No data to update")
        await asyncio.sleep(60)
        return

    # A NULL title or content comes back as None; treat it as empty text
    # instead of failing on string concatenation.
    texts = [(doc[1] or "") + " " + (doc[2] or "") for doc in docs]
    embeddings = await get_embeddings(texts, "acge-large-zh")

    async with get_cur() as cur:
        progress = tqdm.tqdm(
            zip(docs, embeddings),
            total=min(len(docs), len(embeddings)),
            desc="Update embeddings",
        )
        for doc, embedding in progress:
            await cur.execute(
                "UPDATE risk_news SET embedding = %s, embedding_updated_at = now() where id = %s",
                (embedding, doc[0]),
            )


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
@@ -22,3 +22,5 @@ def must_get_env(key: str):
|
|||||||
ES_API = get_env_with_default("ES_API", "http://192.168.1.45:1444")
|
ES_API = get_env_with_default("ES_API", "http://192.168.1.45:1444")
|
||||||
PG_DSN = must_get_env("PG_DSN")
|
PG_DSN = must_get_env("PG_DSN")
|
||||||
MYSQL_DSN = must_get_env("MYSQL_DSN")
|
MYSQL_DSN = must_get_env("MYSQL_DSN")
|
||||||
|
OPENAI_API_KEY = must_get_env("OPENAI_API_KEY")
|
||||||
|
OPENAI_BASE_URL = get_env_with_default("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||||
|
|||||||
32
cucyuqing/res/acge-large-zh/config.json
Normal file
32
cucyuqing/res/acge-large-zh/config.json
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
{
|
||||||
|
"_name_or_path": "acge",
|
||||||
|
"architectures": [
|
||||||
|
"BertModel"
|
||||||
|
],
|
||||||
|
"attention_probs_dropout_prob": 0.1,
|
||||||
|
"classifier_dropout": null,
|
||||||
|
"directionality": "bidi",
|
||||||
|
"gradient_checkpointing": false,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_dropout_prob": 0.1,
|
||||||
|
"hidden_size": 1024,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 4096,
|
||||||
|
"layer_norm_eps": 1e-12,
|
||||||
|
"max_position_embeddings": 1024,
|
||||||
|
"model_type": "bert",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_hidden_layers": 24,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"pooler_fc_size": 768,
|
||||||
|
"pooler_num_attention_heads": 12,
|
||||||
|
"pooler_num_fc_layers": 3,
|
||||||
|
"pooler_size_per_head": 128,
|
||||||
|
"pooler_type": "first_token_transform",
|
||||||
|
"position_embedding_type": "absolute",
|
||||||
|
"torch_dtype": "float16",
|
||||||
|
"transformers_version": "4.28.0",
|
||||||
|
"type_vocab_size": 2,
|
||||||
|
"use_cache": true,
|
||||||
|
"vocab_size": 21128
|
||||||
|
}
|
||||||
21278
cucyuqing/res/acge-large-zh/tokenizer.json
Normal file
21278
cucyuqing/res/acge-large-zh/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,3 +5,6 @@ aiohttp
|
|||||||
fastapi
|
fastapi
|
||||||
pydantic
|
pydantic
|
||||||
databases[aiomysql]
|
databases[aiomysql]
|
||||||
|
openai
|
||||||
|
tiktoken
|
||||||
|
tokenizers
|
||||||
|
|||||||
Reference in New Issue
Block a user