316 lines
9.8 KiB
Python
316 lines
9.8 KiB
Python
import dotenv
|
||
|
||
dotenv.load_dotenv()
|
||
import os
|
||
import traceback
|
||
from functools import wraps
|
||
from nio import AsyncClient, MatrixRoom, RoomMessageText, SyncError, InviteEvent
|
||
import tiktoken
|
||
import databases
|
||
import builtins
|
||
import sys
|
||
|
||
|
||
def print(*args, **kwargs):
    """Module-local replacement for the builtin ``print`` that writes to stderr.

    Positional and keyword arguments are forwarded to ``builtins.print``
    unchanged, except that ``file`` is always set to ``sys.stderr`` -- a
    ``file`` value passed by the caller is overridden.
    """
    # In a dict display later keys win, so "file" is forced to stderr.
    builtins.print(*args, **{**kwargs, "file": sys.stderr})
|
||
|
||
|
||
class Bot(AsyncClient):
|
||
def __init__(self, homeserver: str, user: str, device_id: str, access_token: str):
    """Create the client and attach its database handle and tokenizer.

    Args:
        homeserver: Homeserver URL, passed to the parent constructor.
        user: Matrix user id of the bot; stored as both ``user_id`` and
            ``user``.
        device_id: Device id stored on the client.
        access_token: Access token stored on the client.

    Raises:
        KeyError: if the ``MATRIX_CHAIN_DB`` environment variable is unset.
    """
    super().__init__(homeserver)

    # Identity / credentials, assigned after the parent constructor ran.
    self.access_token = access_token
    self.user_id = user
    self.user = user
    self.device_id = device_id

    # Bot-level state: optional greeting and the rooms joined via invites.
    self.welcome_message = ""
    self._joined_rooms = []

    # The database is only constructed here; connecting happens later.
    database_url = os.environ["MATRIX_CHAIN_DB"]
    self.db = databases.Database(database_url)

    # Tokenizer used to count tokens of text.
    self.enc = tiktoken.encoding_for_model("gpt-4")

    # auto join: accept every invite the client receives
    self.add_event_callback(self.auto_join, InviteEvent)
|
||
|
||
async def init_db(self):
    """Create the ``vector`` extension and all tables used by the bot.

    Every statement uses ``IF NOT EXISTS``.  The statements are executed one
    by one, in order; ``documents`` comes before ``embeddings`` and
    ``event_document`` because both declare a foreign key to it.
    """
    schema_statements = (
        "CREATE EXTENSION IF NOT EXISTS vector",
        """
        CREATE TABLE IF NOT EXISTS documents
        (
            md5 character(32) NOT NULL PRIMARY KEY,
            content text,
            token integer,
            url text
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS embeddings
        (
            document_md5 character(32) NOT NULL,
            md5 character(32) NOT NULL,
            content text NOT NULL,
            token integer NOT NULL,
            embedding vector(1536) NOT NULL,
            PRIMARY KEY (document_md5, md5),
            FOREIGN KEY (document_md5) REFERENCES documents(md5)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS event_document
        (
            event text NOT NULL PRIMARY KEY,
            document_md5 character(32) NOT NULL,
            FOREIGN KEY (document_md5)
                REFERENCES documents (md5)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS memories
        (
            id SERIAL PRIMARY KEY,
            root text NOT NULL,
            role integer NOT NULL,
            content text NOT NULL,
            token integer NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS room_configs
        (
            room text NOT NULL PRIMARY KEY,
            model_name text,
            temperature float NOT NULL DEFAULT 0,
            system text,
            embedding boolean NOT NULL DEFAULT false,
            examples TEXT[] NOT NULL DEFAULT '{}'
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS room_document (
            room text NOT NULL,
            document_md5 character(32) NOT NULL,
            PRIMARY KEY (room, document_md5)
        );
        """,
    )

    database = self.db
    for statement in schema_statements:
        await database.execute(statement)
|
||
|
||
def get_token_length(self, text: str) -> int:
    """Return how many tokens ``self.enc`` produces when encoding ``text``."""
    tokens = self.enc.encode(text)
    return len(tokens)
|
||
|
||
async def sync_forever(self):
    """Connect the database, do one silent initial sync, then sync forever.

    Steps, in order:
      1. connect ``self.db`` and run ``init_db``;
      2. perform one ``sync`` with the event callbacks temporarily removed,
         so nothing delivered by that first sync reaches the handlers;
      3. set presence to ``online``;
      4. hand over to the parent ``sync_forever``, resuming from the
         ``next_batch`` token of the initial sync.

    Returns:
        Whatever the parent ``sync_forever`` returns.

    Raises:
        RuntimeError: if the initial sync returns a ``SyncError``.  The error
            response is included in the message.
    """
    # init
    print("connecting to db")
    await self.db.connect()

    # init db hook
    print("init db hook")
    await self.init_db()

    # Remove the callbacks for the initial sync so it is effectively skipped.
    callbacks = self.event_callbacks
    self.event_callbacks = []
    print("Perform initial sync")
    try:
        resp = await self.sync(timeout=30000)
    finally:
        # Always put the callbacks back, even if the sync itself raised.
        self.event_callbacks = callbacks
    if isinstance(resp, SyncError):
        # Raise an ordinary exception that carries the actual error response
        # instead of a bare BaseException wrapping the SyncError class.
        raise RuntimeError(f"Initial sync failed: {resp}")

    # set online
    print("Set online status")
    await self.set_presence("online")

    # sync
    print("Start forever sync")
    return await super().sync_forever(300000, since=resp.next_batch)
|
||
|
||
async def auto_join(self, room: MatrixRoom, event: InviteEvent):
    """Invite callback: join the room once and optionally greet it.

    The invite is always logged.  A room already listed in
    ``self._joined_rooms`` is ignored; otherwise the room is joined,
    remembered, and -- when ``self.welcome_message`` is non-empty -- the
    welcome message is sent, flagged with ``nogpt``.
    """
    room_id = room.room_id
    print("join", event.sender, room_id)

    if room_id not in self._joined_rooms:
        await self.join(room_id)
        self._joined_rooms.append(room_id)

        if self.welcome_message:
            greeting = {
                "body": self.welcome_message,
                "msgtype": "m.text",
                "nogpt": True,
            }
            await self.room_send(room_id, "m.room.message", greeting)
|
||
|
||
def ignore_self_message(self, func):
    """Decorator: skip events whose sender is the bot itself.

    The wrapped callback is awaited only when ``event.sender`` differs from
    ``self.user``; otherwise the wrapper returns ``None``.
    """

    @wraps(func)
    async def skip_own(room: MatrixRoom, event: RoomMessageText):
        if event.sender != self.user:
            return await func(room, event)
        return None

    return skip_own
|
||
|
||
def log_message(self, func):
    """Decorator: print room id, sender and body before calling ``func``."""

    @wraps(func)
    async def logged(room: MatrixRoom, event: RoomMessageText):
        print(room.room_id, event.sender, event.body)
        result = await func(room, event)
        return result

    return logged
|
||
|
||
def with_typing(self, func):
    """Decorator: show the typing indicator while ``func`` runs.

    Typing is switched on (with a timeout of ``60000 * 3`` passed to
    ``room_typing``) before the wrapped callback is awaited and switched off
    afterwards.  The indicator is cleared in a ``finally`` block, so it is
    also removed when the callback raises; the exception then propagates
    unchanged.

    Returns:
        The wrapper, which returns whatever ``func`` returns.
    """

    @wraps(func)
    async def ret(room, *args, **kargs):
        await self.room_typing(room.room_id, True, 60000 * 3)
        try:
            return await func(room, *args, **kargs)
        finally:
            # Never leave the room showing "typing" after a failure.
            await self.room_typing(room.room_id, False)

    return ret
|
||
|
||
def change_event_id_to_root_id(self, func):
    """Decorator: make ``event.event_id`` point at the thread root.

    When the event's relation type is ``m.thread``, ``event.event_id`` is
    overwritten with the related event id taken from ``event.source``.
    Other events keep their own id.  The event object is mutated in place
    before ``func`` is awaited.
    """

    @wraps(func)
    async def rooted(room, event, *args, **kargs):
        relation = event.flattened().get("content.m.relates_to.rel_type")
        in_thread = relation == "m.thread"
        event.event_id = (
            event.source["content"]["m.relates_to"]["event_id"]
            if in_thread
            else event.event_id
        )
        return await func(room, event, *args, **kargs)

    return rooted
|
||
|
||
def ignore_not_mentioned(self, func):
    """Decorator: only handle events that mention the bot.

    The bot counts as mentioned when ``self.user`` occurs in the event's
    ``content.body`` or ``content.formatted_body``.  Events without a mention
    are dropped and the wrapper returns ``None``.
    """

    @wraps(func)
    async def only_when_mentioned(room, event, *args, **kargs):
        flat = event.flattened()
        # Plain body first, formatted body only if that did not match.
        mentioned = any(
            self.user in flat.get(field, "")
            for field in ("content.body", "content.formatted_body")
        )
        if mentioned:
            return await func(room, event, *args, **kargs)
        return None

    return only_when_mentioned
|
||
|
||
def replace_command_mark(self, func):
    """Decorator: normalise the leading command mark of a message to ``!``.

    A body that starts with the ASCII ``!`` or with the full-width
    exclamation mark (U+FF01) has that first character replaced by the ASCII
    ``!``.  Other bodies, including the empty string, are left untouched.
    ``event.body`` is mutated in place before ``func`` is awaited.

    NOTE(review): accepting the full-width mark is an assumption about the
    intent -- comparing and replacing the same ASCII character would make
    this decorator a no-op.  Confirm against upstream.
    """
    # U+FF01 is written as an escape so it cannot be silently folded to "!".
    command_marks = ("!", "\uff01")

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        if event.body.startswith(command_marks):
            event.body = "!" + event.body[1:]
        return await func(room, event, *args, **kargs)

    return ret
|
||
|
||
def handel_no_gpt(self, func):
    """Decorator: drop events whose content carries a ``nogpt`` entry.

    Any non-``None`` value under ``content.nogpt`` suppresses the event
    (the wrapper returns ``None``); all other events are passed to ``func``.
    """

    @wraps(func)
    async def unless_flagged(room, event, *args, **kargs):
        flag = event.flattened().get("content.nogpt")
        if flag is not None:
            return None
        return await func(room, event, *args, **kargs)

    return unless_flagged
|
||
|
||
def replace_reply_file_with_content(self, func):
    """Decorator: inline a stored document when a message replies to it.

    If the event's ``formatted_body`` starts with an ``<mx-reply>`` block,
    the text after ``</mx-reply>`` is kept as the user's message and the
    replied-to event id is looked up in ``event_document`` / ``documents``.
    When a document is found and its stored token count is below
    ``8192 + 4096``, ``event.body`` becomes
    ``<document content>\\n\\n---\\n\\n<user message>``.

    Rows whose ``content`` or ``token`` is NULL (both columns are nullable)
    are treated like a missing document instead of raising ``TypeError``.
    In every case ``func`` is awaited afterwards.
    """
    # Documents with a token count at or above this are not inlined.
    max_inline_tokens = 8192 + 4096
    reply_close = "</mx-reply>"

    @wraps(func)
    async def ret(room, event, *args, **kargs):
        flattened = event.flattened()
        # ``or ""`` also covers a key that is present but set to None.
        formatted_body = flattened.get("content.formatted_body") or ""
        if formatted_body.startswith("<mx-reply>") and reply_close in formatted_body:
            print("replacing file content")
            # Keep only what follows the quoted reply block.
            formatted_body = formatted_body[
                formatted_body.index(reply_close) + len(reply_close):
            ]
            document_event_id = flattened.get(
                "content.m.relates_to.m.in_reply_to.event_id", ""
            )
            fetch = await self.db.fetch_all(
                query="""select d.content, d.token
                from documents d
                join event_document ed on d.md5 = ed.document_md5
                where ed.event = :event_id""",
                values={"event_id": document_event_id},
            )
            content, token = (fetch[0][0], fetch[0][1]) if fetch else (None, None)
            if (
                content is not None
                and token is not None
                and token < max_inline_tokens
            ):
                print(content)
                print("-----------")
                event.body = content + "\n\n---\n\n" + formatted_body
            else:
                print("document not found or too large", event.event_id)

        return await func(room, event, *args, **kargs)

    return ret
|
||
|
||
def ignore_link(self, func):
    """Decorator: drop events whose body starts with an http(s) URL scheme.

    Bodies beginning with ``https://`` or ``http://`` are ignored (the
    wrapper returns ``None``); everything else is passed on to ``func``.
    """
    link_prefixes = ("https://", "http://")

    @wraps(func)
    async def unless_link(room, event, *args, **kargs):
        if not event.body.startswith(link_prefixes):
            return await func(room, event, *args, **kargs)
        return None

    return unless_link
|
||
|
||
def safe_try(self, func):
    """Decorator: turn handler exceptions into a log entry plus a reaction.

    When ``func`` raises an ``Exception`` the traceback is printed between
    two divider lines and a "😵" ``m.annotation`` reaction is sent on
    ``event.event_id``; the wrapper then returns ``None``.  Without an
    exception the result of ``func`` is returned as is.
    """

    @wraps(func)
    async def guarded(room, event, *args, **kargs):
        try:
            return await func(room, event, *args, **kargs)
        except Exception:
            divider = "--------------"
            # One call, newline-separated: same output as four prints.
            print(divider, "error:", traceback.format_exc(), divider, sep="\n")
            reaction = {
                "m.relates_to": {
                    "event_id": event.event_id,
                    "key": "😵",
                    "rel_type": "m.annotation",
                }
            }
            await self.room_send(room.room_id, "m.reaction", reaction)

    return guarded
|
||
|
||
def message_callback_common_wrapper(self, func):
    """Wrap a message callback in the bot's standard decorator pipeline.

    The layers are applied innermost first, so at call time an event passes
    through them in the reverse order: ``ignore_self_message``,
    ``handel_no_gpt``, ``log_message``, ``with_typing``,
    ``replace_reply_file_with_content``, ``change_event_id_to_root_id``,
    ``replace_command_mark``, ``safe_try`` and finally ``func``.  The
    returned wrapper carries the metadata of ``func``.
    """

    async def call_handler(*args, **kargs):
        return await func(*args, **kargs)

    # Innermost layer first; the last entry ends up outermost.
    layers = (
        self.safe_try,
        self.replace_command_mark,
        self.change_event_id_to_root_id,
        self.replace_reply_file_with_content,
        self.with_typing,
        self.log_message,
        self.handel_no_gpt,
        self.ignore_self_message,
    )
    pipeline = call_handler
    for layer in layers:
        pipeline = layer(pipeline)

    return wraps(func)(pipeline)
|
||
|
||
async def react_ok(self, room_id: str, event_id: str):
    """Send a "😘" ``m.annotation`` reaction to ``event_id`` in ``room_id``."""
    annotation = {
        "event_id": event_id,
        "key": "😘",
        "rel_type": "m.annotation",
    }
    await self.room_send(room_id, "m.reaction", {"m.relates_to": annotation})
|