"""
Test-Time Scaling Module
Implements PMI (log-probability) and top-k recall scoring of conditions (caption, lyrics, metadata) given generated audio codes
"""
import torch
import torch.nn.functional as F
from typing import Tuple, Optional, Dict, Any, List
from loguru import logger
import yaml
import math


def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
    """
    Calculate Pointwise Mutual Information (PMI) score.
    
    PMI = log P(condition|codes) - log P(condition)
        = log [P(codes|condition) / P(codes)]
    
    This removes the bias from P(condition) and measures how much the codes
    improve our ability to predict the condition.
    
    Args:
        log_prob_conditional: Average log probability of condition given codes
        log_prob_unconditional: Average log probability of condition without codes
        
    Returns:
        PMI score (higher is better, can be positive or negative)
        - Positive: codes improve prediction → good match
        - Zero: codes don't help → no correlation
        - Negative: codes hurt prediction → poor match
    """
    return log_prob_conditional - log_prob_unconditional
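
# Worked example (hypothetical log-probs, not from a real run): if the caption
# averages -2.1 nats per token given the codes but -2.6 nats without them,
# pmi_score(-2.1, -2.6) = 0.5 > 0, i.e. the codes genuinely help predict it.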


def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
    """
    Convert PMI score to normalized [0, 1] range using sigmoid function.
    
    score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
    
    Args:
        pmi: PMI score (can be positive or negative)
        scale: Scale parameter to control sensitivity (default 0.1)
               - Smaller scale: more sensitive to PMI changes
               - Larger scale: less sensitive to PMI changes
        
    Returns:
        Normalized score in [0, 1] range, where:
        - PMI > 0 → score > 0.5 (good match)
        - PMI = 0 → score = 0.5 (neutral)
        - PMI < 0 → score < 0.5 (poor match)
        
    Examples (scale=1.0):
        PMI=2.0  → score≈0.88  (excellent)
        PMI=1.0  → score≈0.73  (good)
        PMI=0.0  → score=0.50  (neutral)
        PMI=-1.0 → score≈0.27  (poor)
        PMI=-2.0 → score≈0.12  (bad)
    """
    return 1.0 / (1.0 + math.exp(-pmi / scale))
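

# Illustrative sketch (hypothetical PMI value): shows how `scale` controls how
# quickly the sigmoid saturates. The same PMI of 0.5 maps to ~0.62 at scale=1.0
# but nearly saturates at the default scale=0.1.
def _example_pmi_normalization() -> None:
    pmi = pmi_score(-2.1, -2.6)                     # 0.5, as in the example above
    print(pmi_to_normalized_score(pmi, scale=1.0))  # ~0.622
    print(pmi_to_normalized_score(pmi, scale=0.1))  # ~0.993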


def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
                                       target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        llm_handler: The handler containing the model and tokenizer.
        formatted_prompt: The input context.
        target_text: The text we want to calculate probability/recall for.
        
    Returns:
        Tuple of (target_logits, target_ids)
        - target_logits: Logits used to predict the target tokens.
        - target_ids: The ground truth token IDs of the target.
    """
    model = llm_handler.get_hf_model_for_scoring()
    tokenizer = llm_handler.llm_tokenizer
    device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device

    # 1. Tokenize prompt ONLY to get its length (used for slicing later).
    #    We must ensure special tokens are added to count the offset correctly.
    prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
    prompt_len = prompt_tokens_temp['input_ids'].shape[1]

    # 2. Tokenize the FULL text (Prompt + Target).
    #    This ensures subword merging at boundaries is handled correctly by the tokenizer.
    full_text = formatted_prompt + target_text
    full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)

    input_ids = full_tokens['input_ids']

    # Safety check: if target was empty or truncated entirely
    if input_ids.shape[1] <= prompt_len:
        return torch.empty(0, device=device), torch.empty(0, device=device)

    # 3. Forward Pass (Teacher Forcing)
    with torch.no_grad():
        with llm_handler._load_model_context():
            outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
            all_logits = outputs.logits  # [1, seq_len, vocab_size]

    # 4. Extract Logits and Labels
    #    We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
    #    Target starts at index `prompt_len`.
    #    So we need logits from `prompt_len - 1` up to the second to last position.

    target_logits = all_logits[0, prompt_len - 1:-1, :]  # [target_len, vocab_size]
    target_ids = input_ids[0, prompt_len:]  # [target_len]

    return target_logits, target_ids
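

# Minimal sketch (toy tensors, no model call): demonstrates the off-by-one
# alignment used above, where the logit at position i-1 scores the token at i.
def _example_teacher_forcing_alignment() -> None:
    prompt_len, vocab_size = 3, 8
    input_ids = torch.tensor([[11, 12, 13, 21, 22, 23]])         # 3 prompt + 3 target tokens
    all_logits = torch.randn(1, input_ids.shape[1], vocab_size)  # stand-in for model output
    target_logits = all_logits[0, prompt_len - 1:-1, :]          # logits at positions 2..4
    target_ids = input_ids[0, prompt_len:]                       # tokens at positions 3..5
    assert target_logits.shape[0] == target_ids.shape[0] == 3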


# ==============================================================================
# Scoring Logic
# ==============================================================================


def _calculate_topk_recall(llm_handler,
                           formatted_prompt: str,
                           target_text: str,
                           topk: int = 10) -> Tuple[float, Dict[int, float]]:
    """
    Calculate top-k recall for target text given prompt.
    Checks if the ground truth token is within the top-k probabilities at each step.
    """
    # Use the fixed helper to get aligned logits/labels
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    if target_ids.shape[0] == 0:
        return 0.0, {}

    target_len = target_ids.shape[0]

    # Get top-k indices for all positions at once
    # topk_indices: [target_len, topk]
    _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)

    recall_per_k = {}
    position_scores = []

    # Convert to list for faster CPU iteration
    target_ids_list = target_ids.tolist()
    topk_indices_list = topk_indices.tolist()

    for k in range(1, topk + 1):
        hits = 0
        for pos in range(target_len):
            gt_token = target_ids_list[pos]
            # Check the top-k slice
            topk_at_pos = topk_indices_list[pos][:k]

            if gt_token in topk_at_pos:
                hits += 1
                # Calculate position-weighted score only once (when k=topk)
                if k == topk:
                    rank = topk_at_pos.index(gt_token) + 1
                    # Rank 1 = 1.0, Rank k = small positive
                    position_weight = 1.0 - (rank - 1) / topk
                    position_scores.append(position_weight)

        recall_per_k[k] = hits / target_len if target_len > 0 else 0.0

    # Fill scores for positions where GT was NOT in top-k
    while len(position_scores) < target_len:
        position_scores.append(0.0)

    average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0

    return average_recall, recall_per_k
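

# Illustrative sketch: the rank-to-weight mapping applied above when k == topk.
# With topk=10, a rank-1 hit contributes 1.0, a rank-10 hit 0.1, and a miss 0.0.
def _example_rank_weights(topk: int = 10) -> List[float]:
    return [1.0 - (rank - 1) / topk for rank in range(1, topk + 1)]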


def _calculate_metadata_recall(llm_handler,
                               formatted_prompt: str,
                               fields_dict: Dict[str, Any],
                               topk: int = 10) -> Dict[str, float]:
    """
    Args:
        fields_dict: Dictionary of {field_name: field_value}
    """
    if not fields_dict:
        return {}

    field_scores = {}

    for field_name in sorted(fields_dict.keys()):
        # Construct target text for this specific field
        # e.g. <think>\nbpm: 120\n</think>\n
        field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
        field_target_text = f"<think>\n{field_yaml}\n</think>\n"

        # Calculate recall using the robust logic
        avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)

        field_scores[field_name] = avg_score
        logger.debug(f"Recall for {field_name}: {avg_score:.4f}")

    return field_scores


def _calculate_log_prob(
        llm_handler,
        formatted_prompt: str,
        target_text: str,
        temperature: float = 1.0  # Kept for API compatibility, but ignored for scoring
) -> float:
    """
    Calculate average log probability of target text given prompt.
    """
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    if target_ids.shape[0] == 0:
        return float('-inf')

    # Note: the logits are intentionally NOT divided by `temperature`;
    # log-probabilities used for PMI/perplexity must be exact.

    # Calculate log probabilities (log_softmax)
    log_probs = F.log_softmax(pred_logits, dim=-1)  # [target_len, vocab_size]

    # Gather log probabilities of the ground truth tokens
    target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]

    # Return average log probability
    mean_log_prob = target_log_probs.mean().item()

    return mean_log_prob
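

# Toy sketch (synthetic logits): the log_softmax + gather pattern used above,
# averaging the log-probability assigned to each ground-truth token.
def _example_gather_log_probs() -> float:
    toy_logits = torch.tensor([[2.0, 0.0, 0.0],
                               [0.0, 2.0, 0.0]])   # 2 target steps, vocab of 3
    toy_targets = torch.tensor([0, 1])             # ground-truth token ids
    log_probs = F.log_softmax(toy_logits, dim=-1)
    return log_probs[torch.arange(2), toy_targets].mean().item()  # ~ -0.24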


def calculate_reward_score(
    scores: Dict[str, float],
    weights_config: Optional[Dict[str, float]] = None
) -> Tuple[float, str]:
    """
    Reward Model Calculator: Computes a final reward based on user priorities.
    
    Priority Logic:
        1. Caption (Highest): The overall vibe/style must match.
        2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
        3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
    
    Strategy: Dynamic Weighted Sum
    - Metadata fields are aggregated into a single 'metadata' score first.
    - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
    
    Args:
        scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
        weights_config: Optional custom weights. Defaults to:
                        Caption (50%), Lyrics (30%), Metadata (20%).
        
    Returns:
        final_reward: The calculated reward score (0.0 - 1.0).
        explanation: A formatted string explaining how the score was derived.
    """
    
    # 1. Default Preference Configuration
    # These weights determine the relative importance of each component.
    if weights_config is None:
        weights_config = {
            'caption': 0.50,  # High priority: Style/Vibe
            'lyrics':  0.30,  # Medium priority: Content
            'metadata': 0.20  # Low priority: Technical details
        }
    
    # 2. Extract and Group Scores
    # Caption and Lyrics are standalone high-level features.
    caption_score = scores.get('caption')
    lyrics_score = scores.get('lyrics')
    
    # Metadata fields (bpm, key, duration, etc.) are aggregated.
    # We treat them as a single "Technical Score" to prevent them from 
    # diluting the weight of Caption/Lyrics simply by having many fields.
    meta_scores_list = [
        val for key, val in scores.items() 
        if key not in ['caption', 'lyrics']
    ]
    
    # Calculate average of all metadata fields (if any exist)
    meta_aggregate_score = None
    if meta_scores_list:
        meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
    
    # 3. Select Active Components & Dynamic Weighting
    # We only include components that actually exist in this generation.
    active_components = {}
    
    if caption_score is not None:
        active_components['caption'] = (caption_score, weights_config['caption'])
        
    if lyrics_score is not None:
        active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
        
    if meta_aggregate_score is not None:
        active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
    
    # 4. Calculate Final Weighted Score
    total_base_weight = sum(w for _, w in active_components.values())
    total_score = 0.0
    
    breakdown_lines = []
    
    if total_base_weight == 0:
        return 0.0, "❌ No valid scores available to calculate reward."
    
    # Sort by weight (importance) for display
    sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
    
    for name, (score, base_weight) in sorted_components:
        # Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
        normalized_weight = base_weight / total_base_weight
        weighted_contribution = score * normalized_weight
        total_score += weighted_contribution
        
        breakdown_lines.append(
            f"  • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
            f"-> Contrib: +{weighted_contribution:.4f}"
        )

    return total_score, "\n".join(breakdown_lines)
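

# Illustrative sketch (hypothetical scores): when 'lyrics' is absent, the
# caption/metadata weights renormalize from 0.50/0.20 to roughly 0.71/0.29.
def _example_reward_breakdown() -> None:
    demo_scores = {'caption': 0.82, 'bpm': 0.60, 'keyscale': 0.40}
    reward, breakdown = calculate_reward_score(demo_scores)
    print(f"reward={reward:.4f}\n{breakdown}")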

# ==============================================================================
# Main Public API
# ==============================================================================


def calculate_pmi_score_per_condition(
    llm_handler,
    audio_codes: str,
    caption: str = "",
    lyrics: str = "",
    metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 1.0,
    topk: int = 10,
    score_scale: float = 0.1,
) -> Tuple[Dict[str, float], float, str]:
    """
    Calculate quality score separately for each condition.
    - Metadata: Uses Top-k Recall.
    - Caption/Lyrics: Uses PMI (Normalized).
    """
    if not llm_handler.llm_initialized:
        return {}, 0.0, "❌ LLM not initialized"

    if not audio_codes or not audio_codes.strip():
        return {}, 0.0, "❌ No audio codes provided"

    if "caption" not in metadata:
        metadata['caption'] = caption

    formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
    prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
    # Initialize the score dict and metric routing up front so later references
    # (lyrics scoring, status reporting) can never hit a NameError.
    scores = {}
    # Define which fields use which metric
    metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
    metadata_pmi_keys = ['caption']

    try:
        # 1. Calculate Recall for Metadata Fields
        if metadata and isinstance(metadata, dict):
            for key in metadata_recall_keys:
                if key in metadata and metadata[key] is not None:
                    recall_metadata = {key: metadata[key]}
                    field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
                    scores.update(field_scores)

            # 2. Calculate PMI for Caption
            for key in metadata_pmi_keys:
                if key in metadata and metadata[key] is not None:
                    cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
                    target_text = f"<think>\n{cot_yaml}\n</think>\n"

                    log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
                    log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)

                    pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
                    scores[key] = pmi_normalized

        # 3. Calculate PMI for Lyrics
        if lyrics:
            target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"

            log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
            # Reuse the unconditional prompt already built above.
            log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)

            scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

        if not scores:
            return {}, 0.0, "❌ No conditions to evaluate"

        # 4. Global Score: weighted reward over the per-condition scores
        global_score, breakdown_lines = calculate_reward_score(scores)

        # Status Message
        status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
        for key, score in sorted(scores.items()):
            metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
            status_lines.append(f"  {key}: {score:.4f} ({metric})")
        status = "\n".join(status_lines)
        logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
        return scores, global_score, status

    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return {}, float('-inf'), error_msg
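

# Usage sketch (hypothetical inputs): `llm_handler` must already be initialized
# and expose the methods used above; the codes, caption and metadata below are
# placeholders, not real values.
def _example_score_generation(llm_handler) -> float:
    scores, reward, report = calculate_pmi_score_per_condition(
        llm_handler,
        audio_codes="<placeholder_audio_codes>",
        caption="an energetic synth-pop track",
        lyrics="",
        metadata={"bpm": 120, "keyscale": "C major"},
        topk=10,
        score_scale=0.1,
    )
    logger.info(report)
    return reward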