Remove the usage of transformers.pipeline from BatchedInferencePipeline and fix word timestamps for batched inference (#921)
* fix word timestamps for batched inference
* remove hf pipeline
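For context, here is a minimal usage sketch of the pipeline after this change. The import path, model name, input file, and option values are illustrative assumptions, not part of the diff:

```python
from faster_whisper import BatchedInferencePipeline, WhisperModel

# assumed example values; any CTranslate2 Whisper model should work the same way
model = WhisperModel("large-v3", device="cuda", compute_type="float16")
pipeline = BatchedInferencePipeline(model=model)

# batching is now driven by the pipeline's own loop rather than a transformers
# DataLoader; word_timestamps=True exercises the timestamp path fixed here
segments, info = pipeline.transcribe(
    "audio.wav",
    batch_size=16,
    word_timestamps=True,
)

for segment in segments:
    for word in segment.words or []:
        print(f"[{word.start:.2f} -> {word.end:.2f}] {word.word}")
```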
@@ -15,8 +15,7 @@ import tokenizers
 import torch
 
 from pyannote.audio import Model
-from transformers import Pipeline
-from transformers.pipelines.pt_utils import PipelineIterator
+from tqdm import tqdm
 
 from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
@@ -105,7 +104,7 @@ class TranscriptionInfo(NamedTuple):
 # (https://github.com/m-bain/whisperX) and adapted for faster_whisper
 
 
-class BatchedInferencePipeline(Pipeline):
+class BatchedInferencePipeline:
     """
     Huggingface Pipeline wrapper for WhisperModel.
     Copyright (c) 2022, Max Bain
@@ -119,55 +118,29 @@ class BatchedInferencePipeline(Pipeline):
         use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
-        device: Union[int, str, "torch.device"] = -1,
         chunk_length: int = 30,
         vad_device: Union[int, str, "torch.device"] = "auto",
         vad_onset: float = 0.500,
         vad_offset: float = 0.363,
-        framework="pt",
         language: Optional[str] = None,
-        **kwargs,
     ):
         self.model: WhisperModel = model
         self.tokenizer = tokenizer
         self.options = options
         self.preset_language = language
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = 0
         self.use_vad_model = use_vad_model
         self.vad_onset = vad_onset
         self.vad_offset = vad_offset
         self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
-        self.vad_model = None
-
-        (
-            self._preprocess_params,
-            self._forward_params,
-            self._postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-        self.call_count = 0
-        self.framework = framework
-        if self.framework == "pt":
-            self.device = self.get_device(device)
-        else:
-            self.device = device
-
-        if self.use_vad_model and self.vad_model is None:
+        if self.use_vad_model:
             self.vad_device = self.get_device(vad_device)
-
-            # load vad model and perform VAD preprocessing if needed
             self.vad_model = self.load_vad_model(
                 vad_onset=self.vad_onset, vad_offset=self.vad_offset
             )
+        else:
+            self.vad_model = None
         self.chunk_length = chunk_length  # VAD merging size
         self.last_speech_timestamp = 0.0
-        super(Pipeline, self).__init__()
-
-    def _sanitize_parameters(self, **kwargs):
-        preprocess_kwargs = {}
-        if "tokenizer" in kwargs:
-            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
-        return preprocess_kwargs, {}, {}
 
     def get_device(self, device: Union[int, str, "torch.device"]):
         """
@@ -193,27 +166,17 @@ class BatchedInferencePipeline(Pipeline):
         else:
             return torch.device(f"cuda:{device}")
 
-    def preprocess(self, inputs):
-        audio = inputs["inputs"]
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
-        features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
-            :, : self.model.feature_extractor.nb_max_frames
-        ]
-
-        inputs["features"] = features
-        del features
-        return inputs
-
-    def _forward(self, model_inputs, **forward_params):
+    def forward(self, features, segments_metadata, **forward_params):
         encoder_output, outputs = self.model.generate_segment_batched(
-            model_inputs["features"], self.tokenizer, forward_params
+            features, self.tokenizer, forward_params
         )
 
-        segment_size = encoder_output.shape[1] * 2
         segmented_outputs = []
-        for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
+        segment_sizes = []
+        for segment_metadata, output in zip(segments_metadata, outputs):
+            duration = segment_metadata["end_time"] - segment_metadata["start_time"]
+            segment_size = int(duration * self.model.frames_per_second)
+            segment_sizes.append(segment_size)
             (
                 subsegments,
                 seek,
@@ -223,8 +186,7 @@ class BatchedInferencePipeline(Pipeline):
                 tokens=output["tokens"],
                 time_offset=segment_metadata["start_time"],
                 segment_size=segment_size,
-                segment_duration=segment_metadata["end_time"]
-                - segment_metadata["start_time"],
+                segment_duration=duration,
                 seek=0,
             )
             segmented_outputs.append(
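The word-timestamp fix comes from the hunks above: the frame count passed downstream is now derived from each VAD segment's real duration instead of the fixed padded window. A small illustrative calculation, assuming Whisper's 10 ms feature hop (100 frames per second); the numbers are made up:

```python
# one VAD segment of 7.3 s inside a batch whose entries are padded to 30 s
frames_per_second = 100               # assumed 10 ms feature hop
start_time, end_time = 10.0, 17.3     # illustrative VAD boundaries in seconds
duration = end_time - start_time

old_segment_size = 1500 * 2           # encoder_output.shape[1] * 2: 3000 frames (30 s) for every segment
new_segment_size = int(duration * frames_per_second)  # 730 frames (7.3 s) for this segment

# segment_size / segment_sizes feed the timestamp-splitting and word-alignment
# calls above, so using the padded 3000 frames for a 730-frame chunk skews the
# word timings; the per-segment value keeps them inside the real duration
print(old_segment_size, new_segment_size)  # 3000 730
```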
@@ -248,89 +210,13 @@ class BatchedInferencePipeline(Pipeline):
                 segmented_outputs,
                 self.tokenizer,
                 encoder_output,
-                segment_size,
+                segment_sizes,
                 forward_params["prepend_punctuations"],
                 forward_params["append_punctuations"],
                 self.last_speech_timestamp,
             )
 
-        return {"output": segmented_outputs}
-
-    def __call__(self, inputs, options, batch_size=None, **kwargs):
-        if batch_size is None:
-            if self._batch_size is None:
-                batch_size = 1
-            else:
-                batch_size = self._batch_size
-
-        (
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-
-        # Fuse __init__ params and __call__ params without modifying the __init__ ones.
-        preprocess_params = {
-            **self._preprocess_params,
-            **preprocess_params,
-        }
-        options_dict = options._asdict()
-        forward_params = {**self._forward_params, **forward_params, **options_dict}
-        postprocess_params = {**self._postprocess_params, **postprocess_params}
-
-        self.call_count += 1
-        if (
-            self.call_count > 10
-            and self.framework == "pt"
-            and self.device.type == "cuda"
-        ):
-            logging.warning(
-                "You seem to be using the pipelines sequentially on GPU. Please use a Dataset"
-            )
-
-        return self.get_iterator(
-            inputs,
-            batch_size,
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        )
-
-    def postprocess(self, model_outputs):
-        return model_outputs
-
-    def get_iterator(
-        self,
-        inputs,
-        batch_size: int,
-        preprocess_params=None,
-        forward_params=None,
-        postprocess_params=None,
-    ):
-        def stack(items):
-            return {
-                "inputs": [x["inputs"] for x in items],
-                "seg_metadata": [x["seg_metadata"] for x in items],
-                "features": torch.stack([x["features"] for x in items]),
-            }
-
-        if "TOKENIZERS_PARALLELISM" not in os.environ:
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=self._num_workers,
-            batch_size=batch_size,
-            collate_fn=stack,
-        )
-        model_iterator = PipelineIterator(
-            dataloader, self.forward, forward_params, loader_batch_size=batch_size
-        )
-        final_iterator = PipelineIterator(
-            model_iterator, self.postprocess, postprocess_params
-        )
-        return final_iterator
+        return segmented_outputs
 
     def get_language_and_tokenizer(
         self, audio, task: Optional[str] = None, language: Optional[str] = None
@@ -369,7 +255,8 @@ class BatchedInferencePipeline(Pipeline):
     @staticmethod
     def audio_split(audio, segments, sampling_rate):
         """Returns splitted audio chunks as iterator"""
-
+        audio_segments = []
+        segments_metadata = []
         for seg in segments:
             f1 = int(seg["start"] * sampling_rate)
             f2 = int(seg["end"] * sampling_rate)
@@ -378,7 +265,9 @@ class BatchedInferencePipeline(Pipeline):
                 "end_time": seg["end"],
                 "stitched_seg": seg["segments"],
             }
-            yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
+            audio_segments.append(audio[f1:f2])
+            segments_metadata.append(seg_metadata)
+        return audio_segments, segments_metadata
 
     def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
         vad_model = Model.from_pretrained(self.vad_model_path)
@@ -573,7 +462,6 @@ class BatchedInferencePipeline(Pipeline):
             task,
             all_language_probs,
         ) = self.get_language_and_tokenizer(audio, task, language)
-        batch_size = batch_size or self._batch_size
 
         duration_after_vad = sum(
             segment["end"] - segment["start"] for segment in vad_segments
@@ -623,10 +511,27 @@ class BatchedInferencePipeline(Pipeline):
             all_language_probs=all_language_probs,
         )
 
+        audio_segments, segments_metadata = self.audio_split(
+            audio, vad_segments, sampling_rate
+        )
+        to_cpu = (
+            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
+        )
+        audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
+            padding=0
+        )
+        features = torch.stack(
+            [
+                self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
+                    ..., : self.model.feature_extractor.nb_max_frames
+                ]
+                for audio_segment in audio_segments
+            ]
+        )
+
         segments = self._batched_segments_generator(
-            audio,
-            vad_segments,
-            sampling_rate,
+            features,
+            segments_metadata,
             batch_size,
             batched_options,
             log_progress,
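The hunk above packs the variable-length VAD chunks into one rectangular batch before feature extraction. A minimal sketch of that zero-padding step with toy tensors (values are illustrative, not from the diff):

```python
import torch

# two "audio segments" of different lengths
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# same padding call as in the diff: pack into a nested tensor, then pad to a rectangle
padded = torch.nested.nested_tensor([a, b]).to_padded_tensor(padding=0)
print(padded.shape)  # torch.Size([2, 5])
print(padded[0])     # tensor([1., 2., 3., 0., 0.])
```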
@@ -635,45 +540,40 @@ class BatchedInferencePipeline(Pipeline):
         return segments, info
 
     def _batched_segments_generator(
-        self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
+        self, features, segments_metadata, batch_size, options, log_progress
     ):
+        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
-        total_segments = len(vad_segments)
-        for idx, out in enumerate(
-            self.__call__(
-                self.audio_split(audio, vad_segments, sampling_rate),
-                batch_size=batch_size,
-                options=options,
+        for i in range(0, len(features), batch_size):
+            results = self.forward(
+                features[i : i + batch_size],
+                segments_metadata[i : i + batch_size],
+                **options._asdict(),
             )
-        ):
-            if log_progress:
-                percent_complete = ((idx + 1) / total_segments) * 100
-                self.model.logger.info(f"Progress: {percent_complete:.2f}%...")
 
-            responses = out["output"]
-            if batch_size == 1:
-                responses = responses[0]
+            for result in results:
+                for segment in result:
+                    seg_idx += 1
+                    yield Segment(
+                        seek=int(result[-1]["end"] * self.model.frames_per_second),
+                        id=seg_idx,
+                        text=segment["text"],
+                        start=round(segment["start"], 3),
+                        end=round(segment["end"], 3),
+                        words=(
+                            None
+                            if not options.word_timestamps
+                            else [Word(**word) for word in segment["words"]]
+                        ),
+                        tokens=segment["tokens"],
+                        avg_logprob=segment["avg_logprob"],
+                        no_speech_prob=segment["no_speech_prob"],
+                        compression_ratio=segment["compression_ratio"],
+                    )
 
-            for response in responses:
-                seg_idx += 1
-                segments = Segment(
-                    seek=int(responses[-1]["end"] * self.model.frames_per_second),
-                    id=seg_idx,
-                    text=response["text"],
-                    start=round(response["start"], 3),
-                    end=round(response["end"], 3),
-                    words=(
-                        None
-                        if not options.word_timestamps
-                        else [Word(**word) for word in response["words"]]
-                    ),
-                    tokens=response["tokens"],
-                    avg_logprob=response["avg_logprob"],
-                    no_speech_prob=response["no_speech_prob"],
-                    compression_ratio=response["compression_ratio"],
-                )
-                yield segments
+            pbar.update(1)
 
+        pbar.close()
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
@@ -2,7 +2,7 @@ ctranslate2>=4.0,<5
 huggingface_hub>=0.13
 tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
-transformers
 pyannote-audio>=3.1.1
 torch>=2.1.1
 torchaudio>=2.1.2
+tqdm