"""
DiT Alignment Score Module
This module provides lyrics-to-audio alignment using cross-attention matrices
from the DiT model for generating LRC timestamps.
Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
"""
import numba
import torch
import numpy as np
import torch.nn.functional as F
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple, Union
# ================= Data Classes =================
@dataclass
class TokenTimestamp:
"""Stores per-token timing information."""
token_id: int
text: str
start: float
end: float
probability: float
@dataclass
class SentenceTimestamp:
"""Stores per-sentence timing information with token list."""
text: str
start: float
end: float
tokens: List[TokenTimestamp]
confidence: float
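# Illustrative sketch (not part of the original module): both dataclasses are
# plain dataclasses, so they serialize to nested dicts via dataclasses.asdict
# (imported above), which is convenient when dumping alignment results to JSON.
def _example_serialize_timestamps() -> Dict[str, Any]:
    token = TokenTimestamp(token_id=42, text="la", start=1.0, end=1.5, probability=0.8)
    sentence = SentenceTimestamp(text="la", start=1.0, end=1.5, tokens=[token], confidence=0.8)
    # asdict recursively converts the nested TokenTimestamp as well.
    return asdict(sentence)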
# ================= DTW Algorithm (Numba Optimized) =================
@numba.jit(nopython=True)
def dtw_cpu(x: np.ndarray):
"""
Dynamic Time Warping algorithm optimized with Numba.
Args:
x: Cost matrix of shape [N, M]
Returns:
Array of shape (2, path_len); row 0 holds text indices and row 1 holds time indices, so callers can unpack it as two arrays
"""
N, M = x.shape
# Use float32 for memory efficiency
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, M + 1):
for i in range(1, N + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = x[i - 1, j - 1] + c
trace[i, j] = t
return _backtrace(trace, N, M)
@numba.jit(nopython=True)
def _backtrace(trace: np.ndarray, N: int, M: int):
"""
Optimized backtrace function for DTW.
Args:
trace: Trace matrix of shape (N+1, M+1)
N, M: Original matrix dimensions
Returns:
Path array of shape (2, path_len) - first row is text indices, second is time indices
"""
# Boundary handling
trace[0, :] = 2
trace[:, 0] = 1
# Pre-allocate array, max path length is N+M
max_path_len = N + M
path = np.zeros((2, max_path_len), dtype=np.int32)
i, j = N, M
path_idx = max_path_len - 1
while i > 0 or j > 0:
path[0, path_idx] = i - 1 # text index
path[1, path_idx] = j - 1 # time index
path_idx -= 1
t = trace[i, j]
if t == 0:
i -= 1
j -= 1
elif t == 1:
i -= 1
elif t == 2:
j -= 1
else:
break
return path[:, path_idx + 1:max_path_len]
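# Illustrative sketch (added for clarity, not part of the original module):
# dtw_cpu minimizes accumulated cost, so callers negate a similarity/attention
# matrix before passing it in, exactly as token_timestamps() and
# lyrics_alignment_info() do below.
def _example_dtw_usage() -> None:
    # Toy similarity matrix: 3 tokens x 5 frames with a roughly diagonal ridge.
    similarity = np.array([
        [0.9, 0.8, 0.1, 0.0, 0.0],
        [0.0, 0.2, 0.9, 0.7, 0.1],
        [0.0, 0.0, 0.1, 0.8, 0.9],
    ], dtype=np.float64)
    text_indices, time_indices = dtw_cpu(-similarity)
    # Each (text_indices[k], time_indices[k]) pair assigns a token row to a
    # frame column; both index sequences are monotonically non-decreasing.
    for tok, frame in zip(text_indices, time_indices):
        print(f"token {tok} -> frame {frame}")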
# ================= Utility Functions =================
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
"""
Apply median filter to tensor.
Args:
x: Input tensor
filter_width: Width of median filter
Returns:
Filtered tensor
"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
return x
if x.ndim == 2:
x = x[None, :]
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if result.ndim > 2:
result = result.squeeze(0)
return result
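# Minimal sketch (assumed usage, mirroring the calls below): median_filter
# smooths along the last (frame) axis, removing isolated attention spikes
# while preserving edges better than a moving average would.
def _example_median_filter() -> None:
    x = torch.tensor([[0.0, 0.0, 5.0, 0.0, 0.0, 1.0, 1.0, 1.0]])
    smoothed = median_filter(x, filter_width=3)
    # The isolated spike at index 2 is suppressed; the plateau of 1.0 survives.
    print(smoothed)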
# ================= Main Aligner Class =================
class MusicStampsAligner:
"""
Aligner class for generating lyrics timestamps from cross-attention matrices.
Uses bidirectional consensus denoising and DTW for alignment.
"""
def __init__(self, tokenizer):
"""
Initialize the aligner.
Args:
tokenizer: Text tokenizer for decoding tokens
"""
self.tokenizer = tokenizer
def _apply_bidirectional_consensus(
self,
weights_stack: torch.Tensor,
violence_level: float,
medfilt_width: int
) -> tuple:
"""
Core denoising logic using bidirectional consensus.
Args:
weights_stack: Attention weights [Heads, Tokens, Frames]
violence_level: Denoising strength coefficient
medfilt_width: Median filter width
Returns:
Tuple of (calc_matrix, energy_matrix) as numpy arrays
"""
# A. Bidirectional Consensus
row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
processed = row_prob * col_prob
# B1. Row suppression (kill horizontal crossing lines)
row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
processed = processed - (violence_level * row_medians)
processed = torch.relu(processed)
# B2. Column suppression (kill vertical crossing lines)
col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
processed = processed - (violence_level * col_medians)
processed = torch.relu(processed)
# C. Power sharpening
processed = processed ** 2
# Energy matrix for confidence
energy_matrix = processed.mean(dim=0).cpu().numpy()
# D. Z-Score normalization
std, mean = torch.std_mean(processed, unbiased=False)
weights_processed = (processed - mean) / (std + 1e-9)
# E. Median filtering
weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
calc_matrix = weights_processed.mean(dim=0).numpy()
return calc_matrix, energy_matrix
def _preprocess_attention(
self,
attention_matrix: torch.Tensor,
custom_config: Dict[int, List[int]],
violence_level: float,
medfilt_width: int = 7
) -> tuple:
"""
Preprocess attention matrix for alignment.
Args:
attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
custom_config: Dict mapping layer indices to head indices
violence_level: Denoising strength
medfilt_width: Median filter width
Returns:
Tuple of (calc_matrix, energy_matrix, visual_matrix)
"""
if not isinstance(attention_matrix, torch.Tensor):
weights = torch.tensor(attention_matrix)
else:
weights = attention_matrix.clone()
weights = weights.cpu().float()
selected_tensors = []
for layer_idx, head_indices in custom_config.items():
for head_idx in head_indices:
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
head_matrix = weights[layer_idx, head_idx]
selected_tensors.append(head_matrix)
if not selected_tensors:
return None, None, None
# Stack selected heads: [Heads, Tokens, Frames]
weights_stack = torch.stack(selected_tensors, dim=0)
visual_matrix = weights_stack.mean(dim=0).numpy()
calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
weights_stack, violence_level, medfilt_width
)
return calc_matrix, energy_matrix, visual_matrix
def stamps_align_info(
self,
attention_matrix: torch.Tensor,
lyrics_tokens: List[int],
total_duration_seconds: float,
custom_config: Dict[int, List[int]],
return_matrices: bool = False,
violence_level: float = 2.0,
medfilt_width: int = 1
) -> Dict[str, Any]:
"""
Get alignment information from attention matrix.
Args:
attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
lyrics_tokens: List of lyrics token IDs
total_duration_seconds: Total audio duration in seconds
custom_config: Dict mapping layer indices to head indices
return_matrices: Whether to return intermediate matrices
violence_level: Denoising strength
medfilt_width: Median filter width
Returns:
Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
and optionally energy_matrix and vis_matrix
"""
calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
attention_matrix, custom_config, violence_level, medfilt_width
)
if calc_matrix is None:
return {
"calc_matrix": None,
"lyrics_tokens": lyrics_tokens,
"total_duration_seconds": total_duration_seconds,
"error": "No valid attention heads found"
}
return_dict = {
"calc_matrix": calc_matrix,
"lyrics_tokens": lyrics_tokens,
"total_duration_seconds": total_duration_seconds
}
if return_matrices:
return_dict['energy_matrix'] = energy_matrix
return_dict['vis_matrix'] = visual_matrix
return return_dict
def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
"""
Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
For Chinese and other multi-byte characters, the tokenizer may split them
into multiple byte-level tokens. Decoding each token individually produces
invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
to correctly track which characters each token contributes.
Args:
token_ids: List of token IDs
Returns:
List of decoded text for each token position
"""
decoded_tokens = []
prev_bytes = b""
for i in range(len(token_ids)):
# Decode tokens from start to current position
current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
current_bytes = current_text.encode('utf-8', errors='surrogatepass')
# The contribution of current token is the new bytes added
if len(current_bytes) >= len(prev_bytes):
new_bytes = current_bytes[len(prev_bytes):]
# Try to decode the new bytes; if incomplete, use empty string
try:
token_text = new_bytes.decode('utf-8')
except UnicodeDecodeError:
# Incomplete UTF-8 sequence, this token doesn't complete a character
token_text = ""
else:
# Edge case: current decode is shorter (shouldn't happen normally)
token_text = ""
decoded_tokens.append(token_text)
prev_bytes = current_bytes
return decoded_tokens
def token_timestamps(
self,
calc_matrix: np.ndarray,
lyrics_tokens: List[int],
total_duration_seconds: float
) -> List[TokenTimestamp]:
"""
Generate per-token timestamps using DTW.
Args:
calc_matrix: Processed attention matrix [Tokens, Frames]
lyrics_tokens: List of token IDs
total_duration_seconds: Total audio duration
Returns:
List of TokenTimestamp objects
"""
n_frames = calc_matrix.shape[-1]
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
seconds_per_frame = total_duration_seconds / n_frames
alignment_results = []
# Use incremental decoding to properly handle multi-byte UTF-8 characters
decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
for i in range(len(lyrics_tokens)):
mask = (text_indices == i)
if not np.any(mask):
start = alignment_results[-1].end if alignment_results else 0.0
end = start
token_conf = 0.0
else:
times = time_indices[mask] * seconds_per_frame
start = times[0]
end = times[-1]
token_conf = 0.0  # placeholder: per-token confidence is not derived from the attention energy here
if end < start:
end = start
alignment_results.append(TokenTimestamp(
token_id=lyrics_tokens[i],
text=decoded_tokens[i],
start=float(start),
end=float(end),
probability=token_conf
))
return alignment_results
def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
"""
Decode a sentence by decoding all token IDs together.
This avoids UTF-8 encoding issues from joining individual token texts.
Args:
tokens: List of TokenTimestamp objects
Returns:
Properly decoded sentence text
"""
token_ids = [t.token_id for t in tokens]
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
def sentence_timestamps(
self,
token_alignment: List[TokenTimestamp]
) -> List[SentenceTimestamp]:
"""
Group token timestamps into sentence timestamps.
Args:
token_alignment: List of TokenTimestamp objects
Returns:
List of SentenceTimestamp objects
"""
results = []
current_tokens = []
for token in token_alignment:
current_tokens.append(token)
if '\n' in token.text:
# Decode all token IDs together to avoid UTF-8 issues
full_text = self._decode_sentence_from_tokens(current_tokens)
if full_text.strip():
valid_scores = [t.probability for t in current_tokens if t.probability > 0]
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
results.append(SentenceTimestamp(
text=full_text.strip(),
start=round(current_tokens[0].start, 3),
end=round(current_tokens[-1].end, 3),
tokens=list(current_tokens),
confidence=sent_conf
))
current_tokens = []
# Handle last sentence
if current_tokens:
# Decode all token IDs together to avoid UTF-8 issues
full_text = self._decode_sentence_from_tokens(current_tokens)
if full_text.strip():
valid_scores = [t.probability for t in current_tokens if t.probability > 0]
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
results.append(SentenceTimestamp(
text=full_text.strip(),
start=round(current_tokens[0].start, 3),
end=round(current_tokens[-1].end, 3),
tokens=list(current_tokens),
confidence=sent_conf
))
# Normalize confidence scores
if results:
all_scores = [s.confidence for s in results]
min_score = min(all_scores)
max_score = max(all_scores)
score_range = max_score - min_score
if score_range > 1e-9:
for s in results:
normalized_score = (s.confidence - min_score) / score_range
s.confidence = round(normalized_score, 2)
else:
for s in results:
s.confidence = round(s.confidence, 2)
return results
def format_lrc(
self,
sentence_timestamps: List[SentenceTimestamp],
include_end_time: bool = False
) -> str:
"""
Format sentence timestamps as LRC lyrics format.
Args:
sentence_timestamps: List of SentenceTimestamp objects
include_end_time: Whether to include end time (enhanced LRC format)
Returns:
LRC formatted string
"""
lines = []
for sentence in sentence_timestamps:
# Convert seconds to mm:ss.xx format
start_minutes = int(sentence.start // 60)
start_seconds = sentence.start % 60
if include_end_time:
end_minutes = int(sentence.end // 60)
end_seconds = sentence.end % 60
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
else:
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
# The sentence text is kept as-is; structural tags such as [verse] or [chorus] are not stripped here
text = sentence.text
lines.append(f"{timestamp}{text}")
return "\n".join(lines)
def get_timestamps_and_lrc(
self,
calc_matrix: np.ndarray,
lyrics_tokens: List[int],
total_duration_seconds: float
) -> Dict[str, Any]:
"""
Convenience method to get both timestamps and LRC in one call.
Args:
calc_matrix: Processed attention matrix
lyrics_tokens: List of token IDs
total_duration_seconds: Total audio duration
Returns:
Dict containing token_timestamps, sentence_timestamps, and lrc_text
"""
token_stamps = self.token_timestamps(
calc_matrix=calc_matrix,
lyrics_tokens=lyrics_tokens,
total_duration_seconds=total_duration_seconds
)
sentence_stamps = self.sentence_timestamps(token_stamps)
lrc_text = self.format_lrc(sentence_stamps)
return {
"token_timestamps": token_stamps,
"sentence_timestamps": sentence_stamps,
"lrc_text": lrc_text
}
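# Illustrative end-to-end sketch for MusicStampsAligner. The tokenizer, the
# cross-attention tensor, and the layer/head selection below are placeholders
# supplied by the caller, not values prescribed by ACE-Step.
def _example_stamps_alignment(tokenizer, attention_matrix: torch.Tensor,
                              lyrics_tokens: List[int],
                              total_duration_seconds: float) -> str:
    aligner = MusicStampsAligner(tokenizer)
    # Hypothetical config: heads 0 and 1 of layer 4 plus head 2 of layer 7.
    custom_config = {4: [0, 1], 7: [2]}
    info = aligner.stamps_align_info(
        attention_matrix=attention_matrix,
        lyrics_tokens=lyrics_tokens,
        total_duration_seconds=total_duration_seconds,
        custom_config=custom_config,
    )
    if info["calc_matrix"] is None:
        raise ValueError(info.get("error", "alignment failed"))
    result = aligner.get_timestamps_and_lrc(
        calc_matrix=info["calc_matrix"],
        lyrics_tokens=info["lyrics_tokens"],
        total_duration_seconds=info["total_duration_seconds"],
    )
    return result["lrc_text"]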
class MusicLyricScorer:
"""
Scorer class for evaluating lyrics-to-audio alignment quality.
Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
using tensor operations for potential differentiability or GPU acceleration.
"""
def __init__(self, tokenizer: Any):
"""
Initialize the scorer.
Args:
tokenizer: Tokenizer instance (must implement .decode()).
"""
self.tokenizer = tokenizer
def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
"""
Generate a mask distinguishing lyrics (1) from structural tags (0).
Uses self.tokenizer to decode tokens.
Args:
token_ids: List of token IDs.
Returns:
Numpy array of shape [len(token_ids)] with 1 or 0.
"""
decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
mask = np.ones(len(token_ids), dtype=np.int32)
in_bracket = False
for i, token_str in enumerate(decoded_tokens):
if '[' in token_str:
in_bracket = True
if in_bracket:
mask[i] = 0
if ']' in token_str:
in_bracket = False
mask[i] = 0
return mask
def _preprocess_attention(
self,
attention_matrix: Union[torch.Tensor, np.ndarray],
custom_config: Dict[int, List[int]],
medfilt_width: int = 1
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
"""
Extracts and normalizes the attention matrix.
Logic V4: Uses Min-Max normalization to highlight energy differences.
Args:
attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
custom_config: Config mapping layers to heads.
medfilt_width: Width for median filtering.
Returns:
Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
"""
# 1. Prepare Tensor
if not isinstance(attention_matrix, torch.Tensor):
weights = torch.tensor(attention_matrix)
else:
weights = attention_matrix.clone()
weights = weights.cpu().float()
# 2. Select Heads based on config
selected_tensors = []
for layer_idx, head_indices in custom_config.items():
for head_idx in head_indices:
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
selected_tensors.append(weights[layer_idx, head_idx])
if not selected_tensors:
return None, None, None
weights_stack = torch.stack(selected_tensors, dim=0)
# 3. Average Heads
avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
# 4. Preprocessing Logic
# Min-Max normalization preserving energy distribution
# Median filter is applied to the energy matrix
energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
energy_matrix = energy_tensor.numpy()
e_min, e_max = energy_matrix.min(), energy_matrix.max()
if e_max - e_min > 1e-9:
energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
else:
energy_matrix = np.zeros_like(energy_matrix)
# Contrast enhancement for DTW pathfinding
# calc_matrix is used for pathfinding, energy_matrix for scoring
calc_matrix = energy_matrix ** 2
return calc_matrix, energy_matrix, avg_weights
def _compute_alignment_metrics(
self,
energy_matrix: torch.Tensor,
path_coords: torch.Tensor,
type_mask: torch.Tensor,
time_weight: float = 0.01,
overlap_frames: float = 9.0,
instrumental_weight: float = 1.0
) -> Tuple[float, float, float]:
"""
Core metric calculation logic using high-precision Tensor operations.
Args:
energy_matrix: Normalized energy [Rows, Cols].
path_coords: DTW path coordinates [Steps, 2].
type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
time_weight: Minimum energy threshold for monotonicity.
overlap_frames: Allowed overlap for monotonicity check.
instrumental_weight: Weight for non-lyric tokens in confidence calc.
Returns:
Tuple of (coverage, monotonicity, confidence).
"""
# Ensure high precision for internal calculation
energy_matrix = energy_matrix.to(dtype=torch.float64)
path_coords = path_coords.long()
type_mask = type_mask.long()
device = energy_matrix.device
rows, cols = energy_matrix.shape
is_lyrics_row = (type_mask == 1)
# ================= A. Coverage Score =================
# Ratio of lyric token rows that have a significant energy peak
row_max_energies = energy_matrix.max(dim=1).values
total_sung_rows = is_lyrics_row.sum().double()
coverage_threshold = 0.1
valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
valid_sung_rows = valid_sung_mask.sum().double()
if total_sung_rows > 0:
coverage_score = valid_sung_rows / total_sung_rows
else:
coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
# ================= B. Monotonicity Score =================
# Check whether the "center of mass" of lyric token rows moves forward in time
col_indices = torch.arange(cols, device=device, dtype=torch.float64)
# Zero out low energy noise
weights = torch.where(
energy_matrix > time_weight,
energy_matrix,
torch.zeros_like(energy_matrix)
)
sum_w = weights.sum(dim=1)
sum_t = (weights * col_indices).sum(dim=1)
# Calculate centroids
centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
valid_w_mask = sum_w > 1e-9
centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
# Extract sequence of valid lyrics centroids
valid_sequence_mask = is_lyrics_row & (centroids >= 0)
sung_centroids = centroids[valid_sequence_mask]
cnt = sung_centroids.shape[0]
if cnt > 1:
curr_c = sung_centroids[:-1]
next_c = sung_centroids[1:]
# Check non-decreasing order with overlap tolerance
non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
monotonicity_score = non_decreasing / pairs
else:
monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
# ================= C. Path Confidence =================
# Average energy along the optimal path
if path_coords.shape[0] > 0:
p_rows = path_coords[:, 0]
p_cols = path_coords[:, 1]
path_energies = energy_matrix[p_rows, p_cols]
step_weights = torch.ones_like(path_energies)
# Lower weight for instrumental/tag steps
is_inst_step = (type_mask[p_rows] == 0)
step_weights[is_inst_step] = instrumental_weight
total_energy = (path_energies * step_weights).sum()
total_steps = step_weights.sum()
if total_steps > 0:
path_confidence = total_energy / total_steps
else:
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
else:
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
def lyrics_alignment_info(
self,
attention_matrix: Union[torch.Tensor, np.ndarray],
token_ids: List[int],
custom_config: Dict[int, List[int]],
return_matrices: bool = False,
medfilt_width: int = 1
) -> Dict[str, Any]:
"""
Generates alignment path and processed matrices.
Args:
attention_matrix: Input attention tensor.
token_ids: Corresponding token IDs.
custom_config: Layer/Head configuration.
return_matrices: If True, returns matrices in the output.
medfilt_width: Median filter width.
Returns:
Dict containing path_coords, type_mask, and energy_matrix (plus calc_matrix and vis_matrix when return_matrices is True).
"""
calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
attention_matrix, custom_config, medfilt_width
)
if calc_matrix is None:
return {
"calc_matrix": None,
"error": "No valid attention heads found"
}
# 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
# Uses self.tokenizer internally
type_mask = self._generate_token_type_mask(token_ids)
# Safety check for shape mismatch
if len(type_mask) != energy_matrix.shape[0]:
# Fallback to all lyrics if shapes don't align
type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
# 2. DTW Pathfinding
# Using negative calc_matrix because DTW minimizes cost
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
path_coords = np.stack([text_indices, time_indices], axis=1)
return_dict = {
"path_coords": path_coords,
"type_mask": type_mask,
"energy_matrix": energy_matrix
}
if return_matrices:
return_dict['calc_matrix'] = calc_matrix
return_dict['vis_matrix'] = vis_matrix
return return_dict
def calculate_score(
self,
energy_matrix: Union[torch.Tensor, np.ndarray],
type_mask: Union[torch.Tensor, np.ndarray],
path_coords: Union[torch.Tensor, np.ndarray],
time_weight: float = 0.01,
overlap_frames: float = 9.0,
instrumental_weight: float = 1.0
) -> Dict[str, Any]:
"""
Calculates the final alignment score based on pre-computed components.
Args:
energy_matrix: Processed energy matrix.
type_mask: Token type mask.
path_coords: DTW path coordinates.
time_weight: Minimum energy threshold for monotonicity.
overlap_frames: Allowed backward movement frames.
instrumental_weight: Weight for non-lyric path steps.
Returns:
Dict containing the final rounded "lyrics_score".
"""
# Ensure Inputs are Tensors on the correct device
if not isinstance(energy_matrix, torch.Tensor):
energy_matrix = torch.tensor(energy_matrix, device='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.float32)
device = energy_matrix.device
if not isinstance(type_mask, torch.Tensor):
type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
else:
type_mask = type_mask.to(device=device, dtype=torch.long)
if not isinstance(path_coords, torch.Tensor):
path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
else:
path_coords = path_coords.to(device=device, dtype=torch.long)
# Compute Metrics
coverage, monotonicity, confidence = self._compute_alignment_metrics(
energy_matrix=energy_matrix,
path_coords=path_coords,
type_mask=type_mask,
time_weight=time_weight,
overlap_frames=overlap_frames,
instrumental_weight=instrumental_weight
)
# Final Score Calculation
# (Cov^2 * Mono^2 * Conf)
final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
final_score = float(np.clip(final_score, 0.0, 1.0))
return {
"lyrics_score": round(final_score, 4)
}
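# Illustrative sketch of the scoring flow (the tokenizer, attention tensor, and
# layer/head config are placeholders; only the call sequence mirrors the
# methods above). The final score is coverage^2 * monotonicity^2 * confidence,
# clipped to [0, 1].
def _example_lyrics_score(tokenizer, attention_matrix: torch.Tensor,
                          token_ids: List[int]) -> float:
    scorer = MusicLyricScorer(tokenizer)
    custom_config = {4: [0, 1], 7: [2]}  # hypothetical layer -> heads mapping
    info = scorer.lyrics_alignment_info(
        attention_matrix=attention_matrix,
        token_ids=token_ids,
        custom_config=custom_config,
    )
    if info.get("energy_matrix") is None:
        raise ValueError(info.get("error", "alignment failed"))
    score = scorer.calculate_score(
        energy_matrix=info["energy_matrix"],
        type_mask=info["type_mask"],
        path_coords=info["path_coords"],
    )
    return score["lyrics_score"]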