| """ | |
| Test-Time Scaling Module | |
| Implements perplexity-based scoring for generated audio codes | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import Tuple, Optional, Dict, Any, List | |
| from loguru import logger | |
| import yaml | |
| import math | |
| import re | |
def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
    """
    Calculate the Pointwise Mutual Information (PMI) score.
    PMI = log P(condition|codes) - log P(condition)
        = log [P(codes|condition) / P(codes)]
    This removes the bias from P(condition) and measures how much the codes
    improve our ability to predict the condition.
    Args:
        log_prob_conditional: Average log probability of the condition given the codes.
        log_prob_unconditional: Average log probability of the condition without the codes.
    Returns:
        PMI score (higher is better; can be positive or negative):
        - Positive: codes improve prediction → good match
        - Zero: codes don't help → no correlation
        - Negative: codes hurt prediction → poor match
    """
    return log_prob_conditional - log_prob_unconditional
def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
    """
    Convert a PMI score to the normalized [0, 1] range using a sigmoid.
    score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
    Args:
        pmi: PMI score (can be positive or negative).
        scale: Scale parameter controlling sensitivity (default 0.1).
            - Smaller scale: more sensitive to PMI changes
            - Larger scale: less sensitive to PMI changes
    Returns:
        Normalized score in the [0, 1] range, where:
        - PMI > 0 → score > 0.5 (good match)
        - PMI = 0 → score = 0.5 (neutral)
        - PMI < 0 → score < 0.5 (poor match)
    Examples (with scale=1.0 for readability; the default is 0.1):
        PMI=2.0  → score≈0.88 (excellent)
        PMI=1.0  → score≈0.73 (good)
        PMI=0.0  → score=0.50 (neutral)
        PMI=-1.0 → score≈0.27 (poor)
        PMI=-2.0 → score≈0.12 (bad)
    """
    return 1.0 / (1.0 + math.exp(-pmi / scale))
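
# A minimal sanity check of the PMI -> score mapping above (illustrative
# values, not part of the original pipeline). Note how the default scale=0.1
# saturates the sigmoid much faster than scale=1.0.
def _demo_pmi_normalization() -> None:
    for pmi in (-2.0, -0.05, 0.0, 0.05, 2.0):
        print(f"PMI={pmi:+.2f} -> scale=0.1: {pmi_to_normalized_score(pmi):.3f} "
              f"| scale=1.0: {pmi_to_normalized_score(pmi, scale=1.0):.3f}")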
def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
                                       target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run a teacher-forced forward pass and return logits/labels aligned for
    scoring the target continuation.
    Args:
        llm_handler: The handler containing the model and tokenizer.
        formatted_prompt: The input context.
        target_text: The text we want to calculate probability/recall for.
    Returns:
        Tuple of (target_logits, target_ids):
        - target_logits: Logits used to predict the target tokens.
        - target_ids: The ground-truth token IDs of the target.
    """
    model = llm_handler.get_hf_model_for_scoring()
    tokenizer = llm_handler.llm_tokenizer
    device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
    # 1. Tokenize the prompt ONLY to get its length (used for slicing later).
    #    Special tokens must be added so the offset is counted correctly.
    prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
    prompt_len = prompt_tokens_temp['input_ids'].shape[1]
    # 2. Tokenize the FULL text (prompt + target).
    #    This lets the tokenizer handle subword merging at the boundary correctly.
    full_text = formatted_prompt + target_text
    full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)
    input_ids = full_tokens['input_ids']
    # Safety check: the target may be empty or truncated away entirely.
    if input_ids.shape[1] <= prompt_len:
        return torch.empty(0, device=device), torch.empty(0, dtype=torch.long, device=device)
    # 3. Forward pass (teacher forcing).
    with torch.no_grad():
        with llm_handler._load_model_context():
            outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
            all_logits = outputs.logits  # [1, seq_len, vocab_size]
    # 4. Extract logits and labels.
    #    To predict input_ids[i], the logit lives at all_logits[i - 1].
    #    The target starts at index prompt_len, so we take logits from
    #    prompt_len - 1 up to the second-to-last position.
    target_logits = all_logits[0, prompt_len - 1:-1, :]  # [target_len, vocab_size]
    target_ids = input_ids[0, prompt_len:]  # [target_len]
    return target_logits, target_ids
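
# Toy illustration of the shift alignment used above (dummy tensors, no model):
# the logit that predicts input_ids[i] sits at position i - 1, so for a prompt
# of length P the target slice pairs all_logits[P-1:-1] with input_ids[P:].
def _demo_shift_alignment() -> None:
    seq_len, prompt_len, vocab = 6, 4, 10
    input_ids = torch.arange(seq_len)         # pretend token ids [0..5]
    all_logits = torch.randn(seq_len, vocab)  # pretend [seq_len, vocab] logits
    target_logits = all_logits[prompt_len - 1:-1]  # logits predicting positions 4 and 5
    target_ids = input_ids[prompt_len:]            # ground-truth tokens 4 and 5
    assert target_logits.shape[0] == target_ids.shape[0] == seq_len - prompt_len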
# ==============================================================================
# Scoring Logic
# ==============================================================================
def _calculate_topk_recall(llm_handler,
                           formatted_prompt: str,
                           target_text: str,
                           topk: int = 10) -> Tuple[float, Dict[int, float]]:
    """
    Calculate top-k recall for the target text given the prompt.
    Checks whether the ground-truth token is within the top-k predictions at each step.
    Returns:
        Tuple of (average position-weighted recall, {k: recall@k for k = 1..topk}).
    """
    # Use the shared helper to get aligned logits/labels.
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
    if target_ids.shape[0] == 0:
        return 0.0, {}
    target_len = target_ids.shape[0]
    # Get top-k indices for all positions at once.
    # topk_indices: [target_len, topk]
    _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)
    recall_per_k = {}
    position_scores = []
    # Convert to lists for faster CPU iteration.
    target_ids_list = target_ids.tolist()
    topk_indices_list = topk_indices.tolist()
    for k in range(1, topk + 1):
        hits = 0
        for pos in range(target_len):
            gt_token = target_ids_list[pos]
            # Check the top-k slice.
            topk_at_pos = topk_indices_list[pos][:k]
            if gt_token in topk_at_pos:
                hits += 1
                # Compute the position-weighted score only once (when k == topk).
                if k == topk:
                    rank = topk_at_pos.index(gt_token) + 1
                    # Rank 1 -> 1.0, rank k -> small positive weight.
                    position_weight = 1.0 - (rank - 1) / topk
                    position_scores.append(position_weight)
        recall_per_k[k] = hits / target_len
    # Pad zeros for positions where the GT token was NOT in the top-k.
    # Order is irrelevant here because only the mean is used below.
    while len(position_scores) < target_len:
        position_scores.append(0.0)
    average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
    return average_recall, recall_per_k
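
# Illustration of the position weighting above (assumed numbers): with topk=5,
# a ground-truth token ranked 1st contributes 1.0, ranked 5th contributes 0.2,
# and a token outside the top-5 contributes 0.0, so average_recall rewards
# both coverage and rank.
def _demo_position_weight(topk: int = 5) -> None:
    for rank in range(1, topk + 1):
        print(f"rank {rank}: weight {1.0 - (rank - 1) / topk:.2f}")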
def _calculate_metadata_recall(llm_handler,
                               formatted_prompt: str,
                               fields_dict: Dict[str, Any],
                               topk: int = 10) -> Dict[str, float]:
    """
    Calculate per-field top-k recall for a dictionary of metadata fields.
    Args:
        fields_dict: Dictionary of {field_name: field_value}.
    """
    if not fields_dict:
        return {}
    field_scores = {}
    for field_name in sorted(fields_dict.keys()):
        # Construct the target text for this specific field,
        # e.g. "<think>\nbpm: 120\n</think>\n".
        field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
        field_target_text = f"<think>\n{field_yaml}\n</think>\n"
        # Calculate recall using the shared scoring logic.
        avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
        field_scores[field_name] = avg_score
        logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
    return field_scores
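
# Sketch of the per-field target text built above (illustrative field value):
# each metadata field is scored in isolation inside a <think> block.
def _demo_field_target() -> None:
    field_yaml = yaml.dump({"bpm": 120}, allow_unicode=True, sort_keys=True).strip()
    print(f"<think>\n{field_yaml}\n</think>\n")  # -> "<think>\nbpm: 120\n</think>\n"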
def _calculate_log_prob(
    llm_handler,
    formatted_prompt: str,
    target_text: str,
    temperature: float = 1.0  # Kept for API compatibility, but ignored for scoring
) -> float:
    """
    Calculate the average log probability of the target text given the prompt.
    """
    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
    if target_ids.shape[0] == 0:
        return float('-inf')
    # NOTE: Do not divide by temperature.
    # Log-probabilities for PMI/perplexity should be exact.
    log_probs = F.log_softmax(pred_logits, dim=-1)  # [target_len, vocab_size]
    # Gather the log probabilities of the ground-truth tokens.
    target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]
    # Return the average log probability.
    mean_log_prob = target_log_probs.mean().item()
    return mean_log_prob
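
# Relation to the module's perplexity framing (a standard identity, shown for
# clarity): the mean log-probability returned above maps to perplexity via
# ppl = exp(-mean_log_prob), so a higher log-prob means lower perplexity.
def _demo_perplexity(mean_log_prob: float) -> float:
    return math.exp(-mean_log_prob)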
def calculate_reward_score(
    scores: Dict[str, float],
    weights_config: Optional[Dict[str, float]] = None
) -> Tuple[float, str]:
    """
    Reward Model Calculator: computes a final reward based on user priorities.
    Priority Logic:
        1. Caption (highest): the overall vibe/style must match.
        2. Lyrics (medium): content accuracy is important but secondary to vibe.
        3. Metadata (lowest): technical constraints (BPM, key) allow slight deviations.
    Strategy: Dynamic Weighted Sum
        - Metadata fields are aggregated into a single 'metadata' score first.
        - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
    Args:
        scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
        weights_config: Optional custom weights. Defaults to:
            Caption (50%), Lyrics (30%), Metadata (20%).
    Returns:
        final_reward: The calculated reward score (0.0 - 1.0).
        explanation: A formatted string explaining how the score was derived.
    """
    # 1. Default preference configuration.
    #    These weights determine the relative importance of each component.
    if weights_config is None:
        weights_config = {
            'caption': 0.50,   # High priority: style/vibe
            'lyrics': 0.30,    # Medium priority: content
            'metadata': 0.20   # Low priority: technical details
        }
    # 2. Extract and group scores.
    #    Caption and lyrics are standalone high-level features.
    caption_score = scores.get('caption')
    lyrics_score = scores.get('lyrics')
    # Metadata fields (bpm, key, duration, etc.) are aggregated.
    # Treating them as a single "technical score" prevents them from diluting
    # the weight of caption/lyrics simply by having many fields.
    meta_scores_list = [
        val for key, val in scores.items()
        if key not in ['caption', 'lyrics']
    ]
    # Average all metadata fields (if any exist).
    meta_aggregate_score = None
    if meta_scores_list:
        meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
    # 3. Select active components for dynamic weighting.
    #    Only components that actually exist in this generation are included.
    active_components = {}
    if caption_score is not None:
        active_components['caption'] = (caption_score, weights_config['caption'])
    if lyrics_score is not None:
        active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
    if meta_aggregate_score is not None:
        active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
    # 4. Calculate the final weighted score.
    total_base_weight = sum(w for _, w in active_components.values())
    total_score = 0.0
    breakdown_lines = []
    if total_base_weight == 0:
        return 0.0, "❌ No valid scores available to calculate reward."
    # Sort by weight (importance) for display.
    sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
    for name, (score, base_weight) in sorted_components:
        # Renormalize weight: if lyrics are missing, caption/metadata weights scale up proportionally.
        normalized_weight = base_weight / total_base_weight
        weighted_contribution = score * normalized_weight
        total_score += weighted_contribution
        breakdown_lines.append(
            f"  • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
            f"-> Contrib: +{weighted_contribution:.4f}"
        )
    return total_score, "\n".join(breakdown_lines)
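
# Minimal usage sketch for the reward calculator (illustrative scores): with
# no 'lyrics' key present, the caption/metadata base weights (0.50/0.20) are
# renormalized to roughly 0.71/0.29 before the weighted sum.
def _demo_reward_renormalization() -> None:
    scores = {"caption": 0.8, "bpm": 0.6, "keyscale": 1.0}  # lyrics missing
    reward, breakdown = calculate_reward_score(scores)
    print(f"reward={reward:.4f}\n{breakdown}")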
# ==============================================================================
# Main Public API
# ==============================================================================
def calculate_pmi_score_per_condition(
    llm_handler,
    audio_codes: str,
    caption: str = "",
    lyrics: str = "",
    metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 1.0,  # Kept for API compatibility; scoring uses exact log-probs
    topk: int = 10,
    score_scale: float = 0.1,
) -> Tuple[Dict[str, float], float, str]:
    """
    Calculate a quality score separately for each condition.
    - Metadata: uses top-k recall.
    - Caption/Lyrics: uses PMI (normalized to [0, 1]).
    """
    if not llm_handler.llm_initialized:
        return {}, 0.0, "❌ LLM not initialized"
    if not audio_codes or not audio_codes.strip():
        return {}, 0.0, "❌ No audio codes provided"
    # Guard against a missing metadata dict before mutating it.
    if metadata is None:
        metadata = {}
    if "caption" not in metadata:
        metadata['caption'] = caption
    formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
    # Unconditional prompt: same template without audio codes, used as the
    # PMI baseline P(condition).
    prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
    scores: Dict[str, float] = {}
    # Define which fields use which metric.
    metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
    metadata_pmi_keys = ['caption']
    try:
        # 1. Calculate top-k recall for metadata fields.
        for key in metadata_recall_keys:
            if key in metadata and metadata[key] is not None:
                recall_metadata = {key: metadata[key]}
                field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
                scores.update(field_scores)
        # 2. Calculate PMI for the caption.
        for key in metadata_pmi_keys:
            if key in metadata and metadata[key] is not None:
                cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
                target_text = f"<think>\n{cot_yaml}\n</think>\n"
                log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
                log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
                pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
                scores[key] = pmi_normalized
        # 3. Calculate PMI for the lyrics (reusing the unconditional prompt above).
        if lyrics:
            target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"
            log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
            log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
            scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
        if not scores:
            return {}, 0.0, "❌ No conditions to evaluate"
        # 4. Global score: priority-weighted reward over the per-condition scores.
        global_score, breakdown_lines = calculate_reward_score(scores)
        # Status message
        status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
        for key, score in sorted(scores.items()):
            metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
            status_lines.append(f"  {key}: {score:.4f} ({metric})")
        status = "\n".join(status_lines)
        logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
        return scores, global_score, status
    except Exception as e:
        import traceback
        error_msg = f"❌ Error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return {}, float('-inf'), error_msg
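
# End-to-end usage sketch (hypothetical inputs; the concrete llm_handler class
# and audio-code format come from the surrounding project, not this module):
#
#   scores, reward, status = calculate_pmi_score_per_condition(
#       llm_handler,
#       audio_codes="<audio codes emitted by the generator>",
#       caption="dreamy synth-pop with airy vocals",
#       lyrics="[verse] ...",
#       metadata={"bpm": 120, "keyscale": "C major"},
#       topk=10,
#       score_scale=0.1,
#   )
#   print(status)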