import logging
import pickle
from enum import Enum
from typing import Any, TypeVar, Union

import numpy as np
from mmcv.utils import print_log

from detrsmpl.data.data_structures.human_data import HumanData
from detrsmpl.utils.path_utils import (
    Existence,
    check_path_existence,
    check_path_suffix,
)

# In T = TypeVar('T'), T can be anything.
# See definition of typing.TypeVar for details.
_HumanData = TypeVar('_HumanData')

_MultiHumanData_SUPPORTED_KEYS = HumanData.SUPPORTED_KEYS.copy()
_MultiHumanData_SUPPORTED_KEYS.update(
    {'optional': {
        'type': dict,
        'slice_key': 'frame_range',
        'dim': 0
    }})
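
# Layout sketch (inferred from __check_value_len__ and __set_default_values__
# below; not an upstream docstring): 'optional' holds
#     {'frame_range': [[start, stop], ...]}
# with one [start, stop) pair of instance indices per frame. For example,
# frame_range = [[0, 2], [2, 3]] describes 2 frames and 3 instances: frame 0
# owns instances 0-1 and frame 1 owns instance 2, so
# instance_num = sum(stop - start) = 3 while data_len = 2.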


class _KeyCheck(Enum):
    PASS = 0
    WARN = 1
    ERROR = 2


class MultiHumanData(HumanData):
    SUPPORTED_KEYS = _MultiHumanData_SUPPORTED_KEYS

    def __new__(cls: _HumanData, *args: Any, **kwargs: Any) -> _HumanData:
        """Create a new instance of MultiHumanData.

        Args:
            cls (MultiHumanData): MultiHumanData class.

        Returns:
            MultiHumanData: An instance of MultiHumanData.
        """
        ret_human_data = super().__new__(cls, *args, **kwargs)
        setattr(ret_human_data, '__data_len__', -1)
        setattr(ret_human_data, '__instance_num__', -1)
        setattr(ret_human_data, '__key_strict__', False)
        setattr(ret_human_data, '__keypoints_compressed__', False)
        return ret_human_data

    def load(self, npz_path: str):
        """Load data from npz_path and update them to self.

        Args:
            npz_path (str):
                Path to a dumped npz file.
        """
        supported_keys = self.__class__.SUPPORTED_KEYS
        with np.load(npz_path, allow_pickle=True) as npz_file:
            tmp_data_dict = dict(npz_file)
            for key, value in list(tmp_data_dict.items()):
                if isinstance(value, np.ndarray) and \
                        len(value.shape) == 0:
                    # value is not an ndarray before dump
                    value = value.item()
                elif key in supported_keys and \
                        type(value) != supported_keys[key]['type']:
                    value = supported_keys[key]['type'](value)
                if value is None:
                    tmp_data_dict.pop(key)
                elif key == '__key_strict__' or \
                        key == '__data_len__' or \
                        key == '__instance_num__' or \
                        key == '__keypoints_compressed__':
                    self.__setattr__(key, value)
                    # pop the attributes to keep dict clean
                    tmp_data_dict.pop(key)
                elif key == 'bbox_xywh' and value.shape[1] == 4:
                    # append a confidence column to 4-column bboxes
                    value = np.hstack([value, np.ones([value.shape[0], 1])])
                    tmp_data_dict[key] = value
                else:
                    tmp_data_dict[key] = value
            self.update(tmp_data_dict)
        self.__set_default_values__()

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                npz_path does not end with '.npz'.
            FileExistsError:
                When overwrite is False and file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            raise ValueError('Not an npz file.')
        if not overwrite:
            if check_path_existence(npz_path, 'file') == Existence.FileExist:
                raise FileExistsError
        dict_to_dump = {
            '__key_strict__': self.__key_strict__,
            '__data_len__': self.__data_len__,
            '__instance_num__': self.__instance_num__,
            '__keypoints_compressed__': self.__keypoints_compressed__,
        }
        dict_to_dump.update(self)
        np.savez_compressed(npz_path, **dict_to_dump)
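
    # Round-trip sketch (hypothetical path, assuming the inherited HumanData
    # dict behaviour; not part of the original module):
    #     data = MultiHumanData()
    #     data['optional'] = {'frame_range': [[0, 2], [2, 3]]}
    #     data.dump('/tmp/sample.npz')        # attrs + dict keys -> npz
    #     restored = MultiHumanData()
    #     restored.load('/tmp/sample.npz')    # restored.instance_num == 3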

    def dump_by_pickle(self, pkl_path: str, overwrite: bool = True) -> None:
        """Dump keys and items to a pickle file. It's a secondary dump method
        for when a HumanData instance is too large to be dumped by
        self.dump().

        Args:
            pkl_path (str):
                Path to a dumped pickle file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                pkl_path does not end with '.pkl'.
            FileExistsError:
                When overwrite is False and file exists.
        """
        if not check_path_suffix(pkl_path, ['.pkl']):
            raise ValueError('Not a pkl file.')
        if not overwrite:
            if check_path_existence(pkl_path, 'file') == Existence.FileExist:
                raise FileExistsError
        dict_to_dump = {
            '__key_strict__': self.__key_strict__,
            '__data_len__': self.__data_len__,
            '__instance_num__': self.__instance_num__,
            '__keypoints_compressed__': self.__keypoints_compressed__,
        }
        dict_to_dump.update(self)
        with open(pkl_path, 'wb') as f_writeb:
            pickle.dump(dict_to_dump,
                        f_writeb,
                        protocol=pickle.HIGHEST_PROTOCOL)

    def load_by_pickle(self, pkl_path: str) -> None:
        """Load data from pkl_path and update them to self. When a HumanData
        instance was dumped by self.dump_by_pickle(), use this to load.

        Args:
            pkl_path (str):
                Path to a dumped pickle file.
        """
        with open(pkl_path, 'rb') as f_readb:
            tmp_data_dict = pickle.load(f_readb)
            for key, value in list(tmp_data_dict.items()):
                if value is None:
                    tmp_data_dict.pop(key)
                elif key == '__key_strict__' or \
                        key == '__data_len__' or \
                        key == '__instance_num__' or \
                        key == '__keypoints_compressed__':
                    self.__setattr__(key, value)
                    # pop the attributes to keep dict clean
                    tmp_data_dict.pop(key)
                elif key == 'bbox_xywh' and value.shape[1] == 4:
                    # append a confidence column to 4-column bboxes
                    value = np.hstack([value, np.ones([value.shape[0], 1])])
                    tmp_data_dict[key] = value
                else:
                    tmp_data_dict[key] = value
            self.update(tmp_data_dict)
        self.__set_default_values__()
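
    # Pickle variant of the sketch above (hypothetical path), for instances
    # too large for self.dump():
    #     data.dump_by_pickle('/tmp/sample.pkl')
    #     restored = MultiHumanData()
    #     restored.load_by_pickle('/tmp/sample.pkl')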

    @property
    def instance_num(self) -> int:
        """Get the human instance num of this MultiHumanData instance. In
        MultiHumanData, an image may have multiple corresponding human
        instances.

        Returns:
            int:
                Number of human instances related to this instance.
        """
        return self.__instance_num__

    @instance_num.setter
    def instance_num(self, value: int):
        """Set the human instance num of this MultiHumanData instance.

        Args:
            value (int):
                Number of human instances related to this instance.
        """
        self.__instance_num__ = value

    def get_slice(self,
                  arg_0: int,
                  arg_1: Union[int, Any] = None,
                  step: int = 1) -> _HumanData:
        """Slice all sliceable values along the major dimension. For
        MultiHumanData, the slice indices refer to frames; every human
        instance belonging to a selected frame is kept.

        Args:
            arg_0 (int):
                When arg_1 is None, arg_0 is stop and start=0.
                When arg_1 is not None, arg_0 is start.
            arg_1 (Union[int, Any], optional):
                None or where to stop.
                Defaults to None.
            step (int, optional):
                Length of step. Defaults to 1.

        Returns:
            MultiHumanData:
                A new MultiHumanData instance with sliced values.
        """
        ret_human_data = \
            MultiHumanData.new(key_strict=self.get_key_strict())
        if arg_1 is None:
            start = 0
            stop = arg_0
        else:
            start = arg_0
            stop = arg_1
        slice_index = slice(start, stop, step)
        dim_dict = self.__get_slice_dim__()
        # frame_range = self.get_raw_value('optional')['frame_range']
        for key, dim in dim_dict.items():
            # 'optional' is the primary index and is sliced directly;
            # other keys are gathered through frame_range
            if key == 'optional':
                frame_range = None
            else:
                frame_range = self.get_raw_value('optional')['frame_range']
            # keys not expected to be sliced
            if dim is None:
                ret_human_data[key] = self[key]
            elif isinstance(dim, dict):
                value_dict = self.get_raw_value(key)
                sliced_dict = {}
                for sub_key in value_dict.keys():
                    sub_value = value_dict[sub_key]
                    if dim[sub_key] is None:
                        sliced_dict[sub_key] = sub_value
                    else:
                        sub_dim = dim[sub_key]
                        sliced_sub_value = \
                            MultiHumanData.__get_sliced_result__(
                                sub_value, sub_dim, slice_index, frame_range)
                        sliced_dict[sub_key] = sliced_sub_value
                ret_human_data[key] = sliced_dict
            else:
                value = self[key]
                sliced_value = \
                    MultiHumanData.__get_sliced_result__(
                        value, dim, slice_index, frame_range)
                ret_human_data[key] = sliced_value
        # check keypoints compressed
        if self.check_keypoints_compressed():
            ret_human_data.compress_keypoints_by_mask()
        return ret_human_data
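
    # Slicing sketch: with frame_range = [[0, 2], [2, 3]], get_slice(0, 1)
    # keeps frame 0 only; per-instance arrays such as bbox_xywh are gathered
    # row-wise through frame_range, so both instances of frame 0 survive
    # while frame 1's single instance is dropped.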

    def __get_slice_dim__(self) -> dict:
        """For each key in this MultiHumanData, get the dimension for
        slicing; 0 by default if no other value is specified.

        Returns:
            dict:
                Keys are self.keys().
                Values indicate where to slice.
                None for not expected to be sliced or failed.
        """
        supported_keys = self.__class__.SUPPORTED_KEYS
        ret_dict = {}
        for key in self.keys():
            # keys not expected to be sliced
            if key in supported_keys and \
                    'dim' in supported_keys[key] and \
                    supported_keys[key]['dim'] is None:
                ret_dict[key] = None
            else:
                value = self[key]
                if isinstance(value, dict) and len(value) > 0:
                    ret_dict[key] = {}
                    for sub_key in value.keys():
                        try:
                            sub_value_len = len(value[sub_key])
                            if sub_value_len != self.instance_num and \
                                    sub_value_len != self.data_len:
                                ret_dict[key][sub_key] = None
                            elif 'dim' in value:
                                ret_dict[key][sub_key] = value['dim']
                            else:
                                ret_dict[key][sub_key] = 0
                        except TypeError:
                            ret_dict[key][sub_key] = None
                    continue
                # instance cannot be sliced without len method
                try:
                    value_len = len(value)
                except TypeError:
                    ret_dict[key] = None
                    continue
                # slice on dim 0 by default
                slice_dim = 0
                if key in supported_keys and \
                        'dim' in supported_keys[key]:
                    slice_dim = supported_keys[key]['dim']
                data_len = value_len if slice_dim == 0 \
                    else value.shape[slice_dim]
                # dim not for slice
                if data_len != self.__instance_num__:
                    ret_dict[key] = None
                    continue
                else:
                    ret_dict[key] = slice_dim
        return ret_dict
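
    # Illustration (not upstream documentation): for a dataset with
    # 3 instances over 2 frames, this might return
    #     {'optional': {'frame_range': 0}, 'bbox_xywh': 0, 'config': None}
    # i.e. dim 0 for values whose length matches instance_num (or, for
    # sub-keys, data_len), and None for anything unsliceable.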

    # TODO: to support cache
    def __check_value_len__(self, key: Any, val: Any) -> bool:
        """Check whether the temporal length of val matches other values.

        Args:
            key (Any):
                Key in MultiHumanData.
            val (Any):
                Value to the key.

        Returns:
            bool:
                If temporal dim is defined and temporal length doesn't
                match, return False.
                Else return True.
        """
        ret_bool = True
        supported_keys = self.__class__.SUPPORTED_KEYS

        # MultiHumanData
        instance_num = 0
        if key == 'optional' and \
                'frame_range' in val:
            # count the instances covered by all [start, stop) frame ranges
            for frame_range in val['frame_range']:
                instance_num += (frame_range[-1] - frame_range[0])
            if self.instance_num == -1:
                # init instance_num for multi_human_data
                self.instance_num = instance_num
            elif self.instance_num != instance_num:
                ret_bool = False

            data_len = len(val['frame_range'])
            if self.data_len == -1:
                # init data_len
                self.data_len = data_len
            elif self.data_len == self.instance_num:
                # update data_len
                self.data_len = data_len
            elif self.data_len != self.instance_num:
                ret_bool = False
        # check definition
        elif key in supported_keys:
            # check data length
            if 'dim' in supported_keys[key] and \
                    supported_keys[key]['dim'] is not None:
                val_slice_dim = supported_keys[key]['dim']
                if supported_keys[key]['type'] == dict:
                    slice_key = supported_keys[key]['slice_key']
                    val_data_len = val[slice_key].shape[val_slice_dim]
                else:
                    val_data_len = val.shape[val_slice_dim]
                if self.instance_num < 0:
                    # Init instance_num for HumanData,
                    # which is equal to data_len.
                    self.instance_num = val_data_len
                else:
                    # check if val_data_len matches recorded instance_num
                    if self.instance_num != val_data_len:
                        ret_bool = False
                if self.data_len < 0:
                    # Init data_len for HumanData, where it equals
                    # instance_num. For MultiHumanData it is updated
                    # once 'optional'['frame_range'] arrives.
                    self.data_len = val_data_len
        if not ret_bool:
            err_msg = 'Data length check Failed:\n'
            err_msg += f'key={str(key)}\n'
            if self.data_len != self.instance_num:
                err_msg += f'val\'s instance_num={self.data_len}\n'
                err_msg += f'expected instance_num={self.instance_num}\n'
            print_log(msg=err_msg,
                      logger=self.__class__.logger,
                      level=logging.ERROR)
        return ret_bool

    def __set_default_values__(self) -> None:
        """For older versions of HumanData, call this method to apply missing
        values (also attributes).

        Note:
            1. Older HumanData doesn't define `data_len`.
            2. In the newer HumanData, `data_len` equals `instance_num`.
            3. In MultiHumanData, `instance_num` is the number of human
               instances and `data_len` is the number of frames.
        """
        supported_keys = self.__class__.SUPPORTED_KEYS
        if self.instance_num == -1:
            # the loaded file is not multi_human_data
            for key in supported_keys:
                if key in self and \
                        'dim' in supported_keys[key] and \
                        supported_keys[key]['dim'] is not None:
                    if 'slice_key' in supported_keys[key] and \
                            supported_keys[key]['type'] == dict:
                        sub_key = supported_keys[key]['slice_key']
                        slice_dim = supported_keys[key]['dim']
                        self.instance_num = \
                            self[key][sub_key].shape[slice_dim]
                    else:
                        slice_dim = supported_keys[key]['dim']
                        self.instance_num = self[key].shape[slice_dim]
                    # convert HumanData to MultiHumanData:
                    # every frame owns exactly one instance
                    self.data_len = self.instance_num
                    optional = {}
                    optional['frame_range'] = \
                        [[i, i + 1] for i in range(self.data_len)]
                    self['optional'] = optional
                    break
        for key in list(self.keys()):
            convention_key = f'{key}_convention'
            if key.startswith('keypoints') and \
                    not key.endswith('_mask') and \
                    not key.endswith('_convention') and \
                    convention_key not in self:
                self[convention_key] = 'human_data'

    @classmethod
    def __get_sliced_result__(
            cls,
            input_data: Union[np.ndarray, list, tuple],
            slice_dim: int,
            slice_range: slice,
            frame_index: list = None) -> Union[np.ndarray, list, tuple]:
        if frame_index is not None:
            # gather instance-level data frame by frame: each selected
            # frame contributes its [start, stop) rows
            slice_data = []
            for frame_range in frame_index[slice_range]:
                slice_index = slice(frame_range[0], frame_range[-1], 1)
                slice_result = \
                    HumanData.__get_sliced_result__(
                        input_data,
                        slice_dim,
                        slice_index)
                for element in slice_result:
                    slice_data.append(element)
            if isinstance(input_data, np.ndarray):
                slice_data = np.array(slice_data)
            else:
                slice_data = type(input_data)(slice_data)
        else:
            # primary index: slice the frame-level structure directly
            slice_data = \
                HumanData.__get_sliced_result__(
                    input_data,
                    slice_dim,
                    slice_range)
        return slice_data
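
# A minimal usage sketch (assumes the inherited HumanData dict behaviour;
# the data below is hypothetical and not part of the original module).
# It registers 3 instances spread over 2 frames, then slices out frame 0.
if __name__ == '__main__':
    multi_human_data = MultiHumanData()
    # frame 0 owns instances [0, 2), frame 1 owns instance [2, 3)
    multi_human_data['optional'] = {'frame_range': [[0, 2], [2, 3]]}
    # one bbox row per instance: [x, y, w, h, confidence]
    multi_human_data['bbox_xywh'] = np.zeros([3, 5])
    assert multi_human_data.data_len == 2  # frames
    assert multi_human_data.instance_num == 3  # human instances
    first_frame = multi_human_data.get_slice(1)
    # both instances of frame 0 survive the frame-level slice
    assert first_frame['bbox_xywh'].shape[0] == 2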