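"""Batching utilities for midigpt training.

Wraps the midigpt C++ dataloader (midigpt.Jagged) in a Python iterable
that yields batches as dicts of torch tensors ("input_ids",
"attention_mask", "labels") shaped for gpt2-, xl-, or bert-style models.
"""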
import os
import sys
import json
import time
import datetime

import numpy as np
import torch
import tqdm

# make the compiled midigpt bindings importable
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "python_lib"))

import midigpt

class CustomDataset:

    def __init__(self, split_id=0, is_training=True, batch_size=32, dataset=None,
                 num_bars=4, min_tracks=2, max_tracks=12, max_seq_len=2048,
                 expressive=False, no_max_length=False, resolution=12,
                 encoding=None, pad_value=-100, arch="gpt2", accum_steps=1,
                 batches_per_epoch=1000, overload_batches_per_epoch=None, **kwargs):
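        """Configure the midigpt dataloader and per-batch encoding options.

        `batch_size` is divided by `accum_steps` so gradient accumulation
        reproduces the requested logical batch size; when given,
        `overload_batches_per_epoch` overrides `batches_per_epoch`.
        """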
        self.is_training = is_training
        # split the logical batch across gradient-accumulation steps
        self.batch_size = batch_size // accum_steps
        self.split_id = split_id
        self.max_seq_len = max_seq_len
        self.batches_per_epoch = batches_per_epoch if overload_batches_per_epoch is None else overload_batches_per_epoch
        # placeholder indices; the actual data comes from the dataloader
        self.dataset = list(range(self.batches_per_epoch))
        self.pad_value = pad_value
        self.arch = arch

        # configure the C++ dataloader
        self.dataloader = midigpt.Jagged(dataset)
        self.dataloader.set_num_bars(num_bars)
        self.dataloader.set_min_tracks(min_tracks)
        self.dataloader.set_max_tracks(max_tracks)
        self.dataloader.set_max_seq_len(max_seq_len)
        # randomize sampling order across runs
        seed = np.random.randint(2**20)
        self.dataloader.set_seed(seed)
        self.encoder_mode = midigpt.getEncoderType(encoding)

        # per-batch encoding options handed to read_batch_v2
        self.tc = midigpt.TrainConfig()
        self.tc.num_bars = num_bars
        self.tc.min_tracks = min_tracks
        self.tc.max_tracks = max_tracks
        self.tc.use_microtiming = expressive
        self.tc.no_max_length = no_max_length
        self.tc.resolution = resolution

        # iterator position within the current epoch
        self.current = 0

    def _get_batch(self):
        input_ids, mask = self.dataloader.read_batch_v2(
            self.batch_size, self.split_id, self.encoder_mode, self.tc)
        input_ids = np.array(input_ids)
        mask = np.array(mask)
        # mask padded positions in the labels so the loss ignores them
        labels = np.copy(input_ids)
        labels[mask == 0] = self.pad_value
        batch = {
            "input_ids": torch.from_numpy(input_ids),
            "attention_mask": torch.from_numpy(mask),
            "labels": torch.from_numpy(labels)
        }
        if self.arch == "xl":
            # transformer-xl models take no attention mask, so every
            # sequence in the batch must already be full length
            batch.pop("attention_mask")
            assert np.all(np.sum(mask, axis=1) == self.max_seq_len)
        if self.arch == "bert":
            batch.pop("labels")
        return batch

    def _get_batch_test(self):
        # fixed-size dummy batch for smoke-testing the training loop
        inputs = torch.ones((32, 800), dtype=torch.int64)
        return {
            "input_ids": inputs,
            "labels": inputs
        }

    def __iter__(self):
        self.current = 0
        return self

    def __next__(self):
        self.current += 1
        if self.current <= self.batches_per_epoch:
            # retry indefinitely on dataloader errors instead of
            # ending the epoch early
            while True:
                try:
                    return self._get_batch()
                except Exception as e:
                    print("ERROR IN BATCHER : ", e)
        raise StopIteration

    def __len__(self):
        return self.batches_per_epoch


def pad(seqs, pad_value):
    # right-pad each 1-D sequence to the length of the longest one;
    # returns the padded 2-D array and the original lengths
    seqlens = np.array([len(seq) for seq in seqs])
    maxlen = np.max(seqlens)
    padded = np.array([
        np.pad(seq, (0, maxlen - len(seq)), mode="constant", constant_values=pad_value)
        for seq in seqs
    ])
    return padded, seqlens
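

if __name__ == "__main__":
    # Minimal usage sketch. The dataset path and encoding name below are
    # placeholders (assumptions), not values shipped with this module;
    # substitute a real midigpt jagged dataset and encoder name.
    ds = CustomDataset(dataset="/path/to/dataset.arr", encoding="TRACK",
                       batch_size=8, max_seq_len=2048, arch="gpt2")
    for batch in ds:
        print({k: tuple(v.shape) for k, v in batch.items()})
        break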