from datasets import Audio, load_dataset

# Worker count for the `filter`/`map` calls below; None means no
# multiprocessing in `datasets`.
N_PROC = None

# Drop the "tja" column and the hard/normal/easy/ura charts; only the "oni"
# chart is read when building the beat labels.
ds = load_dataset("JacobLinCool/taiko-1000-parsed").remove_columns(
    ["tja", "hard", "normal", "easy", "ura"]
)
| |
|
| |
|
def filter_out_broken(example):
    """Return True if the example's audio array can be accessed, else False.

    Used as a ``ds.filter`` predicate. Looking up ``example["audio"]["array"]``
    presumably forces ``datasets`` to decode the clip, so any failure during
    that lookup marks the row as broken and it is dropped.

    Args:
        example: One dataset row; expected to contain an ``"audio"`` entry.

    Returns:
        True when the audio array is reachable, False on any ordinary error.
    """
    try:
        example["audio"]["array"]
    except Exception:
        # Any lookup/decode failure means the row is unusable. Catching
        # Exception (not a bare ``except:``) lets KeyboardInterrupt and
        # SystemExit propagate, so a long filter run can still be cancelled.
        return False
    return True
| |
|
| |
|
# Keep only rows whose audio can be read, then have every clip decoded at
# 16 kHz from here on.
ds = ds.filter(
    filter_out_broken,
    num_proc=N_PROC,
    batch_size=32,
    writer_batch_size=32,
)
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
| |
|
| |
|
def build_beat_and_downbeat_labels(example):
    """
    Extract beat and downbeat times from the oni chart's segments.

    Each segment is treated as one measure:

    - Downbeats: the first beat of each measure (the segment timestamp).
    - Beats: ``measure_num`` evenly spaced beats starting at the segment
      timestamp, each lasting ``(60 / bpm) * (4 / measure_den)``.

    BPM selection per measure:

    - A measure with notes uses the BPM of its first note.
    - A measure without notes inherits the BPM of the most recent valid note
      before it, since a chart's BPM stays in effect until it is changed.
    - Measures before the first note use the BPM of the chart's first note.
    - If no positive BPM is available, 120 is used.

    Args:
        example: A dataset row with ``metadata["TITLE"]`` and
            ``oni["segments"]``. Each segment provides ``timestamp``,
            ``measure_num``, ``measure_den`` and a ``notes`` list whose items
            carry a ``bpm``.

    Returns:
        dict with ``title`` and sorted, de-duplicated ``beats`` and
        ``downbeats`` lists of times in seconds.
    """
    title = example["metadata"]["TITLE"]
    segments = example["oni"]["segments"]

    # BPM for measures that come before any note. Computed once instead of
    # re-scanning ahead for every empty measure (which was O(n^2)).
    first_note_bpm = next(
        (seg["notes"][0]["bpm"] for seg in segments if seg["notes"]),
        None,
    )
    last_note_bpm = None  # BPM of the latest note with a positive BPM

    beats = []
    downbeats = []

    for segment in segments:
        seg_timestamp = segment["timestamp"]
        measure_num = segment["measure_num"]
        measure_den = segment["measure_den"]
        notes = segment["notes"]

        # Every measure starts on a downbeat.
        downbeats.append(seg_timestamp)

        if notes:
            bpm = notes[0]["bpm"]
        elif last_note_bpm is not None:
            # Empty measure: use the BPM still in effect from the last note,
            # not the (possibly changed) BPM of some later measure.
            bpm = last_note_bpm
        else:
            bpm = first_note_bpm

        if bpm is None or bpm <= 0:
            bpm = 120.0

        if notes:
            latest = notes[-1]["bpm"]
            if latest is not None and latest > 0:
                last_note_bpm = latest

        # BPM counts quarter notes per minute; one beat is a 1/measure_den note.
        beat_duration = (60.0 / bpm) * (4.0 / measure_den)
        beats.extend(
            seg_timestamp + beat_idx * beat_duration
            for beat_idx in range(measure_num)
        )

    return {
        "title": title,
        "beats": sorted(set(beats)),
        "downbeats": sorted(set(downbeats)),
    }
| |
|
| |
|
# Replace the "oni"/"metadata" columns with title, beat and downbeat labels.
# Note: this map is per-example (batched=False), so per the datasets docs the
# batch_size argument has no effect here; writer_batch_size still applies.
ds = ds.map(
    build_beat_and_downbeat_labels,
    remove_columns=["oni", "metadata"],
    num_proc=N_PROC,
    writer_batch_size=32,
    batch_size=32,
)

# Return rows in torch format (tensors where the column type allows it).
ds = ds.with_format("torch")
| |
|
if __name__ == "__main__":
    # Summarise the processed dataset, then show the "train" split's schema.
    print(ds)
    train_split = ds["train"]
    print(train_split.features)
| |
|