Compare commits
5 Commits
dd41d6fa5f
...
71fd688227
| Author | SHA1 | Date | |
|---|---|---|---|
|
71fd688227
|
|||
|
0b658dee88
|
|||
|
92e3699cd8
|
|||
|
4b5e59f35d
|
|||
|
7395d98ce3
|
@@ -1,5 +1,7 @@
|
|||||||
|
import traceback
|
||||||
import datetime
|
import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from sqlalchemy.sql.ddl import exc
|
||||||
import tqdm
|
import tqdm
|
||||||
import os
|
import os
|
||||||
from tokenizers import Tokenizer
|
from tokenizers import Tokenizer
|
||||||
@@ -11,7 +13,7 @@ from cucyuqing.pg import pool, get_cur
|
|||||||
from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL
|
from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL
|
||||||
from cucyuqing.utils import print
|
from cucyuqing.utils import print
|
||||||
|
|
||||||
EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-large"]
|
EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-small"]
|
||||||
|
|
||||||
embedding_client = openai.AsyncOpenAI(
|
embedding_client = openai.AsyncOpenAI(
|
||||||
api_key=OPENAI_API_KEY,
|
api_key=OPENAI_API_KEY,
|
||||||
@@ -19,11 +21,14 @@ embedding_client = openai.AsyncOpenAI(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
tokenizer = Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json")
|
tokenizers: dict[EmbeddingModel, Any] = {}
|
||||||
|
tokenizers['acge-large-zh'] = Tokenizer.from_file("cucyuqing/res/acge-large-zh/tokenizer.json")
|
||||||
|
tokenizers['text-embedding-3-small'] = Tokenizer.from_file("cucyuqing/res/cl100k_base/tokenizer.json")
|
||||||
|
|
||||||
|
|
||||||
def get_token_length(text: str) -> int:
|
def get_token_length(model_name: EmbeddingModel, text: str) -> int:
|
||||||
"""使用 openai 提供的 tokenizer **估算** token 长度"""
|
"""使用 openai 提供的 tokenizer **估算** token 长度"""
|
||||||
|
tokenizer = tokenizers[model_name]
|
||||||
return len(tokenizer.encode(text).tokens)
|
return len(tokenizer.encode(text).tokens)
|
||||||
|
|
||||||
|
|
||||||
@@ -39,8 +44,9 @@ def hash_text(text: str, model: EmbeddingModel) -> str:
|
|||||||
return hashlib.md5((text + "|" + model).encode()).hexdigest()
|
return hashlib.md5((text + "|" + model).encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def truncate_text(text: str, max_length: int) -> str:
|
def truncate_text(model_name: EmbeddingModel, text: str, max_length: int) -> str:
|
||||||
"""截断文本"""
|
"""截断文本"""
|
||||||
|
tokenizer = tokenizers[model_name]
|
||||||
tokens = tokenizer.encode(text).tokens[0:max_length]
|
tokens = tokenizer.encode(text).tokens[0:max_length]
|
||||||
return ''.join(tokens)
|
return ''.join(tokens)
|
||||||
|
|
||||||
@@ -59,9 +65,9 @@ async def get_embeddings(
|
|||||||
- quiet: 是否关闭输出
|
- quiet: 是否关闭输出
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 针对 acge-large-zh 模型,需要将文本截断 1024 - 200
|
# 针对 acge-large-zh 模型,需要将文本截断 1024 - 2
|
||||||
if model == "acge-large-zh":
|
if model == "acge-large-zh":
|
||||||
texts = [truncate_text(text, 1024 - 2) for text in texts]
|
texts = [truncate_text(model, text, 1024 - 2) for text in texts]
|
||||||
|
|
||||||
# 构建任务列表
|
# 构建任务列表
|
||||||
ids = list(range(len(texts)))
|
ids = list(range(len(texts)))
|
||||||
@@ -84,13 +90,13 @@ async def get_embeddings(
|
|||||||
batch_token_length = 0 # TEMP
|
batch_token_length = 0 # TEMP
|
||||||
iter_batch: list[Task] = [] # TEMP
|
iter_batch: list[Task] = [] # TEMP
|
||||||
for q in query:
|
for q in query:
|
||||||
batch_token_length += get_token_length(q.text)
|
batch_token_length += get_token_length(model, q.text)
|
||||||
|
|
||||||
# 该批次已满,将该批次加入 batch_query
|
# 该批次已满,将该批次加入 batch_query
|
||||||
if batch_token_length > max_batch_token_length or len(iter_batch) >= 32:
|
if batch_token_length > max_batch_token_length or len(iter_batch) >= 32:
|
||||||
batch_query.append(iter_batch)
|
batch_query.append(iter_batch)
|
||||||
iter_batch = [q]
|
iter_batch = [q]
|
||||||
batch_token_length = get_token_length(q.text)
|
batch_token_length = get_token_length(model, q.text)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
iter_batch.append(q)
|
iter_batch.append(q)
|
||||||
@@ -113,10 +119,10 @@ async def get_embeddings(
|
|||||||
input=[q.text for q in query],
|
input=[q.text for q in query],
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
elif model == "text-embedding-3-large":
|
elif model == "text-embedding-3-small":
|
||||||
resp = await embedding_client.embeddings.create(
|
resp = await embedding_client.embeddings.create(
|
||||||
input=[q.text for q in query],
|
input=[q.text for q in query],
|
||||||
model="text-embedding-3-large",
|
model=model,
|
||||||
dimensions=1024,
|
dimensions=1024,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -160,8 +166,12 @@ async def get_embedding_from_cache(hash: str) -> list[float] | None:
|
|||||||
async def main():
|
async def main():
|
||||||
await pool.open()
|
await pool.open()
|
||||||
while True:
|
while True:
|
||||||
await do_update()
|
try:
|
||||||
await asyncio.sleep(60)
|
await do_update()
|
||||||
|
except Exception as e:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
finally:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
async def do_update():
|
async def do_update():
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
ES 数据同步脚本
|
ES 数据同步脚本
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
@@ -180,11 +181,11 @@ async def sync():
|
|||||||
await cur.execute(
|
await cur.execute(
|
||||||
"""
|
"""
|
||||||
WITH RECURSIVE time_slots AS (
|
WITH RECURSIVE time_slots AS (
|
||||||
SELECT date_trunc('hour', now() - interval '1 hour') - interval '14 day' AS start_time
|
SELECT date_trunc('hour', now() - interval '2 hour') - interval '14 day' AS start_time
|
||||||
UNION ALL
|
UNION ALL
|
||||||
SELECT start_time + INTERVAL '1 hour'
|
SELECT start_time + INTERVAL '1 hour'
|
||||||
FROM time_slots
|
FROM time_slots
|
||||||
WHERE start_time < date_trunc('hour', now() - interval '1 hour')
|
WHERE start_time < date_trunc('hour', now() - interval '2 hour')
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
ts.start_time,
|
ts.start_time,
|
||||||
@@ -219,9 +220,15 @@ async def sync():
|
|||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
while True:
|
while True:
|
||||||
await sync()
|
try:
|
||||||
print("同步完成,等待下一轮同步")
|
await sync()
|
||||||
await asyncio.sleep(60)
|
print("同步完成,等待下一轮同步")
|
||||||
|
except Exception as e:
|
||||||
|
# 打印出错误堆栈
|
||||||
|
traceback.print_exc()
|
||||||
|
print("同步出错,等待 60 秒后重试", e)
|
||||||
|
finally:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -22,5 +22,5 @@ def must_get_env(key: str):
|
|||||||
ES_API = get_env_with_default("ES_API", "http://192.168.1.45:1444")
|
ES_API = get_env_with_default("ES_API", "http://192.168.1.45:1444")
|
||||||
PG_DSN = must_get_env("PG_DSN")
|
PG_DSN = must_get_env("PG_DSN")
|
||||||
MYSQL_DSN = must_get_env("MYSQL_DSN")
|
MYSQL_DSN = must_get_env("MYSQL_DSN")
|
||||||
OPENAI_API_KEY = must_get_env("OPENAI_API_KEY")
|
OPENAI_EMBEDDING_API_KEY = must_get_env("OPENAI_EMBEDDING_API_KEY")
|
||||||
OPENAI_BASE_URL = get_env_with_default("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
OPENAI_EMBEDDING_BASE_URL = get_env_with_default("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
|
||||||
|
|||||||
200353
cucyuqing/res/cl100k_base/tokenizer.json
Normal file
200353
cucyuqing/res/cl100k_base/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user