import io
from concurrent import futures

import grpc
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

import msw_pb2
import msw_pb2_grpc

model_name_or_path = "google/siglip-so400m-patch14-384"

print("Loading model and processor")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(model_name_or_path).eval().to(device)
processor = AutoProcessor.from_pretrained(model_name_or_path)


class MSWService(msw_pb2_grpc.MSWServicer):
    def ClipEmbeddingForText(
        self, request: msw_pb2.ClipEmbeddingForTextRequest, context
    ) -> msw_pb2.ClipEmbeddingResponse:
        print("Received text request:", list(request.Text))
        with torch.no_grad():
            # SigLIP was trained with padding="max_length"; plain padding=True
            # yields noticeably worse text embeddings.
            inputs = processor(
                text=list(request.Text),
                return_tensors="pt",
                padding="max_length",
                truncation=True,
            ).to(device)
            outputs = model.get_text_features(**inputs)
        # One embedding per input text, in request order.
        return msw_pb2.ClipEmbeddingResponse(
            Embeddings=[msw_pb2.Embedding(Embedding=i) for i in outputs.tolist()]
        )

    def ClipEmbeddingForImage(
        self, request: msw_pb2.ClipEmbeddingForImageRequest, context
    ) -> msw_pb2.ClipEmbeddingResponse:
        # Decode the raw image bytes sent in the request.
        image = Image.open(io.BytesIO(request.Image))
        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(device)
            outputs = model.get_image_features(**inputs)
        # Single image in, single embedding out.
        return msw_pb2.ClipEmbeddingResponse(
            Embeddings=[msw_pb2.Embedding(Embedding=outputs[0].tolist())]
        )


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    msw_pb2_grpc.add_MSWServicer_to_server(MSWService(), server)
    server.add_insecure_port("0.0.0.0:8888")
    server.start()
    print("Server started on port 8888")
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
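

# Example client usage (a minimal sketch, kept as a comment so it does not run
# alongside the server). It assumes the generated stubs follow the standard
# gRPC naming convention for a service named "MSW", i.e. msw_pb2_grpc.MSWStub;
# the request text is an arbitrary placeholder:
#
#   import grpc
#   import msw_pb2
#   import msw_pb2_grpc
#
#   channel = grpc.insecure_channel("localhost:8888")
#   stub = msw_pb2_grpc.MSWStub(channel)
#   response = stub.ClipEmbeddingForText(
#       msw_pb2.ClipEmbeddingForTextRequest(Text=["a photo of a cat"])
#   )
#   print(len(response.Embeddings[0].Embedding))  # embedding dimensionality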