From 727ab81f31fccddb9fc4a9a5028871bc598d8c41 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Tue, 12 Sep 2023 10:02:23 +0200 Subject: [PATCH] Improve error message for invalid task and language parameters (#466) --- faster_whisper/tokenizer.py | 128 ++++++++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 7 deletions(-) diff --git a/faster_whisper/tokenizer.py b/faster_whisper/tokenizer.py index b040044..1af70b9 100644 --- a/faster_whisper/tokenizer.py +++ b/faster_whisper/tokenizer.py @@ -19,15 +19,21 @@ class Tokenizer: self.tokenizer = tokenizer if multilingual: + if task not in _TASKS: + raise ValueError( + "'%s' is not a valid task (accepted tasks: %s)" + % (task, ", ".join(_TASKS)) + ) + + if language not in _LANGUAGE_CODES: + raise ValueError( + "'%s' is not a valid language code (accepted language codes: %s)" + % (language, ", ".join(_LANGUAGE_CODES)) + ) + self.task = self.tokenizer.token_to_id("<|%s|>" % task) - if self.task is None: - raise ValueError("%s is not a valid task" % task) - - self.language_code = language self.language = self.tokenizer.token_to_id("<|%s|>" % language) - if self.language is None: - raise ValueError("%s is not a valid language code" % language) - + self.language_code = language else: self.task = None self.language = None @@ -161,3 +167,111 @@ class Tokenizer: word_tokens[-1].extend(subword_tokens) return words, word_tokens + + +_TASKS = ( + "transcribe", + "translate", +) + +_LANGUAGE_CODES = ( + "af", + "am", + "ar", + "as", + "az", + "ba", + "be", + "bg", + "bn", + "bo", + "br", + "bs", + "ca", + "cs", + "cy", + "da", + "de", + "el", + "en", + "es", + "et", + "eu", + "fa", + "fi", + "fo", + "fr", + "gl", + "gu", + "ha", + "haw", + "he", + "hi", + "hr", + "ht", + "hu", + "hy", + "id", + "is", + "it", + "ja", + "jw", + "ka", + "kk", + "km", + "kn", + "ko", + "la", + "lb", + "ln", + "lo", + "lt", + "lv", + "mg", + "mi", + "mk", + "ml", + "mn", + "mr", + "ms", + "mt", + "my", + "ne", + "nl", + "nn", + "no", + "oc", + "pa", + "pl", + "ps", + "pt", + "ro", + "ru", + "sa", + "sd", + "si", + "sk", + "sl", + "sn", + "so", + "sq", + "sr", + "su", + "sv", + "sw", + "ta", + "te", + "tg", + "th", + "tk", + "tl", + "tr", + "tt", + "uk", + "ur", + "uz", + "vi", + "yi", + "yo", + "zh", +)