* Add CSV format output in transcript, containing lines of characters formatted like: <startTime-in-integer-milliseconds>, <endTime-in-integer-milliseconds>, <transcript-including-commas> * for easier reading by spreadsheets importing CSV, the third column of the CSV file is delimited by quotes, and any quote characters that might be in the transcript (which would interfere with parsing the third column as a string) are converted to "''". * fix syntax error * docstring edit Co-authored-by: Jong Wook Kim <jongwook@openai.com> Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
151 lines
4.6 KiB
Python
151 lines
4.6 KiB
Python
import json
|
|
import os
|
|
import zlib
|
|
from typing import Callable, TextIO
|
|
|
|
|
|
def exact_div(x, y):
    """Return ``x // y``, requiring that ``y`` divides ``x`` with no remainder.

    Raises:
        AssertionError: if ``x % y`` is non-zero.
    """
    assert x % y == 0, f"{x} is not evenly divisible by {y}"
    return x // y
|
|
|
|
|
|
def str2bool(string):
    """Convert the exact strings ``"True"`` / ``"False"`` to the matching bool.

    The match is case-sensitive; no other spellings are accepted.

    Raises:
        ValueError: if ``string`` is neither ``"True"`` nor ``"False"``.
    """
    str2val = {"True": True, "False": False}
    if string not in str2val:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return str2val[string]
|
|
|
|
|
|
def optional_int(string):
    """Parse ``string`` as an int, mapping the literal ``"None"`` to ``None``.

    Raises:
        ValueError: if ``string`` is not ``"None"`` and ``int()`` rejects it.
    """
    if string == "None":
        return None
    return int(string)
|
|
|
|
|
|
def optional_float(string):
    """Parse ``string`` as a float, mapping the literal ``"None"`` to ``None``.

    Raises:
        ValueError: if ``string`` is not ``"None"`` and ``float()`` rejects it.
    """
    if string == "None":
        return None
    return float(string)
|
|
|
|
|
|
def compression_ratio(text) -> float:
    """Return the ratio of ``text``'s UTF-8 size to its zlib-compressed size.

    The numerator is the length of the UTF-8 encoding of ``text``; the
    denominator is the length of ``zlib.compress`` applied to those bytes.
    """
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
|
|
|
|
|
|
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.') -> str:
    """Format a non-negative duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to the nearest millisecond before being split into
    fields.

    Args:
        seconds: duration in seconds; must be >= 0.
        always_include_hours: when true, emit the ``HH:`` field even if it is
            zero; otherwise it is emitted only when the hour count is positive.
        decimal_marker: string placed between the seconds and milliseconds
            fields.

    Raises:
        AssertionError: if ``seconds`` is negative.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    # Work in whole milliseconds so every field is an exact integer.
    total_ms = round(seconds * 1000.0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{secs:02d}{decimal_marker}{millis:03d}"
|
|
|
|
|
|
class ResultWriter:
    """Base class for writers that save a result dict to a file.

    Subclasses set ``extension`` and implement ``write_result``. Calling an
    instance opens the output file and delegates to ``write_result``.
    """

    # File extension (without the leading dot) appended to the output name.
    extension: str

    def __init__(self, output_dir: str):
        """Remember the directory in which output files are created."""
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str):
        """Write ``result`` to ``<output_dir>/<basename of audio_path>.<extension>``.

        The audio file's own extension is kept, so ``clip.wav`` with extension
        ``txt`` produces ``clip.wav.txt``. The file is opened for writing as
        UTF-8 text.
        """
        audio_basename = os.path.basename(audio_path)
        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f)

    def write_result(self, result: dict, file: TextIO):
        """Serialize ``result`` into the open text stream ``file``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class WriteTXT(ResultWriter):
    """Write a transcript as plain text, one stripped segment text per line."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO):
        """Print the ``'text'`` of every entry in ``result["segments"]`` to ``file``."""
        for segment in result["segments"]:
            print(segment['text'].strip(), file=file, flush=True)
|
|
|
|
|
|
class WriteVTT(ResultWriter):
    """Write a transcript in WebVTT layout.

    The output starts with a ``WEBVTT`` line and a blank line. Each segment
    then becomes a ``start --> end`` line, a line with the stripped text, and
    a blank line.
    """

    extension: str = "vtt"

    def write_result(self, result: dict, file: TextIO):
        """Print the header and one cue per entry in ``result["segments"]``."""
        print("WEBVTT\n", file=file)
        for segment in result["segments"]:
            start = format_timestamp(segment['start'])
            end = format_timestamp(segment['end'])
            # "-->" separates the two timestamps, so it is rewritten when it
            # occurs inside the segment text.
            text = segment['text'].strip().replace('-->', '->')
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
|
|
|
|
class WriteSRT(ResultWriter):
    """Write a transcript in SRT layout.

    Each segment becomes a block of: its 1-based index, a ``start --> end``
    line, the stripped text, and a blank line. Timestamps always include the
    hours field and use ``,`` before the milliseconds.
    """

    extension: str = "srt"

    def write_result(self, result: dict, file: TextIO):
        """Print one numbered block per entry in ``result["segments"]``."""
        for i, segment in enumerate(result["segments"], start=1):
            start = format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')
            end = format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')
            # "-->" separates the two timestamps, so it is rewritten when it
            # occurs inside the segment text.
            text = segment['text'].strip().replace('-->', '->')
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
|
|
|
|
|
class WriteTSV(ResultWriter):
    """
    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>

    The first line written is a ``start``, ``end``, ``text`` header row. Tab characters inside
    the transcript text are replaced with spaces so they cannot be read as column separators.

    Using integer milliseconds as start and end times means there's no chance of interference from
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO):
        """Print the header row, then one row per entry in ``result["segments"]``."""
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            start_ms = round(1000 * segment['start'])
            end_ms = round(1000 * segment['end'])
            text = segment['text'].strip().replace("\t", " ")
            print(start_ms, end_ms, text, sep="\t", file=file, flush=True)
|
|
|
|
|
|
class WriteJSON(ResultWriter):
    """Write the whole result dict as a single JSON document."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO):
        """Dump ``result`` to ``file`` with ``json.dump`` and its default settings."""
        json.dump(result, file)
|
|
|
|
|
|
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, str], None]:
    """Return a callable that writes a result dict in ``output_format``.

    Args:
        output_format: one of ``"txt"``, ``"vtt"``, ``"srt"``, ``"tsv"``,
            ``"json"``, or ``"all"`` to write every format.
        output_dir: directory handed to each writer's constructor.

    Returns:
        A callable taking ``(result, audio_path)``. For ``"all"`` it invokes
        every writer in turn with the same arguments; otherwise it is a single
        writer instance.

    Raises:
        KeyError: if ``output_format`` is not ``"all"`` and has no entry in
            the writer table.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        # The second parameter keeps its existing name ``file`` so keyword
        # callers are unaffected, but it is the audio path forwarded to each
        # writer, not an open stream.
        def write_all(result: dict, file: str):
            for writer in all_writers:
                writer(result, file)

        return write_all

    return writers[output_format](output_dir)
|
|
|