Compare commits


5 Commits

Author SHA1 Message Date
71fd688227 Rename the OPENAI environment variables to OPENAI_EMBEDDING 2024-09-19 14:29:30 +08:00
0b658dee88 Add error handling to the embedding job 2024-09-19 14:27:42 +08:00
92e3699cd8 Add the cl100k_base vocabulary and adapt for text-embedding-3 2024-09-19 14:26:20 +08:00
4b5e59f35d Add the same retry mechanism to the sync script 2024-09-19 11:07:16 +08:00
7395d98ce3 Delay ingestion by one hour
-1 hour -> start_time: effectively immediate ingestion
-2 hour -> start_time, -1 hour -> end_time: effectively one-hour-delayed ingestion
2024-09-19 10:55:01 +08:00
4 changed files with 200389 additions and 19 deletions
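
The window reasoning in commit 7395d98ce3 can be checked with plain date arithmetic: the sync query anchors its newest hourly slot at date_trunc('hour', now() - interval 'N hour'), so moving the anchor from 1 hour to 2 hours pushes the newest processed slot one full hour further into the past. A minimal sketch of that arithmetic (standalone Python; the helper name and the sample timestamp are illustrative only, not part of the diff):

    from datetime import datetime, timedelta

    def newest_slot(now: datetime, anchor_hours: int) -> tuple[datetime, datetime]:
        # Mirrors date_trunc('hour', now() - interval '<anchor_hours> hour'):
        # step back by the anchor offset, then truncate to the hour.
        start = (now - timedelta(hours=anchor_hours)).replace(minute=0, second=0, microsecond=0)
        return start, start + timedelta(hours=1)

    now = datetime(2024, 9, 19, 10, 55)  # sample clock time
    print(newest_slot(now, 1))  # (09:00, 10:00) -> data is picked up almost as soon as its hour ends
    print(newest_slot(now, 2))  # (08:00, 09:00) -> one hour further back, i.e. the one-hour delay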

View File

@@ -1,5 +1,7 @@
+import traceback
 import datetime
 import asyncio
+from sqlalchemy.sql.ddl import exc
 import tqdm
 import os
 from tokenizers import Tokenizer
@@ -11,7 +13,7 @@ from cucyuqing.pg import pool, get_cur
 from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL
 from cucyuqing.utils import print
-EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-large"]
+EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-small"]
 embedding_client = openai.AsyncOpenAI(
     api_key=OPENAI_API_KEY,
@@ -19,11 +21,14 @@ embedding_client = openai.AsyncOpenAI(
 )
-tokenizer = Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json")
+tokenizers: dict[EmbeddingModel, Any] = {}
+tokenizers['acge-large-zh'] = Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json")
+tokenizers['text-embedding-3-small'] = Tokenizer.from_file("cucyuqing/res/cl100k_base/tokenizer.json")
-def get_token_length(text: str) -> int:
+def get_token_length(model_name: EmbeddingModel, text: str) -> int:
     """Use the tokenizer provided by openai to **estimate** the token length"""
+    tokenizer = tokenizers[model_name]
     return len(tokenizer.encode(text).tokens)
@@ -39,8 +44,9 @@ def hash_text(text: str, model: EmbeddingModel) -> str:
     return hashlib.md5((text + "|" + model).encode()).hexdigest()
-def truncate_text(text: str, max_length: int) -> str:
+def truncate_text(model_name: EmbeddingModel, text: str, max_length: int) -> str:
     """Truncate the text"""
+    tokenizer = tokenizers[model_name]
     tokens = tokenizer.encode(text).tokens[0:max_length]
     return ''.join(tokens)
@@ -59,9 +65,9 @@ async def get_embeddings(
     - quiet: whether to suppress output
     """
-    # For the acge-large-zh model, the text must be truncated to 1024 - 200
+    # For the acge-large-zh model, the text must be truncated to 1024 - 2
     if model == "acge-large-zh":
-        texts = [truncate_text(text, 1024 - 2) for text in texts]
+        texts = [truncate_text(model, text, 1024 - 2) for text in texts]
     # Build the task list
     ids = list(range(len(texts)))
@@ -84,13 +90,13 @@ async def get_embeddings(
     batch_token_length = 0  # TEMP
     iter_batch: list[Task] = []  # TEMP
     for q in query:
-        batch_token_length += get_token_length(q.text)
+        batch_token_length += get_token_length(model, q.text)
         # This batch is full; append it to batch_query
         if batch_token_length > max_batch_token_length or len(iter_batch) >= 32:
             batch_query.append(iter_batch)
             iter_batch = [q]
-            batch_token_length = get_token_length(q.text)
+            batch_token_length = get_token_length(model, q.text)
             continue
         iter_batch.append(q)
@@ -113,10 +119,10 @@ async def get_embeddings(
             input=[q.text for q in query],
             model=model,
         )
-    elif model == "text-embedding-3-large":
+    elif model == "text-embedding-3-small":
         resp = await embedding_client.embeddings.create(
             input=[q.text for q in query],
-            model="text-embedding-3-large",
+            model=model,
             dimensions=1024,
         )
     else:
@@ -160,8 +166,12 @@ async def get_embedding_from_cache(hash: str) -> list[float] | None:
 async def main():
     await pool.open()
     while True:
-        await do_update()
-        await asyncio.sleep(60)
+        try:
+            await do_update()
+        except Exception as e:
+            print(traceback.format_exc())
+        finally:
+            await asyncio.sleep(60)
 async def do_update():
     while True:
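
The main change in this file is that the single module-level tokenizer becomes a per-model registry, and get_token_length / truncate_text take the model name. A standalone sketch of that registry pattern as it appears in the hunks above; it assumes the two tokenizer files exist at the repository paths, and the demo text under __main__ is made up:

    from typing import Any, Literal

    from tokenizers import Tokenizer

    EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-small"]

    # One tokenizer per embedding model, loaded once at import time.
    tokenizers: dict[EmbeddingModel, Any] = {
        "acge-large-zh": Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json"),
        "text-embedding-3-small": Tokenizer.from_file("cucyuqing/res/cl100k_base/tokenizer.json"),
    }

    def get_token_length(model_name: EmbeddingModel, text: str) -> int:
        """Estimate the token count of text under the chosen model's tokenizer."""
        return len(tokenizers[model_name].encode(text).tokens)

    def truncate_text(model_name: EmbeddingModel, text: str, max_length: int) -> str:
        """Keep at most max_length tokens and join them back into a string."""
        tokens = tokenizers[model_name].encode(text).tokens[:max_length]
        return "".join(tokens)

    if __name__ == "__main__":
        sample = "舆情监测样例文本"
        print(get_token_length("acge-large-zh", sample))
        # acge-large-zh accepts at most 1024 tokens, hence the 1024 - 2 truncation in the diff
        print(truncate_text("acge-large-zh", sample, 1024 - 2))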

View File

@@ -2,6 +2,7 @@
 ES data synchronization script
 """
+import traceback
 import asyncio
 import time
 import json
@@ -180,11 +181,11 @@ async def sync():
         await cur.execute(
             """
             WITH RECURSIVE time_slots AS (
-                SELECT date_trunc('hour', now() - interval '1 hour') - interval '14 day' AS start_time
+                SELECT date_trunc('hour', now() - interval '2 hour') - interval '14 day' AS start_time
                 UNION ALL
                 SELECT start_time + INTERVAL '1 hour'
                 FROM time_slots
-                WHERE start_time < date_trunc('hour', now() - interval '1 hour')
+                WHERE start_time < date_trunc('hour', now() - interval '2 hour')
             )
             SELECT
                 ts.start_time,
@@ -219,9 +220,15 @@ async def sync():
 async def main():
     while True:
-        await sync()
-        print("Sync finished, waiting for the next round")
-        await asyncio.sleep(60)
+        try:
+            await sync()
+            print("Sync finished, waiting for the next round")
+        except Exception as e:
+            # Print the error traceback
+            traceback.print_exc()
+            print("Sync failed, retrying in 60 seconds", e)
+        finally:
+            await asyncio.sleep(60)
 if __name__ == "__main__":
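
The recursive time_slots CTE above enumerates hourly slots from 14 days ago up to the now 2-hour-delayed anchor, and it can be inspected on its own. A minimal sketch using psycopg (an assumption; the project's own pool/get_cur helpers are not shown in this diff) with a placeholder DSN:

    import asyncio

    import psycopg  # assumption: any async Postgres driver would do

    SLOTS_SQL = """
    WITH RECURSIVE time_slots AS (
        SELECT date_trunc('hour', now() - interval '2 hour') - interval '14 day' AS start_time
        UNION ALL
        SELECT start_time + INTERVAL '1 hour'
        FROM time_slots
        WHERE start_time < date_trunc('hour', now() - interval '2 hour')
    )
    SELECT start_time, start_time + INTERVAL '1 hour' AS end_time
    FROM time_slots
    ORDER BY start_time DESC
    LIMIT 3
    """

    async def show_newest_slots(dsn: str) -> None:
        # Print the three newest slots; with the '2 hour' anchor the newest
        # end_time sits one to two hours in the past instead of zero to one.
        async with await psycopg.AsyncConnection.connect(dsn) as conn:
            async with conn.cursor() as cur:
                await cur.execute(SLOTS_SQL)
                for start_time, end_time in await cur.fetchall():
                    print(start_time, "->", end_time)

    if __name__ == "__main__":
        asyncio.run(show_newest_slots("postgresql://localhost/cucyuqing"))  # placeholder DSN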

View File

@@ -22,5 +22,5 @@ def must_get_env(key: str):
 ES_API = get_env_with_default("ES_API", "http://192.168.1.45:1444")
 PG_DSN = must_get_env("PG_DSN")
 MYSQL_DSN = must_get_env("MYSQL_DSN")
-OPENAI_API_KEY = must_get_env("OPENAI_API_KEY")
-OPENAI_BASE_URL = get_env_with_default("OPENAI_BASE_URL", "https://api.openai.com/v1")
+OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
+OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
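
With this rename the embedding credentials come from OPENAI_EMBEDDING_API_KEY and OPENAI_EMBEDDING_BASE_URL. The bodies of must_get_env and get_env_with_default are not part of this diff, so the implementations below are only a sketch of how they presumably behave:

    import os

    def must_get_env(key: str) -> str:
        # Assumed behaviour: fail fast when a required variable is missing.
        value = os.environ.get(key)
        if value is None:
            raise RuntimeError(f"missing required environment variable: {key}")
        return value

    def get_env_with_default(key: str, default: str) -> str:
        # Assumed behaviour: fall back to a default for optional variables.
        return os.environ.get(key, default)

    OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
    OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")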

File diff suppressed because it is too large