# gRPC service exposing SigLIP (google/siglip-so400m-patch14-384) text and
# image embedding endpoints.
# Module setup: imports, then load the SigLIP model/processor once at import
# time so every request handler shares the same instances.
import io

from concurrent import futures

import grpc
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

import msw_pb2 as msw_pb2
import msw_pb2_grpc as msw_pb2_grpc

# Single source of truth for the checkpoint name — used for BOTH the model
# and the processor below (previously hard-coded twice and this variable
# was never read).
model_name_or_path = "google/siglip-so400m-patch14-384"

print("Loading model and processor")

# Prefer GPU when present. Inference-only service, so eval mode.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(model_name_or_path).eval().to(device)
processor = AutoProcessor.from_pretrained(model_name_or_path)
|
class MSWService(msw_pb2_grpc.MSWServicer):
    """gRPC servicer that returns SigLIP embeddings for texts and images.

    Relies on the module-level ``model``, ``processor`` and ``device``
    loaded at import time.
    """

    def ClipEmbeddingForText(
        self, request: msw_pb2.ClipEmbeddingForTextRequest, context
    ) -> msw_pb2.ClipEmbeddingResponse:
        """Embed every string in ``request.Text``.

        Returns one ``Embedding`` message per input text, in order.
        """
        print("Received request", request.Text, type(list(request.Text)))

        # Inference only: disable autograd to avoid building a graph.
        with torch.no_grad():
            inputs = processor(
                text=list(request.Text), return_tensors="pt", padding=True
            ).to(device)
            outputs = model.get_text_features(**inputs)

        return msw_pb2.ClipEmbeddingResponse(
            Embeddings=[msw_pb2.Embedding(Embedding=i) for i in outputs.tolist()]
        )

    def ClipEmbeddingForImage(
        self, request: msw_pb2.ClipEmbeddingForImageRequest, context
    ) -> msw_pb2.ClipEmbeddingResponse:
        """Embed a single image supplied as raw encoded bytes in ``request.Image``.

        Returns a response containing exactly one ``Embedding``.
        """
        # Fix: normalize to RGB so palette, RGBA, or grayscale uploads do not
        # break the processor, which expects 3-channel images.
        image = Image.open(io.BytesIO(request.Image)).convert("RGB")

        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt", padding=True).to(
                device
            )
            outputs = model.get_image_features(**inputs)

        # Batch of one image -> take row 0 of the feature tensor.
        return msw_pb2.ClipEmbeddingResponse(
            Embeddings=[msw_pb2.Embedding(Embedding=outputs[0].tolist())]
        )
|
|
|
|
|
|
def serve(address: str = "0.0.0.0:8888", max_workers: int = 10) -> None:
    """Start the embedding gRPC server and block until termination.

    Args:
        address: ``host:port`` to bind without TLS. The default preserves
            the previously hard-coded ``0.0.0.0:8888``.
        max_workers: size of the request-handling thread pool (default
            preserves the previous hard-coded 10).
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    msw_pb2_grpc.add_MSWServicer_to_server(MSWService(), server)
    # NOTE(review): insecure port — fine on a trusted network; add TLS
    # credentials via add_secure_port if this is ever exposed publicly.
    server.add_insecure_port(address)
    server.start()
    print(f"Server started at {address}")
    server.wait_for_termination()
|
|
|
|
|
# Script entry point: model loading already happened at import time above;
# this only binds the port and blocks.
if __name__ == "__main__":
    serve()