 import gradio as gr
-import gc
 import torch
 import os
 
 from typing import TYPE_CHECKING
 
+from tts_webui.utils.manage_model_state import manage_model_state
+from tts_webui.utils.list_dir_models import unload_model_button
+
 if TYPE_CHECKING:
     from transformers import Pipeline
 
@@ -14,7 +16,7 @@ def extension__tts_generation_webui():
     return {
         "package_name": "extension_whisper",
         "name": "Whisper",
-        "version": "0.0.1",
+        "version": "0.0.2",
         "requirements": "git+https://github.com/rsxdalv/extension_whisper@main",
         "description": "Whisper allows transcribing audio files.",
         "extension_type": "interface",
@@ -28,40 +30,59 @@ def extension__tts_generation_webui():
     }
 
 
-local_dir = os.path.join("data", "models", "whisper")
-local_cache_dir = os.path.join(local_dir, "cache")
+@manage_model_state("whisper")
+def get_model(
+    model_name="openai/whisper-large-v3",
+    torch_dtype=torch.float16,
+    device="cuda:0",
+    compile=False,
+):
+    from transformers import AutoModelForSpeechSeq2Seq
+    from transformers import AutoProcessor
+
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+    ).to(device)
+    if compile:
+        model.generation_config.cache_implementation = "static"
+        model.generation_config.max_new_tokens = 256
+        model.forward = torch.compile(
+            model.forward, mode="reduce-overhead", fullgraph=True
+        )
 
-pipe = None
-last_model_name = None
+    processor = AutoProcessor.from_pretrained(model_name)
 
+    return model, processor
 
-def unload_models():
-    global pipe, last_model_name
-    pipe = None
-    last_model_name = None
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-    return "Unloaded"
 
+local_dir = os.path.join("data", "models", "whisper")
+local_cache_dir = os.path.join(local_dir, "cache")
 
+
+@manage_model_state("whisper-pipe")
 def get_pipe(model_name, device="cuda:0") -> "Pipeline":
     from transformers import pipeline
 
-    global pipe, last_model_name
-    if pipe is not None:
-        if model_name == last_model_name:
-            return pipe
-        unload_models()
-    pipe = pipeline(
-        "automatic-speech-recognition",
+    torch_dtype = torch.float16
+
+    model, processor = get_model(
+        # model_name, torch_dtype=torch.float16, device=device, compile=False
         model_name,
+        torch_dtype=torch_dtype,
+        device=device,
+        compile=False,
+    )
+    return pipeline(
+        "automatic-speech-recognition",
+        model=model,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        # chunk_length_s=30,
+        # batch_size=16,  # batch size for inference - set based on your device
         torch_dtype=torch.float16,
         model_kwargs={"cache_dir": local_cache_dir},
         device=device,
     )
-    last_model_name = model_name
-    return pipe
 
 
 def transcribe(inputs, model_name="openai/whisper-large-v3"):
@@ -72,13 +93,11 @@ def transcribe(inputs, model_name="openai/whisper-large-v3"):
 
     pipe = get_pipe(model_name)
 
-    generate_kwargs = (
-        {"task": "transcribe"} if model_name == "openai/whisper-large-v3" else {}
-    )
-
     result = pipe(
         inputs,
-        generate_kwargs=generate_kwargs,
+        generate_kwargs=(
+            {"task": "transcribe"} if model_name == "openai/whisper-large-v3" else {}
+        ),
         return_timestamps=True,
     )
     return result["text"]
@@ -108,7 +127,8 @@ def transcribe_ui():
     text = gr.Textbox(label="Transcription", interactive=False)
 
     with gr.Row():
-        unload_models_button = gr.Button("Unload models")
+        unload_model_button("whisper-pipe")
+        unload_model_button("whisper")
 
     transcribe_button = gr.Button("Transcribe", variant="primary")
 
@@ -117,21 +137,12 @@ def transcribe_ui():
         inputs=[audio, model_dropdown],
         outputs=[text],
         api_name="whisper_transcribe",
-    ).then(
-        fn=lambda: gr.Button(value="Unload models"),
-        outputs=[unload_models_button],
-    )
-
-    unload_models_button.click(
-        fn=unload_models,
-        outputs=[unload_models_button],
-        api_name="whisper_unload_models",
     )
 
 
 if __name__ == "__main__":
     if "demo" in locals():
-        demo.close()
+        locals()["demo"].close()
 
     with gr.Blocks() as demo:
         with gr.Tab("Whisper"):