暂存

tmp
This commit is contained in:
2024-05-09 16:35:21 +08:00
parent 3afa0d81bb
commit bcc1b51006
27 changed files with 1526 additions and 0 deletions

View File

@@ -0,0 +1,64 @@
import io
from concurrent import futures
from PIL import Image
import grpc
import msw_pb2 as msw_pb2
import msw_pb2_grpc as msw_pb2_grpc
from transformers import AutoModel, AutoProcessor
model_name_or_path = "google/siglip-so400m-patch14-384"
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch
print("Loading model and processor")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").eval().to(device)
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
class MSWService(msw_pb2_grpc.MSWServicer):
def ClipEmbeddingForText(
self, request: msw_pb2.ClipEmbeddingForTextRequest, context
) -> msw_pb2.ClipEmbeddingResponse:
print("Received request", request.Text, type(list(request.Text)))
with torch.no_grad():
inputs = processor(
text=list(request.Text), return_tensors="pt", padding=True
).to(device)
outputs = model.get_text_features(**inputs)
return msw_pb2.ClipEmbeddingResponse(
Embeddings=[msw_pb2.Embedding(Embedding=i) for i in outputs.tolist()]
)
def ClipEmbeddingForImage(
self, request: msw_pb2.ClipEmbeddingForImageRequest, context
) -> msw_pb2.ClipEmbeddingResponse:
image = Image.open(io.BytesIO(request.Image))
with torch.no_grad():
inputs = processor(images=image, return_tensors="pt", padding=True).to(
device
)
outputs = model.get_image_features(**inputs)
return msw_pb2.ClipEmbeddingResponse(
Embeddings=[msw_pb2.Embedding(Embedding=outputs[0].tolist())]
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
msw_pb2_grpc.add_MSWServicer_to_server(MSWService(), server)
server.add_insecure_port("0.0.0.0:8888")
server.start()
print("Server started at port 8888")
server.wait_for_termination()
if __name__ == "__main__":
serve()