Added --output_format option (#333)

* Added --output option

The --output option lets the user select which output files will be generated.

Corrected the logic that wrongly showed the progress bar when verbose was set to False

* Changed output_files variable

* Changed back the tqdm verbose

* refactor output format handling

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
This commit is contained in:
Aaryan YVS
2023-01-22 13:28:38 +05:30
committed by GitHub
parent 9f7aba6099
commit da600abd2b
2 changed files with 82 additions and 50 deletions

View File

@@ -11,7 +11,7 @@ import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, get_writer
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Whisper from .model import Whisper
@@ -260,6 +260,7 @@ def cli():
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
@@ -286,6 +287,7 @@ def cli():
model_name: str = args.pop("model") model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir") model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device") device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@@ -308,22 +310,11 @@ def cli():
from . import load_model from . import load_model
model = load_model(model_name, device=device, download_root=model_dir) model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
result = transcribe(model, audio_path, temperature=temperature, **args) result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path)
audio_basename = os.path.basename(audio_path)
# save TXT
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
write_txt(result["segments"], file=txt)
# save VTT
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
write_vtt(result["segments"], file=vtt)
# save SRT
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -1,5 +1,7 @@
import json
import os
import zlib import zlib
from typing import Iterator, TextIO from typing import Callable, TextIO
def exact_div(x, y): def exact_div(x, y):
@@ -45,44 +47,83 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
class ResultWriter:
    """Base class for transcript writers.

    A writer is constructed with an output directory and called with a
    transcription ``result`` dict and the source ``audio_path``; it writes
    ``<output_dir>/<audio_basename>.<extension>``.
    """

    # File extension (without the dot); set by each concrete subclass.
    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str):
        # Derive the output filename from the audio file's basename,
        # keeping the original extension (e.g. "sample.wav" -> "sample.wav.txt").
        audio_basename = os.path.basename(audio_path)
        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f)

    def write_result(self, result: dict, file: TextIO):
        """Serialize ``result`` to ``file``; implemented by subclasses."""
        raise NotImplementedError


class WriteTXT(ResultWriter):
    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO):
        # One line of plain text per segment.
        for segment in result["segments"]:
            print(segment['text'].strip(), file=file, flush=True)


class WriteVTT(ResultWriter):
    extension: str = "vtt"

    def write_result(self, result: dict, file: TextIO):
        print("WEBVTT\n", file=file)
        for segment in result["segments"]:
            # "-->" inside the text would break the cue syntax, so soften it.
            print(
                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
                f"{segment['text'].strip().replace('-->', '->')}\n",
                file=file,
                flush=True,
            )


class WriteSRT(ResultWriter):
    extension: str = "srt"

    def write_result(self, result: dict, file: TextIO):
        for i, segment in enumerate(result["segments"], start=1):
            # write srt lines: sequence number, timestamp range, then text
            print(
                f"{i}\n"
                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
                f"{segment['text'].strip().replace('-->', '->')}\n",
                file=file,
                flush=True,
            )


class WriteJSON(ResultWriter):
    extension: str = "json"

    def write_result(self, result: dict, file: TextIO):
        # Dump the full result dict (text, segments, language, ...) as-is.
        json.dump(result, file)


def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
    """Return a callable that writes a result in ``output_format`` to ``output_dir``.

    ``output_format`` is one of "txt", "vtt", "srt", "json", or "all"; "all"
    returns a callable that writes every available format.

    Raises KeyError for an unknown format.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "json": WriteJSON,
    }

    if output_format == "all":
        # Instantiate every writer once and fan each call out to all of them.
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(result: dict, file: TextIO):
            for writer in all_writers:
                writer(result, file)

        return write_all

    return writers[output_format](output_dir)