暂存

tmp
This commit is contained in:
2024-05-09 16:35:21 +08:00
parent 3afa0d81bb
commit bcc1b51006
27 changed files with 1526 additions and 0 deletions

2
clip_embedding/build-proto.sh Executable file
View File

@@ -0,0 +1,2 @@
#!/bin/bash
python -m grpc_tools.protoc -I ../proto/ --python_out=. --pyi_out=. --grpc_python_out=. ../proto/msw.proto

View File

@@ -0,0 +1,64 @@
import io
from concurrent import futures
from PIL import Image
import grpc
import msw_pb2 as msw_pb2
import msw_pb2_grpc as msw_pb2_grpc
from transformers import AutoModel, AutoProcessor
model_name_or_path = "google/siglip-so400m-patch14-384"
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch
print("Loading model and processor")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").eval().to(device)
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
class MSWService(msw_pb2_grpc.MSWServicer):
def ClipEmbeddingForText(
self, request: msw_pb2.ClipEmbeddingForTextRequest, context
) -> msw_pb2.ClipEmbeddingResponse:
print("Received request", request.Text, type(list(request.Text)))
with torch.no_grad():
inputs = processor(
text=list(request.Text), return_tensors="pt", padding=True
).to(device)
outputs = model.get_text_features(**inputs)
return msw_pb2.ClipEmbeddingResponse(
Embeddings=[msw_pb2.Embedding(Embedding=i) for i in outputs.tolist()]
)
def ClipEmbeddingForImage(
self, request: msw_pb2.ClipEmbeddingForImageRequest, context
) -> msw_pb2.ClipEmbeddingResponse:
image = Image.open(io.BytesIO(request.Image))
with torch.no_grad():
inputs = processor(images=image, return_tensors="pt", padding=True).to(
device
)
outputs = model.get_image_features(**inputs)
return msw_pb2.ClipEmbeddingResponse(
Embeddings=[msw_pb2.Embedding(Embedding=outputs[0].tolist())]
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
msw_pb2_grpc.add_MSWServicer_to_server(MSWService(), server)
server.add_insecure_port("0.0.0.0:8888")
server.start()
print("Server started at port 8888")
server.wait_for_termination()
if __name__ == "__main__":
serve()

View File

@@ -0,0 +1,19 @@
from concurrent import futures
import grpc
import msw_pb2 as msw_pb2
import msw_pb2_grpc as msw_pb2_grpc
channel = grpc.insecure_channel("127.0.0.1:8888")
stub = msw_pb2_grpc.MSWStub(channel)
testImageFile = "/home/hmsy/Pictures/Screenshots/Screenshot_20240515_103419.jpeg"
with open(testImageFile, "rb") as f:
imageData = f.read()
count = 0
while 1:
response = stub.ClipEmbeddingForImage(
msw_pb2.ClipEmbeddingForImageRequest(Image=imageData)
)
count += 1
print(count)

46
clip_embedding/msw_pb2.py Normal file
View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: msw.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tmsw.proto\x12\x03msw\x1a\x1fgoogle/protobuf/timestamp.proto\"\x07\n\x05\x45mpty\"\x9f\x01\n\x0fVersionResponse\x12\r\n\x05major\x18\x01 \x01(\x05\x12\r\n\x05minor\x18\x02 \x01(\x05\x12\r\n\x05patch\x18\x03 \x01(\x05\x12\x10\n\x08hostname\x18\x04 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x05 \x03(\t\x12\x0f\n\x07latency\x18\x06 \x01(\x03\x12.\n\nstarted_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"\x8a\x01\n\x0b\x44\x61taPackage\x12\x10\n\x08ToPlugin\x18\x01 \x01(\t\x12,\n\x06Params\x18\x02 \x03(\x0b\x32\x1c.msw.DataPackage.ParamsEntry\x12\x0c\n\x04\x42ody\x18\x03 \x01(\t\x1a-\n\x0bParamsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"+\n\x1b\x43lipEmbeddingForTextRequest\x12\x0c\n\x04Text\x18\x01 \x03(\t\"-\n\x1c\x43lipEmbeddingForImageRequest\x12\r\n\x05Image\x18\x01 \x01(\x0c\";\n\x15\x43lipEmbeddingResponse\x12\"\n\nEmbeddings\x18\x01 \x03(\x0b\x32\x0e.msw.Embedding\"\x1e\n\tEmbedding\x12\x11\n\tEmbedding\x18\x02 \x03(\x02\x32\xe6\x01\n\x03MSW\x12-\n\x07Version\x12\n.msw.Empty\x1a\x14.msw.VersionResponse\"\x00\x12V\n\x14\x43lipEmbeddingForText\x12 .msw.ClipEmbeddingForTextRequest\x1a\x1a.msw.ClipEmbeddingResponse\"\x00\x12X\n\x15\x43lipEmbeddingForImage\x12!.msw.ClipEmbeddingForImageRequest\x1a\x1a.msw.ClipEmbeddingResponse\"\x00\x42\x0bZ\tmsw/protob\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'msw_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'Z\tmsw/proto'
_globals['_DATAPACKAGE_PARAMSENTRY']._loaded_options = None
_globals['_DATAPACKAGE_PARAMSENTRY']._serialized_options = b'8\001'
_globals['_EMPTY']._serialized_start=51
_globals['_EMPTY']._serialized_end=58
_globals['_VERSIONRESPONSE']._serialized_start=61
_globals['_VERSIONRESPONSE']._serialized_end=220
_globals['_DATAPACKAGE']._serialized_start=223
_globals['_DATAPACKAGE']._serialized_end=361
_globals['_DATAPACKAGE_PARAMSENTRY']._serialized_start=316
_globals['_DATAPACKAGE_PARAMSENTRY']._serialized_end=361
_globals['_CLIPEMBEDDINGFORTEXTREQUEST']._serialized_start=363
_globals['_CLIPEMBEDDINGFORTEXTREQUEST']._serialized_end=406
_globals['_CLIPEMBEDDINGFORIMAGEREQUEST']._serialized_start=408
_globals['_CLIPEMBEDDINGFORIMAGEREQUEST']._serialized_end=453
_globals['_CLIPEMBEDDINGRESPONSE']._serialized_start=455
_globals['_CLIPEMBEDDINGRESPONSE']._serialized_end=514
_globals['_EMBEDDING']._serialized_start=516
_globals['_EMBEDDING']._serialized_end=546
_globals['_MSW']._serialized_start=549
_globals['_MSW']._serialized_end=779
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,70 @@
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class Empty(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class VersionResponse(_message.Message):
__slots__ = ("major", "minor", "patch", "hostname", "addr", "latency", "started_at")
MAJOR_FIELD_NUMBER: _ClassVar[int]
MINOR_FIELD_NUMBER: _ClassVar[int]
PATCH_FIELD_NUMBER: _ClassVar[int]
HOSTNAME_FIELD_NUMBER: _ClassVar[int]
ADDR_FIELD_NUMBER: _ClassVar[int]
LATENCY_FIELD_NUMBER: _ClassVar[int]
STARTED_AT_FIELD_NUMBER: _ClassVar[int]
major: int
minor: int
patch: int
hostname: str
addr: _containers.RepeatedScalarFieldContainer[str]
latency: int
started_at: _timestamp_pb2.Timestamp
def __init__(self, major: _Optional[int] = ..., minor: _Optional[int] = ..., patch: _Optional[int] = ..., hostname: _Optional[str] = ..., addr: _Optional[_Iterable[str]] = ..., latency: _Optional[int] = ..., started_at: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ...
class DataPackage(_message.Message):
__slots__ = ("ToPlugin", "Params", "Body")
class ParamsEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: str
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
TOPLUGIN_FIELD_NUMBER: _ClassVar[int]
PARAMS_FIELD_NUMBER: _ClassVar[int]
BODY_FIELD_NUMBER: _ClassVar[int]
ToPlugin: str
Params: _containers.ScalarMap[str, str]
Body: str
def __init__(self, ToPlugin: _Optional[str] = ..., Params: _Optional[_Mapping[str, str]] = ..., Body: _Optional[str] = ...) -> None: ...
class ClipEmbeddingForTextRequest(_message.Message):
__slots__ = ("Text",)
TEXT_FIELD_NUMBER: _ClassVar[int]
Text: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, Text: _Optional[_Iterable[str]] = ...) -> None: ...
class ClipEmbeddingForImageRequest(_message.Message):
__slots__ = ("Image",)
IMAGE_FIELD_NUMBER: _ClassVar[int]
Image: bytes
def __init__(self, Image: _Optional[bytes] = ...) -> None: ...
class ClipEmbeddingResponse(_message.Message):
__slots__ = ("Embeddings",)
EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
Embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
def __init__(self, Embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message):
__slots__ = ("Embedding",)
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
Embedding: _containers.RepeatedScalarFieldContainer[float]
def __init__(self, Embedding: _Optional[_Iterable[float]] = ...) -> None: ...

View File

@@ -0,0 +1,187 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import msw_pb2 as msw__pb2
GRPC_GENERATED_VERSION = '1.63.0'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
warnings.warn(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in msw_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+ f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+ f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
RuntimeWarning
)
class MSWStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Version = channel.unary_unary(
'/msw.MSW/Version',
request_serializer=msw__pb2.Empty.SerializeToString,
response_deserializer=msw__pb2.VersionResponse.FromString,
_registered_method=True)
self.ClipEmbeddingForText = channel.unary_unary(
'/msw.MSW/ClipEmbeddingForText',
request_serializer=msw__pb2.ClipEmbeddingForTextRequest.SerializeToString,
response_deserializer=msw__pb2.ClipEmbeddingResponse.FromString,
_registered_method=True)
self.ClipEmbeddingForImage = channel.unary_unary(
'/msw.MSW/ClipEmbeddingForImage',
request_serializer=msw__pb2.ClipEmbeddingForImageRequest.SerializeToString,
response_deserializer=msw__pb2.ClipEmbeddingResponse.FromString,
_registered_method=True)
class MSWServicer(object):
"""Missing associated documentation comment in .proto file."""
def Version(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ClipEmbeddingForText(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def ClipEmbeddingForImage(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_MSWServicer_to_server(servicer, server):
rpc_method_handlers = {
'Version': grpc.unary_unary_rpc_method_handler(
servicer.Version,
request_deserializer=msw__pb2.Empty.FromString,
response_serializer=msw__pb2.VersionResponse.SerializeToString,
),
'ClipEmbeddingForText': grpc.unary_unary_rpc_method_handler(
servicer.ClipEmbeddingForText,
request_deserializer=msw__pb2.ClipEmbeddingForTextRequest.FromString,
response_serializer=msw__pb2.ClipEmbeddingResponse.SerializeToString,
),
'ClipEmbeddingForImage': grpc.unary_unary_rpc_method_handler(
servicer.ClipEmbeddingForImage,
request_deserializer=msw__pb2.ClipEmbeddingForImageRequest.FromString,
response_serializer=msw__pb2.ClipEmbeddingResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'msw.MSW', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class MSW(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Version(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/msw.MSW/Version',
msw__pb2.Empty.SerializeToString,
msw__pb2.VersionResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def ClipEmbeddingForText(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/msw.MSW/ClipEmbeddingForText',
msw__pb2.ClipEmbeddingForTextRequest.SerializeToString,
msw__pb2.ClipEmbeddingResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def ClipEmbeddingForImage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/msw.MSW/ClipEmbeddingForImage',
msw__pb2.ClipEmbeddingForImageRequest.SerializeToString,
msw__pb2.ClipEmbeddingResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -0,0 +1,5 @@
grpcio-tools
torch
torchvision
transformers
SentencePiece

View File

@@ -0,0 +1,3 @@
grpcio==1.63.0
grpcio-tools==1.63.0
protobuf==5.26.1