| | import torch |
| | import numpy as np |
| | from tqdm import tqdm |
| | from scipy.signal import find_peaks |
| | import argparse |
| | import os |
| |
|
| | from .model import ResNet |
| | from ..baseline1.utils import MultiViewSpectrogram |
| | from ..data.load import ds |
| | from ..data.eval import evaluate_all, format_results |
| |
|
| |
|
def get_activation_function(
    model, waveform, device, batch_size=128, context_frames=50
):
    """
    Computes probability curve over time.

    The waveform is turned into a spectrogram by ``MultiViewSpectrogram``,
    standardized, cut into one sliding window per frame of the last axis and
    pushed through ``model`` in mini-batches.

    Args:
        model: Callable applied to every batch of windows. Its output is moved
            to the CPU and converted to NumPy.
        waveform: Audio tensor. A leading dimension of size 1 is added here,
            so this is presumably a single un-batched signal -- TODO confirm
            against ``MultiViewSpectrogram``.
        device: Device the spectrogram processor and the waveform are moved to.
        batch_size: Number of windows per forward pass. Defaults to 128, the
            value that used to be hard-coded.
        context_frames: Number of zero frames padded on each side of the last
            axis. Every window spans ``2 * context_frames + 1`` frames, so the
            default of 50 gives the previously hard-coded 101-frame windows.
            NOTE(review): the model presumably expects the width it was
            trained with -- confirm before changing this.

    Returns:
        A 1-D ``numpy.ndarray`` with the flattened model outputs in window
        order. Empty if the spectrogram has no frames on its last axis.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or ``context_frames``
            is negative.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if context_frames < 0:
        raise ValueError(f"context_frames must be >= 0, got {context_frames}")

    processor = MultiViewSpectrogram().to(device)
    waveform = waveform.unsqueeze(0).to(device)
    window_size = 2 * context_frames + 1

    # Everything runs without autograd: the outputs are converted with
    # ``.numpy()``, which is not allowed on tensors that require grad.
    with torch.no_grad():
        spec = processor(waveform)

        # Nothing to window: unfold() would fail on an axis shorter than the
        # window, so report "no activations" instead.
        if spec.shape[-1] == 0:
            return np.zeros(0, dtype=np.float32)

        # Standardize over the last two axes (the code below fixes spec as 4-D).
        mean = spec.mean(dim=(2, 3), keepdim=True)
        std = spec.std(dim=(2, 3), keepdim=True) + 1e-6  # epsilon avoids /0
        spec = (spec - mean) / std

        # Zero-pad the last axis so that every original frame gets a full
        # window centred on it, then slide a window with stride 1.
        spec = torch.nn.functional.pad(spec, (context_frames, context_frames))
        windows = spec.unfold(3, window_size, 1)
        # Move the window index to the front and drop the size-1 batch axis
        # so the windows themselves form the batch dimension for the model.
        windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)

        activations = []
        for start in range(0, len(windows), batch_size):
            batch = windows[start : start + batch_size]
            out = model(batch)
            activations.append(out.cpu().numpy())

    return np.concatenate(activations).flatten()
| |
|
| |
|
def pick_peaks(
    activations,
    hop_length=160,
    sr=16000,
    threshold=0.5,
    min_distance=5,
    smooth_width=5,
):
    """
    Smooth with Hamming window and report local maxima.

    Args:
        activations: 1-D sequence of per-frame activation values.
        hop_length: Number of audio samples between consecutive frames.
        sr: Audio sample rate in Hz; ``frame * hop_length / sr`` is the
            timestamp of a frame in seconds.
        threshold: Minimum smoothed value a peak must reach (passed to
            ``find_peaks`` as ``height``). Defaults to the former constant 0.5.
        min_distance: Minimum number of frames between reported peaks (passed
            to ``find_peaks`` as ``distance``). Defaults to the former
            constant 5.
        smooth_width: Length of the normalized Hamming window used for
            smoothing. Defaults to the former constant 5.

    Returns:
        List of peak times in seconds, in increasing order. Empty when
        ``activations`` is empty or no peak reaches ``threshold``.

    Raises:
        ValueError: If ``smooth_width`` is smaller than 1.
    """
    if smooth_width < 1:
        raise ValueError(f"smooth_width must be >= 1, got {smooth_width}")

    activations = np.asarray(activations, dtype=float).ravel()
    # np.convolve rejects empty input; an empty curve simply has no peaks.
    if activations.size == 0:
        return []

    window = np.hamming(smooth_width)
    window /= window.sum()  # unit gain so `threshold` stays comparable

    # mode="same" returns max(len(a), len(window)) samples, which shifts the
    # frame indices when the input is shorter than the window. Take the
    # "full" result and cut out the part aligned with the input instead;
    # for inputs at least as long as the window this is the same slice
    # that mode="same" would return.
    full = np.convolve(activations, window, mode="full")
    offset = (len(window) - 1) // 2
    smoothed = full[offset : offset + activations.size]

    peaks, _ = find_peaks(smoothed, height=threshold, distance=min_distance)

    # Convert frame indices to seconds.
    timestamps = peaks * hop_length / sr
    return timestamps.tolist()
| |
|
| |
|
def visualize_track(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    time_range: tuple[float, float] | None = None,
):
    """
    Plot predicted vs. ground-truth beats for one track and save the figure.

    The figure is written to ``<output_dir>/track_<NNN>.png`` where ``NNN`` is
    ``track_idx`` zero-padded to three digits. ``output_dir`` is created if it
    does not exist yet.

    Args:
        audio: Audio samples handed to the plotting helper.
        sr: Sample rate handed to the plotting helper.
        pred_beats: Predicted beat times.
        pred_downbeats: Predicted downbeat times.
        gt_beats: Ground-truth beat times.
        gt_downbeats: Ground-truth downbeat times.
        output_dir: Directory that receives the PNG.
        track_idx: Track number used in the plot title and the file name.
        time_range: Optional ``(start, end)`` forwarded to the plotting
            helper; ``None`` is forwarded unchanged.
    """
    # Imported lazily so the plotting helpers are only loaded when needed.
    from ..data.viz import plot_waveform_with_beats, save_figure

    os.makedirs(output_dir, exist_ok=True)

    plot_title = f"Track {track_idx}: Beat Comparison"
    # Note the helper's positional order: beats (pred, gt) come before
    # downbeats (pred, gt), unlike this function's own parameter order.
    figure = plot_waveform_with_beats(
        audio,
        sr,
        pred_beats,
        gt_beats,
        pred_downbeats,
        gt_downbeats,
        title=plot_title,
        time_range=time_range,
    )

    png_path = os.path.join(output_dir, f"track_{track_idx:03d}.png")
    save_figure(figure, png_path)
| |
|
| |
|
def synthesize_audio(
    audio: np.ndarray,
    sr: int,
    pred_beats: list[float],
    pred_downbeats: list[float],
    gt_beats: list[float],
    gt_downbeats: list[float],
    output_dir: str,
    track_idx: int,
    click_volume: float = 0.5,
):
    """
    Render click-track versions of one track and save them as WAV files.

    Three files are written to ``output_dir`` (created if missing), named
    ``track_<NNN>_pred.wav``, ``track_<NNN>_gt.wav`` and
    ``track_<NNN>_both.wav`` with ``NNN`` being ``track_idx`` zero-padded to
    three digits.

    Args:
        audio: Audio samples handed to the synthesis helper.
        sr: Sample rate, forwarded to both the synthesis and the save helper.
        pred_beats: Predicted beat times.
        pred_downbeats: Predicted downbeat times.
        gt_beats: Ground-truth beat times.
        gt_downbeats: Ground-truth downbeat times.
        output_dir: Directory that receives the WAV files.
        track_idx: Track number used in the file names.
        click_volume: Click loudness forwarded to the synthesis helper.
    """
    # Imported lazily so the audio helpers are only loaded when needed.
    from ..data.audio import create_comparison_audio, save_audio

    os.makedirs(output_dir, exist_ok=True)

    audio_pred, audio_gt, audio_both = create_comparison_audio(
        audio,
        pred_beats,
        pred_downbeats,
        gt_beats,
        gt_downbeats,
        sr=sr,
        click_volume=click_volume,
    )

    # Written in this fixed order: predictions, ground truth, both.
    renderings = (
        ("pred", audio_pred),
        ("gt", audio_gt),
        ("both", audio_both),
    )
    for suffix, signal in renderings:
        wav_path = os.path.join(output_dir, f"track_{track_idx:03d}_{suffix}.wav")
        save_audio(signal, wav_path, sr)
| |
|
| |
|
def _build_arg_parser():
    """Build the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Evaluate beat tracking models with visualization and audio synthesis"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="outputs/baseline2",
        help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=116,
        help="Number of samples to evaluate",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/eval_baseline2",
        help="Directory to save visualizations and audio",
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Generate visualization plots for each track",
    )
    parser.add_argument(
        "--synthesize",
        action="store_true",
        help="Generate audio files with click tracks",
    )
    parser.add_argument(
        "--viz-tracks",
        type=int,
        default=5,
        help="Number of tracks to visualize/synthesize (default: 5)",
    )
    parser.add_argument(
        "--time-range",
        type=float,
        nargs=2,
        default=None,
        metavar=("START", "END"),
        help="Time range for visualization in seconds (default: full track)",
    )
    parser.add_argument(
        "--click-volume",
        type=float,
        default=0.5,
        help="Volume of click sounds relative to audio (0.0 to 1.0)",
    )
    parser.add_argument(
        "--summary-plot",
        action="store_true",
        help="Generate summary evaluation plot",
    )
    return parser


def _load_model_if_present(model_dir, label, device):
    """Load a ResNet from ``model_dir`` in eval mode, or return None if no weights exist."""
    if not os.path.exists(os.path.join(model_dir, "model.safetensors")):
        print(f"Warning: No {label.lower()} model found in {model_dir}")
        return None

    model = ResNet.from_pretrained(model_dir).to(device)
    model.eval()
    print(f"Loaded {label} Model from {model_dir}")
    return model


def main():
    """
    Command-line entry point: evaluate the beat and downbeat models.

    Steps:
        1. Parse the command-line arguments.
        2. Load the beat and downbeat models from ``<model-dir>/beats`` and
           ``<model-dir>/downbeats``; either one may be missing, but the
           function returns early if both are.
        3. Run both models over the first ``--num-samples`` items of
           ``ds["train"]`` (clamped to the size of the split) and pick peaks
           from the activation curves.
        4. Print the metrics returned by ``evaluate_all``.
        5. Optionally write per-track plots, click-track audio and a summary
           plot below ``--output-dir``.
    """
    args = _build_arg_parser().parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    beat_model = _load_model_if_present(
        os.path.join(args.model_dir, "beats"), "Beat", device
    )
    downbeat_model = _load_model_if_present(
        os.path.join(args.model_dir, "downbeats"), "Downbeat", device
    )

    if beat_model is None and downbeat_model is None:
        print("No models found. Please run training first.")
        return

    predictions = []
    ground_truths = []
    audio_data = []

    # NOTE(review): evaluation reads the "train" split -- confirm this is the
    # intended evaluation set.
    split = ds["train"]
    # Selecting indices past the end of the split raises, so never ask for
    # more items than it holds.
    num_samples = min(args.num_samples, len(split))
    if num_samples < args.num_samples:
        print(
            f"Warning: requested {args.num_samples} samples but only "
            f"{num_samples} are available"
        )
    test_set = split.select(range(num_samples))

    # Raw audio is only kept for the tracks that will be plotted/synthesized.
    keep_audio = args.visualize or args.synthesize

    print("Running evaluation...")
    for i, item in enumerate(tqdm(test_set)):
        waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
        waveform_device = waveform.to(device)

        pred_entry = {"beats": [], "downbeats": []}

        if beat_model is not None:
            act_b = get_activation_function(beat_model, waveform_device, device)
            pred_entry["beats"] = pick_peaks(act_b)

        if downbeat_model is not None:
            act_d = get_activation_function(downbeat_model, waveform_device, device)
            pred_entry["downbeats"] = pick_peaks(act_d)

        predictions.append(pred_entry)
        ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})

        if keep_audio and i < args.viz_tracks:
            audio_data.append(
                {
                    "audio": waveform.numpy(),
                    "sr": item["audio"]["sampling_rate"],
                    "pred": pred_entry,
                    "gt": ground_truths[-1],
                }
            )

    results = evaluate_all(predictions, ground_truths)
    print(format_results(results))

    if args.visualize or args.synthesize or args.summary_plot:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.visualize:
        print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
        viz_dir = os.path.join(args.output_dir, "plots")
        # argparse yields a list for nargs=2; the plot helper is given a tuple.
        time_range = tuple(args.time_range) if args.time_range else None
        for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
            visualize_track(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                viz_dir,
                i,
                time_range=time_range,
            )
        print(f"Saved visualizations to {viz_dir}")

    if args.synthesize:
        print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
        audio_dir = os.path.join(args.output_dir, "audio")
        for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
            synthesize_audio(
                data["audio"],
                data["sr"],
                data["pred"]["beats"],
                data["pred"]["downbeats"],
                data["gt"]["beats"],
                data["gt"]["downbeats"],
                audio_dir,
                i,
                click_volume=args.click_volume,
            )
        print(f"Saved audio files to {audio_dir}")
        print(" *_pred.wav - Original audio with predicted beat clicks")
        print(" *_gt.wav - Original audio with ground truth beat clicks")
        print(" *_both.wav - Original audio with both predicted and GT clicks")

    if args.summary_plot:
        # Imported lazily so plotting is only loaded when a plot is requested.
        from ..data.viz import plot_evaluation_summary, save_figure

        print("\nGenerating summary plot...")
        fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
        summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
        save_figure(fig, summary_path)
        print(f"Saved summary plot to {summary_path}")
| |
|
| |
|
# Run the evaluation CLI only when this file is executed as a script, not
# when it is imported. NOTE(review): the relative imports at the top of the
# file mean it presumably has to be started as a module (``python -m ...``).
if __name__ == "__main__":
    main()
| |
|