Commit 21fcf42
Parent(s): 8ed6ca2
add punctuations
app.py CHANGED

@@ -26,7 +26,7 @@ from pathlib import Path
 import gradio as gr
 
 from decode import decode
-from model import get_pretrained_model, get_vad, language_to_models
+from model import get_pretrained_model, get_vad, language_to_models, get_punct_model
 
 title = "# Next-gen Kaldi: Generate subtitles for videos"
 
@@ -89,6 +89,7 @@ def show_file_info(in_filename: str):
 def process_uploaded_video_file(
     language: str,
     repo_id: str,
+    add_punctuation: str,
     in_filename: str,
 ):
     if in_filename is None or in_filename == "":
@@ -105,13 +106,14 @@ def process_uploaded_video_file(
 
     logging.info(f"Processing uploaded file: {in_filename}")
 
-    ans = process(language, repo_id, in_filename)
+    ans = process(language, repo_id, add_punctuation, in_filename)
     return (in_filename, ans[0]), ans[0], ans[1], ans[2]
 
 
 def process_uploaded_audio_file(
     language: str,
     repo_id: str,
+    add_punctuation: str,
     in_filename: str,
 ):
     if in_filename is None or in_filename == "":
@@ -131,11 +133,15 @@ def process_uploaded_audio_file(
     return process(language, repo_id, in_filename)
 
 
-def process(language: str, repo_id: str, in_filename: str):
+def process(language: str, repo_id: str, add_punctuation: str, in_filename: str):
     recognizer = get_pretrained_model(repo_id)
     vad = get_vad()
+    if add_punctuation == "Yes":
+        punct = get_punct_model()
+    else:
+        punct = None
 
-    result = decode(recognizer, vad, in_filename)
+    result = decode(recognizer, vad, punct, in_filename)
     logging.info(result)
 
     srt_filename = Path(in_filename).with_suffix(".srt")
@@ -176,6 +182,11 @@ with demo:
         inputs=language_radio,
         outputs=model_dropdown,
     )
+    punct_radio = gr.Radio(
+        label="Whether to add punctuation",
+        choices=["Yes", "No"],
+        value="Yes",
+    )
 
     with gr.Tabs():
         with gr.TabItem("Upload video from disk"):
@@ -218,6 +229,7 @@ with demo:
                 inputs=[
                     language_radio,
                     model_dropdown,
+                    punct_radio,
                     uploaded_video_file,
                 ],
                 outputs=[
@@ -233,6 +245,7 @@ with demo:
                 inputs=[
                     language_radio,
                     model_dropdown,
+                    punct_radio,
                     uploaded_audio_file,
                 ],
                 outputs=[
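Note that in the @@ -131,11 +133,15 @@ hunk, the return process(language, repo_id, in_filename) line inside process_uploaded_audio_file appears as unchanged context, so that path apparently still passes three arguments and does not forward the new add_punctuation flag.

The sketch below is a minimal, self-contained illustration (all names hypothetical, not part of this commit) of the wiring pattern used above: a "Yes"/"No" gr.Radio value is passed as one of the click() inputs and gates an optional post-processing step inside the handler.

# Hypothetical standalone sketch of the punct_radio wiring pattern.
import gradio as gr

def fake_transcribe(text: str, add_punctuation: str) -> str:
    # Stand-in for process(); the real app runs a sherpa-onnx recognizer here.
    result = text.strip().lower()
    if add_punctuation == "Yes":
        # Placeholder for punct.add_punctuation(result).
        result = result.capitalize() + "."
    return result

with gr.Blocks() as demo:
    punct_radio = gr.Radio(
        label="Whether to add punctuation",
        choices=["Yes", "No"],
        value="Yes",
    )
    text_in = gr.Textbox(label="Input text")
    text_out = gr.Textbox(label="Output text")
    run_button = gr.Button("Run")
    run_button.click(
        fake_transcribe,
        inputs=[text_in, punct_radio],
        outputs=text_out,
    )

if __name__ == "__main__":
    demo.launch()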
decode.py CHANGED

@@ -48,6 +48,7 @@ class Segment:
 def decode(
     recognizer: sherpa_onnx.OfflineRecognizer,
     vad: sherpa_onnx.VoiceActivityDetector,
+    punct: Optional[sherpa_onnx.OfflinePunctuation],
     filename: str,
 ) -> str:
     ffmpeg_cmd = [
@@ -114,6 +115,8 @@ def decode(
 
     for seg, stream in zip(segments, streams):
         seg.text = stream.result.text.strip()
+        if punct is not None:
+            seg.text = punct.add_punctuation(seg.text)
         segment_list.append(seg)
 
     return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1))
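The new signature uses Optional, so decode.py is assumed to already import it from typing (no import change appears in this diff). Below is a standalone sketch of the per-segment gate added above; DummyPunct is a hypothetical stand-in for sherpa_onnx.OfflinePunctuation that only needs to expose add_punctuation(text) -> str, and punct=None skips the step entirely.

# Hypothetical sketch of the optional per-segment punctuation step.
from typing import List, Optional

class DummyPunct:
    def add_punctuation(self, text: str) -> str:
        # Toy behavior for illustration; the real model inserts , . ? etc.
        return text[:1].upper() + text[1:] + "."

def punctuate_segments(texts: List[str], punct: Optional[DummyPunct]) -> List[str]:
    segment_list = []
    for text in texts:
        text = text.strip()
        if punct is not None:
            text = punct.add_punctuation(text)
        segment_list.append(text)
    return segment_list

print(punctuate_segments(["how are you", "i am fine"], DummyPunct()))
print(punctuate_segments(["how are you", "i am fine"], None))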
model.py CHANGED

@@ -168,6 +168,21 @@ def _get_russian_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognize
     return recognizer
 
 
+@lru_cache(maxsize=2)
+def get_punct_model() -> sherpa_onnx.OfflinePunctuation:
+    model = _get_nn_model_filename(
+        repo_id="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
+        filename="model.onnx",
+        subfolder=".",
+    )
+    config = sherpa_onnx.OfflinePunctuationConfig(
+        model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
+    )
+
+    punct = sherpa_onnx.OfflinePunctuation(config)
+    return punct
+
+
 @lru_cache(maxsize=2)
 def get_vad() -> sherpa_onnx.VoiceActivityDetector:
     vad_model = _get_nn_model_filename(