Files
matrix-chain/bot.py
2023-10-21 21:14:57 +08:00

316 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import dotenv
dotenv.load_dotenv()
import os
import traceback
from functools import wraps
from nio import AsyncClient, MatrixRoom, RoomMessageText, SyncError, InviteEvent
import tiktoken
import databases
import builtins
import sys
def print(*args, **kwargs):
    """Print to standard error via the builtin ``print``.

    This module-level function shadows the builtin. Whatever ``file`` keyword
    the caller passes is replaced, so all output goes to ``sys.stderr``.
    """
    # Build a fresh mapping whose ``file`` entry always wins.
    stderr_kwargs = {**kwargs, "file": sys.stderr}
    builtins.print(*args, **stderr_kwargs)
class Bot(AsyncClient):
def __init__(self, homeserver: str, user: str, device_id: str, access_token: str):
    """Configure the client identity, database handle, tokenizer and invite hook.

    Args:
        homeserver: Forwarded to ``AsyncClient.__init__``.
        user: Stored as both ``self.user_id`` and ``self.user``.
        device_id: Stored as ``self.device_id``.
        access_token: Stored as ``self.access_token``.

    Raises:
        KeyError: If the ``MATRIX_CHAIN_DB`` environment variable is unset.
    """
    super().__init__(homeserver)

    # Credentials / identity (assigned in the same order as before).
    self.access_token = access_token
    self.user_id = user
    self.user = user
    self.device_id = device_id

    # Bot state: ``auto_join`` sends the welcome message only when it is
    # non-empty and records each joined room id in ``_joined_rooms``.
    self.welcome_message = ""
    self._joined_rooms = []

    # The database URL comes from the environment; the connection itself is
    # opened later in ``sync_forever``.
    database_url = os.environ["MATRIX_CHAIN_DB"]
    self.db = databases.Database(database_url)
    self.enc = tiktoken.encoding_for_model("gpt-4")

    # auto join
    self.add_event_callback(self.auto_join, InviteEvent)
async def init_db(self):
    """Create the ``vector`` extension and every table the bot uses.

    Each statement uses ``IF NOT EXISTS``. The statements are executed one at
    a time, in order, against ``self.db``.
    """
    # NOTE: the SQL text is kept flush-left inside the literals so the strings
    # sent to the database are unchanged.
    schema_statements = (
        "CREATE EXTENSION IF NOT EXISTS vector",
        """
CREATE TABLE IF NOT EXISTS documents
(
md5 character(32) NOT NULL PRIMARY KEY,
content text,
token integer,
url text
)
""",
        """
CREATE TABLE IF NOT EXISTS embeddings
(
document_md5 character(32) NOT NULL,
md5 character(32) NOT NULL,
content text NOT NULL,
token integer NOT NULL,
embedding vector(1536) NOT NULL,
PRIMARY KEY (document_md5, md5),
FOREIGN KEY (document_md5) REFERENCES documents(md5)
);
""",
        """
CREATE TABLE IF NOT EXISTS event_document
(
event text NOT NULL PRIMARY KEY,
document_md5 character(32) NOT NULL,
FOREIGN KEY (document_md5)
REFERENCES documents (md5)
);
""",
        """
CREATE TABLE IF NOT EXISTS memories
(
id SERIAL PRIMARY KEY,
root text NOT NULL,
role integer NOT NULL,
content text NOT NULL,
token integer NOT NULL
)
""",
        """
CREATE TABLE IF NOT EXISTS room_configs
(
room text NOT NULL PRIMARY KEY,
model_name text,
temperature float NOT NULL DEFAULT 0,
system text,
embedding boolean NOT NULL DEFAULT false,
examples TEXT[] NOT NULL DEFAULT '{}'
)
""",
        """
CREATE TABLE IF NOT EXISTS room_document (
room text NOT NULL,
document_md5 character(32) NOT NULL,
PRIMARY KEY (room, document_md5)
);
""",
    )

    database = self.db
    for statement in schema_statements:
        await database.execute(statement)
def get_token_length(self, text: str) -> int:
    """Return how many tokens ``self.enc`` produces when encoding *text*."""
    tokens = self.enc.encode(text)
    return len(tokens)
async def sync_forever(self):
    """Connect the database, do one callback-free sync, then sync forever.

    Steps, in order: connect ``self.db``, run ``init_db``, perform a single
    ``sync`` with the event callbacks temporarily removed, set presence to
    ``"online"``, and hand over to ``AsyncClient.sync_forever`` starting from
    the initial sync's ``next_batch`` token.

    Returns:
        Whatever the parent ``sync_forever`` returns.

    Raises:
        RuntimeError: If the initial sync returns a ``SyncError`` response.
    """
    # init
    print("connecting to db")
    await self.db.connect()

    # init db hook
    print("init db hook")
    await self.init_db()

    # Remove the callbacks for the initial sync (the original intent is
    # "skip initial sync"; presumably so backlog events are not handled).
    callbacks = self.event_callbacks
    self.event_callbacks = []
    print("Perform initial sync")
    try:
        resp = await self.sync(timeout=30000)
    finally:
        # Restore the callbacks even if the sync raised, so the client is
        # never left with its handlers silently removed.
        self.event_callbacks = callbacks

    if isinstance(resp, SyncError):
        # Raise an ordinary exception that carries the actual error response
        # (previously a bare BaseException wrapping the SyncError *class*).
        raise RuntimeError(f"initial sync failed: {resp}")

    # set online
    print("Set online status")
    await self.set_presence("online")

    # sync
    print("Start forever sync")
    return await super().sync_forever(300000, since=resp.next_batch)
async def auto_join(self, room: MatrixRoom, event: InviteEvent):
    """Join the room an invite refers to and optionally greet it.

    A room id already present in ``self._joined_rooms`` is ignored. Otherwise
    the room is joined, its id recorded, and, when ``self.welcome_message`` is
    non-empty, that message is sent with a ``nogpt`` marker in its content.
    """
    print("join", event.sender, room.room_id)
    room_id = room.room_id

    if room_id not in self._joined_rooms:
        await self.join(room_id)
        self._joined_rooms.append(room_id)

        if self.welcome_message:
            greeting = {
                "body": self.welcome_message,
                "msgtype": "m.text",
                "nogpt": True,
            }
            await self.room_send(room_id, "m.room.message", greeting)
def ignore_self_message(self, func):
    """Decorator: skip events whose sender equals ``self.user``.

    The wrapped handler is awaited with ``(room, event)`` for every other
    sender; for the bot's own events the wrapper returns ``None``.
    """

    @wraps(func)
    async def ret(room: MatrixRoom, event: RoomMessageText):
        if event.sender != self.user:
            return await func(room, event)
        return None

    return ret
def log_message(self, func):
    """Decorator: print room id, sender and body before calling the handler."""

    @wraps(func)
    async def ret(room: MatrixRoom, event: RoomMessageText):
        log_fields = (room.room_id, event.sender, event.body)
        print(*log_fields)
        outcome = await func(room, event)
        return outcome

    return ret
def with_typing(self, func):
    """Decorator: show a typing notification while the handler runs.

    Typing is switched on (with a 3-minute timeout) before the handler is
    awaited and switched off afterwards. The switch-off happens in a
    ``finally`` block, so it also runs when the handler raises. Previously an
    exception left the indicator on until the timeout expired.

    Returns:
        A wrapper that returns the handler's result and re-raises its errors.
    """

    @wraps(func)
    async def ret(room, *args, **kargs):
        await self.room_typing(room.room_id, True, 60000 * 3)
        try:
            return await func(room, *args, **kargs)
        finally:
            # Always clear the indicator, whether the handler succeeded or not.
            await self.room_typing(room.room_id, False)

    return ret
def change_event_id_to_root_id(self, func):
    """Decorator: for thread messages, replace ``event.event_id`` with the thread root.

    When the flattened event has ``rel_type == "m.thread"``, ``event.event_id``
    is overwritten with the ``m.relates_to`` event id from the raw source.
    Otherwise it keeps its own id. The event object is mutated in place.
    """

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        own_id = event.event_id
        relation_type = event.flattened().get("content.m.relates_to.rel_type")
        in_thread = relation_type == "m.thread"
        event.event_id = (
            event.source["content"]["m.relates_to"]["event_id"] if in_thread else own_id
        )
        return await func(room, event, *args, **kargs)

    return ret
def ignore_not_mentioned(self, func):
    """Decorator: only call the handler when ``self.user`` appears in the message.

    The user id is searched for in the flattened ``content.body`` first and
    then in ``content.formatted_body``; a missing field counts as empty.
    """

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        fields = event.flattened()
        mentioned = any(
            self.user in fields.get(key, "")
            for key in ("content.body", "content.formatted_body")
        )
        if mentioned:
            return await func(room, event, *args, **kargs)
        return None

    return ret
def replace_command_mark(self, func):
    """Decorator: normalise a leading full-width "！" to an ASCII "!".

    This lets commands typed with a CJK input method be recognised like
    ``!command``. Messages that do not start with the full-width mark are
    passed through untouched.

    NOTE(review): the prefix literal had been reduced to an empty string, so
    ``startswith("")`` matched everything and the first character of every
    message was overwritten with "!". The full-width exclamation mark
    (U+FF01) is assumed to be the intended prefix; confirm against the
    upstream file. It is written as an escape so it cannot be stripped again.
    """
    fullwidth_mark = "\uff01"  # "！" FULLWIDTH EXCLAMATION MARK

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        if event.body.startswith(fullwidth_mark):
            event.body = "!" + event.body[len(fullwidth_mark):]
        return await func(room, event, *args, **kargs)

    return ret
def handel_no_gpt(self, func):
    """Decorator: drop any event whose content carries a ``nogpt`` field.

    The handler runs only when the flattened event has no ``content.nogpt``
    entry (or it is ``None``); any other value, including ``False``, makes
    the wrapper return ``None`` without calling the handler.
    """

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        marker = event.flattened().get("content.nogpt")
        if marker is None:
            return await func(room, event, *args, **kargs)
        return None

    return ret
def replace_reply_file_with_content(self, func):
    """Decorator: when a message replies to a stored document, inline the document.

    If the formatted body starts with an ``<mx-reply>`` block, the text after
    ``</mx-reply>`` is kept as the user's message. The replied-to event id is
    looked up in ``event_document``/``documents``. When a document is found
    and its token count is below the limit, ``event.body`` becomes
    ``<document content> + "\\n\\n---\\n\\n" + <user message>``. In every
    other case the event is passed on unchanged.

    Rows whose ``content`` or ``token`` is NULL are treated as "not found".
    Both columns are nullable in the schema, and comparing or concatenating
    ``None`` would otherwise raise ``TypeError``.
    """
    reply_close = "</mx-reply>"
    max_document_tokens = 8192 + 4096

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        flattened = event.flattened()
        formatted_body = flattened.get("content.formatted_body", "")
        if formatted_body.startswith("<mx-reply>") and reply_close in formatted_body:
            print("replacing file content")
            # Keep only what follows the first closing reply tag.
            formatted_body = formatted_body.partition(reply_close)[2]
            document_event_id = flattened.get(
                "content.m.relates_to.m.in_reply_to.event_id", ""
            )
            rows = await self.db.fetch_all(
                query=(
                    "select d.content, d.token\n"
                    "from documents d\n"
                    "join event_document ed on d.md5 = ed.document_md5\n"
                    "where ed.event = :event_id"
                ),
                values={"event_id": document_event_id},
            )
            content, token = (rows[0][0], rows[0][1]) if rows else (None, None)
            usable = (
                content is not None
                and token is not None
                and token < max_document_tokens
            )
            if usable:
                print(content)
                print("-----------")
                event.body = content + "\n\n---\n\n" + formatted_body
            else:
                print("document not found or too large", event.event_id)
        return await func(room, event, *args, **kargs)

    return ret
def ignore_link(self, func):
    """Decorator: skip messages whose body begins with an http(s) URL scheme."""
    link_prefixes = ("https://", "http://")

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        if not event.body.startswith(link_prefixes):
            return await func(room, event, *args, **kargs)
        return None

    return ret
def safe_try(self, func):
    """Decorator: turn handler exceptions into a logged traceback and a reaction.

    Any ``Exception`` raised by the handler is caught, its traceback is
    printed between separator lines, and a "😵" reaction is sent to the event
    that was being handled. The wrapper then returns ``None``.
    """

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        try:
            outcome = await func(room, event, *args, **kargs)
        except Exception:
            separator = "--------------"
            print(separator)
            print("error:")
            print(traceback.format_exc())
            print(separator)
            failure_reaction = {
                "m.relates_to": {
                    "event_id": event.event_id,
                    "key": "😵",
                    "rel_type": "m.annotation",
                }
            }
            await self.room_send(room.room_id, "m.reaction", failure_reaction)
            return None
        return outcome

    return ret
def message_callback_common_wrapper(self, func):
    """Wrap a message handler in the bot's standard decorator pipeline.

    From outermost to innermost the layers are: ``ignore_self_message``,
    ``handel_no_gpt``, ``log_message``, ``with_typing``,
    ``replace_reply_file_with_content``, ``change_event_id_to_root_id``,
    ``replace_command_mark`` and ``safe_try``. The returned callable carries
    the metadata of *func* via ``functools.wraps``.
    """

    async def ret(*args, **kargs):
        return await func(*args, **kargs)

    # Listed innermost first: each layer wraps the result of the previous one.
    layers = (
        self.safe_try,
        self.replace_command_mark,
        self.change_event_id_to_root_id,
        self.replace_reply_file_with_content,
        self.with_typing,
        self.log_message,
        self.handel_no_gpt,
        self.ignore_self_message,
    )
    pipeline = ret
    for layer in layers:
        pipeline = layer(pipeline)
    return wraps(func)(pipeline)
async def react_ok(self, room_id: str, event_id: str):
    """Send a "😘" annotation reaction to *event_id* in *room_id*."""
    annotation = {
        "event_id": event_id,
        "key": "😘",
        "rel_type": "m.annotation",
    }
    reaction_content = {"m.relates_to": annotation}
    await self.room_send(room_id, "m.reaction", reaction_content)