import numpy as np
import torch
from torch.utils.data import Dataset
import os
import h5py
from torch_cluster import fps
import json
import random
# Assumption: the "[bold red]...[/bold red]" markup in the prints below targets rich's
# drop-in replacement for print; import it so the markup is actually rendered.
from rich import print


class TrajDataset(Dataset):
    def __init__(self, split, cfg):
        self.cfg = cfg
        self.dataset_path = cfg.dataset_path
        self.split = split
        self.stage = cfg.stage  # 'shape' or 'deform'
        self.mode = cfg.mode    # 'ae' or 'diff'
        self.repeat = cfg.repeat
        self.seed = cfg.seed
        self.pc_size = cfg.pc_size
        self.n_sample_pro_model = cfg.n_sample_pro_model
        self.n_frames_interval = cfg.n_frames_interval
        self.n_training_frames = cfg.n_training_frames
        self.batch_size = cfg.batch_size
        self.has_gravity = cfg.get('has_gravity', False)
        self.max_num_forces = cfg.get('max_num_forces', 1)

        # if os.path.exists(os.path.join(self.dataset_path, cfg.dataset_list)):
        if os.path.exists(cfg.dataset_list):
            print(f'Loading {cfg.dataset_list}')
            with open(cfg.dataset_list, 'r') as f:
                self.split_lst = json.load(f)
        else:
            self.split_lst = [f for f in sorted(os.listdir(self.dataset_path)) if f.endswith('h5')]
            random.seed(0)
            random.shuffle(self.split_lst)
        print('Number of data:', len(self.split_lst))

        if cfg.overfit:
            self.split_lst = self.split_lst[:1]
        elif cfg.dataset_path.endswith('_test') or cfg.dataset_list.endswith('test.json') or cfg.dataset_list.endswith('test_list.json'):
            self.split_lst = self.split_lst[:100]
            print('Test split:', self.split_lst)
        else:
            if split == 'train':
                self.split_lst = self.split_lst[:-4]
            else:
                self.split_lst = self.split_lst[-8:]
                print('Test split:', self.split_lst)

        self.split_lst_save = self.split_lst.copy()
        self.split_lst_pcl_len = [49] * len(self.split_lst_save)

        # if not os.path.exists(os.path.join(self.dataset_path, f'info_deform_ae_{split}.json')):
        self.prepare_data_lst()
        #     with open(os.path.join(self.dataset_path, f'info_deform_ae_{split}.json'), "w") as f:
        #         json.dump(self.models, f)
        #     print(f'Saved info_deform_ae_{split}.json')
        # else:
        #     self.models = json.load(open(os.path.join(self.dataset_path, f'info_deform_ae_{split}.json'), 'r'))
        #     print(f'Loaded info_deform_ae_{split}.json')

        print("Current stage: [bold red]{}[/bold red]".format(self.stage))
        print("Current mode: [bold red]{}[/bold red]".format(self.mode))
        print("Current split: [bold red]{}[/bold red]".format(self.split))
        print("Dataset is repeated [bold cyan]{}[/bold cyan] times".format(self.repeat))
        print("Length of split: {}".format(len(self.split_lst) if self.stage == 'shape' else len(self.models)))

    def prepare_data_lst(self):
        self.models = []
        if self.stage == 'deform':
            if self.mode == 'ae':
                if self.split == 'train':
                    models_out, indices_out = self.random_sample_indexes(
                        self.split_lst_save * self.repeat, self.split_lst_pcl_len * self.repeat)
                    self.models += [{"model": m, "indices": indices_out[i]} for i, m in enumerate(models_out)]
                else:
                    # Evaluate: one (previous frame, current frame) pair per batch element
                    for m in self.split_lst_save:
                        for i in range(1, self.batch_size + 1):
                            self.models += [{"model": m, "indices": [i - 1, i]}]
            elif self.mode == 'diff':
                # models_out, indices_out = self.subdivide_into_sequences(self.split_lst_save * self.repeat, self.split_lst_pcl_len * self.repeat)
                # self.models += [{"model": m, "start_idx": indices_out[i]} for i, m in enumerate(models_out)]
                self.models += [{"model": m, "start_idx": 0} for m in self.split_lst_save]
            else:
                raise NotImplementedError("mode not implemented")

    def __getitem__(self, index):
        if self.stage == 'deform':
            if self.mode == 'ae':
                return self.get_deform_ae(index)
            elif self.mode == 'diff':
                return self.get_deform_diff(index)

    def __len__(self):
        if self.stage == 'deform':
            if self.mode == 'ae':
                if self.split == 'train':
                    return sum(self.split_lst_pcl_len) * self.repeat
                else:
                    return len(self.split_lst_save) * self.batch_size  # number of sequences
            elif self.mode == 'diff':
                return len(self.models)
            else:
                raise NotImplementedError("mode not implemented")

    def random_sample_indexes(self, models, models_len):
        n_sample_pro_model = self.n_sample_pro_model
        # The original read self.interval_between_frames / self.n_selected_frames, which are never
        # set in __init__. Fall back to the configured frame interval and a two-frame
        # (source, target) sample, which is what get_deform_ae consumes.
        interval_between_frames = getattr(self, 'interval_between_frames', self.n_frames_interval)
        n_selected_frames = getattr(self, 'n_selected_frames', 2)
        # Initialize output lists
        models_out = []
        indexes_out = []
        # Loop over each model
        for idx, model in enumerate(models):
            # For each sample per model
            for n in range(n_sample_pro_model):
                # Initialize indices list for current sample
                indexes = []
                # Select n_selected_frames indices
                for i in range(n_selected_frames):
                    if i == 0:
                        # First index: randomly select from the full range
                        # indexes.append(np.random.randint(0, models_len[idx] - interval_between_frames))
                        indexes.append(np.random.randint(0, models_len[idx]))
                    else:
                        # Subsequent indices: select within interval_between_frames of the previous index
                        indexes.append(
                            min(indexes[-1] + np.random.randint(0, interval_between_frames), models_len[idx] - 1)
                        )
                # Append the selected indices and corresponding model to the output lists
                indexes_out.append(sorted(indexes))
                models_out.append(model)
        return models_out, indexes_out

    def get_deform_ae(self, index):
        model = self.models[index]
        model_name = model["model"]
        model_indices = model["indices"]

        model_info = {}
        model_info["model"] = model_name
        model_info["indices"] = model_indices

        model_metas = h5py.File(os.path.join(self.dataset_path, f'{model_name}'), 'r')
        model_pcls = torch.from_numpy(np.array(model_metas['x']))

        ind = np.random.default_rng(seed=self.seed).choice(model_pcls[0].shape[0], self.pc_size, replace=False)
        points_src = model_pcls[model_indices[0]][ind]
        points_tgt = model_pcls[model_indices[1]][ind]

        model_data = {}
        model_data['points_src'] = points_src.float()
        model_data['points_tgt'] = points_tgt.float()
        return model_data, model_info

    def get_deform_diff(self, index):
        model = self.models[index]
        model_name = model["model"]

        model_info = {}
        model_info["model"] = model_name
        model_info["indices"] = np.arange(self.n_training_frames)

        model_data = {}
        model_data['model'] = model_name

        model_metas = h5py.File(os.path.join(self.dataset_path, f'{model_name}'), 'r')
        model_pcls = torch.from_numpy(np.array(model_metas['x']))

        # if model_pcls[0].shape[0] > self.pc_size:
        #     ind = np.random.default_rng(seed=self.seed).choice(model_pcls[0].shape[0], self.pc_size, replace=False)
        #     points_src = model_pcls[:1]
        #     points_tgt = model_pcls[1:(self.n_training_frames*self.n_frames_interval+1):self.n_frames_interval][:, ind]
        # else:
        # No need to do fps in the new dataset case (input is 2048 points)
        points_src = model_pcls[:1]
        points_tgt = model_pcls[1:(self.n_training_frames * self.n_frames_interval + 1):self.n_frames_interval]

        if 'drag_point' not in model_metas:
            # Assume the drag direction passes through the sphere center
            drag_dir = np.array(model_metas['drag_force'])
            drag_dir = drag_dir / np.linalg.norm(drag_dir)
            drag_point = np.array([self.cfg.norm_fac, self.cfg.norm_fac, self.cfg.norm_fac]) + drag_dir
        else:
            drag_point = np.array(model_metas['drag_point'])

        if 'floor_height' not in model_metas:
            model_data['floor_height'] = torch.from_numpy(np.array(-2.4)).unsqueeze(-1).float()
        else:
            model_data['floor_height'] = (torch.from_numpy(np.array(model_metas['floor_height'])).unsqueeze(-1).float() - self.cfg.norm_fac) / 2
        model_data['points_src'] = (points_src.float() - self.cfg.norm_fac) / 2
        model_data['points_tgt'] = (points_tgt.float() - self.cfg.norm_fac) / 2
        model_data['vol'] = torch.from_numpy(np.array(model_metas['vol']))
        model_data['F'] = torch.from_numpy(np.array(model_metas['F']))
        model_data['F'] = model_data['F'][1:(self.n_training_frames * self.n_frames_interval + 1):self.n_frames_interval]
        model_data['C'] = torch.from_numpy(np.array(model_metas['C']))
        model_data['C'] = model_data['C'][1:(self.n_training_frames * self.n_frames_interval + 1):self.n_frames_interval]

        mask = torch.from_numpy(np.array(model_metas['drag_mask'])).bool()

        if 'gravity' in model_metas:
            model_data['gravity'] = torch.from_numpy(np.array(model_metas['gravity'])).long().unsqueeze(0)
        else:
            # print('no gravity in model_metas')
            model_data['gravity'] = torch.from_numpy(np.array(0)).long().unsqueeze(0)

        model_data['drag_point'] = (torch.from_numpy(drag_point).float() - self.cfg.norm_fac) / 2
        if model_data['drag_point'].ndim == 1:
            # For compatibility: only one force
            model_data['drag_point'] = torch.cat([model_data['drag_point'], torch.tensor([mask.sum()]).float()], dim=0).unsqueeze(0)
        else:
            model_data['drag_point'] = torch.cat([model_data['drag_point'], mask.sum(dim=-1, keepdim=True).float()], dim=1)

        force_order = torch.randperm(self.max_num_forces) if self.split == 'train' else torch.arange(self.max_num_forces)
        mask = mask.unsqueeze(0) if mask.ndim == 1 else mask
        # force_mask = torch.ones(self.max_num_forces, 1)
        # force_mask[:mask.shape[0]] *= 0
        # force_mask = force_mask[force_order].bool()

        if mask.shape[1] == 0:
            mask = torch.zeros(0, self.pc_size).bool()
            model_data['force'] = torch.zeros(0, 3)
            model_data['drag_point'] = torch.zeros(0, 4)
            model_data['base_drag_coeff'] = torch.zeros(self.max_num_forces, 1)
        elif 'base_drag_coeff' not in model_metas:
            vol = model_data['vol'].unsqueeze(0)
            total_volume = torch.sum(vol)
            masked_volume = torch.sum(vol * mask, dim=1)
            mean_masked_volume = masked_volume / mask.sum(dim=1)
            mask_ratio = masked_volume / total_volume
            base_drag_coeff = 9.8 * 1000 * mean_masked_volume / mask_ratio
            weighted_force = torch.from_numpy(np.array(model_metas['drag_force'])).float()
            weighted_force = weighted_force.unsqueeze(0) if weighted_force.ndim == 1 else weighted_force
            model_data['force'] = weighted_force / base_drag_coeff.unsqueeze(1)
            coeff = torch.zeros(self.max_num_forces, 1)
            coeff = coeff[force_order]
            coeff[:base_drag_coeff.shape[0]] = base_drag_coeff.unsqueeze(1)
            model_data['base_drag_coeff'] = coeff
            # model_data['weighted_force'] = weighted_force
        else:
            model_data['force'] = torch.from_numpy(np.array(model_metas['drag_force'])).float()
            model_data['base_drag_coeff'] = torch.from_numpy(np.array(model_metas['base_drag_coeff'])).float()

        model_data['is_mpm'] = torch.tensor(1).bool()
        if 'mat_type' in model_metas:
            model_data['mat_type'] = torch.from_numpy(np.array(model_metas['mat_type'])).long()
            if np.array(model_data['mat_type']).item() == 3:
                # Rigid dataset
                model_data['is_mpm'] = torch.tensor(0).bool()
        else:
            # Temporary fix for elastic data
            model_data['mat_type'] = torch.tensor(0).long()

        if self.has_gravity and model_data['gravity'][0] == 1:
            # Add gravity to the force set
            gravity = torch.tensor([[0, -1.0, 0]]).float()
            drag_point = (model_data['points_src'][0] * (model_data['vol'] / model_data['vol'].sum()).unsqueeze(1)).sum(axis=0) if model_data['is_mpm'] else model_data['points_src'][0].mean(axis=0)
            drag_point = torch.cat([drag_point, torch.tensor([self.pc_size]).float()]).unsqueeze(0)
            assert model_data['force'].sum() == 0, f'combining drag and gravity is not supported yet: {model_name}'
            model_data['force'] = torch.cat([model_data['force'], gravity], dim=0) if not model_data['force'].sum() == 0 else gravity
            model_data['drag_point'] = torch.cat([model_data['drag_point'], drag_point], dim=0) if not drag_point.sum() == 0 else drag_point
            mask = torch.cat([mask, torch.ones_like(mask).bool()], dim=0) if not mask.sum() == 0 else torch.ones(1, self.pc_size).bool()

        all_forces = torch.zeros(self.max_num_forces, 3)
        all_forces[:model_data['force'].shape[0]] = model_data['force']
        all_forces = all_forces[force_order]
        model_data['force'] = all_forces

        all_drag_points = torch.zeros(self.max_num_forces, 4)
        all_drag_points[:model_data['drag_point'].shape[0]] = model_data['drag_point']
        all_drag_points = all_drag_points[force_order]
        model_data['drag_point'] = all_drag_points

        if model_pcls[0].shape[0] > self.pc_size:
            ind = np.random.default_rng(seed=self.seed).choice(model_pcls[0].shape[0], self.pc_size, replace=False)
            model_data['points_src'] = model_data['points_src'][:, ind]
            model_data['points_tgt'] = model_data['points_tgt'][:, ind]
            mask = mask[:, ind] if mask.shape[-1] > self.pc_size else mask

        all_mask = torch.zeros(self.max_num_forces, self.pc_size).bool()
        all_mask[:mask.shape[0]] = mask
        all_mask = all_mask[force_order]
        model_data['mask'] = all_mask[..., None]  # (n_forces, pc_size, 1) for compatibility

        model_data['E'] = torch.log10(torch.from_numpy(np.array(model_metas['E'])).unsqueeze(-1).float()) if np.array(model_metas['E']) > 0 else torch.zeros(1).float()
        model_data['nu'] = torch.from_numpy(np.array(model_metas['nu'])).unsqueeze(-1).float()

        return model_data, model_info
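

# ---------------------------------------------------------------------------
# Usage sketch (not part of the dataset class): builds a TrajDataset for the
# 'deform'/'diff' stage and wraps it in a DataLoader. All config values below
# are illustrative placeholders, not the project's real settings; cfg is
# assumed to behave like an OmegaConf DictConfig (attribute access plus
# .get with a default), which matches how it is used in __init__.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    cfg = OmegaConf.create({
        'dataset_path': './data/traj_h5',      # directory of *.h5 trajectories (placeholder)
        'dataset_list': './data/train.json',   # optional JSON file list (placeholder)
        'stage': 'deform',                     # 'shape' or 'deform'
        'mode': 'diff',                        # 'ae' or 'diff'
        'repeat': 1,
        'seed': 0,
        'pc_size': 2048,
        'n_sample_pro_model': 1,
        'n_frames_interval': 1,
        'n_training_frames': 16,
        'batch_size': 4,
        'overfit': False,
        'norm_fac': 1.0,                       # placeholder normalization offset
        'has_gravity': False,
        'max_num_forces': 1,
    })

    dataset = TrajDataset(split='train', cfg=cfg)
    loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0)

    # Each item is a (model_data, model_info) pair; model_data holds the point
    # clouds, force/drag tensors, and material parameters assembled above.
    model_data, model_info = next(iter(loader))
    print(model_data['points_src'].shape, model_data['points_tgt'].shape)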