diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py
index b6f5709..fbecc48 100644
--- a/faster_whisper/audio.py
+++ b/faster_whisper/audio.py
@@ -15,19 +15,27 @@
 import av
 import numpy as np
 
 
-def decode_audio(input_file: Union[str, BinaryIO], sampling_rate: int = 16000):
+def decode_audio(
+    input_file: Union[str, BinaryIO],
+    sampling_rate: int = 16000,
+    split_stereo: bool = False,
+):
     """Decodes the audio.
 
     Args:
       input_file: Path to the input file or a file-like object.
       sampling_rate: Resample the audio to this sample rate.
+      split_stereo: Return separate left and right channels.
 
     Returns:
       A float32 Numpy array.
+
+      If `split_stereo` is enabled, the function returns a 2-tuple with the
+      separated left and right channels.
     """
     resampler = av.audio.resampler.AudioResampler(
         format="s16",
-        layout="mono",
+        layout="mono" if not split_stereo else "stereo",
         rate=sampling_rate,
     )
@@ -48,7 +56,14 @@ def decode_audio(input_file: Union[str, BinaryIO], sampling_rate: int = 16000):
     audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
 
     # Convert s16 back to f32.
-    return audio.astype(np.float32) / 32768.0
+    audio = audio.astype(np.float32) / 32768.0
+
+    if split_stereo:
+        left_channel = audio[0::2]
+        right_channel = audio[1::2]
+        return left_channel, right_channel
+
+    return audio
 
 
 def _ignore_invalid_frames(frames):
diff --git a/tests/data/stereo_diarization.wav b/tests/data/stereo_diarization.wav
new file mode 100644
index 0000000..3f5ae75
Binary files /dev/null and b/tests/data/stereo_diarization.wav differ
diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index 575bbd4..10e39db 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -1,4 +1,6 @@
-from faster_whisper import WhisperModel
+import os
+
+from faster_whisper import WhisperModel, decode_audio
 
 
 def test_transcribe(jfk_path):
@@ -23,3 +25,21 @@
         assert segment.text == "".join(word.word for word in segment.words)
         assert segment.start == segment.words[0].start
         assert segment.end == segment.words[-1].end
+
+
+def test_stereo_diarization(data_dir):
+    model = WhisperModel("tiny")
+
+    audio_path = os.path.join(data_dir, "stereo_diarization.wav")
+    left, right = decode_audio(audio_path, split_stereo=True)
+
+    segments, _ = model.transcribe(left)
+    transcription = "".join(segment.text for segment in segments).strip()
+    assert transcription == (
+        "He began a confused complaint against the wizard, "
+        "who had vanished behind the curtain on the left."
+    )
+
+    segments, _ = model.transcribe(right)
+    transcription = "".join(segment.text for segment in segments).strip()
+    assert transcription == "The horizon seems extremely distant."