Skip to content

Commit

Permalink
Handle gradio None values
Browse files Browse the repository at this point in the history
  • Loading branch information
jhj0517 committed Oct 28, 2024
1 parent 8933c2e commit e58ee71
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions modules/whisper/data_classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gradio as gr
import torch
from typing import Optional, Dict, List
from typing import Optional, Dict, List, Union
from pydantic import BaseModel, Field, field_validator, ConfigDict
from gradio_i18n import Translate, gettext as _
from enum import Enum
Expand Down Expand Up @@ -241,7 +241,7 @@ class WhisperParams(BaseParams):
default=True,
description="Suppress blank outputs at start of sampling"
)
suppress_tokens: Optional[str] = Field(default="[-1]", description="Token IDs to suppress")
suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
max_initial_timestamp: float = Field(
default=0.0,
ge=0.0,
Expand Down Expand Up @@ -279,6 +279,20 @@ def validate_lang(cls, v):
from modules.utils.constants import AUTOMATIC_DETECTION
return None if v == AUTOMATIC_DETECTION.unwrap() else v

@field_validator('suppress_tokens')
def validate_supress_tokens(cls, v):
import ast
try:
if isinstance(v, str):
suppress_tokens = ast.literal_eval(v)
if not isinstance(suppress_tokens, list):
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
return suppress_tokens
if isinstance(v, list):
return v
except Exception as e:
raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")

@classmethod
def to_gradio_inputs(cls,
defaults: Optional[Dict] = None,
Expand All @@ -301,7 +315,7 @@ def to_gradio_inputs(cls,
gr.Dropdown(
label=_("Language"),
choices=available_langs,
value=defaults.get("lang", cls.__fields__["lang"].default),
value=defaults.get("lang", AUTOMATIC_DETECTION),
),
gr.Checkbox(
label=_("Translate to English?"),
Expand Down Expand Up @@ -407,7 +421,7 @@ def to_gradio_inputs(cls,
),
gr.Textbox(
label="Suppress Tokens",
value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default),
value=defaults.get("suppress_tokens", "[-1]"),
info="Token IDs to suppress"
),
gr.Number(
Expand Down

0 comments on commit e58ee71

Please sign in to comment.