This commit is contained in:
2023-10-21 21:14:57 +08:00
commit 47ca0639da
11 changed files with 1541 additions and 0 deletions

315
bot.py Normal file
View File

@@ -0,0 +1,315 @@
import dotenv
dotenv.load_dotenv()
import os
import traceback
from functools import wraps
from nio import AsyncClient, MatrixRoom, RoomMessageText, SyncError, InviteEvent
import tiktoken
import databases
import builtins
import sys
def print(*args, **kwargs):
    """Module-wide replacement for the builtin ``print`` that writes to stderr.

    All positional and keyword arguments are forwarded to ``builtins.print``,
    except that ``file`` is always forced to ``sys.stderr`` (a caller-supplied
    ``file`` is overridden).
    """
    forwarded = dict(kwargs, file=sys.stderr)
    builtins.print(*args, **forwarded)
class Bot(AsyncClient):
def __init__(self, homeserver: str, user: str, device_id: str, access_token: str):
    """Create the bot client and its helper objects.

    Args:
        homeserver: passed unchanged to the parent client constructor.
        user: stored as both ``self.user_id`` and ``self.user``.
        device_id: stored as ``self.device_id``.
        access_token: stored as ``self.access_token``.

    Raises:
        KeyError: if the ``MATRIX_CHAIN_DB`` environment variable is unset.
    """
    super().__init__(homeserver)

    # Identity / credentials are assigned directly on the client.
    self.access_token = access_token
    self.user_id = user
    self.user = user
    self.device_id = device_id

    # Bot-level state: optional greeting sent after joining a room, and the
    # room ids this process has already joined via ``auto_join``.
    self.welcome_message = ""
    self._joined_rooms = []

    # Database handle built from the connection URL in the environment.
    database_url = os.environ["MATRIX_CHAIN_DB"]
    self.db = databases.Database(database_url)

    # Tokenizer used by ``get_token_length``.
    self.enc = tiktoken.encoding_for_model("gpt-4")

    # auto join: accept every invite the client receives.
    self.add_event_callback(self.auto_join, InviteEvent)
async def init_db(self):
    """Create the ``vector`` extension and every table the bot uses.

    Each statement is idempotent (``IF NOT EXISTS``) and is executed on
    ``self.db`` one after another, in dependency order: ``documents`` is
    created before the tables that reference it.
    """
    # NOTE: the SQL text is kept flush-left inside the literals so the
    # strings sent to the database stay exactly as they were.
    statements = (
        "CREATE EXTENSION IF NOT EXISTS vector",
        """
CREATE TABLE IF NOT EXISTS documents
(
md5 character(32) NOT NULL PRIMARY KEY,
content text,
token integer,
url text
)
""",
        """
CREATE TABLE IF NOT EXISTS embeddings
(
document_md5 character(32) NOT NULL,
md5 character(32) NOT NULL,
content text NOT NULL,
token integer NOT NULL,
embedding vector(1536) NOT NULL,
PRIMARY KEY (document_md5, md5),
FOREIGN KEY (document_md5) REFERENCES documents(md5)
);
""",
        """
CREATE TABLE IF NOT EXISTS event_document
(
event text NOT NULL PRIMARY KEY,
document_md5 character(32) NOT NULL,
FOREIGN KEY (document_md5)
REFERENCES documents (md5)
);
""",
        """
CREATE TABLE IF NOT EXISTS memories
(
id SERIAL PRIMARY KEY,
root text NOT NULL,
role integer NOT NULL,
content text NOT NULL,
token integer NOT NULL
)
""",
        """
CREATE TABLE IF NOT EXISTS room_configs
(
room text NOT NULL PRIMARY KEY,
model_name text,
temperature float NOT NULL DEFAULT 0,
system text,
embedding boolean NOT NULL DEFAULT false,
examples TEXT[] NOT NULL DEFAULT '{}'
)
""",
        """
CREATE TABLE IF NOT EXISTS room_document (
room text NOT NULL,
document_md5 character(32) NOT NULL,
PRIMARY KEY (room, document_md5)
);
""",
    )
    database = self.db
    for statement in statements:
        await database.execute(statement)
def get_token_length(self, text: str) -> int:
    """Return how many tokens ``self.enc`` produces when encoding *text*."""
    tokens = self.enc.encode(text)
    return len(tokens)
async def sync_forever(self):
    """Prepare the database, run one silent initial sync, then sync forever.

    Steps, in order:
      1. connect ``self.db`` and create the schema through ``init_db``;
      2. perform a first ``sync`` with all event callbacks detached, so the
         events returned by that sync are not dispatched to handlers;
      3. re-attach the callbacks and set presence to ``"online"``;
      4. delegate to the parent class's ``sync_forever`` starting from the
         initial sync's ``next_batch`` token.

    Returns:
        Whatever the parent ``sync_forever`` returns.

    Raises:
        RuntimeError: if the initial sync returns a ``SyncError`` response.
    """
    # init
    print("connecting to db")
    await self.db.connect()
    # init db hook
    print("init db hook")
    await self.init_db()

    # Remove the callbacks while the initial sync runs.
    callbacks = self.event_callbacks
    self.event_callbacks = []
    # Skip dispatching the events delivered by the initial sync.
    print("Perform initial sync")
    try:
        resp = await self.sync(timeout=30000)
    finally:
        # Always put the callbacks back, even if the sync raised, so the
        # client is never left permanently deaf to events.
        self.event_callbacks = callbacks
    if isinstance(resp, SyncError):
        # Carry the actual error response in the message instead of raising
        # a bare BaseException around the SyncError class.
        raise RuntimeError(f"initial sync failed: {resp}")

    # set online
    print("Set online status")
    await self.set_presence("online")
    # sync
    print("Start forever sync")
    return await super().sync_forever(300000, since=resp.next_batch)
async def auto_join(self, room: MatrixRoom, event: InviteEvent):
    """Invite callback: join the room once and optionally greet it.

    The invite is always logged. Rooms already recorded in
    ``self._joined_rooms`` are skipped. After joining, the room id is
    recorded and, when ``self.welcome_message`` is non-empty, it is sent as
    an ``m.text`` message carrying a ``nogpt`` marker in its content.
    """
    room_id = room.room_id
    print("join", event.sender, room_id)
    if room_id not in self._joined_rooms:
        await self.join(room_id)
        self._joined_rooms.append(room_id)
        if self.welcome_message:
            greeting = {
                "body": self.welcome_message,
                "msgtype": "m.text",
                # Marker checked by ``handel_no_gpt`` on incoming events.
                "nogpt": True,
            }
            await self.room_send(room_id, "m.room.message", greeting)
def ignore_self_message(self, func):
    """Decorator: skip events whose sender is the bot's own user.

    The wrapped callback is awaited only when ``event.sender`` differs from
    ``self.user``; otherwise the wrapper returns ``None``.
    """

    @wraps(func)
    async def wrapper(room: MatrixRoom, event: RoomMessageText):
        sent_by_bot = event.sender == self.user
        if not sent_by_bot:
            return await func(room, event)
        return None

    return wrapper
def log_message(self, func):
    """Decorator: log room id, sender and body, then run the callback."""

    @wraps(func)
    async def wrapper(room: MatrixRoom, event: RoomMessageText):
        print(room.room_id, event.sender, event.body)
        outcome = await func(room, event)
        return outcome

    return wrapper
def with_typing(self, func):
    """Decorator: show a typing notification while the callback runs.

    Typing is switched on (with a three-minute timeout) before the wrapped
    callback is awaited and switched off afterwards. The switch-off runs in
    a ``finally`` block so the indicator is cleared even when the callback
    raises; the exception then propagates unchanged.

    Returns:
        A coroutine function with the same call signature as *func*,
        returning whatever *func* returns.
    """

    @wraps(func)
    async def ret(room, *args, **kargs):
        await self.room_typing(room.room_id, True, 60000 * 3)
        try:
            return await func(room, *args, **kargs)
        finally:
            # Without this, a failing handler would leave the room showing
            # "typing" until the timeout above expires.
            await self.room_typing(room.room_id, False)

    return ret
def change_event_id_to_root_id(self, func):
    """Decorator: make ``event.event_id`` point at the thread root.

    When the event's ``m.relates_to.rel_type`` is ``"m.thread"``, the
    ``event_id`` found under ``m.relates_to`` in the event source replaces
    ``event.event_id``; otherwise the id is left as it was. The event object
    is mutated in place before the callback runs.
    """

    @wraps(func)
    async def wrapper(room, event, *args, **kargs):
        relation_type = event.flattened().get("content.m.relates_to.rel_type")
        in_thread = relation_type == "m.thread"
        event.event_id = (
            event.source["content"]["m.relates_to"]["event_id"]
            if in_thread
            else event.event_id
        )
        return await func(room, event, *args, **kargs)

    return wrapper
def ignore_not_mentioned(self, func):
    """Decorator: run the callback only when the bot's user id is mentioned.

    A mention means ``self.user`` occurs as a substring of the event's
    ``content.body`` or ``content.formatted_body`` (missing fields count as
    empty strings). Unmentioned events make the wrapper return ``None``.
    """

    @wraps(func)
    async def wrapper(room, event, *args, **kargs):
        fields = event.flattened()
        me = self.user
        # The formatted body is only inspected when the plain body has no
        # mention (short-circuit ``or``).
        if me in fields.get("content.body", "") or me in fields.get(
            "content.formatted_body", ""
        ):
            return await func(room, event, *args, **kargs)
        return None

    return wrapper
def replace_command_mark(self, func):
    """Decorator: normalise a leading full-width "!" to an ASCII "!".

    Only a body that starts with the full-width exclamation mark (U+FF01)
    is rewritten; every other body, including an empty one, is passed to
    the callback untouched. The event object is mutated in place.
    """
    # NOTE(review): the prefix test must not be the empty string, which
    # matches every message and overwrites its first character with "!".
    # U+FF01 is assumed to be the intended alternative command mark
    # (what CJK input methods produce for "!") - confirm.
    fullwidth_mark = "\uff01"

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        if event.body.startswith(fullwidth_mark):
            event.body = "!" + event.body[len(fullwidth_mark):]
        return await func(room, event, *args, **kargs)

    return ret
def handel_no_gpt(self, func):
    """Decorator: drop events whose content carries a ``nogpt`` field.

    Any non-``None`` value of ``content.nogpt`` (including ``False``) makes
    the wrapper return ``None`` without calling the callback.
    """

    @wraps(func)
    async def wrapper(room, event, *args, **kargs):
        marker = event.flattened().get("content.nogpt")
        if marker is not None:
            return None
        return await func(room, event, *args, **kargs)

    return wrapper
def replace_reply_file_with_content(self, func):
    """Decorator: when a message replies to a stored document, inline it.

    If the event's ``formatted_body`` starts with an ``<mx-reply>`` block,
    the text after ``</mx-reply>`` is taken as the user's message and the
    replied-to event id is looked up in ``event_document`` / ``documents``.
    When a document is found and its token count is below the limit,
    ``event.body`` becomes ``<document>\\n\\n---\\n\\n<message>``. In every
    other case (no reply, no document, document too large, or a row with a
    NULL ``content``/``token``) the event body is left unchanged.

    The wrapped callback is always awaited afterwards.
    """
    # Documents at or above this token count are not inlined.
    max_document_tokens = 8192 + 4096
    reply_close = "</mx-reply>"

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        flattened = event.flattened()
        formatted_body = flattened.get("content.formatted_body", "")
        # Event content is remote input: ignore a non-string formatted_body
        # instead of failing on ``.startswith``.
        if (
            isinstance(formatted_body, str)
            and formatted_body.startswith("<mx-reply>")
            and reply_close in formatted_body
        ):
            print("replacing file content")
            # Keep only what the user wrote after the quoted reply block.
            formatted_body = formatted_body[
                formatted_body.index(reply_close) + len(reply_close):
            ]
            document_event_id = flattened.get(
                "content.m.relates_to.m.in_reply_to.event_id", ""
            )
            fetch = await self.db.fetch_all(
                query="""select d.content, d.token
from documents d
join event_document ed on d.md5 = ed.document_md5
where ed.event = :event_id""",
                values={"event_id": document_event_id},
            )
            content = token = None
            if fetch:
                content, token = fetch[0][0], fetch[0][1]
            # ``content`` and ``token`` are nullable columns; a NULL in
            # either is treated like a missing document rather than letting
            # the comparison / concatenation raise TypeError.
            if (
                content is not None
                and token is not None
                and token < max_document_tokens
            ):
                print(content)
                print("-----------")
                event.body = content + "\n\n---\n\n" + formatted_body
            else:
                print("document not found or too large", event.event_id)
        return await func(room, event, *args, **kargs)

    return ret
def ignore_link(self, func):
    """Decorator: skip events whose body begins with an http(s) URL scheme.

    Bodies starting with ``https://`` or ``http://`` make the wrapper return
    ``None``; everything else is forwarded to the callback.
    """
    url_prefixes = ("https://", "http://")

    @wraps(func)
    async def wrapper(room, event, *args, **kargs):
        if not event.body.startswith(url_prefixes):
            return await func(room, event, *args, **kargs)
        return None

    return wrapper
def safe_try(self, func):
    """Decorator: contain handler failures and flag them with a reaction.

    The wrapped callback's result is returned on success. If it raises an
    ``Exception``, the traceback is logged, a "😵" ``m.reaction`` annotating
    ``event.event_id`` is sent to the room, and the wrapper returns ``None``.
    A failure while sending that reaction is logged as well and does not
    propagate, so the wrapper never raises ``Exception`` itself.
    """

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        try:
            return await func(room, event, *args, **kargs)
        except Exception:
            # Boundary handler: log everything, then signal the failure
            # to the user below.
            print("--------------")
            print("error:")
            print(traceback.format_exc())
            print("--------------")
        try:
            await self.room_send(
                room.room_id,
                "m.reaction",
                {
                    "m.relates_to": {
                        "event_id": event.event_id,
                        "key": "😵",
                        "rel_type": "m.annotation",
                    }
                },
            )
        except Exception:
            # The reaction is best-effort; if sending it fails too, keep
            # the error contained instead of letting it escape the wrapper.
            print("failed to send error reaction:")
            print(traceback.format_exc())
        return None

    return ret
def message_callback_common_wrapper(self, func):
    """Decorator: wrap a message callback in the bot's standard pipeline.

    At call time an event passes through these layers from outermost to
    innermost before reaching *func*: ``ignore_self_message``,
    ``handel_no_gpt``, ``log_message``, ``with_typing``,
    ``replace_reply_file_with_content``, ``change_event_id_to_root_id``,
    ``replace_command_mark``, ``safe_try``. The returned coroutine function
    carries *func*'s metadata via ``functools.wraps``.
    """

    async def call_target(*args, **kargs):
        return await func(*args, **kargs)

    # Listed innermost first: each layer wraps the result of the previous
    # one, so the last entry ends up as the outermost layer.
    layers_inner_to_outer = (
        self.safe_try,
        self.replace_command_mark,
        self.change_event_id_to_root_id,
        self.replace_reply_file_with_content,
        self.with_typing,
        self.log_message,
        self.handel_no_gpt,
        self.ignore_self_message,
    )
    pipeline = call_target
    for layer in layers_inner_to_outer:
        pipeline = layer(pipeline)
    return wraps(func)(pipeline)
async def react_ok(self, room_id: str, event_id: str):
    """Send a "😘" reaction annotating *event_id* in room *room_id*."""
    annotation = {
        "event_id": event_id,
        "key": "😘",
        "rel_type": "m.annotation",
    }
    content = {"m.relates_to": annotation}
    await self.room_send(room_id, "m.reaction", content)