re-init
暂存 tmp
This commit is contained in:
64
clip_embedding/clip_embedding.py
Normal file
64
clip_embedding/clip_embedding.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import io
|
||||
from concurrent import futures
|
||||
from PIL import Image
|
||||
import grpc
|
||||
import msw_pb2 as msw_pb2
|
||||
import msw_pb2_grpc as msw_pb2_grpc
|
||||
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
model_name_or_path = "google/siglip-so400m-patch14-384"
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
import torch
|
||||
|
||||
print("Loading model and processor")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").eval().to(device)
|
||||
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
|
||||
|
||||
|
||||
class MSWService(msw_pb2_grpc.MSWServicer):
|
||||
def ClipEmbeddingForText(
|
||||
self, request: msw_pb2.ClipEmbeddingForTextRequest, context
|
||||
) -> msw_pb2.ClipEmbeddingResponse:
|
||||
print("Received request", request.Text, type(list(request.Text)))
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = processor(
|
||||
text=list(request.Text), return_tensors="pt", padding=True
|
||||
).to(device)
|
||||
outputs = model.get_text_features(**inputs)
|
||||
|
||||
return msw_pb2.ClipEmbeddingResponse(
|
||||
Embeddings=[msw_pb2.Embedding(Embedding=i) for i in outputs.tolist()]
|
||||
)
|
||||
|
||||
def ClipEmbeddingForImage(
|
||||
self, request: msw_pb2.ClipEmbeddingForImageRequest, context
|
||||
) -> msw_pb2.ClipEmbeddingResponse:
|
||||
image = Image.open(io.BytesIO(request.Image))
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = processor(images=image, return_tensors="pt", padding=True).to(
|
||||
device
|
||||
)
|
||||
outputs = model.get_image_features(**inputs)
|
||||
|
||||
return msw_pb2.ClipEmbeddingResponse(
|
||||
Embeddings=[msw_pb2.Embedding(Embedding=outputs[0].tolist())]
|
||||
)
|
||||
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
msw_pb2_grpc.add_MSWServicer_to_server(MSWService(), server)
|
||||
server.add_insecure_port("0.0.0.0:8888")
|
||||
server.start()
|
||||
print("Server started at port 8888")
|
||||
server.wait_for_termination()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
Reference in New Issue
Block a user