diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f15a189 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/resources +/msw +venv +__pycache__ \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..bd61119 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +all: + protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative proto/msw.proto + go build -o msw \ + main/main.go + +# python -m grpc_tools.protoc -I ../proto/ --python_out=. --pyi_out=. --grpc_python_out=. ../proto/msw.proto \ No newline at end of file diff --git a/clip_embedding/build-proto.sh b/clip_embedding/build-proto.sh new file mode 100755 index 0000000..eef3c5d --- /dev/null +++ b/clip_embedding/build-proto.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python -m grpc_tools.protoc -I ../proto/ --python_out=. --pyi_out=. --grpc_python_out=. ../proto/msw.proto diff --git a/clip_embedding/clip_embedding.py b/clip_embedding/clip_embedding.py new file mode 100644 index 0000000..fb648cf --- /dev/null +++ b/clip_embedding/clip_embedding.py @@ -0,0 +1,64 @@ +import io +from concurrent import futures +from PIL import Image +import grpc +import msw_pb2 as msw_pb2 +import msw_pb2_grpc as msw_pb2_grpc + +from transformers import AutoModel, AutoProcessor + +model_name_or_path = "google/siglip-so400m-patch14-384" +from PIL import Image +import requests +from transformers import AutoProcessor, AutoModel +import torch + +print("Loading model and processor") +device = "cuda" if torch.cuda.is_available() else "cpu" +model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").eval().to(device) +processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384") + + +class MSWService(msw_pb2_grpc.MSWServicer): + def ClipEmbeddingForText( + self, request: msw_pb2.ClipEmbeddingForTextRequest, context + ) -> msw_pb2.ClipEmbeddingResponse: + print("Received request", request.Text, type(list(request.Text))) + + with torch.no_grad(): + inputs = processor( + text=list(request.Text), return_tensors="pt", padding=True + ).to(device) + outputs = model.get_text_features(**inputs) + + return msw_pb2.ClipEmbeddingResponse( + Embeddings=[msw_pb2.Embedding(Embedding=i) for i in outputs.tolist()] + ) + + def ClipEmbeddingForImage( + self, request: msw_pb2.ClipEmbeddingForImageRequest, context + ) -> msw_pb2.ClipEmbeddingResponse: + image = Image.open(io.BytesIO(request.Image)) + + with torch.no_grad(): + inputs = processor(images=image, return_tensors="pt", padding=True).to( + device + ) + outputs = model.get_image_features(**inputs) + + return msw_pb2.ClipEmbeddingResponse( + Embeddings=[msw_pb2.Embedding(Embedding=outputs[0].tolist())] + ) + + +def serve(): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + msw_pb2_grpc.add_MSWServicer_to_server(MSWService(), server) + server.add_insecure_port("0.0.0.0:8888") + server.start() + print("Server started at port 8888") + server.wait_for_termination() + + +if __name__ == "__main__": + serve() diff --git a/clip_embedding/clip_embedding_test.py b/clip_embedding/clip_embedding_test.py new file mode 100644 index 0000000..a9ebf20 --- /dev/null +++ b/clip_embedding/clip_embedding_test.py @@ -0,0 +1,19 @@ +from concurrent import futures +import grpc +import msw_pb2 as msw_pb2 +import msw_pb2_grpc as msw_pb2_grpc + +channel = grpc.insecure_channel("127.0.0.1:8888") +stub = msw_pb2_grpc.MSWStub(channel) + +testImageFile = "/home/hmsy/Pictures/Screenshots/Screenshot_20240515_103419.jpeg" +with open(testImageFile, "rb") as f: + imageData = f.read() + +count = 0 +while 1: + response = stub.ClipEmbeddingForImage( + msw_pb2.ClipEmbeddingForImageRequest(Image=imageData) + ) + count += 1 + print(count) diff --git a/clip_embedding/msw_pb2.py b/clip_embedding/msw_pb2.py new file mode 100644 index 0000000..5ed5a94 --- /dev/null +++ b/clip_embedding/msw_pb2.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: msw.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tmsw.proto\x12\x03msw\x1a\x1fgoogle/protobuf/timestamp.proto\"\x07\n\x05\x45mpty\"\x9f\x01\n\x0fVersionResponse\x12\r\n\x05major\x18\x01 \x01(\x05\x12\r\n\x05minor\x18\x02 \x01(\x05\x12\r\n\x05patch\x18\x03 \x01(\x05\x12\x10\n\x08hostname\x18\x04 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x05 \x03(\t\x12\x0f\n\x07latency\x18\x06 \x01(\x03\x12.\n\nstarted_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"\x8a\x01\n\x0b\x44\x61taPackage\x12\x10\n\x08ToPlugin\x18\x01 \x01(\t\x12,\n\x06Params\x18\x02 \x03(\x0b\x32\x1c.msw.DataPackage.ParamsEntry\x12\x0c\n\x04\x42ody\x18\x03 \x01(\t\x1a-\n\x0bParamsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"+\n\x1b\x43lipEmbeddingForTextRequest\x12\x0c\n\x04Text\x18\x01 \x03(\t\"-\n\x1c\x43lipEmbeddingForImageRequest\x12\r\n\x05Image\x18\x01 \x01(\x0c\";\n\x15\x43lipEmbeddingResponse\x12\"\n\nEmbeddings\x18\x01 \x03(\x0b\x32\x0e.msw.Embedding\"\x1e\n\tEmbedding\x12\x11\n\tEmbedding\x18\x02 \x03(\x02\x32\xe6\x01\n\x03MSW\x12-\n\x07Version\x12\n.msw.Empty\x1a\x14.msw.VersionResponse\"\x00\x12V\n\x14\x43lipEmbeddingForText\x12 .msw.ClipEmbeddingForTextRequest\x1a\x1a.msw.ClipEmbeddingResponse\"\x00\x12X\n\x15\x43lipEmbeddingForImage\x12!.msw.ClipEmbeddingForImageRequest\x1a\x1a.msw.ClipEmbeddingResponse\"\x00\x42\x0bZ\tmsw/protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'msw_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'Z\tmsw/proto' + _globals['_DATAPACKAGE_PARAMSENTRY']._loaded_options = None + _globals['_DATAPACKAGE_PARAMSENTRY']._serialized_options = b'8\001' + _globals['_EMPTY']._serialized_start=51 + _globals['_EMPTY']._serialized_end=58 + _globals['_VERSIONRESPONSE']._serialized_start=61 + _globals['_VERSIONRESPONSE']._serialized_end=220 + _globals['_DATAPACKAGE']._serialized_start=223 + _globals['_DATAPACKAGE']._serialized_end=361 + _globals['_DATAPACKAGE_PARAMSENTRY']._serialized_start=316 + _globals['_DATAPACKAGE_PARAMSENTRY']._serialized_end=361 + _globals['_CLIPEMBEDDINGFORTEXTREQUEST']._serialized_start=363 + _globals['_CLIPEMBEDDINGFORTEXTREQUEST']._serialized_end=406 + _globals['_CLIPEMBEDDINGFORIMAGEREQUEST']._serialized_start=408 + _globals['_CLIPEMBEDDINGFORIMAGEREQUEST']._serialized_end=453 + _globals['_CLIPEMBEDDINGRESPONSE']._serialized_start=455 + _globals['_CLIPEMBEDDINGRESPONSE']._serialized_end=514 + _globals['_EMBEDDING']._serialized_start=516 + _globals['_EMBEDDING']._serialized_end=546 + _globals['_MSW']._serialized_start=549 + _globals['_MSW']._serialized_end=779 +# @@protoc_insertion_point(module_scope) diff --git a/clip_embedding/msw_pb2.pyi b/clip_embedding/msw_pb2.pyi new file mode 100644 index 0000000..d20814c --- /dev/null +++ b/clip_embedding/msw_pb2.pyi @@ -0,0 +1,70 @@ +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Empty(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class VersionResponse(_message.Message): + __slots__ = ("major", "minor", "patch", "hostname", "addr", "latency", "started_at") + MAJOR_FIELD_NUMBER: _ClassVar[int] + MINOR_FIELD_NUMBER: _ClassVar[int] + PATCH_FIELD_NUMBER: _ClassVar[int] + HOSTNAME_FIELD_NUMBER: _ClassVar[int] + ADDR_FIELD_NUMBER: _ClassVar[int] + LATENCY_FIELD_NUMBER: _ClassVar[int] + STARTED_AT_FIELD_NUMBER: _ClassVar[int] + major: int + minor: int + patch: int + hostname: str + addr: _containers.RepeatedScalarFieldContainer[str] + latency: int + started_at: _timestamp_pb2.Timestamp + def __init__(self, major: _Optional[int] = ..., minor: _Optional[int] = ..., patch: _Optional[int] = ..., hostname: _Optional[str] = ..., addr: _Optional[_Iterable[str]] = ..., latency: _Optional[int] = ..., started_at: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... + +class DataPackage(_message.Message): + __slots__ = ("ToPlugin", "Params", "Body") + class ParamsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + TOPLUGIN_FIELD_NUMBER: _ClassVar[int] + PARAMS_FIELD_NUMBER: _ClassVar[int] + BODY_FIELD_NUMBER: _ClassVar[int] + ToPlugin: str + Params: _containers.ScalarMap[str, str] + Body: str + def __init__(self, ToPlugin: _Optional[str] = ..., Params: _Optional[_Mapping[str, str]] = ..., Body: _Optional[str] = ...) -> None: ... + +class ClipEmbeddingForTextRequest(_message.Message): + __slots__ = ("Text",) + TEXT_FIELD_NUMBER: _ClassVar[int] + Text: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, Text: _Optional[_Iterable[str]] = ...) -> None: ... + +class ClipEmbeddingForImageRequest(_message.Message): + __slots__ = ("Image",) + IMAGE_FIELD_NUMBER: _ClassVar[int] + Image: bytes + def __init__(self, Image: _Optional[bytes] = ...) -> None: ... + +class ClipEmbeddingResponse(_message.Message): + __slots__ = ("Embeddings",) + EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] + Embeddings: _containers.RepeatedCompositeFieldContainer[Embedding] + def __init__(self, Embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ... + +class Embedding(_message.Message): + __slots__ = ("Embedding",) + EMBEDDING_FIELD_NUMBER: _ClassVar[int] + Embedding: _containers.RepeatedScalarFieldContainer[float] + def __init__(self, Embedding: _Optional[_Iterable[float]] = ...) -> None: ... diff --git a/clip_embedding/msw_pb2_grpc.py b/clip_embedding/msw_pb2_grpc.py new file mode 100644 index 0000000..bad9d3e --- /dev/null +++ b/clip_embedding/msw_pb2_grpc.py @@ -0,0 +1,187 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import msw_pb2 as msw__pb2 + +GRPC_GENERATED_VERSION = '1.63.0' +GRPC_VERSION = grpc.__version__ +EXPECTED_ERROR_RELEASE = '1.65.0' +SCHEDULED_RELEASE_DATE = 'June 25, 2024' +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + warnings.warn( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in msw_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', + RuntimeWarning + ) + + +class MSWStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Version = channel.unary_unary( + '/msw.MSW/Version', + request_serializer=msw__pb2.Empty.SerializeToString, + response_deserializer=msw__pb2.VersionResponse.FromString, + _registered_method=True) + self.ClipEmbeddingForText = channel.unary_unary( + '/msw.MSW/ClipEmbeddingForText', + request_serializer=msw__pb2.ClipEmbeddingForTextRequest.SerializeToString, + response_deserializer=msw__pb2.ClipEmbeddingResponse.FromString, + _registered_method=True) + self.ClipEmbeddingForImage = channel.unary_unary( + '/msw.MSW/ClipEmbeddingForImage', + request_serializer=msw__pb2.ClipEmbeddingForImageRequest.SerializeToString, + response_deserializer=msw__pb2.ClipEmbeddingResponse.FromString, + _registered_method=True) + + +class MSWServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Version(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ClipEmbeddingForText(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ClipEmbeddingForImage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_MSWServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Version': grpc.unary_unary_rpc_method_handler( + servicer.Version, + request_deserializer=msw__pb2.Empty.FromString, + response_serializer=msw__pb2.VersionResponse.SerializeToString, + ), + 'ClipEmbeddingForText': grpc.unary_unary_rpc_method_handler( + servicer.ClipEmbeddingForText, + request_deserializer=msw__pb2.ClipEmbeddingForTextRequest.FromString, + response_serializer=msw__pb2.ClipEmbeddingResponse.SerializeToString, + ), + 'ClipEmbeddingForImage': grpc.unary_unary_rpc_method_handler( + servicer.ClipEmbeddingForImage, + request_deserializer=msw__pb2.ClipEmbeddingForImageRequest.FromString, + response_serializer=msw__pb2.ClipEmbeddingResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'msw.MSW', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class MSW(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Version(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/msw.MSW/Version', + msw__pb2.Empty.SerializeToString, + msw__pb2.VersionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ClipEmbeddingForText(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/msw.MSW/ClipEmbeddingForText', + msw__pb2.ClipEmbeddingForTextRequest.SerializeToString, + msw__pb2.ClipEmbeddingResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ClipEmbeddingForImage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/msw.MSW/ClipEmbeddingForImage', + msw__pb2.ClipEmbeddingForImageRequest.SerializeToString, + msw__pb2.ClipEmbeddingResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/clip_embedding/requirements.txt b/clip_embedding/requirements.txt new file mode 100644 index 0000000..6ad3dc6 --- /dev/null +++ b/clip_embedding/requirements.txt @@ -0,0 +1,5 @@ +grpcio-tools +torch +torchvision +transformers +SentencePiece diff --git a/clip_embedding/requirements_version.txt b/clip_embedding/requirements_version.txt new file mode 100644 index 0000000..0705ea6 --- /dev/null +++ b/clip_embedding/requirements_version.txt @@ -0,0 +1,3 @@ +grpcio==1.63.0 +grpcio-tools==1.63.0 +protobuf==5.26.1 diff --git a/core/channel.go b/core/channel.go new file mode 100644 index 0000000..34d043d --- /dev/null +++ b/core/channel.go @@ -0,0 +1,3 @@ +package core + +var pluginChannels = map[string]chan string{} diff --git a/core/conns.go b/core/conns.go new file mode 100644 index 0000000..b719e23 --- /dev/null +++ b/core/conns.go @@ -0,0 +1,36 @@ +package core + +import ( + "fmt" + "log" + "sync" + + "google.golang.org/grpc" +) + +var conns = map[string]*grpc.ClientConn{} + +var connsLock sync.Mutex + +func AddConn(name string, conn *grpc.ClientConn) { + connsLock.Lock() + defer connsLock.Unlock() + + if oldConn, ok := conns[name]; ok { + oldConn.Close() + } + + conns[name] = conn +} + +func PrintConns() { + connsLock.Lock() + defer connsLock.Unlock() + + report := "Current connections:\n" + for name, conn := range conns { + report += fmt.Sprintf(" %s %s -> %s\n", conn.GetState(), name, conn.Target()) + } + + log.Print(report) +} diff --git a/core/core.go b/core/core.go new file mode 100644 index 0000000..34329ee --- /dev/null +++ b/core/core.go @@ -0,0 +1,47 @@ +package core + +import ( + "log" + "msw/proto" + "msw/rpc" + "msw/shell" + "net" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +var kaep = keepalive.EnforcementPolicy{ + MinTime: 5 * time.Second, + PermitWithoutStream: true, +} + +var kasp = keepalive.ServerParameters{ + Time: 1 * time.Second, + Timeout: 5 * time.Second, +} + +func Start() { + log.Println("Starting main loop") + + go shell.ExecuteOne("./resources/node_exporter") + go shell.ExecuteOne("./resources/ipmi_exporter") + + go func() { + listAddr := "0.0.0.0:3939" + list, err := net.Listen("tcp", listAddr) + if err != nil { + log.Fatal(err) + } + s := grpc.NewServer( + grpc.KeepaliveEnforcementPolicy(kaep), + grpc.KeepaliveParams(kasp), + ) + proto.RegisterMSWServer(s, &rpc.MSWServer{}) + log.Println("RPC Server started on", listAddr) + if err := s.Serve(list); err != nil { + log.Fatal(err) + } + }() +} diff --git a/core/discover.go b/core/discover.go new file mode 100644 index 0000000..ad88d3b --- /dev/null +++ b/core/discover.go @@ -0,0 +1,155 @@ +package core + +import ( + "context" + "fmt" + "log" + proto "msw/proto" + "net" + "os" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" +) + +func inc(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} + +func Discover(cidrs ...string) map[string]*proto.VersionResponse { + tasks := make(chan net.IP, 1024) + resultLock := sync.Mutex{} + seen := map[string]*proto.VersionResponse{} + wg := sync.WaitGroup{} + + // start 256 workers + for i := 0; i < 256; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for ip := range tasks { + version, conn, err := CheckHealth(ip) + if err != nil { + continue + } + + resultLock.Lock() + oldVersion, ok := seen[version.Hostname] + if ok { + oldVersion.Addr = removeDuplicate(append(oldVersion.Addr, version.Addr...)) + } else { + seen[version.Hostname] = version + } + resultLock.Unlock() + + AddConn(version.Hostname, conn) + } + }() + } + + // start producer + go func() { + defer close(tasks) + for _, cidr := range cidrs { + ip, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + log.Println(err) + } + + for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + ipCopy := make(net.IP, len(ip)) + copy(ipCopy, ip) + tasks <- ipCopy + } + } + }() + + wg.Wait() + + report := fmt.Sprintf("Discovered %d nodes:\n", len(seen)) + for _, v := range seen { + report += fmt.Sprintf(" %s: v%d.%d.%d started_at %s (%s), at %s %dms\n", + v.Hostname, v.Major, v.Minor, v.Patch, + v.StartedAt.AsTime().In(time.Local), + time.Now().Sub(v.StartedAt.AsTime().In(time.Local)).Round(time.Second), + v.Addr, v.Latency) + } + log.Print(report) + + PrintConns() + + return seen +} + +var kacp = keepalive.ClientParameters{ + Time: 1 * time.Second, // send pings every 10 seconds if there is no activity + Timeout: 5 * time.Second, // wait 1 second for ping ack before considering the connection dead + PermitWithoutStream: true, // send pings even without active streams +} + +func CheckHealth(ip net.IP) (*proto.VersionResponse, *grpc.ClientConn, error) { + target := ip.String() + ":3939" + conn, err := grpc.NewClient( + target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithKeepaliveParams(kacp), + ) + if err != nil { + log.Fatal("Create client error:", err) + } + + rpc := proto.NewMSWClient(conn) + + ctx, cancle := context.WithTimeout(context.Background(), 1*time.Second) + defer cancle() + + begin := time.Now() + + r, err := rpc.Version(ctx, &proto.Empty{}) + if err != nil { + defer conn.Close() + return &proto.VersionResponse{}, conn, err + } + + r.Latency = time.Since(begin).Milliseconds() + + r.Addr = append(r.Addr, ip.String()+":3939") + + if r.Hostname == "" { + log.Println("Empty hostname from", ip) + return &proto.VersionResponse{}, nil, fmt.Errorf("Empty hostname from %s", ip) + } + + myHostname, err := os.Hostname() + if err != nil { + log.Println("Failed to get hostname:", err) + myHostname = "unknown" + } + + if r.Hostname == myHostname { + log.Println("Skip self", r.Hostname, ip) + return &proto.VersionResponse{}, nil, fmt.Errorf("Skip self %s", r.Hostname) + } + + return r, conn, nil +} + +func removeDuplicate[T comparable](sliceList []T) []T { + allKeys := make(map[T]bool) + list := []T{} + for _, item := range sliceList { + if _, value := allKeys[item]; !value { + allKeys[item] = true + list = append(list, item) + } + } + return list +} diff --git a/core/restart.go b/core/restart.go new file mode 100644 index 0000000..8533b56 --- /dev/null +++ b/core/restart.go @@ -0,0 +1,22 @@ +package core + +import ( + "os" + "syscall" +) + +func RestartSelf() { + executable, err := os.Executable() + if err != nil { + panic("获取可执行文件路径失败: " + err.Error()) + } + + args := os.Args + env := os.Environ() + + // 使用exec替换当前进程为新的进程 + err = syscall.Exec(executable, args, env) + if err != nil { + panic("重启失败: " + err.Error()) + } +} diff --git a/deploy.sh b/deploy.sh new file mode 100755 index 0000000..d49fac1 --- /dev/null +++ b/deploy.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +make + +rsync ./msw root@$1:/tmp/msw -rvPhz + +echo "stoping remote services" +# ssh root@$1 systemctl disable --now prometheus-node-exporter + +echo "syncing" +rsync ./resources root@$1:/tmp/ -rvPhz + +echo killing remote services +ssh root@$1 pkill -f msw +ssh root@$1 tmux kill-session -t "msw" + +echo starting remote services +ssh root@$1 " + tmux new-session -d -s "msw" /tmp/msw && + tmux ls +" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e527270 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module msw + +go 1.22.2 + +require ( + google.golang.org/grpc v1.63.2 + google.golang.org/protobuf v1.33.0 +) + +require ( + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e1723cb --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY= +google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= +google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= diff --git a/main/main.go b/main/main.go new file mode 100644 index 0000000..d83acde --- /dev/null +++ b/main/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "bufio" + "log" + "msw/core" + "os" + "strings" +) + +func main() { + core.Start() + + for { + reader := bufio.NewReaderSize(os.Stdin, 1024*16) + line, err := reader.ReadString('\n') + if err != nil { + log.Println("Error reading input:", err) + continue + } + line = strings.TrimSpace(line) + + parts := strings.Split(line, " ") + if len(parts) == 0 { + continue + } + command := parts[0] + args := parts[1:] + + switch command { + case "": + continue + case "exit": + log.Println("Exiting main loop") + return + case "discover": + core.Discover(args...) + case "conns": + core.PrintConns() + case "restart": + core.RestartSelf() + default: + log.Println("Unknown command:", command) + } + + } +} diff --git a/msw b/msw new file mode 100755 index 0000000..c266fae Binary files /dev/null and b/msw differ diff --git a/msw.go b/msw.go new file mode 100644 index 0000000..e9f648f --- /dev/null +++ b/msw.go @@ -0,0 +1 @@ +package msw diff --git a/proto/msw.pb.go b/proto/msw.pb.go new file mode 100644 index 0000000..9c9d453 --- /dev/null +++ b/proto/msw.pb.go @@ -0,0 +1,487 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.33.0 +// protoc v4.25.3 +// source: proto/msw.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Empty struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *Empty) Reset() { + *x = Empty{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_msw_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Empty) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Empty) ProtoMessage() {} + +func (x *Empty) ProtoReflect() protoreflect.Message { + mi := &file_proto_msw_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Empty.ProtoReflect.Descriptor instead. +func (*Empty) Descriptor() ([]byte, []int) { + return file_proto_msw_proto_rawDescGZIP(), []int{0} +} + +type VersionResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Major int32 `protobuf:"varint,1,opt,name=major,proto3" json:"major,omitempty"` + Minor int32 `protobuf:"varint,2,opt,name=minor,proto3" json:"minor,omitempty"` + Patch int32 `protobuf:"varint,3,opt,name=patch,proto3" json:"patch,omitempty"` + Hostname string `protobuf:"bytes,4,opt,name=hostname,proto3" json:"hostname,omitempty"` + Addr []string `protobuf:"bytes,5,rep,name=addr,proto3" json:"addr,omitempty"` + Latency int64 `protobuf:"varint,6,opt,name=latency,proto3" json:"latency,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` +} + +func (x *VersionResponse) Reset() { + *x = VersionResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_msw_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *VersionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*VersionResponse) ProtoMessage() {} + +func (x *VersionResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_msw_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use VersionResponse.ProtoReflect.Descriptor instead. +func (*VersionResponse) Descriptor() ([]byte, []int) { + return file_proto_msw_proto_rawDescGZIP(), []int{1} +} + +func (x *VersionResponse) GetMajor() int32 { + if x != nil { + return x.Major + } + return 0 +} + +func (x *VersionResponse) GetMinor() int32 { + if x != nil { + return x.Minor + } + return 0 +} + +func (x *VersionResponse) GetPatch() int32 { + if x != nil { + return x.Patch + } + return 0 +} + +func (x *VersionResponse) GetHostname() string { + if x != nil { + return x.Hostname + } + return "" +} + +func (x *VersionResponse) GetAddr() []string { + if x != nil { + return x.Addr + } + return nil +} + +func (x *VersionResponse) GetLatency() int64 { + if x != nil { + return x.Latency + } + return 0 +} + +func (x *VersionResponse) GetStartedAt() *timestamppb.Timestamp { + if x != nil { + return x.StartedAt + } + return nil +} + +type DataPackage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ToPlugin string `protobuf:"bytes,1,opt,name=ToPlugin,proto3" json:"ToPlugin,omitempty"` + Params map[string]string `protobuf:"bytes,2,rep,name=Params,proto3" json:"Params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + Body string `protobuf:"bytes,3,opt,name=Body,proto3" json:"Body,omitempty"` +} + +func (x *DataPackage) Reset() { + *x = DataPackage{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_msw_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DataPackage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DataPackage) ProtoMessage() {} + +func (x *DataPackage) ProtoReflect() protoreflect.Message { + mi := &file_proto_msw_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DataPackage.ProtoReflect.Descriptor instead. +func (*DataPackage) Descriptor() ([]byte, []int) { + return file_proto_msw_proto_rawDescGZIP(), []int{2} +} + +func (x *DataPackage) GetToPlugin() string { + if x != nil { + return x.ToPlugin + } + return "" +} + +func (x *DataPackage) GetParams() map[string]string { + if x != nil { + return x.Params + } + return nil +} + +func (x *DataPackage) GetBody() string { + if x != nil { + return x.Body + } + return "" +} + +type ClipEmbeddingForTextRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Text string `protobuf:"bytes,1,opt,name=Text,proto3" json:"Text,omitempty"` +} + +func (x *ClipEmbeddingForTextRequest) Reset() { + *x = ClipEmbeddingForTextRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_msw_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ClipEmbeddingForTextRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClipEmbeddingForTextRequest) ProtoMessage() {} + +func (x *ClipEmbeddingForTextRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_msw_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClipEmbeddingForTextRequest.ProtoReflect.Descriptor instead. +func (*ClipEmbeddingForTextRequest) Descriptor() ([]byte, []int) { + return file_proto_msw_proto_rawDescGZIP(), []int{3} +} + +func (x *ClipEmbeddingForTextRequest) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +type ClipEmbeddingForTextResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Embedding []float32 `protobuf:"fixed32,1,rep,packed,name=Embedding,proto3" json:"Embedding,omitempty"` +} + +func (x *ClipEmbeddingForTextResponse) Reset() { + *x = ClipEmbeddingForTextResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_msw_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ClipEmbeddingForTextResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClipEmbeddingForTextResponse) ProtoMessage() {} + +func (x *ClipEmbeddingForTextResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_msw_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClipEmbeddingForTextResponse.ProtoReflect.Descriptor instead. +func (*ClipEmbeddingForTextResponse) Descriptor() ([]byte, []int) { + return file_proto_msw_proto_rawDescGZIP(), []int{4} +} + +func (x *ClipEmbeddingForTextResponse) GetEmbedding() []float32 { + if x != nil { + return x.Embedding + } + return nil +} + +var File_proto_msw_proto protoreflect.FileDescriptor + +var file_proto_msw_proto_rawDesc = []byte{ + 0x0a, 0x0f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6d, 0x73, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x12, 0x03, 0x6d, 0x73, 0x77, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, + 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0xd8, 0x01, 0x0a, 0x0f, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x61, 0x6a, 0x6f, 0x72, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x05, 0x6d, 0x61, 0x6a, 0x6f, 0x72, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x69, + 0x6e, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x6d, 0x69, 0x6e, 0x6f, 0x72, + 0x12, 0x14, 0x0a, 0x05, 0x70, 0x61, 0x74, 0x63, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x05, 0x70, 0x61, 0x74, 0x63, 0x68, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x61, 0x64, 0x64, 0x72, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x04, 0x61, 0x64, 0x64, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, + 0x79, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, + 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, + 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x22, 0xae, 0x01, 0x0a, 0x0b, + 0x44, 0x61, 0x74, 0x61, 0x50, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x54, + 0x6f, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x54, + 0x6f, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x12, 0x34, 0x0a, 0x06, 0x50, 0x61, 0x72, 0x61, 0x6d, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x73, 0x77, 0x2e, 0x44, 0x61, + 0x74, 0x61, 0x50, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x2e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x06, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x12, 0x12, 0x0a, + 0x04, 0x42, 0x6f, 0x64, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x42, 0x6f, 0x64, + 0x79, 0x1a, 0x39, 0x0a, 0x0b, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, + 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, + 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x31, 0x0a, 0x1b, + 0x43, 0x6c, 0x69, 0x70, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x46, 0x6f, 0x72, + 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, + 0x65, 0x78, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x54, 0x65, 0x78, 0x74, 0x22, + 0x3c, 0x0a, 0x1c, 0x43, 0x6c, 0x69, 0x70, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, + 0x46, 0x6f, 0x72, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x1c, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x02, 0x52, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x32, 0x8f, 0x01, + 0x0a, 0x03, 0x4d, 0x53, 0x57, 0x12, 0x2b, 0x0a, 0x07, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x12, 0x0a, 0x2e, 0x6d, 0x73, 0x77, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x14, 0x2e, 0x6d, + 0x73, 0x77, 0x2e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x5b, 0x0a, 0x14, 0x43, 0x6c, 0x69, 0x70, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, + 0x69, 0x6e, 0x67, 0x46, 0x6f, 0x72, 0x54, 0x65, 0x78, 0x74, 0x12, 0x20, 0x2e, 0x6d, 0x73, 0x77, + 0x2e, 0x43, 0x6c, 0x69, 0x70, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x46, 0x6f, + 0x72, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, + 0x73, 0x77, 0x2e, 0x43, 0x6c, 0x69, 0x70, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, + 0x46, 0x6f, 0x72, 0x54, 0x65, 0x78, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, + 0x0b, 0x5a, 0x09, 0x6d, 0x73, 0x77, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_proto_msw_proto_rawDescOnce sync.Once + file_proto_msw_proto_rawDescData = file_proto_msw_proto_rawDesc +) + +func file_proto_msw_proto_rawDescGZIP() []byte { + file_proto_msw_proto_rawDescOnce.Do(func() { + file_proto_msw_proto_rawDescData = protoimpl.X.CompressGZIP(file_proto_msw_proto_rawDescData) + }) + return file_proto_msw_proto_rawDescData +} + +var file_proto_msw_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_proto_msw_proto_goTypes = []interface{}{ + (*Empty)(nil), // 0: msw.Empty + (*VersionResponse)(nil), // 1: msw.VersionResponse + (*DataPackage)(nil), // 2: msw.DataPackage + (*ClipEmbeddingForTextRequest)(nil), // 3: msw.ClipEmbeddingForTextRequest + (*ClipEmbeddingForTextResponse)(nil), // 4: msw.ClipEmbeddingForTextResponse + nil, // 5: msw.DataPackage.ParamsEntry + (*timestamppb.Timestamp)(nil), // 6: google.protobuf.Timestamp +} +var file_proto_msw_proto_depIdxs = []int32{ + 6, // 0: msw.VersionResponse.started_at:type_name -> google.protobuf.Timestamp + 5, // 1: msw.DataPackage.Params:type_name -> msw.DataPackage.ParamsEntry + 0, // 2: msw.MSW.Version:input_type -> msw.Empty + 3, // 3: msw.MSW.ClipEmbeddingForText:input_type -> msw.ClipEmbeddingForTextRequest + 1, // 4: msw.MSW.Version:output_type -> msw.VersionResponse + 4, // 5: msw.MSW.ClipEmbeddingForText:output_type -> msw.ClipEmbeddingForTextResponse + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_proto_msw_proto_init() } +func file_proto_msw_proto_init() { + if File_proto_msw_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_proto_msw_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Empty); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_msw_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*VersionResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_msw_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DataPackage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_msw_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ClipEmbeddingForTextRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proto_msw_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ClipEmbeddingForTextResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_proto_msw_proto_rawDesc, + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_proto_msw_proto_goTypes, + DependencyIndexes: file_proto_msw_proto_depIdxs, + MessageInfos: file_proto_msw_proto_msgTypes, + }.Build() + File_proto_msw_proto = out.File + file_proto_msw_proto_rawDesc = nil + file_proto_msw_proto_goTypes = nil + file_proto_msw_proto_depIdxs = nil +} diff --git a/proto/msw.proto b/proto/msw.proto new file mode 100644 index 0000000..e7ef809 --- /dev/null +++ b/proto/msw.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +package msw; + +import "google/protobuf/timestamp.proto"; + +option go_package = "msw/proto"; + +service MSW { + rpc Version(Empty) returns (VersionResponse) {}; + rpc ClipEmbeddingForText(ClipEmbeddingForTextRequest) + returns (ClipEmbeddingResponse) {}; + rpc ClipEmbeddingForImage(ClipEmbeddingForImageRequest) + returns (ClipEmbeddingResponse) {}; +} + +message Empty {} + +message VersionResponse { + int32 major = 1; + int32 minor = 2; + int32 patch = 3; + string hostname = 4; + repeated string addr = 5; + int64 latency = 6; + google.protobuf.Timestamp started_at = 7; +} + +message DataPackage { + string ToPlugin = 1; + map Params = 2; + string Body = 3; +} + +message ClipEmbeddingForTextRequest { repeated string Text = 1; } + +message ClipEmbeddingForImageRequest { bytes Image = 1; } + +message ClipEmbeddingResponse { repeated Embedding Embeddings = 1; } + +message Embedding { repeated float Embedding = 2; } \ No newline at end of file diff --git a/proto/msw_grpc.pb.go b/proto/msw_grpc.pb.go new file mode 100644 index 0000000..97e9ca5 --- /dev/null +++ b/proto/msw_grpc.pb.go @@ -0,0 +1,146 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.3.0 +// - protoc v4.25.3 +// source: proto/msw.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +const ( + MSW_Version_FullMethodName = "/msw.MSW/Version" + MSW_ClipEmbeddingForText_FullMethodName = "/msw.MSW/ClipEmbeddingForText" +) + +// MSWClient is the client API for MSW service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type MSWClient interface { + Version(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*VersionResponse, error) + ClipEmbeddingForText(ctx context.Context, in *ClipEmbeddingForTextRequest, opts ...grpc.CallOption) (*ClipEmbeddingForTextResponse, error) +} + +type mSWClient struct { + cc grpc.ClientConnInterface +} + +func NewMSWClient(cc grpc.ClientConnInterface) MSWClient { + return &mSWClient{cc} +} + +func (c *mSWClient) Version(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*VersionResponse, error) { + out := new(VersionResponse) + err := c.cc.Invoke(ctx, MSW_Version_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *mSWClient) ClipEmbeddingForText(ctx context.Context, in *ClipEmbeddingForTextRequest, opts ...grpc.CallOption) (*ClipEmbeddingForTextResponse, error) { + out := new(ClipEmbeddingForTextResponse) + err := c.cc.Invoke(ctx, MSW_ClipEmbeddingForText_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// MSWServer is the server API for MSW service. +// All implementations must embed UnimplementedMSWServer +// for forward compatibility +type MSWServer interface { + Version(context.Context, *Empty) (*VersionResponse, error) + ClipEmbeddingForText(context.Context, *ClipEmbeddingForTextRequest) (*ClipEmbeddingForTextResponse, error) + mustEmbedUnimplementedMSWServer() +} + +// UnimplementedMSWServer must be embedded to have forward compatible implementations. +type UnimplementedMSWServer struct { +} + +func (UnimplementedMSWServer) Version(context.Context, *Empty) (*VersionResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Version not implemented") +} +func (UnimplementedMSWServer) ClipEmbeddingForText(context.Context, *ClipEmbeddingForTextRequest) (*ClipEmbeddingForTextResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ClipEmbeddingForText not implemented") +} +func (UnimplementedMSWServer) mustEmbedUnimplementedMSWServer() {} + +// UnsafeMSWServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to MSWServer will +// result in compilation errors. +type UnsafeMSWServer interface { + mustEmbedUnimplementedMSWServer() +} + +func RegisterMSWServer(s grpc.ServiceRegistrar, srv MSWServer) { + s.RegisterService(&MSW_ServiceDesc, srv) +} + +func _MSW_Version_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MSWServer).Version(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: MSW_Version_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MSWServer).Version(ctx, req.(*Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _MSW_ClipEmbeddingForText_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ClipEmbeddingForTextRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MSWServer).ClipEmbeddingForText(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: MSW_ClipEmbeddingForText_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MSWServer).ClipEmbeddingForText(ctx, req.(*ClipEmbeddingForTextRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// MSW_ServiceDesc is the grpc.ServiceDesc for MSW service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var MSW_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "msw.MSW", + HandlerType: (*MSWServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Version", + Handler: _MSW_Version_Handler, + }, + { + MethodName: "ClipEmbeddingForText", + Handler: _MSW_ClipEmbeddingForText_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "proto/msw.proto", +} diff --git a/rpc/server.go b/rpc/server.go new file mode 100644 index 0000000..1aa1cd1 --- /dev/null +++ b/rpc/server.go @@ -0,0 +1,9 @@ +package rpc + +import ( + proto "msw/proto" +) + +type MSWServer struct { + proto.UnimplementedMSWServer +} diff --git a/rpc/version.go b/rpc/version.go new file mode 100644 index 0000000..fef43a8 --- /dev/null +++ b/rpc/version.go @@ -0,0 +1,30 @@ +package rpc + +import ( + "context" + "log" + proto "msw/proto" + "os" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" +) + +var startedAt = time.Now() + +func (r *MSWServer) Version(ctx context.Context, in *proto.Empty) (*proto.VersionResponse, error) { + log.Println("Health check request received") + hostname, err := os.Hostname() + if err != nil { + log.Println("Failed to get hostname:", err) + hostname = "unknown" + } + log.Println("[rpc.version] Health check response sent") + return &proto.VersionResponse{ + Major: 0, + Minor: 0, + Patch: 1, + Hostname: hostname, + StartedAt: ×tamppb.Timestamp{Seconds: startedAt.Unix()}, + }, nil +} diff --git a/shell/shell.go b/shell/shell.go new file mode 100644 index 0000000..181a958 --- /dev/null +++ b/shell/shell.go @@ -0,0 +1,46 @@ +package shell + +import ( + "log" + "os" + "os/exec" + "path/filepath" +) + +func ExecuteOne(name string, args ...string) { + KillAll(name) + Execute(name, args...) +} + +func Execute(name string, args ...string) { + // get the directory of the executable + file, err := os.Executable() + if err != nil { + log.Fatal("Error getting executable path:", err) + } + + dir := filepath.Dir(file) + + // change the current working directory to the directory of the executable + err = os.Chdir(dir) + if err != nil { + log.Fatal("Error changing working directory:", err) + } + + log.Println("Executing command:", name, args) + cmd := exec.Command(name, args...) + // cmd.Stderr = os.Stderr + // cmd.Stdout = os.Stdout + err = cmd.Run() + if err != nil { + log.Println("Error executing command:", name, args, err) + } +} + +func KillAll(name string) { + cmd := exec.Command("pkill", "-f", name) + err := cmd.Run() + if err != nil { + log.Println("Error killing all processes with name:", name, err) + } +}