import gradio as gr
import pandas as pd
from datasets import load_dataset
import numpy as np
from functools import lru_cache
import re
from collections import Counter
import editdistance

# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    try:
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception:
        # Fallback: load the test parquet file directly if the default config fails.
        # split="train" makes the parquet builder return a Dataset (plain data_files
        # are assigned to a "train" split), matching the type returned above.
        return load_dataset(
            "parquet",
            data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            split="train",
        )

# Preprocess text for better WER calculation
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
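
# Illustrative example (hand-checked, not executed): punctuation is stripped,
# case is folded, and whitespace is collapsed, e.g.
#   preprocess_text("Hello,   World!")  ->  "hello world"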

# Score a hypothesis by length plus 4-gram diversity (a simple heuristic, not a true LM)
def score_hypothesis(hypothesis, n=4):
    if not hypothesis:
        return 0
    
    words = hypothesis.split()
    if len(words) < n:
        return len(words)
    
    ngrams = []
    for i in range(len(words) - n + 1):
        ngram = ' '.join(words[i:i+n])
        ngrams.append(ngram)
    
    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)
    score = len(words) + unique_ngrams/max(1, total_ngrams) * 5
    return score
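
# Illustrative scores (hand-computed, not executed): the heuristic rewards
# length and 4-gram diversity, e.g.
#   score_hypothesis("the cat")                    -> 2     (fewer words than n)
#   score_hypothesis("the quick brown fox jumps")  -> 10.0  (5 words + 2/2 * 5)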

# N-gram ranking approach: return the highest-scoring hypothesis from the n-best list
def get_best_hypothesis_lm(hypotheses):
    if not hypotheses:
        return ""
    
    if isinstance(hypotheses, str):
        return hypotheses
    
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    
    if not hypothesis_list:
        return ""
    
    scores = [(score_hypothesis(h), h) for h in hypothesis_list]
    best_hypothesis = max(scores, key=lambda x: x[0])[1]
    return best_hypothesis
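
# Illustrative example (hand-computed): longer, less repetitive hypotheses win,
# e.g. given ["the cat sat", "the cat sat on the mat"] the second string scores
# higher (11.0 vs 3) and is returned.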

# Subwords voting correction approach: position-wise majority vote over the
# hypotheses that share the most common length
def correct_hypotheses(hypotheses):
    if not hypotheses:
        return ""
    
    if isinstance(hypotheses, str):
        return hypotheses
    
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    
    if not hypothesis_list:
        return ""
    
    word_lists = [h.split() for h in hypothesis_list]
    lengths = [len(words) for words in word_lists]
    
    if not lengths:
        return ""
    
    most_common_length = Counter(lengths).most_common(1)[0][0]
    filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
    
    if not filtered_word_lists:
        return max(hypothesis_list, key=len)
    
    corrected_words = []
    for i in range(most_common_length):
        position_words = [words[i] for words in filtered_word_lists]
        most_common_word = Counter(position_words).most_common(1)[0][0]
        corrected_words.append(most_common_word)
    
    return ' '.join(corrected_words)
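
# Illustrative example (hand-computed): with
#   ["the cat sat", "the bat sat", "the cat sit"]
# all hypotheses have length 3, and the per-position majority vote yields
# "the cat sat".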

# Calculate WER: word-level edit distance normalized by reference length
def calculate_simple_wer(reference, hypothesis):
    if not reference or not hypothesis:
        return 1.0
    
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    
    if len(ref_words) == 0:
        return 1.0
    
    # Word-level Levenshtein distance normalized by the reference length
    distance = editdistance.eval(ref_words, hyp_words)
    return float(distance) / float(len(ref_words))
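
# Illustrative example (hand-computed, not executed): one deleted word out of
# six reference words gives WER 1/6, e.g.
#   calculate_simple_wer("the cat sat on the mat", "the cat sat on mat")  ->  ~0.167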

# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples, max_samples=200):
    if not examples or len(examples) == 0:
        return np.nan, np.nan, np.nan
    
    # Limit sample size for efficiency
    if hasattr(examples, 'select'):
        items_to_process = examples.select(range(min(max_samples, len(examples))))
    else:
        items_to_process = examples[:max_samples]
    
    wer_values_no_lm = []
    wer_values_lm_ranking = []
    wer_values_n_best_correction = []
    
    for ex in items_to_process:
        # Get reference transcription
        transcription = ex.get("transcription")
        if not transcription or not isinstance(transcription, str):
            continue
        
        reference = preprocess_text(transcription)
        if not reference:
            continue
        
        # Get 1-best hypothesis for baseline
        input1 = ex.get("input1")
        if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
            if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                input1 = ex["hypothesis"][0]
            elif isinstance(ex["hypothesis"], str):
                input1 = ex["hypothesis"]
        
        # Get n-best hypotheses for other methods
        n_best_hypotheses = ex.get("hypothesis", [])
        
        # Method 1: No LM (1-best ASR output)
        if input1 and isinstance(input1, str):
            no_lm_hyp = preprocess_text(input1)
            if no_lm_hyp:
                wer_no_lm = calculate_simple_wer(reference, no_lm_hyp)
                wer_values_no_lm.append(wer_no_lm)
        
        # Method 2: N-gram ranking
        if n_best_hypotheses:
            lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
            if lm_best_hyp:
                wer_lm = calculate_simple_wer(reference, lm_best_hyp)
                wer_values_lm_ranking.append(wer_lm)
        
        # Method 3: Subwords voting correction
        if n_best_hypotheses:
            corrected_hyp = correct_hypotheses(n_best_hypotheses)
            if corrected_hyp:
                wer_corrected = calculate_simple_wer(reference, corrected_hyp)
                wer_values_n_best_correction.append(wer_corrected)
    
    # Calculate average WER for each method
    no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
    lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
    n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan
    
    return no_lm_wer, lm_ranking_wer, n_best_correction_wer
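
# Illustrative usage (a sketch; the keys are those read above, the values are
# made up for illustration):
#   examples = [{"transcription": "the cat sat",
#                "input1": "the cat sit",
#                "hypothesis": ["the cat sit", "the cat sat"]}]
#   no_lm, ranked, voted = calculate_wer_methods(examples)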

# Get WER metrics by source
def get_wer_metrics(dataset):
    # Group examples by source
    examples_by_source = {}
    
    for ex in dataset:
        source = ex.get("source", "unknown")
        # Skip all_et05_real as requested
        if source == "all_et05_real":
            continue
            
        if source not in examples_by_source:
            examples_by_source[source] = []
        examples_by_source[source].append(ex)
    
    # Get all unique sources
    all_sources = sorted(examples_by_source.keys())
    
    # Calculate metrics for each source
    source_results = {}
    for source in all_sources:
        examples = examples_by_source.get(source, [])
        count = len(examples)
        
        if count > 0:
            no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
        else:
            no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan
        
        source_results[source] = {
            "Count": count,
            "No LM Baseline": no_lm_wer,
            "N-best LM Ranking": lm_ranking_wer,
            "N-best Correction": n_best_wer
        }
    
    # Calculate overall metrics
    filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
    total_count = len(filtered_dataset)
    
    # Use up to 500 examples for the overall estimate; pass max_samples explicitly
    # so the per-call default cap of 200 does not silently shrink the sample
    sample_size = min(500, total_count)
    sample_dataset = filtered_dataset[:sample_size]
    no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset, max_samples=sample_size)
    
    source_results["OVERALL"] = {
        "Count": total_count,
        "No LM Baseline": no_lm_wer,
        "N-best LM Ranking": lm_ranking_wer,
        "N-best Correction": n_best_wer
    }
    
    # Create flat DataFrame with labels in the first column
    rows = []
    
    # First add row for number of examples
    example_row = {"Metric": "Number of Examples"}
    for source in all_sources + ["OVERALL"]:
        example_row[source] = source_results[source]["Count"]
    rows.append(example_row)
    
    # Then add rows for each WER method
    no_lm_row = {"Metric": "Word Error Rate (No LM)"}
    lm_ranking_row = {"Metric": "Word Error Rate (N-gram Ranking)"}
    n_best_row = {"Metric": "Word Error Rate (Subwords Voting Correction)"}
    
    for source in all_sources + ["OVERALL"]:
        no_lm_row[source] = source_results[source]["No LM Baseline"]
        lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
        n_best_row[source] = source_results[source]["N-best Correction"]
    
    rows.append(no_lm_row)
    rows.append(lm_ranking_row)
    rows.append(n_best_row)
    
    # Create DataFrame from rows
    result_df = pd.DataFrame(rows)
    
    return result_df
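
# Resulting layout (sketch): one "Metric" column plus one column per source and
# an "OVERALL" column; the rows are "Number of Examples" and the three WER
# variants computed above.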

# Format the dataframe for display
def format_dataframe(df):
    df = df.copy()
    
    # Find the rows containing WER values
    wer_row_indices = []
    for i, metric in enumerate(df["Metric"]):
        if "WER" in metric or "Error Rate" in metric:
            wer_row_indices.append(i)
    
    # Format WER values
    for idx in wer_row_indices:
        for col in df.columns:
            if col != "Metric":
                value = df.loc[idx, col]
                if pd.notna(value):
                    df.loc[idx, col] = f"{value:.4f}"
                else:
                    df.loc[idx, col] = "N/A"
    
    return df
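
# Illustrative example (not executed): 0.16666 is rendered as "0.1667" and a
# missing value (NaN) as "N/A"; the "Number of Examples" row is left untouched.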

# Main function to create the leaderboard
def create_leaderboard():
    dataset = load_data()
    metrics_df = get_wer_metrics(dataset)
    return format_dataframe(metrics_df)

# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")
    
    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")
    
    with gr.Row():
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": f"Error initializing leaderboard: {e}"}]))
    
    def refresh_and_report():
        return create_leaderboard()
    
    refresh_btn.click(refresh_and_report, outputs=[leaderboard])

if __name__ == "__main__":
    demo.launch()