添加 cl100k_base 字典与 text-embedding-3 适配

This commit is contained in:
2024-09-19 14:24:19 +08:00
parent 4b5e59f35d
commit 92e3699cd8
2 changed files with 200367 additions and 10 deletions

View File

@@ -11,7 +11,7 @@ from cucyuqing.pg import pool, get_cur
from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL from cucyuqing.config import OPENAI_API_KEY, OPENAI_BASE_URL
from cucyuqing.utils import print from cucyuqing.utils import print
EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-large"] EmbeddingModel = Literal["acge-large-zh", "text-embedding-3-small"]
embedding_client = openai.AsyncOpenAI( embedding_client = openai.AsyncOpenAI(
api_key=OPENAI_API_KEY, api_key=OPENAI_API_KEY,
@@ -19,11 +19,14 @@ embedding_client = openai.AsyncOpenAI(
) )
# Tokenizer registry keyed by embedding model name; every entry is loaded
# from a tokenizer.json file stored under cucyuqing/res/.
tokenizers: dict[EmbeddingModel, Any] = {
    "acge-large-zh": Tokenizer.from_file(
        "cucyuqing/res/acge-large-zh/tokenizer.json"
    ),
    "text-embedding-3-small": Tokenizer.from_file(
        "cucyuqing/res/cl100k_base/tokenizer.json"
    ),
}
def get_token_length(model_name: EmbeddingModel, text: str) -> int:
    """Estimate how many tokens ``text`` occupies for ``model_name``.

    The count is produced by the local tokenizer stored under ``model_name``
    in ``tokenizers``; it is an estimate, not a figure reported by the
    embedding service.

    Raises:
        KeyError: if no tokenizer is registered for ``model_name``.
    """
    encoding = tokenizers[model_name].encode(text)
    token_strings = encoding.tokens
    return len(token_strings)
@@ -39,8 +42,9 @@ def hash_text(text: str, model: EmbeddingModel) -> str:
return hashlib.md5((text + "|" + model).encode()).hexdigest() return hashlib.md5((text + "|" + model).encode()).hexdigest()
def truncate_text(model_name: EmbeddingModel, text: str, max_length: int) -> str:
    """Truncate ``text`` to at most ``max_length`` tokens of ``model_name``.

    The result is always a prefix of the original ``text``. Joining the
    tokenizer's token strings instead would not reproduce the input: WordPiece
    tokens carry ``##`` prefixes and special markers, and byte-level BPE
    tokens use a remapped byte alphabet, so the joined string differs from
    the text that was tokenized.

    Args:
        model_name: key of the tokenizer to use in ``tokenizers``.
        text: the text to shorten.
        max_length: maximum number of tokens to keep, not counting any
            special tokens the model adds around the sequence.

    Returns:
        ``text`` unchanged when it already fits, otherwise the longest prefix
        that ends on a token boundary within the first ``max_length`` tokens.
        An empty string when ``max_length`` is not positive.

    Raises:
        KeyError: if no tokenizer is registered for ``model_name``.
    """
    if max_length <= 0:
        return ""

    tokenizer = tokenizers[model_name]
    # Special tokens have no span in the input, so leave them out and let
    # every offset refer to a real slice of ``text``.
    encoding = tokenizer.encode(text, add_special_tokens=False)
    offsets = encoding.offsets
    if len(offsets) <= max_length:
        return text

    last_kept_end = offsets[max_length - 1][1]
    first_dropped_start = offsets[max_length][0]
    # When one character is split across the cut (byte-level BPE), both
    # tokens report the same span; taking the minimum drops that character
    # instead of letting the prefix exceed the token budget.
    return text[: min(last_kept_end, first_dropped_start)]
@@ -59,9 +63,9 @@ async def get_embeddings(
- quiet: 是否关闭输出 - quiet: 是否关闭输出
""" """
# 针对 acge-large-zh 模型,需要将文本截断 1024 - 200 # 针对 acge-large-zh 模型,需要将文本截断 1024 - 2
if model == "acge-large-zh": if model == "acge-large-zh":
texts = [truncate_text(text, 1024 - 2) for text in texts] texts = [truncate_text(model, text, 1024 - 2) for text in texts]
# 构建任务列表 # 构建任务列表
ids = list(range(len(texts))) ids = list(range(len(texts)))
@@ -84,13 +88,13 @@ async def get_embeddings(
batch_token_length = 0 # TEMP batch_token_length = 0 # TEMP
iter_batch: list[Task] = [] # TEMP iter_batch: list[Task] = [] # TEMP
for q in query: for q in query:
batch_token_length += get_token_length(q.text) batch_token_length += get_token_length(model, q.text)
# 该批次已满,将该批次加入 batch_query # 该批次已满,将该批次加入 batch_query
if batch_token_length > max_batch_token_length or len(iter_batch) >= 32: if batch_token_length > max_batch_token_length or len(iter_batch) >= 32:
batch_query.append(iter_batch) batch_query.append(iter_batch)
iter_batch = [q] iter_batch = [q]
batch_token_length = get_token_length(q.text) batch_token_length = get_token_length(model, q.text)
continue continue
iter_batch.append(q) iter_batch.append(q)
@@ -113,10 +117,10 @@ async def get_embeddings(
input=[q.text for q in query], input=[q.text for q in query],
model=model, model=model,
) )
elif model == "text-embedding-3-large": elif model == "text-embedding-3-small":
resp = await embedding_client.embeddings.create( resp = await embedding_client.embeddings.create(
input=[q.text for q in query], input=[q.text for q in query],
model="text-embedding-3-large", model=model,
dimensions=1024, dimensions=1024,
) )
else: else:

File diff suppressed because it is too large Load Diff