| | import torch |
| | from torch.utils.data import Dataset |
| | import numpy as np |
| | from tqdm import tqdm |
| |
|
| |
|
class BeatTrackingDataset(Dataset):
    """Frame-level beat / downbeat classification dataset over cached audio.

    Each example is a fixed-length waveform window centred on one analysis
    frame, paired with a soft label:

    * ``1.0``  - the frame an annotated beat (or downbeat) time falls in,
    * ``0.25`` - the frames immediately before / after such a frame,
    * ``0.0``  - randomly sampled frames that are neither of the above.

    All audio is decoded once and kept in memory (``self.audio_cache``), so
    ``__getitem__`` only slices and pads arrays.
    """

    def __init__(
        self,
        hf_dataset,
        target_type="beats",
        sample_rate=16000,
        hop_length=160,
        context_frames=50,
        extra_context_samples=1488,
        seed=None,
    ):
        """
        Args:
            hf_dataset: HuggingFace dataset object. Each item is read as
                ``item["audio"]["array"]`` plus ``item["beats"]`` or
                ``item["downbeats"]`` (annotation times in seconds; they are
                converted to samples via ``sample_rate``).
            target_type (str): "beats" or "downbeats". Determines which labels
                are treated as positive.
            sample_rate (int): Sample rate used to convert annotation times to
                sample positions. NOTE(review): assumed to match the rate of
                the decoded audio arrays - not checked here.
            hop_length (int): Samples per analysis frame.
            context_frames (int): Number of frames before and after the center
                frame. Total frames = 2 * context_frames + 1.
                Default 50 means 101 frames (~1s).
            extra_context_samples (int): Fixed number of samples added on top
                of the frame context to form the window length. The default
                1488 is the value hard-coded by the original implementation;
                presumably a model receptive-field margin - TODO confirm.
            seed (int | None): Seed for negative-frame sampling. ``None``
                (default) uses the global ``np.random`` state, exactly as
                before; an int makes index construction reproducible.

        Raises:
            ValueError: if ``target_type`` is not "beats" or "downbeats".
        """
        # Fail fast: previously any other value silently fell back to beats.
        if target_type not in ("beats", "downbeats"):
            raise ValueError(
                f"target_type must be 'beats' or 'downbeats', got {target_type!r}"
            )

        self.sr = sample_rate
        self.hop_length = hop_length
        self.target_type = target_type
        self.context_frames = context_frames

        # Window length in samples: (2 * context_frames + 1) frames plus a
        # fixed margin.
        self.context_samples = (
            self.context_frames * 2 + 1
        ) * hop_length + extra_context_samples

        # The global module keeps legacy behaviour (honours np.random.seed);
        # a dedicated RandomState gives reproducibility when a seed is passed.
        self._rng = np.random if seed is None else np.random.RandomState(seed)

        self.audio_cache = []  # one float32 ndarray per track
        self.indices = []  # (track_idx, frame_idx, label) tuples
        self._prepare_indices(hf_dataset)

    def _prepare_indices(self, hf_dataset):
        """Cache every track's audio and build the balanced example index.

        Uses the same "Fuzzier" training examples strategy as the baseline:
        in-range annotated frames get label 1.0, their immediate neighbours
        0.25, and up to twice as many random non-positive frames get 0.0.
        """
        print(f"Preparing dataset indices for target: {self.target_type}...")

        for i, item in tqdm(
            enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
        ):
            audio = item["audio"]["array"]
            if hasattr(audio, "numpy"):
                # Tensor-like objects expose .numpy(); plain arrays/lists don't.
                audio = audio.numpy()
            # __getitem__ casts to float32 anyway; storing float32 up front
            # yields the same samples while avoiding float64/list overhead.
            audio = np.asarray(audio, dtype=np.float32)
            self.audio_cache.append(audio)

            # Only frames whose first sample lies inside the track are indexed.
            n_frames = len(audio) // self.hop_length

            # target_type was validated in __init__, so it is a valid key.
            gt_times = item[self.target_type]
            if hasattr(gt_times, "tolist"):
                gt_times = gt_times.tolist()

            # Seconds -> samples -> frame index (truncating, as the baseline does).
            gt_frames = {int(t * self.sr / self.hop_length) for t in gt_times}

            pos_frames = set()
            for bf in gt_frames:
                if not 0 <= bf < n_frames:
                    continue
                self.indices.append((i, bf, 1.0))
                pos_frames.add(bf)
                # Down-weighted "fuzzy" positives one frame either side.
                for nb in (bf - 1, bf + 1):
                    if 0 <= nb < n_frames:
                        self.indices.append((i, nb, 0.25))
                        pos_frames.add(nb)

            # Two negatives per positive frame, drawn with replacement. The
            # attempt cap stops tracks dense with positives from looping
            # forever; such tracks may end up with fewer negatives.
            num_neg = len(pos_frames) * 2
            count = 0
            attempts = 0
            while count < num_neg and attempts < num_neg * 5:
                f = int(self._rng.randint(0, n_frames))
                if f not in pos_frames:
                    self.indices.append((i, f, 0.0))
                    count += 1
                attempts += 1

        print(
            f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
        )

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` for the ``idx``-th indexed frame.

        ``waveform`` is a float32 tensor of exactly ``self.context_samples``
        samples. Its element at offset ``context_samples // 2`` is sample
        ``frame_idx * hop_length`` of the track, and any part of the window
        outside the track is zero-padded. ``label`` is a float32 tensor of
        shape ``(1,)``.
        """
        track_idx, frame_idx, label = self.indices[idx]
        audio = self.audio_cache[track_idx]
        audio_len = len(audio)

        center_sample = frame_idx * self.hop_length
        start = center_sample - self.context_samples // 2
        # Derive the end from the full window length rather than
        # center + half, so an odd context_samples still yields exactly
        # context_samples samples.
        end = start + self.context_samples

        pad_left = max(0, -start)
        pad_right = max(0, end - audio_len)
        chunk = audio[max(0, start):min(audio_len, end)]
        if pad_left > 0 or pad_right > 0:
            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")

        waveform = torch.tensor(chunk, dtype=torch.float32)
        return waveform, torch.tensor([label], dtype=torch.float32)
| |
|