""" Test-Time Scaling Module Implements perplexity-based scoring for generated audio codes """ import torch import torch.nn.functional as F from typing import Tuple, Optional, Dict, Any, List from loguru import logger import yaml import math import re def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float: """ Calculate Pointwise Mutual Information (PMI) score. PMI = log P(condition|codes) - log P(condition) = log [P(codes|condition) / P(codes)] This removes the bias from P(condition) and measures how much the codes improve our ability to predict the condition. Args: log_prob_conditional: Average log probability of condition given codes log_prob_unconditional: Average log probability of condition without codes Returns: PMI score (higher is better, can be positive or negative) - Positive: codes improve prediction → good match - Zero: codes don't help → no correlation - Negative: codes hurt prediction → poor match """ return log_prob_conditional - log_prob_unconditional def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float: """ Convert PMI score to normalized [0, 1] range using sigmoid function. score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale)) Args: pmi: PMI score (can be positive or negative) scale: Scale parameter to control sensitivity (default 0.1) - Smaller scale: more sensitive to PMI changes - Larger scale: less sensitive to PMI changes Returns: Normalized score in [0, 1] range, where: - PMI > 0 → score > 0.5 (good match) - PMI = 0 → score = 0.5 (neutral) - PMI < 0 → score < 0.5 (poor match) Examples (scale=1.0): PMI=2.0 → score≈0.88 (excellent) PMI=1.0 → score≈0.73 (good) PMI=0.0 → score=0.50 (neutral) PMI=-1.0 → score≈0.27 (poor) PMI=-2.0 → score≈0.12 (bad) """ return 1.0 / (1.0 + math.exp(-pmi / scale)) def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str, target_text: str) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: llm_handler: The handler containing the model and tokenizer. formatted_prompt: The input context. target_text: The text we want to calculate probability/recall for. Returns: Tuple of (target_logits, target_ids) - target_logits: Logits used to predict the target tokens. - target_ids: The ground truth token IDs of the target. """ model = llm_handler.get_hf_model_for_scoring() tokenizer = llm_handler.llm_tokenizer device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device # 1. Tokenize prompt ONLY to get its length (used for slicing later). # We must ensure special tokens are added to count the offset correctly. prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True) prompt_len = prompt_tokens_temp['input_ids'].shape[1] # 2. Tokenize the FULL text (Prompt + Target). # This ensures subword merging at boundaries is handled correctly by the tokenizer. full_text = formatted_prompt + target_text full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device) input_ids = full_tokens['input_ids'] # Safety check: if target was empty or truncated entirely if input_ids.shape[1] <= prompt_len: return torch.empty(0, device=device), torch.empty(0, device=device) # 3. Forward Pass (Teacher Forcing) with torch.no_grad(): with llm_handler._load_model_context(): outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask']) all_logits = outputs.logits # [1, seq_len, vocab_size] # 4. Extract Logits and Labels # We need to predict `input_ids[i]`. 


# ==============================================================================
# Scoring Logic
# ==============================================================================

def _calculate_topk_recall(llm_handler,
                           formatted_prompt: str,
                           target_text: str,
                           topk: int = 10) -> Tuple[float, Dict[int, float]]:
    """
    Calculate top-k recall for the target text given the prompt.

    Checks whether the ground-truth token is within the top-k probabilities at
    each step.

    Returns:
        Tuple of (position-weighted average recall, recall-at-k for each k).
    """
    # Use the shared helper to get aligned logits/labels.
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
    if target_ids.shape[0] == 0:
        return 0.0, {}

    target_len = target_ids.shape[0]

    # Get the top-k indices for all positions at once.
    # topk_indices: [target_len, topk]
    _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)

    recall_per_k = {}
    position_scores = []

    # Convert to lists for faster CPU iteration.
    target_ids_list = target_ids.tolist()
    topk_indices_list = topk_indices.tolist()

    for k in range(1, topk + 1):
        hits = 0
        for pos in range(target_len):
            gt_token = target_ids_list[pos]
            # Check the top-k slice at this position.
            topk_at_pos = topk_indices_list[pos][:k]
            if gt_token in topk_at_pos:
                hits += 1
                # Calculate the position-weighted score only once (when k == topk).
                if k == topk:
                    rank = topk_at_pos.index(gt_token) + 1
                    # Rank 1 = 1.0, rank k = small positive weight.
                    position_weight = 1.0 - (rank - 1) / topk
                    position_scores.append(position_weight)
        recall_per_k[k] = hits / target_len if target_len > 0 else 0.0

    # Fill scores for positions where the ground truth was NOT in the top-k.
    while len(position_scores) < target_len:
        position_scores.append(0.0)

    average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
    return average_recall, recall_per_k


def _calculate_metadata_recall(llm_handler,
                               formatted_prompt: str,
                               fields_dict: Dict[str, Any],
                               topk: int = 10) -> Dict[str, float]:
    """
    Calculate the per-field top-k recall for a set of metadata fields.

    Args:
        fields_dict: Dictionary of {field_name: field_value}.
    """
    if not fields_dict:
        return {}

    field_scores = {}
    for field_name in sorted(fields_dict.keys()):
        # Construct the target text for this specific field, e.g. "\nbpm: 120\n\n".
        field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
        field_target_text = f"\n{field_yaml}\n\n"

        # Calculate recall using the shared scoring logic.
        avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
        field_scores[field_name] = avg_score
        logger.debug(f"Recall for {field_name}: {avg_score:.4f}")

    return field_scores


def _calculate_log_prob(
    llm_handler,
    formatted_prompt: str,
    target_text: str,
    temperature: float = 1.0  # Kept for API compatibility; ignored for scoring.
) -> float:
    """
    Calculate the average log probability of the target text given the prompt.
    """
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
    if target_ids.shape[0] == 0:
        return float('-inf')

    # Note: logits are NOT divided by temperature; the log probability used for
    # PMI/perplexity should be exact.

    # Calculate log probabilities (log_softmax).
    log_probs = F.log_softmax(pred_logits, dim=-1)  # [target_len, vocab_size]

    # Gather the log probabilities of the ground-truth tokens.
    target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]

    # Return the average log probability.
    mean_log_prob = target_log_probs.mean().item()
    return mean_log_prob
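

# Hedged sketch (illustrative only): how the scoring pieces above combine.
# The gather-and-mean pattern mirrors `_calculate_log_prob` on toy tensors,
# and the PMI is the conditional minus the unconditional average log prob.
def _toy_pmi_example() -> float:
    toy_logits = torch.tensor([[2.0, 0.5, -1.0],
                               [0.1, 1.5, 0.3]])
    toy_targets = torch.tensor([0, 1])
    log_probs = F.log_softmax(toy_logits, dim=-1)
    avg_log_prob_cond = log_probs[torch.arange(2), toy_targets].mean().item()
    # Pretend the unconditional pass scored the same targets slightly worse.
    avg_log_prob_uncond = avg_log_prob_cond - 0.05
    return pmi_to_normalized_score(pmi_score(avg_log_prob_cond, avg_log_prob_uncond), scale=0.1)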


def calculate_reward_score(
    scores: Dict[str, float],
    weights_config: Optional[Dict[str, float]] = None
) -> Tuple[float, str]:
    """
    Reward model calculator: computes a final reward based on user priorities.

    Priority logic:
        1. Caption (highest): the overall vibe/style must match.
        2. Lyrics (medium): content accuracy is important but secondary to vibe.
        3. Metadata (lowest): technical constraints (BPM, key) allow slight deviations.

    Strategy: dynamic weighted sum.
        - Metadata fields are aggregated into a single 'metadata' score first.
        - Weights are renormalized if any component (e.g., lyrics) is missing.

    Args:
        scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
        weights_config: Optional custom weights. Defaults to:
            caption (50%), lyrics (30%), metadata (20%).

    Returns:
        final_reward: The calculated reward score (0.0 - 1.0).
        explanation: A formatted string explaining how the score was derived.
    """
    # 1. Default preference configuration.
    #    These weights determine the relative importance of each component.
    if weights_config is None:
        weights_config = {
            'caption': 0.50,   # High priority: style/vibe
            'lyrics': 0.30,    # Medium priority: content
            'metadata': 0.20   # Low priority: technical details
        }

    # 2. Extract and group scores.
    #    Caption and lyrics are standalone high-level features.
    caption_score = scores.get('caption')
    lyrics_score = scores.get('lyrics')

    # Metadata fields (bpm, key, duration, etc.) are aggregated. Treating them
    # as a single "technical score" prevents them from diluting the weight of
    # caption/lyrics simply by having many fields.
    meta_scores_list = [
        val for key, val in scores.items()
        if key not in ['caption', 'lyrics']
    ]

    # Average all metadata fields (if any exist).
    meta_aggregate_score = None
    if meta_scores_list:
        meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)

    # 3. Select active components for dynamic weighting.
    #    Only components that actually exist in this generation are included.
    active_components = {}
    if caption_score is not None:
        active_components['caption'] = (caption_score, weights_config['caption'])
    if lyrics_score is not None:
        active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
    if meta_aggregate_score is not None:
        active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])

    # 4. Calculate the final weighted score.
    total_base_weight = sum(w for _, w in active_components.values())
    total_score = 0.0
    breakdown_lines = []

    if total_base_weight == 0:
        return 0.0, "❌ No valid scores available to calculate reward."

    # Sort by weight (importance) for display.
    sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)

    for name, (score, base_weight) in sorted_components:
        # Renormalize the weight: if lyrics are missing, the caption/metadata
        # weights scale up proportionally.
        normalized_weight = base_weight / total_base_weight
        weighted_contribution = score * normalized_weight
        total_score += weighted_contribution
        breakdown_lines.append(
            f"  • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
            f"-> Contrib: +{weighted_contribution:.4f}"
        )

    return total_score, "\n".join(breakdown_lines)
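

# Illustrative usage (hypothetical numbers): with lyrics absent, the base
# weights 0.50 (caption) and 0.20 (metadata) renormalize to ~0.71 and ~0.29.
# For scores {'caption': 0.9, 'bpm': 0.6, 'keyscale': 0.8}, the metadata
# aggregate is 0.7 and the reward is 0.9*(0.5/0.7) + 0.7*(0.2/0.7) ≈ 0.843.
def _reward_score_example() -> float:
    reward, _explanation = calculate_reward_score({'caption': 0.9, 'bpm': 0.6, 'keyscale': 0.8})
    return reward  # ≈ 0.843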


# ==============================================================================
# Main Public API
# ==============================================================================

def calculate_pmi_score_per_condition(
    llm_handler,
    audio_codes: str,
    caption: str = "",
    lyrics: str = "",
    metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 1.0,  # Kept for API compatibility; not used for scoring.
    topk: int = 10,
    score_scale: float = 0.1,
) -> Tuple[Dict[str, float], float, str]:
    """
    Calculate a quality score separately for each condition.

    - Metadata fields: top-k recall.
    - Caption / lyrics: normalized PMI.

    Returns:
        Tuple of (per-condition scores, global reward score, status report).
    """
    if not llm_handler.llm_initialized:
        return {}, 0.0, "❌ LLM not initialized"
    if not audio_codes or not audio_codes.strip():
        return {}, 0.0, "❌ No audio codes provided"

    # Normalize metadata so the caption can be folded in as a scored field.
    if not isinstance(metadata, dict):
        metadata = {}
    if "caption" not in metadata and caption:
        metadata['caption'] = caption

    formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes,
                                                                            is_negative_prompt=False)
    prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT",
                                                                         is_negative_prompt=False)

    try:
        scores = {}
        # Define which fields use which metric.
        metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
        metadata_pmi_keys = ['caption']

        # 1. Calculate top-k recall for metadata fields.
        for key in metadata_recall_keys:
            if key in metadata and metadata[key] is not None:
                recall_metadata = {key: metadata[key]}
                field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
                scores.update(field_scores)

        # 2. Calculate PMI for the caption.
        for key in metadata_pmi_keys:
            if key in metadata and metadata[key] is not None:
                cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
                target_text = f"\n{cot_yaml}\n\n"
                log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
                log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
                pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
                scores[key] = pmi_normalized

        # 3. Calculate PMI for the lyrics.
        if lyrics:
            target_text = f"\n\n# Lyric\n{lyrics}\n"
            log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
            log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
            scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

        if not scores:
            return {}, 0.0, "❌ No conditions to evaluate"
        # 4. Global score: aggregate the per-condition scores with the
        #    priority-weighted reward model (caption > lyrics > metadata).
        global_score, breakdown = calculate_reward_score(scores)

        # Status message.
        status_lines = [breakdown, "\n✅ Per-condition scores (0-1):"]
        for key, score in sorted(scores.items()):
            metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
            status_lines.append(f"  {key}: {score:.4f} ({metric})")
        status = "\n".join(status_lines)

        logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
        return scores, global_score, status

    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return {}, float('-inf'), error_msg
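

# Hedged usage sketch: `llm_handler` is expected to expose llm_initialized,
# build_formatted_prompt_for_understanding(), get_hf_model_for_scoring(),
# llm_tokenizer and _load_model_context(), as used above. The call below is
# illustrative, not an executable test.
#
#   scores, reward, report = calculate_pmi_score_per_condition(
#       llm_handler,
#       audio_codes=generated_codes,
#       caption="lo-fi hip hop with mellow piano",
#       lyrics=lyrics_text,
#       metadata={"bpm": 82, "keyscale": "A minor"},
#       topk=10,
#       score_scale=0.1,
#   )
#
# The PMI-to-score mapping itself is model-free and can be sanity-checked directly:
if __name__ == "__main__":
    for pmi in (-0.2, 0.0, 0.2):
        print(f"PMI={pmi:+.1f} -> normalized score {pmi_to_normalized_score(pmi):.3f}")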