re-init
暂存 tmp
This commit is contained in:
2
clip_embedding/build-proto.sh
Executable file
2
clip_embedding/build-proto.sh
Executable file
@@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
python -m grpc_tools.protoc -I ../proto/ --python_out=. --pyi_out=. --grpc_python_out=. ../proto/msw.proto
|
||||
64
clip_embedding/clip_embedding.py
Normal file
64
clip_embedding/clip_embedding.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import io
|
||||
from concurrent import futures
|
||||
from PIL import Image
|
||||
import grpc
|
||||
import msw_pb2 as msw_pb2
|
||||
import msw_pb2_grpc as msw_pb2_grpc
|
||||
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
model_name_or_path = "google/siglip-so400m-patch14-384"
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
import torch
|
||||
|
||||
print("Loading model and processor")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").eval().to(device)
|
||||
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
|
||||
|
||||
|
||||
class MSWService(msw_pb2_grpc.MSWServicer):
|
||||
def ClipEmbeddingForText(
|
||||
self, request: msw_pb2.ClipEmbeddingForTextRequest, context
|
||||
) -> msw_pb2.ClipEmbeddingResponse:
|
||||
print("Received request", request.Text, type(list(request.Text)))
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = processor(
|
||||
text=list(request.Text), return_tensors="pt", padding=True
|
||||
).to(device)
|
||||
outputs = model.get_text_features(**inputs)
|
||||
|
||||
return msw_pb2.ClipEmbeddingResponse(
|
||||
Embeddings=[msw_pb2.Embedding(Embedding=i) for i in outputs.tolist()]
|
||||
)
|
||||
|
||||
def ClipEmbeddingForImage(
|
||||
self, request: msw_pb2.ClipEmbeddingForImageRequest, context
|
||||
) -> msw_pb2.ClipEmbeddingResponse:
|
||||
image = Image.open(io.BytesIO(request.Image))
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = processor(images=image, return_tensors="pt", padding=True).to(
|
||||
device
|
||||
)
|
||||
outputs = model.get_image_features(**inputs)
|
||||
|
||||
return msw_pb2.ClipEmbeddingResponse(
|
||||
Embeddings=[msw_pb2.Embedding(Embedding=outputs[0].tolist())]
|
||||
)
|
||||
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
msw_pb2_grpc.add_MSWServicer_to_server(MSWService(), server)
|
||||
server.add_insecure_port("0.0.0.0:8888")
|
||||
server.start()
|
||||
print("Server started at port 8888")
|
||||
server.wait_for_termination()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve()
|
||||
19
clip_embedding/clip_embedding_test.py
Normal file
19
clip_embedding/clip_embedding_test.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import msw_pb2 as msw_pb2
|
||||
import msw_pb2_grpc as msw_pb2_grpc
|
||||
|
||||
channel = grpc.insecure_channel("127.0.0.1:8888")
|
||||
stub = msw_pb2_grpc.MSWStub(channel)
|
||||
|
||||
testImageFile = "/home/hmsy/Pictures/Screenshots/Screenshot_20240515_103419.jpeg"
|
||||
with open(testImageFile, "rb") as f:
|
||||
imageData = f.read()
|
||||
|
||||
count = 0
|
||||
while 1:
|
||||
response = stub.ClipEmbeddingForImage(
|
||||
msw_pb2.ClipEmbeddingForImageRequest(Image=imageData)
|
||||
)
|
||||
count += 1
|
||||
print(count)
|
||||
46
clip_embedding/msw_pb2.py
Normal file
46
clip_embedding/msw_pb2.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: msw.proto
|
||||
# Protobuf Python Version: 5.26.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tmsw.proto\x12\x03msw\x1a\x1fgoogle/protobuf/timestamp.proto\"\x07\n\x05\x45mpty\"\x9f\x01\n\x0fVersionResponse\x12\r\n\x05major\x18\x01 \x01(\x05\x12\r\n\x05minor\x18\x02 \x01(\x05\x12\r\n\x05patch\x18\x03 \x01(\x05\x12\x10\n\x08hostname\x18\x04 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x05 \x03(\t\x12\x0f\n\x07latency\x18\x06 \x01(\x03\x12.\n\nstarted_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"\x8a\x01\n\x0b\x44\x61taPackage\x12\x10\n\x08ToPlugin\x18\x01 \x01(\t\x12,\n\x06Params\x18\x02 \x03(\x0b\x32\x1c.msw.DataPackage.ParamsEntry\x12\x0c\n\x04\x42ody\x18\x03 \x01(\t\x1a-\n\x0bParamsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"+\n\x1b\x43lipEmbeddingForTextRequest\x12\x0c\n\x04Text\x18\x01 \x03(\t\"-\n\x1c\x43lipEmbeddingForImageRequest\x12\r\n\x05Image\x18\x01 \x01(\x0c\";\n\x15\x43lipEmbeddingResponse\x12\"\n\nEmbeddings\x18\x01 \x03(\x0b\x32\x0e.msw.Embedding\"\x1e\n\tEmbedding\x12\x11\n\tEmbedding\x18\x02 \x03(\x02\x32\xe6\x01\n\x03MSW\x12-\n\x07Version\x12\n.msw.Empty\x1a\x14.msw.VersionResponse\"\x00\x12V\n\x14\x43lipEmbeddingForText\x12 .msw.ClipEmbeddingForTextRequest\x1a\x1a.msw.ClipEmbeddingResponse\"\x00\x12X\n\x15\x43lipEmbeddingForImage\x12!.msw.ClipEmbeddingForImageRequest\x1a\x1a.msw.ClipEmbeddingResponse\"\x00\x42\x0bZ\tmsw/protob\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'msw_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
_globals['DESCRIPTOR']._loaded_options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'Z\tmsw/proto'
|
||||
_globals['_DATAPACKAGE_PARAMSENTRY']._loaded_options = None
|
||||
_globals['_DATAPACKAGE_PARAMSENTRY']._serialized_options = b'8\001'
|
||||
_globals['_EMPTY']._serialized_start=51
|
||||
_globals['_EMPTY']._serialized_end=58
|
||||
_globals['_VERSIONRESPONSE']._serialized_start=61
|
||||
_globals['_VERSIONRESPONSE']._serialized_end=220
|
||||
_globals['_DATAPACKAGE']._serialized_start=223
|
||||
_globals['_DATAPACKAGE']._serialized_end=361
|
||||
_globals['_DATAPACKAGE_PARAMSENTRY']._serialized_start=316
|
||||
_globals['_DATAPACKAGE_PARAMSENTRY']._serialized_end=361
|
||||
_globals['_CLIPEMBEDDINGFORTEXTREQUEST']._serialized_start=363
|
||||
_globals['_CLIPEMBEDDINGFORTEXTREQUEST']._serialized_end=406
|
||||
_globals['_CLIPEMBEDDINGFORIMAGEREQUEST']._serialized_start=408
|
||||
_globals['_CLIPEMBEDDINGFORIMAGEREQUEST']._serialized_end=453
|
||||
_globals['_CLIPEMBEDDINGRESPONSE']._serialized_start=455
|
||||
_globals['_CLIPEMBEDDINGRESPONSE']._serialized_end=514
|
||||
_globals['_EMBEDDING']._serialized_start=516
|
||||
_globals['_EMBEDDING']._serialized_end=546
|
||||
_globals['_MSW']._serialized_start=549
|
||||
_globals['_MSW']._serialized_end=779
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
70
clip_embedding/msw_pb2.pyi
Normal file
70
clip_embedding/msw_pb2.pyi
Normal file
@@ -0,0 +1,70 @@
|
||||
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
||||
from google.protobuf.internal import containers as _containers
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
|
||||
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
class Empty(_message.Message):
|
||||
__slots__ = ()
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class VersionResponse(_message.Message):
|
||||
__slots__ = ("major", "minor", "patch", "hostname", "addr", "latency", "started_at")
|
||||
MAJOR_FIELD_NUMBER: _ClassVar[int]
|
||||
MINOR_FIELD_NUMBER: _ClassVar[int]
|
||||
PATCH_FIELD_NUMBER: _ClassVar[int]
|
||||
HOSTNAME_FIELD_NUMBER: _ClassVar[int]
|
||||
ADDR_FIELD_NUMBER: _ClassVar[int]
|
||||
LATENCY_FIELD_NUMBER: _ClassVar[int]
|
||||
STARTED_AT_FIELD_NUMBER: _ClassVar[int]
|
||||
major: int
|
||||
minor: int
|
||||
patch: int
|
||||
hostname: str
|
||||
addr: _containers.RepeatedScalarFieldContainer[str]
|
||||
latency: int
|
||||
started_at: _timestamp_pb2.Timestamp
|
||||
def __init__(self, major: _Optional[int] = ..., minor: _Optional[int] = ..., patch: _Optional[int] = ..., hostname: _Optional[str] = ..., addr: _Optional[_Iterable[str]] = ..., latency: _Optional[int] = ..., started_at: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ...
|
||||
|
||||
class DataPackage(_message.Message):
|
||||
__slots__ = ("ToPlugin", "Params", "Body")
|
||||
class ParamsEntry(_message.Message):
|
||||
__slots__ = ("key", "value")
|
||||
KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
VALUE_FIELD_NUMBER: _ClassVar[int]
|
||||
key: str
|
||||
value: str
|
||||
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
|
||||
TOPLUGIN_FIELD_NUMBER: _ClassVar[int]
|
||||
PARAMS_FIELD_NUMBER: _ClassVar[int]
|
||||
BODY_FIELD_NUMBER: _ClassVar[int]
|
||||
ToPlugin: str
|
||||
Params: _containers.ScalarMap[str, str]
|
||||
Body: str
|
||||
def __init__(self, ToPlugin: _Optional[str] = ..., Params: _Optional[_Mapping[str, str]] = ..., Body: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class ClipEmbeddingForTextRequest(_message.Message):
|
||||
__slots__ = ("Text",)
|
||||
TEXT_FIELD_NUMBER: _ClassVar[int]
|
||||
Text: _containers.RepeatedScalarFieldContainer[str]
|
||||
def __init__(self, Text: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||
|
||||
class ClipEmbeddingForImageRequest(_message.Message):
|
||||
__slots__ = ("Image",)
|
||||
IMAGE_FIELD_NUMBER: _ClassVar[int]
|
||||
Image: bytes
|
||||
def __init__(self, Image: _Optional[bytes] = ...) -> None: ...
|
||||
|
||||
class ClipEmbeddingResponse(_message.Message):
|
||||
__slots__ = ("Embeddings",)
|
||||
EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
|
||||
Embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
|
||||
def __init__(self, Embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
|
||||
|
||||
class Embedding(_message.Message):
|
||||
__slots__ = ("Embedding",)
|
||||
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
|
||||
Embedding: _containers.RepeatedScalarFieldContainer[float]
|
||||
def __init__(self, Embedding: _Optional[_Iterable[float]] = ...) -> None: ...
|
||||
187
clip_embedding/msw_pb2_grpc.py
Normal file
187
clip_embedding/msw_pb2_grpc.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import msw_pb2 as msw__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.63.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
EXPECTED_ERROR_RELEASE = '1.65.0'
|
||||
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
warnings.warn(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in msw_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
+ f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
|
||||
+ f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
|
||||
RuntimeWarning
|
||||
)
|
||||
|
||||
|
||||
class MSWStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Version = channel.unary_unary(
|
||||
'/msw.MSW/Version',
|
||||
request_serializer=msw__pb2.Empty.SerializeToString,
|
||||
response_deserializer=msw__pb2.VersionResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.ClipEmbeddingForText = channel.unary_unary(
|
||||
'/msw.MSW/ClipEmbeddingForText',
|
||||
request_serializer=msw__pb2.ClipEmbeddingForTextRequest.SerializeToString,
|
||||
response_deserializer=msw__pb2.ClipEmbeddingResponse.FromString,
|
||||
_registered_method=True)
|
||||
self.ClipEmbeddingForImage = channel.unary_unary(
|
||||
'/msw.MSW/ClipEmbeddingForImage',
|
||||
request_serializer=msw__pb2.ClipEmbeddingForImageRequest.SerializeToString,
|
||||
response_deserializer=msw__pb2.ClipEmbeddingResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class MSWServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Version(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def ClipEmbeddingForText(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def ClipEmbeddingForImage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_MSWServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Version': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Version,
|
||||
request_deserializer=msw__pb2.Empty.FromString,
|
||||
response_serializer=msw__pb2.VersionResponse.SerializeToString,
|
||||
),
|
||||
'ClipEmbeddingForText': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ClipEmbeddingForText,
|
||||
request_deserializer=msw__pb2.ClipEmbeddingForTextRequest.FromString,
|
||||
response_serializer=msw__pb2.ClipEmbeddingResponse.SerializeToString,
|
||||
),
|
||||
'ClipEmbeddingForImage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ClipEmbeddingForImage,
|
||||
request_deserializer=msw__pb2.ClipEmbeddingForImageRequest.FromString,
|
||||
response_serializer=msw__pb2.ClipEmbeddingResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'msw.MSW', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class MSW(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Version(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/msw.MSW/Version',
|
||||
msw__pb2.Empty.SerializeToString,
|
||||
msw__pb2.VersionResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def ClipEmbeddingForText(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/msw.MSW/ClipEmbeddingForText',
|
||||
msw__pb2.ClipEmbeddingForTextRequest.SerializeToString,
|
||||
msw__pb2.ClipEmbeddingResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def ClipEmbeddingForImage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/msw.MSW/ClipEmbeddingForImage',
|
||||
msw__pb2.ClipEmbeddingForImageRequest.SerializeToString,
|
||||
msw__pb2.ClipEmbeddingResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
5
clip_embedding/requirements.txt
Normal file
5
clip_embedding/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
grpcio-tools
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
SentencePiece
|
||||
3
clip_embedding/requirements_version.txt
Normal file
3
clip_embedding/requirements_version.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.63.0
|
||||
grpcio-tools==1.63.0
|
||||
protobuf==5.26.1
|
||||
Reference in New Issue
Block a user