Remove the usage of transformers.pipeline from BatchedInferencePipeline and fix word timestamps for batched inference (#921)

* fix word timestamps for batched inference
* remove hf pipeline
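For context, a minimal usage sketch of the pipeline after this change (hedged: the import path and `transcribe` arguments follow the faster-whisper API of this release and are not part of the diff below):

    from faster_whisper import WhisperModel, BatchedInferencePipeline

    model = WhisperModel("large-v3", device="cuda")
    batched_pipeline = BatchedInferencePipeline(model=model)

    # word_timestamps=True exercises the timestamp fix in this commit
    segments, info = batched_pipeline.transcribe(
        "audio.wav", batch_size=8, word_timestamps=True
    )
    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")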
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -15,8 +15,7 @@ import tokenizers
 import torch

 from pyannote.audio import Model
-from transformers import Pipeline
-from transformers.pipelines.pt_utils import PipelineIterator
 from tqdm import tqdm

 from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
@@ -105,7 +104,7 @@ class TranscriptionInfo(NamedTuple):
 # (https://github.com/m-bain/whisperX) and adapted for faster_whisper


-class BatchedInferencePipeline(Pipeline):
+class BatchedInferencePipeline:
     """
     Huggingface Pipeline wrapper for WhisperModel.
     Copyright (c) 2022, Max Bain
@@ -119,55 +118,29 @@ class BatchedInferencePipeline(Pipeline):
         use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
-        device: Union[int, str, "torch.device"] = -1,
         chunk_length: int = 30,
         vad_device: Union[int, str, "torch.device"] = "auto",
         vad_onset: float = 0.500,
         vad_offset: float = 0.363,
-        framework="pt",
         language: Optional[str] = None,
-        **kwargs,
     ):
         self.model: WhisperModel = model
         self.tokenizer = tokenizer
         self.options = options
         self.preset_language = language
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = 0
         self.use_vad_model = use_vad_model
         self.vad_onset = vad_onset
         self.vad_offset = vad_offset
         self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
-        self.vad_model = None
-
-        (
-            self._preprocess_params,
-            self._forward_params,
-            self._postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-        self.call_count = 0
-        self.framework = framework
-        if self.framework == "pt":
-            self.device = self.get_device(device)
-        else:
-            self.device = device
-
-        if self.use_vad_model and self.vad_model is None:
+        if self.use_vad_model:
             self.vad_device = self.get_device(vad_device)

             # load vad model and perform VAD preprocessing if needed
             self.vad_model = self.load_vad_model(
                 vad_onset=self.vad_onset, vad_offset=self.vad_offset
             )
+        else:
+            self.vad_model = None
         self.chunk_length = chunk_length  # VAD merging size
         self.last_speech_timestamp = 0.0
-        super(Pipeline, self).__init__()
-
-    def _sanitize_parameters(self, **kwargs):
-        preprocess_kwargs = {}
-        if "tokenizer" in kwargs:
-            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
-        return preprocess_kwargs, {}, {}

     def get_device(self, device: Union[int, str, "torch.device"]):
         """
@@ -193,27 +166,17 @@ class BatchedInferencePipeline(Pipeline):
         else:
             return torch.device(f"cuda:{device}")

-    def preprocess(self, inputs):
-        audio = inputs["inputs"]
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
-        features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
-            :, : self.model.feature_extractor.nb_max_frames
-        ]
-
-        inputs["features"] = features
-        del features
-        return inputs
-
-    def _forward(self, model_inputs, **forward_params):
+    def forward(self, features, segments_metadata, **forward_params):
         encoder_output, outputs = self.model.generate_segment_batched(
-            model_inputs["features"], self.tokenizer, forward_params
+            features, self.tokenizer, forward_params
         )

-        segment_size = encoder_output.shape[1] * 2
         segmented_outputs = []
-        for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
+        segment_sizes = []
+        for segment_metadata, output in zip(segments_metadata, outputs):
+            duration = segment_metadata["end_time"] - segment_metadata["start_time"]
+            segment_size = int(duration * self.model.frames_per_second)
+            segment_sizes.append(segment_size)
             (
                 subsegments,
                 seek,
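The `get_device` helper kept above now only resolves the device for the VAD model. A sketch of its behavior, inferred from the `cuda:{device}` branch and the `vad_device="auto"` default (the full body is elided from this diff, so treat this as an assumption):

    import torch

    def get_device(device):
        # torch.device passes through; strings resolve directly,
        # with "auto" preferring CUDA when available (assumed);
        # negative ints mean CPU, other ints a CUDA device index.
        if isinstance(device, torch.device):
            return device
        if isinstance(device, str):
            if device == "auto":
                return torch.device("cuda" if torch.cuda.is_available() else "cpu")
            return torch.device(device)
        if device < 0:
            return torch.device("cpu")
        return torch.device(f"cuda:{device}")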
@@ -223,8 +186,7 @@ class BatchedInferencePipeline(Pipeline):
                 tokens=output["tokens"],
                 time_offset=segment_metadata["start_time"],
                 segment_size=segment_size,
-                segment_duration=segment_metadata["end_time"]
-                - segment_metadata["start_time"],
+                segment_duration=duration,
                 seek=0,
             )
             segmented_outputs.append(
@@ -248,89 +210,13 @@ class BatchedInferencePipeline(Pipeline):
                 segmented_outputs,
                 self.tokenizer,
                 encoder_output,
-                segment_size,
+                segment_sizes,
                 forward_params["prepend_punctuations"],
                 forward_params["append_punctuations"],
                 self.last_speech_timestamp,
             )

-        return {"output": segmented_outputs}
-
-    def __call__(self, inputs, options, batch_size=None, **kwargs):
-        if batch_size is None:
-            if self._batch_size is None:
-                batch_size = 1
-            else:
-                batch_size = self._batch_size
-
-        (
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-
-        # Fuse __init__ params and __call__ params without modifying the __init__ ones.
-        preprocess_params = {
-            **self._preprocess_params,
-            **preprocess_params,
-        }
-        options_dict = options._asdict()
-        forward_params = {**self._forward_params, **forward_params, **options_dict}
-        postprocess_params = {**self._postprocess_params, **postprocess_params}
-
-        self.call_count += 1
-        if (
-            self.call_count > 10
-            and self.framework == "pt"
-            and self.device.type == "cuda"
-        ):
-            logging.warning(
-                "You seem to be using the pipelines sequentially on GPU. Please use a Dataset"
-            )
-
-        return self.get_iterator(
-            inputs,
-            batch_size,
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        )
-
-    def postprocess(self, model_outputs):
-        return model_outputs
-
-    def get_iterator(
-        self,
-        inputs,
-        batch_size: int,
-        preprocess_params=None,
-        forward_params=None,
-        postprocess_params=None,
-    ):
-        def stack(items):
-            return {
-                "inputs": [x["inputs"] for x in items],
-                "seg_metadata": [x["seg_metadata"] for x in items],
-                "features": torch.stack([x["features"] for x in items]),
-            }
-
-        if "TOKENIZERS_PARALLELISM" not in os.environ:
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=self._num_workers,
-            batch_size=batch_size,
-            collate_fn=stack,
-        )
-        model_iterator = PipelineIterator(
-            dataloader, self.forward, forward_params, loader_batch_size=batch_size
-        )
-        final_iterator = PipelineIterator(
-            model_iterator, self.postprocess, postprocess_params
-        )
-        return final_iterator
+        return segmented_outputs

     def get_language_and_tokenizer(
         self, audio, task: Optional[str] = None, language: Optional[str] = None
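The `segment_size` -> `segment_sizes` switch above is the core of the word-timestamp fix: `add_word_timestamps` now receives one frame count per segment, computed from that segment's VAD duration, instead of a single size taken from the padded encoder output. A toy illustration of the per-segment computation (values invented; 100 frames per second matches Whisper's 16 kHz input with a 160-sample hop):

    frames_per_second = 100
    segments_metadata = [
        {"start_time": 0.0, "end_time": 7.4},
        {"start_time": 7.4, "end_time": 19.1},
    ]
    # One frame count per segment, as in the diff above.
    segment_sizes = [
        int((m["end_time"] - m["start_time"]) * frames_per_second)
        for m in segments_metadata
    ]
    print(segment_sizes)  # [740, 1170]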
@@ -369,7 +255,8 @@ class BatchedInferencePipeline(Pipeline):
     @staticmethod
     def audio_split(audio, segments, sampling_rate):
         """Returns splitted audio chunks as iterator"""
+        audio_segments = []
+        segments_metadata = []
         for seg in segments:
             f1 = int(seg["start"] * sampling_rate)
             f2 = int(seg["end"] * sampling_rate)
@@ -378,7 +265,9 @@ class BatchedInferencePipeline(Pipeline):
                 "end_time": seg["end"],
                 "stitched_seg": seg["segments"],
             }
-            yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
+            audio_segments.append(audio[f1:f2])
+            segments_metadata.append(seg_metadata)
+        return audio_segments, segments_metadata

     def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
         vad_model = Model.from_pretrained(self.vad_model_path)
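`audio_split` is now eager: it returns two parallel lists rather than yielding one dict per chunk, so the caller can pad and stack every chunk in a single pass. A toy run of the same slicing logic (fabricated inputs):

    import numpy as np

    sampling_rate = 16000
    audio = np.zeros(10 * sampling_rate, dtype=np.float32)  # 10 s of silence
    vad_segments = [
        {"start": 0.0, "end": 2.5},
        {"start": 4.0, "end": 9.0},
    ]
    chunks = [
        audio[int(s["start"] * sampling_rate) : int(s["end"] * sampling_rate)]
        for s in vad_segments
    ]
    print([len(c) / sampling_rate for c in chunks])  # [2.5, 5.0]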
@@ -573,7 +462,6 @@ class BatchedInferencePipeline(Pipeline):
             task,
             all_language_probs,
         ) = self.get_language_and_tokenizer(audio, task, language)
-        batch_size = batch_size or self._batch_size

         duration_after_vad = sum(
             segment["end"] - segment["start"] for segment in vad_segments
@@ -623,10 +511,27 @@ class BatchedInferencePipeline(Pipeline):
             all_language_probs=all_language_probs,
         )

+        audio_segments, segments_metadata = self.audio_split(
+            audio, vad_segments, sampling_rate
+        )
+        to_cpu = (
+            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
+        )
+        audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
+            padding=0
+        )
+        features = torch.stack(
+            [
+                self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
+                    ..., : self.model.feature_extractor.nb_max_frames
+                ]
+                for audio_segment in audio_segments
+            ]
+        )
+
         segments = self._batched_segments_generator(
-            audio,
-            vad_segments,
-            sampling_rate,
+            features,
+            segments_metadata,
             batch_size,
             batched_options,
             log_progress,
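The new preprocessing above pads the variable-length VAD chunks into one dense batch before feature extraction. A self-contained sketch of just that padding step:

    import torch

    # Two chunks of different lengths (3 s and 7 s at 16 kHz), zero-padded
    # to the longest, as nested_tensor(...).to_padded_tensor(padding=0) yields.
    chunks = [torch.randn(3 * 16000), torch.randn(7 * 16000)]
    batch = torch.nested.nested_tensor(chunks).to_padded_tensor(padding=0)
    print(batch.shape)  # torch.Size([2, 112000])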
@@ -635,45 +540,40 @@ class BatchedInferencePipeline(Pipeline):
         return segments, info

     def _batched_segments_generator(
-        self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
+        self, features, segments_metadata, batch_size, options, log_progress
     ):
+        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
-        total_segments = len(vad_segments)
-        for idx, out in enumerate(
-            self.__call__(
-                self.audio_split(audio, vad_segments, sampling_rate),
-                batch_size=batch_size,
-                options=options,
-            )
-        ):
-            if log_progress:
-                percent_complete = ((idx + 1) / total_segments) * 100
-                self.model.logger.info(f"Progress: {percent_complete:.2f}%...")
-
-            responses = out["output"]
-            if batch_size == 1:
-                responses = responses[0]
-
-            for response in responses:
-                seg_idx += 1
-                segments = Segment(
-                    seek=int(responses[-1]["end"] * self.model.frames_per_second),
-                    id=seg_idx,
-                    text=response["text"],
-                    start=round(response["start"], 3),
-                    end=round(response["end"], 3),
-                    words=(
-                        None
-                        if not options.word_timestamps
-                        else [Word(**word) for word in response["words"]]
-                    ),
-                    tokens=response["tokens"],
-                    avg_logprob=response["avg_logprob"],
-                    no_speech_prob=response["no_speech_prob"],
-                    compression_ratio=response["compression_ratio"],
-                )
-                yield segments
+        for i in range(0, len(features), batch_size):
+            results = self.forward(
+                features[i : i + batch_size],
+                segments_metadata[i : i + batch_size],
+                **options._asdict(),
+            )
+            for result in results:
+                for segment in result:
+                    seg_idx += 1
+                    yield Segment(
+                        seek=int(result[-1]["end"] * self.model.frames_per_second),
+                        id=seg_idx,
+                        text=segment["text"],
+                        start=round(segment["start"], 3),
+                        end=round(segment["end"], 3),
+                        words=(
+                            None
+                            if not options.word_timestamps
+                            else [Word(**word) for word in segment["words"]]
+                        ),
+                        tokens=segment["tokens"],
+                        avg_logprob=segment["avg_logprob"],
+                        no_speech_prob=segment["no_speech_prob"],
+                        compression_ratio=segment["compression_ratio"],
+                    )
+                pbar.update(1)
+
+        pbar.close()
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
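In outline, the HF `PipelineIterator`/`DataLoader` machinery is replaced by a plain slicing loop over the precomputed features (a simplified sketch, not the verbatim method):

    def batched_generator(features, segments_metadata, batch_size, forward):
        # Slice fixed-size batches and stream each segment's result out.
        for i in range(0, len(features), batch_size):
            for result in forward(
                features[i : i + batch_size],
                segments_metadata[i : i + batch_size],
            ):
                yield result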
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ ctranslate2>=4.0,<5
 huggingface_hub>=0.13
 tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
-transformers
 pyannote-audio>=3.1.1
 torch>=2.1.1
 torchaudio>=2.1.2
+tqdm