word-level timestamps in transcribe() (#869)
* word-level timestamps in `transcribe()` * moving to `timing.py` * numba implementation for dtw, replacing dtw-python * triton implementation for dtw * add test for dtw implementations * triton implementation of median_filter * a simple word-level timestamps test * add scipy as dev dependency * installs an older version of Triton if CUDA < 11.4 * fix broken merge * loosen nvcc version match regex * find_alignment() function * miscellaneous improvements * skip median filtering when the input is too small * Expose punctuation options in cli and transcribe() (#973) * fix merge error * fix merge error 2 * annotating that word_timestamps is experimental --------- Co-authored-by: ryanheise <ryan@ryanheise.com>
This commit is contained in:
20
setup.py
20
setup.py
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pkg_resources
|
||||
from setuptools import setup, find_packages
|
||||
@@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
|
||||
return locals()["__version__"]
|
||||
|
||||
|
||||
requirements = []
|
||||
if sys.platform.startswith("linux"):
|
||||
triton_requirement = "triton>=2.0.0.dev20221202"
|
||||
try:
|
||||
import re
|
||||
import subprocess
|
||||
version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
|
||||
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
|
||||
if (int(major), int(minor)) < (11, 4):
|
||||
# the last version supporting CUDA < 11.4
|
||||
triton_requirement = "triton==2.0.0.dev20221011"
|
||||
except (IndexError, OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
requirements.append(triton_requirement)
|
||||
|
||||
setup(
|
||||
name="openai-whisper",
|
||||
py_modules=["whisper"],
|
||||
@@ -22,7 +38,7 @@ setup(
|
||||
url="https://github.com/openai/whisper",
|
||||
license="MIT",
|
||||
packages=find_packages(exclude=["tests*"]),
|
||||
install_requires=[
|
||||
install_requires=requirements + [
|
||||
str(r)
|
||||
for r in pkg_resources.parse_requirements(
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
||||
@@ -32,5 +48,5 @@ setup(
|
||||
"console_scripts": ["whisper=whisper.transcribe:cli"],
|
||||
},
|
||||
include_package_data=True,
|
||||
extras_require={"dev": ["pytest"]},
|
||||
extras_require={"dev": ["pytest", "scipy"]},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user