# matrix-chain/bot_chatgpt.py
import os

import dotenv

dotenv.load_dotenv()

import asyncio
import datetime
import json

import jinja2
import requests
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from nio import MatrixRoom, RoomMessageFile, RoomMessageText

from bot import Bot, print
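
# Embedding model used to vectorise incoming messages so they can be matched
# against per-room document embeddings stored in Postgres (pgvector).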
embeddings_model = OpenAIEmbeddings(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    openai_api_base=os.environ["OPENAI_API_BASE"],
    show_progress_bar=True,
)

client = Bot(
    os.environ["BOT_CHATGPT_HOMESERVER"],
    os.environ["BOT_CHATGPT_USER"],
    os.environ["MATRIX_CHAIN_DEVICE"],
    os.environ["BOT_CHATGPT_ACCESS_TOKEN"],
)

client.welcome_message = """Hello 👋, I am the large language model plugin for matrix chain.
## Usage:
- Send a message directly in the room and GPT will reply in that message thread. GPT remembers the full content of each thread separately; threads are independent and do not interfere with one another.
## Configuration:
- Send "!system <system message>" to configure the language model's role, e.g. "!system You are a professional English translator. Translate everything I say into English. You may adjust word order and wording to make the translation read naturally."
"""
class SilentUndefined(jinja2.Undefined):
    def _fail_with_undefined_error(self, *args, **kwargs):
        print(f'jinja2.Undefined: "{self._undefined_name}" is undefined')
        return ""
def render(template: str, **kwargs) -> str:
    env = jinja2.Environment(undefined=SilentUndefined)
    temp = env.from_string(template)

    def now() -> str:
        return datetime.datetime.now().strftime("%Y-%m-%d")

    temp.globals["now"] = now
    return temp.render(**kwargs)


async def get_reply_file_content(event):
    """When the user replies to an event, retrieve the file content (document)
    attached to the replied-to event.

    Returns the file content and its token length.
    """
    flattened = event.flattened()
    formatted_body = flattened.get("content.formatted_body", "")
    if not (
        formatted_body.startswith("<mx-reply>") and "</mx-reply>" in formatted_body
    ):
        return "", 0
    print("replacing file content")
    formatted_body = formatted_body[
        formatted_body.index("</mx-reply>") + len("</mx-reply>") :
    ]
    document_event_id = flattened.get("content.m.relates_to.m.in_reply_to.event_id", "")
    fetch = await client.db.fetch_one(
        query="""select d.content, d.token
            from documents d
            join event_document ed on d.md5 = ed.document_md5
            where ed.event = :document_event_id""",
        values={
            "document_event_id": document_event_id,
        },
    )
    # only inline documents that fit the context budget
    if fetch and fetch[1] < 8192 + 4096:
        content = fetch[0]
        token = fetch[1]
        print(content)
        print(token)
        print("-----------")
        return content, token
    print("document not found or too large", event.event_id)
    return "", 0
@client.ignore_link
@client.message_callback_common_wrapper
async def message_callback(room: MatrixRoom, event: RoomMessageText) -> None:
    # handle configuration commands
    if event.body.startswith("!"):
        if event.body.startswith("!system"):
            # str.removeprefix rather than str.lstrip: lstrip strips a
            # character set and can eat the start of the actual message
            systemMessageContent = event.body.removeprefix("!system").strip()
            # save to db
            await client.db.execute(
                query="""
                insert into room_configs (room, system, examples)
                values (:room_id, :systemMessageContent, '{}')
                on conflict (room)
                do update set system = excluded.system, examples = '{}'
                """,
                values={
                    "room_id": room.room_id,
                    "systemMessageContent": systemMessageContent,
                },
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        if event.body.startswith("!model"):
            model_name = event.body.removeprefix("!model").strip()
            # save to db
            await client.db.execute(
                query="""
                insert into room_configs (room, model_name)
                values (:room_id, :model_name)
                on conflict (room)
                do update set model_name = excluded.model_name
                """,
                values={"room_id": room.room_id, "model_name": model_name},
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        if event.body.startswith("!temp"):
            temperature = float(event.body.removeprefix("!temp").strip())
            # save to db
            await client.db.execute(
                query="""
                insert into room_configs (room, temperature)
                values (:room_id, :temperature)
                on conflict (room)
                do update set temperature = excluded.temperature
                """,
                values={"room_id": room.room_id, "temperature": temperature},
            )
            await client.react_ok(room.room_id, event.event_id)
            return
        return
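
    # Build the prompt: system message first, then any few-shot examples,
    # then retrieved context and thread memory, and finally the current
    # message composed with that context.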
    messages: list[BaseMessage] = []
    # query prompt config from db
    db_result = await client.db.fetch_one(
        query="""
        select system, examples, model_name, temperature
        from room_configs
        where room = :room_id
        limit 1
        """,
        values={"room_id": room.room_id},
    )
    model_name: str = db_result[2] if db_result else ""
    temperature: float = db_result[3] if db_result else 0
    systemMessageContent: str = db_result[0] if db_result else ""
    systemMessageContent = systemMessageContent or ""
    if systemMessageContent:
        messages.append(SystemMessage(content=systemMessageContent))
    # examples alternate human / assistant turns
    examples = db_result[1] if db_result else []
    for i, m in enumerate(examples):
        if not m:
            print("Warning: message is empty", m)
            continue
        if i % 2 == 0:
            messages.append(HumanMessage(content=m["content"], example=True))
        else:
            messages.append(AIMessage(content=m["content"], example=True))
    exampleTokens = sum(client.get_token_length(m) for m in examples if m)
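
    # Embedding retrieval: `<#>` is pgvector's negative inner product, so
    # ordering by it ascending returns the most similar chunks first. The
    # window function accumulates token counts so that only chunks fitting
    # a 6144-token budget are returned.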
    # get embedding matches for the current message
    embedding_query = await client.db.fetch_all(
        query="""select content, distance, total_token from (
            select
                content,
                document_md5,
                distance,
                sum(token) over (partition by room order by distance) as total_token
            from (
                select
                    content,
                    rd.room,
                    e.document_md5,
                    e.embedding <#> :embedding as distance,
                    token
                from embeddings e
                join room_document rd on rd.document_md5 = e.document_md5
                join room_configs rc on rc.room = rd.room
                where rd.room = :room_id and rc.embedding
                order by distance
                limit 16
            ) as sub
        ) as sub2
        where total_token < 6144
        ;""",
        values={
            "embedding": str(await embeddings_model.aembed_query(event.body)),
            "room_id": room.room_id,
        },
    )
    print("embedding_query", embedding_query)
    embedding_token = 0
    embedding_text = ""
    if len(embedding_query) > 0:
        # least similar first, so the best match sits closest to the question
        embedding_query.reverse()
        embedding_text = "\n\n".join([i[0] for i in embedding_query])
        embedding_token = client.get_token_length(embedding_text)
    filecontent, filetoken = await get_reply_file_content(event)
    # Token budget: reserve token_margin for the completion; whatever is left
    # after the system message, examples, embeddings and file goes to memory.
    max_token = 4096 * 4
    token_margin = 4096
    system_token = client.get_token_length(systemMessageContent) + exampleTokens
    memory_token = max_token - token_margin - system_token - embedding_token - filetoken
    print(
        "system_token",
        system_token,
        "embedding_token",
        embedding_token,
        "filetoken",
        filetoken,
    )
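
    # Thread memory: the inner query accumulates tokens from the newest row
    # backwards while keeping chronological order, so the result is the most
    # recent slice of the thread that fits the remaining budget.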
    rows = await client.db.fetch_all(
        query="""select role, content from (
            select role, content, sum(token) over (partition by root order by id desc) as total_token
            from memories
            where root = :root
            order by id
        ) as sub
        where total_token < :token
        ;""",
        values={"root": event.event_id, "token": memory_token},
    )
    for role, content in rows:
        if role == 1:  # user turn
            messages.append(HumanMessage(content=content))
        elif role == 2:  # assistant turn
            messages.append(AIMessage(content=content))
        else:
            print("Unknown message role", role, content)
temp = "{{input}}"
if filecontent and embedding_text:
temp = """## Reference information:
{{embedding}}
---
## Query document:
{{filecontent}}
---
{{input}}"""
elif embedding_text:
temp = """## Reference information:
{{embedding}}
---
{{input}}"""
elif filecontent:
temp = """ ## Query document:
{{filecontent}}
---
{{input}}"""
temp = render(
temp, input=event.body, embedding=embedding_text, filecontent=filecontent
)
messages.append(HumanMessage(content=temp))
    # rough prompt-size estimate: ~6 tokens of chat formatting per message
    total_token = (
        sum(client.get_token_length(m.content) for m in messages) + len(messages) * 6
    )
    if not model_name:
        model_name = "gpt-3.5-turbo-1106"
    print("messages", messages)
    chat_model = ChatOpenAI(
        openai_api_base=os.environ["OPENAI_API_BASE"],
        openai_api_key=os.environ["OPENAI_API_KEY"],
        model=model_name,
        temperature=temperature,
    )
    chain = LLMChain(llm=chat_model, prompt=ChatPromptTemplate.from_messages(messages))
    result = await chain.arun(
        {
            "input": event.body,
            "embedding": embedding_text,
            "filecontent": filecontent,
        }
    )
    print(result)
    # reply in the same Matrix thread as the triggering message
    await client.room_send(
        room.room_id,
        "m.room.message",
        {
            "body": result,
            "msgtype": "m.text",
            "m.relates_to": {
                "rel_type": "m.thread",
                "event_id": event.event_id,
            },
        },
    )
    # record query and result as thread memory
    await client.db.execute_many(
        query="insert into memories(root, role, content, token) values (:root, :role, :content, :token)",
        values=[
            {
                "root": event.event_id,
                "role": 1,
                "content": event.body,
                "token": client.get_token_length(event.body),
            },
            {
                "root": event.event_id,
                "role": 2,
                "content": result,
                "token": client.get_token_length(result),
            },
        ],
    )


client.add_event_callback(message_callback, RoomMessageText)
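

# File-message handler: imports a chat configuration exported by
# chatgpt-api-web (JSON), storing its system message and example turns as the
# room's config.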
@client.ignore_self_message
@client.handel_no_gpt
@client.log_message
@client.with_typing
@client.replace_command_mark
@client.safe_try
async def message_file(room: MatrixRoom, event: RoomMessageFile):
    if not event.flattened().get("content.info.mimetype") == "application/json":
        print("not application/json")
        return
    size = event.flattened().get("content.info.size", 1024 * 1024 + 1)
    if size > 1024 * 1024:
        print("json file too large")
        return
    print("event url", event.url)
    # note: requests.get is a blocking call inside an async handler
    j = requests.get(
        f'https://yongyuancv.cn/_matrix/media/r0/download/yongyuancv.cn/{event.url.rsplit("/", 1)[-1]}'
    ).json()
    if j.get("chatgpt_api_web_version") is None:
        print("not a chatgpt-api-web chat-store export file")
        return
    if j["chatgpt_api_web_version"] < "v1.5.0":
        raise ValueError(j["chatgpt_api_web_version"])
examples = [m["content"] for m in j["history"] if m["example"]]
await client.db.execute(
query="""
insert into room_configs (room, system, examples)
values (:room_id, :system, :examples)
on conflict (room) do update set system = excluded.system, examples = excluded.examples
""",
values={
"room_id": room.room_id,
"system": j["systemMessageContent"],
"examples": str(examples),
},
)
await client.room_send(
room.room_id,
"m.reaction",
{
"m.relates_to": {
"event_id": event.event_id,
"key": "😘",
"rel_type": "m.annotation",
}
},
)


client.add_event_callback(message_file, RoomMessageFile)

asyncio.run(client.sync_forever())