import gradio as gr
import pandas as pd
from datasets import load_dataset
import numpy as np
from functools import lru_cache
import re
from collections import Counter
import editdistance

# Cache the dataset load so repeated refreshes don't re-download
@lru_cache(maxsize=1)
def load_data():
    try:
        return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
    except Exception:
        # Fallback to the explicit parquet file if default loading fails;
        # split="train" is needed here because load_dataset("parquet", ...)
        # otherwise returns a DatasetDict rather than a Dataset
        return load_dataset(
            "parquet",
            data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            split="train",
        )
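
# Expected example fields (assumed from how the code below consumes them):
#   "source":        subset name used to group results
#   "transcription": reference transcript
#   "input1":        1-best ASR hypothesis
#   "hypothesis":    list of n-best ASR hypotheses (or a single string)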

# Preprocess text for better WER calculation: lowercase, strip punctuation,
# and collapse whitespace so surface formatting is not counted as an error
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)  # drop punctuation
    text = re.sub(r'\s+', ' ', text).strip()  # collapse whitespace runs
    return text
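
# Example: preprocess_text("Hello, World!  ") -> "hello world"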

# N-gram scoring for hypothesis ranking: rewards longer hypotheses and higher
# n-gram diversity (a lightweight fluency proxy, not a true language model)
def score_hypothesis(hypothesis, n=4):
    if not hypothesis:
        return 0
    words = hypothesis.split()
    if len(words) < n:
        return len(words)
    ngrams = [' '.join(words[i:i + n]) for i in range(len(words) - n + 1)]
    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)
    return len(words) + unique_ngrams / max(1, total_ngrams) * 5
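
# Worked example (illustrative): "the cat sat on the mat" has 6 words and
# three 4-grams, all unique, so score = 6 + (3/3) * 5 = 11.0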

# N-gram ranking approach: pick the n-best hypothesis with the highest score
def get_best_hypothesis_lm(hypotheses):
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return hypotheses
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""
    scores = [(score_hypothesis(h), h) for h in hypothesis_list]
    return max(scores, key=lambda x: x[0])[1]
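
# Worked example (illustrative): "the cat sat on the mat" (all 4-grams unique,
# score 11.0) outranks "the cat the cat the cat" (repeated 4-grams, score ~9.3)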

# Subwords voting correction approach: majority vote over aligned word
# positions among the n-best hypotheses that share the most common length
def correct_hypotheses(hypotheses):
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return hypotheses
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""
    word_lists = [h.split() for h in hypothesis_list]
    lengths = [len(words) for words in word_lists]
    if not lengths:
        return ""
    # Only vote among hypotheses of the most frequent length so positions align
    most_common_length = Counter(lengths).most_common(1)[0][0]
    filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
    if not filtered_word_lists:
        return max(hypothesis_list, key=len)
    corrected_words = []
    for i in range(most_common_length):
        position_words = [words[i] for words in filtered_word_lists]
        most_common_word = Counter(position_words).most_common(1)[0][0]
        corrected_words.append(most_common_word)
    return ' '.join(corrected_words)
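
# Worked example (illustrative): ["the cat sat", "the cat sad", "a cat sat"]
# -> per-position majority vote yields "the cat sat"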

# Calculate WER: word-level edit distance normalized by reference length
def calculate_simple_wer(reference, hypothesis):
    if not reference or not hypothesis:
        return 1.0
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    distance = editdistance.eval(ref_words, hyp_words)
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
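
# Example: reference "the cat sat" vs. hypothesis "the cat sit"
# -> one substitution, WER = 1/3 ≈ 0.3333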

# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples, max_samples=200):
    if not examples or len(examples) == 0:
        return np.nan, np.nan, np.nan
    # Limit sample size for efficiency
    if hasattr(examples, 'select'):
        items_to_process = examples.select(range(min(max_samples, len(examples))))
    else:
        items_to_process = examples[:max_samples]
    wer_values_no_lm = []
    wer_values_lm_ranking = []
    wer_values_n_best_correction = []
    for ex in items_to_process:
        # Get reference transcription
        transcription = ex.get("transcription")
        if not transcription or not isinstance(transcription, str):
            continue
        reference = preprocess_text(transcription)
        if not reference:
            continue
        # Get 1-best hypothesis for the baseline, falling back to the first n-best entry
        input1 = ex.get("input1")
        if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
            if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                input1 = ex["hypothesis"][0]
            elif isinstance(ex["hypothesis"], str):
                input1 = ex["hypothesis"]
        # Get n-best hypotheses for the other methods
        n_best_hypotheses = ex.get("hypothesis", [])
        # Method 1: No LM (1-best ASR output)
        if input1 and isinstance(input1, str):
            no_lm_hyp = preprocess_text(input1)
            if no_lm_hyp:
                wer_values_no_lm.append(calculate_simple_wer(reference, no_lm_hyp))
        # Method 2: N-gram ranking
        if n_best_hypotheses:
            lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
            if lm_best_hyp:
                wer_values_lm_ranking.append(calculate_simple_wer(reference, lm_best_hyp))
        # Method 3: Subwords voting correction
        if n_best_hypotheses:
            corrected_hyp = correct_hypotheses(n_best_hypotheses)
            if corrected_hyp:
                wer_values_n_best_correction.append(calculate_simple_wer(reference, corrected_hyp))
    # Average WER for each method (NaN when a method produced no values)
    no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
    lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
    n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan
    return no_lm_wer, lm_ranking_wer, n_best_correction_wer
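
# Returns (baseline WER, n-gram-ranking WER, voting WER); e.g. a hypothetical
# call might yield (0.12, 0.11, 0.10) for a source where reranking helps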

# Get WER metrics by source
def get_wer_metrics(dataset):
    # Group examples by source
    examples_by_source = {}
    for ex in dataset:
        source = ex.get("source", "unknown")
        # Skip all_et05_real as requested
        if source == "all_et05_real":
            continue
        examples_by_source.setdefault(source, []).append(ex)
    # Get all unique sources
    all_sources = sorted(examples_by_source.keys())
    # Calculate metrics for each source
    source_results = {}
    for source in all_sources:
        examples = examples_by_source.get(source, [])
        count = len(examples)
        if count > 0:
            no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
        else:
            no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan
        source_results[source] = {
            "Count": count,
            "No LM Baseline": no_lm_wer,
            "N-best LM Ranking": lm_ranking_wer,
            "N-best Correction": n_best_wer,
        }
    # Calculate overall metrics on a capped sample of the filtered dataset
    filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
    total_count = len(filtered_dataset)
    sample_size = min(500, total_count)
    sample_dataset = filtered_dataset[:sample_size]
    no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset)
    source_results["OVERALL"] = {
        "Count": total_count,
        "No LM Baseline": no_lm_wer,
        "N-best LM Ranking": lm_ranking_wer,
        "N-best Correction": n_best_wer,
    }
    # Create a flat DataFrame with metric labels in the first column
    rows = []
    # First, a row for the number of examples
    example_row = {"Metric": "Number of Examples"}
    for source in all_sources + ["OVERALL"]:
        example_row[source] = source_results[source]["Count"]
    rows.append(example_row)
    # Then one row per WER method
    no_lm_row = {"Metric": "Word Error Rate (No LM)"}
    lm_ranking_row = {"Metric": "Word Error Rate (N-gram Ranking)"}
    n_best_row = {"Metric": "Word Error Rate (Subwords Voting Correction)"}
    for source in all_sources + ["OVERALL"]:
        no_lm_row[source] = source_results[source]["No LM Baseline"]
        lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
        n_best_row[source] = source_results[source]["N-best Correction"]
    rows.extend([no_lm_row, lm_ranking_row, n_best_row])
    return pd.DataFrame(rows)
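
# Resulting layout (columns are source names; values illustrative):
#   Metric                         source_a  ...  OVERALL
#   Number of Examples             1000      ...  5000
#   Word Error Rate (No LM)        0.1523    ...  0.1489
#   ...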

# Format the dataframe for display
def format_dataframe(df):
    # Cast to object so formatted strings can replace float values without
    # pandas dtype warnings
    df = df.copy().astype(object)
    # Find the rows containing WER values
    wer_row_indices = [
        i for i, metric in enumerate(df["Metric"])
        if "WER" in metric or "Error Rate" in metric
    ]
    # Format WER values to four decimals; show missing values as "N/A"
    for idx in wer_row_indices:
        for col in df.columns:
            if col != "Metric":
                value = df.loc[idx, col]
                df.loc[idx, col] = f"{value:.4f}" if pd.notna(value) else "N/A"
    return df
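
# Example: 0.2345678 becomes "0.2346"; NaN becomes "N/A"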

# Main function to create the leaderboard
def create_leaderboard():
    dataset = load_data()
    metrics_df = get_wer_metrics(dataset)
    return format_dataframe(metrics_df)

# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")
    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")
    with gr.Row():
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception:
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}]))

    def refresh_and_report():
        return create_leaderboard()

    refresh_btn.click(refresh_and_report, outputs=[leaderboard])

if __name__ == "__main__":
    demo.launch()
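
# To run locally: `python <this_file>.py` (Gradio serves at http://127.0.0.1:7860 by default)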