| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
import csv
import gc
import logging
import logging.config
import os
import shutil
import sys
import time
from collections import OrderedDict
from datetime import datetime

import monai.transforms as mt
import numpy as np
import torch
import torch.distributed as dist
import yaml
from monai.apps import get_logger
from monai.auto3dseg.utils import datafold_read
from monai.bundle import BundleWorkflow, ConfigParser
from monai.config import print_config
from monai.data import DataLoader, Dataset, decollate_batch
from monai.metrics import CumulativeAverage
from monai.utils import (
    BundleProperty,
    ImageMetaKey,
    convert_to_dst_type,
    ensure_tuple,
    look_up_option,
    optional_import,
    set_determinism,
)
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import WeightedRandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
|
|
| mlflow, mlflow_is_imported = optional_import("mlflow") |
|
|
|
|
| if __package__ in (None, ""): |
| from cell_distributed_weighted_sampler import DistributedWeightedSampler |
| from components import LabelsToFlows, LoadTiffd, LogitsToLabels |
| from utils import LOGGING_CONFIG, parsing_bundle_config |
| else: |
| from .cell_distributed_weighted_sampler import DistributedWeightedSampler |
| from .components import LabelsToFlows, LoadTiffd, LogitsToLabels |
| from .utils import LOGGING_CONFIG, parsing_bundle_config |
|
|
|
|
| logger = get_logger("VistaCell") |
|
|
|
|
| class VistaCell(BundleWorkflow): |
| """ |
| Primary vista model training workflow that extends |
| monai.bundle.BundleWorkflow for cell segmentation. |
| """ |
|
|
| def __init__(self, config_file=None, meta_file=None, logging_file=None, workflow_type="train", **override): |
| """ |
| config_file can be one or a list of config files. |
| the rest key-values in the `override` are to override config content. |
| """ |
|
|
| parser = parsing_bundle_config(config_file, logging_file=logging_file, meta_file=meta_file) |
| parser.update(pairs=override) |
|
|
| mode = parser.get("mode", None) |
| if mode is not None: |
| workflow_type = mode |
| else: |
| mode = workflow_type |
| super().__init__(workflow_type=workflow_type) |
| self._props = {} |
| self._set_props = {} |
| self.parser = parser |
|
|
| self.rank = int(os.getenv("LOCAL_RANK", "0")) |
| self.global_rank = int(os.getenv("RANK", "0")) |
| self.is_distributed = dist.is_available() and dist.is_initialized() |
|
|
| |
| if dist.is_torchelastic_launched() or ( |
| os.getenv("NGC_ARRAY_SIZE") is not None and int(os.getenv("NGC_ARRAY_SIZE")) > 1 |
| ): |
| if dist.is_available(): |
| dist.init_process_group(backend="nccl", init_method="env://") |
|
|
| self.is_distributed = dist.is_available() and dist.is_initialized() |
|
|
| torch.cuda.set_device(self.config("device")) |
| dist.barrier() |
|
|
| else: |
| self.is_distributed = False |
|
|
| if self.global_rank == 0 and self.config("ckpt_path") and not os.path.exists(self.config("ckpt_path")): |
| os.makedirs(self.config("ckpt_path"), exist_ok=True) |
|
|
| if self.rank == 0: |
| |
| _log_file = self.config("log_output_file", "vista_cell.log") |
| _log_file_dir = os.path.dirname(_log_file) |
| if _log_file_dir and not os.path.exists(_log_file_dir): |
| os.makedirs(_log_file_dir, exist_ok=True) |
|
|
| print_config() |
|
|
| if self.is_distributed: |
| dist.barrier() |
|
|
| seed = self.config("seed", None) |
| if seed is not None: |
| set_determinism(seed) |
| logger.info(f"set determinism seed: {self.config('seed', None)}") |
| elif torch.cuda.is_available(): |
| torch.backends.cudnn.benchmark = True |
| logger.info("No seed provided, using cudnn.benchmark for performance.") |
|
|
| if os.path.exists(self.config("ckpt_path")): |
| self.parser.export_config_file( |
| self.parser.config, |
| os.path.join(self.config("ckpt_path"), "working.yaml"), |
| fmt="yaml", |
| default_flow_style=None, |
| ) |
|
|
| self.add_property("network", required=True) |
| self.add_property("train_loader", required=True) |
| self.add_property("val_dataset", required=False) |
| self.add_property("val_loader", required=False) |
| self.add_property("val_preprocessing", required=False) |
| self.add_property("train_sampler", required=True) |
| self.add_property("val_sampler", required=True) |
| self.add_property("mode", required=False) |
| |
| |
| self.evaluator = None |
|
|
| def _set_property(self, name, property, value): |
| |
| self._set_props[name] = value |
|
|
| def _get_property(self, name, property): |
| """ |
| The customized bundle workflow must implement required properties in: |
| https://github.com/Project-MONAI/MONAI/blob/dev/monai/bundle/properties.py. |
| """ |
| if name in self._set_props: |
| self._props[name] = self._set_props[name] |
| return self._props[name] |
| if name in self._props: |
| return self._props[name] |
| try: |
| value = getattr(self, f"get_{name}")() |
| except AttributeError as err: |
| if property[BundleProperty.REQUIRED]: |
| raise ValueError( |
| f"Property '{name}' is required by the bundle format, " |
| f"but the method 'get_{name}' is not implemented." |
| ) from err |
| raise AttributeError from err |
| self._props[name] = value |
| return value |
|
|
| def config(self, name, default="null", **kwargs): |
| """read the parsed content (evaluate the expression) from the config file.""" |
| if default != "null": |
| return self.parser.get_parsed_content(name, default=default, **kwargs) |
| return self.parser.get_parsed_content(name, **kwargs) |
|
|
| def initialize(self): |
| _log_file = self.config("log_output_file", "vista_cell.log") |
| if _log_file is None: |
| LOGGING_CONFIG["loggers"]["VistaCell"]["handlers"].remove("file") |
| LOGGING_CONFIG["handlers"].pop("file", None) |
| else: |
| LOGGING_CONFIG["handlers"]["file"]["filename"] = _log_file |
| logging.config.dictConfig(LOGGING_CONFIG) |
|
|
| def get_mode(self): |
| mode_str = self.config("mode", self.workflow_type) |
| return look_up_option(mode_str, ("train", "training", "infer", "inference", "eval", "evaluation")) |
|
|
| def run(self): |
| if str(self.mode).startswith("train"): |
| return self.train() |
| if str(self.mode).startswith("infer"): |
| return self.infer() |
| return self.validate() |
|
|
| def finalize(self): |
| if self.is_distributed: |
| dist.destroy_process_group() |
| set_determinism(None) |
|
|
| def get_network_def(self): |
| return self.config("network_def") |
|
|
| def get_network(self): |
| pretrained_ckpt_name = self.config("pretrained_ckpt_name", None) |
| pretrained_ckpt_path = self.config("pretrained_ckpt_path", None) |
| if pretrained_ckpt_name is not None and pretrained_ckpt_path is None: |
| |
| pretrained_ckpt_path = os.path.join(self.config("ckpt_path"), pretrained_ckpt_name) |
|
|
| if pretrained_ckpt_path is not None and not os.path.exists(pretrained_ckpt_path): |
| logger.info(f"Pretrained checkpoint {pretrained_ckpt_path} not found.") |
| raise ValueError(f"Pretrained checkpoint {pretrained_ckpt_path} not found.") |
|
|
| if pretrained_ckpt_path is not None and os.path.exists(pretrained_ckpt_path): |
| |
| if "checkpoint" in self.parser.config["network_def"]: |
| self.parser.config["network_def"]["checkpoint"] = None |
| model = self.config("network") |
| self.checkpoint_load(ckpt=pretrained_ckpt_path, model=model) |
| else: |
| model = self.config("network") |
|
|
| if self.config("channels_last", False): |
| model = model.to(memory_format=torch.channels_last) |
|
|
| if self.is_distributed: |
| model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) |
|
|
| if self.config("compile", False): |
| model = torch.compile(model) |
|
|
| if self.is_distributed: |
| model = torch.nn.parallel.DistributedDataParallel( |
| module=model, |
| device_ids=[self.rank], |
| output_device=self.rank, |
| find_unused_parameters=self.config("find_unused_parameters", False), |
| ) |
|
|
| pytorch_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| logger.info(f"total parameters count {pytorch_params} distributed {self.is_distributed}") |
| return model |
|
|
| def get_train_dataset_data(self): |
| train_files, valid_files = [], [] |
| dataset_data = self.config("train#dataset#data") |
| val_key = None |
| if isinstance(dataset_data, dict): |
| val_key = dataset_data.get("key", None) |
| data_list_files = dataset_data["data_list_files"] |
|
|
| if isinstance(data_list_files, str): |
| data_list_files = ConfigParser.load_config_file( |
| data_list_files |
| ) |
| else: |
| data_list_files = ensure_tuple(data_list_files) |
|
|
| if self.global_rank == 0: |
| print("Using data_list_files ", data_list_files) |
|
|
| for idx, d in enumerate(data_list_files): |
| logger.info(f"adding datalist ({idx}): {d['datalist']}") |
| t, v = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=self.config("fold")) |
|
|
| if val_key is not None: |
| v, _ = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=-1, key=val_key) |
|
|
| for item in t: |
| item["datalist_id"] = idx |
| item["datalist_count"] = len(t) |
| for item in v: |
| item["datalist_id"] = idx |
| item["datalist_count"] = len(v) |
| train_files.extend(t) |
| valid_files.extend(v) |
|
|
| if self.config("quick", False): |
| logger.info("quick_data") |
| train_files = train_files[:8] |
| valid_files = valid_files[:7] |
| if not valid_files: |
| logger.warning("No validation data found.") |
| return train_files, valid_files |
|
|
| def read_val_datalists(self, section="validate", data_list_files=None, val_key=None, merge=True): |
| """read the corresponding folds of the datalist for validation or inference""" |
| dataset_data = self.config(f"{section}#dataset#data") |
|
|
| if isinstance(dataset_data, list): |
| return dataset_data |
|
|
| if data_list_files is None: |
| data_list_files = dataset_data["data_list_files"] |
|
|
| if isinstance(data_list_files, str): |
| data_list_files = ConfigParser.load_config_file( |
| data_list_files |
| ) |
| else: |
| data_list_files = ensure_tuple(data_list_files) |
|
|
| if val_key is None: |
| val_key = dataset_data.get("key", None) |
|
|
| val_files, idx = [], 0 |
| for d in data_list_files: |
| if val_key is not None: |
| v_files, _ = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=-1, key=val_key) |
| else: |
| _, v_files = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=self.config("fold")) |
| logger.info(f"adding datalist ({idx} -- {val_key}): {d['datalist']} {len(v_files)}") |
| if merge: |
| val_files.extend(v_files) |
| else: |
| val_files.append(v_files) |
| idx += 1 |
|
|
| if self.config("quick", False): |
| logger.info("quick_data") |
| val_files = val_files[:8] if merge else [val_files[0][:8]] |
| return val_files |
|
|
| def get_train_preprocessing(self): |
| roi_size = self.config("train#dataset#preprocessing#roi_size") |
|
|
| train_xforms = [] |
| train_xforms.append(LoadTiffd(keys=["image", "label"])) |
| train_xforms.append(mt.EnsureTyped(keys=["image", "label"], data_type="tensor", dtype=torch.float)) |
| if self.config("prescale", True): |
| print("Prescaling images to 0..1") |
| train_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True)) |
| train_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True)) |
| train_xforms.append( |
| mt.ScaleIntensityRangePercentilesd( |
| keys="image", lower=1, upper=99, b_min=0.0, b_max=1.0, channel_wise=True, clip=True |
| ) |
| ) |
| train_xforms.append(mt.SpatialPadd(keys=["image", "label"], spatial_size=roi_size)) |
| train_xforms.append( |
| mt.RandSpatialCropd(keys=["image", "label"], roi_size=roi_size) |
| ) |
|
|
| |
| train_xforms.extend( |
| [ |
| mt.RandAffined( |
| keys=["image", "label"], |
| prob=0.5, |
| rotate_range=np.pi, |
| scale_range=[-0.5, 0.5], |
| mode=["bilinear", "nearest"], |
| spatial_size=roi_size, |
| cache_grid=True, |
| padding_mode="border", |
| ), |
| mt.RandAxisFlipd(keys=["image", "label"], prob=0.5), |
| mt.RandGaussianNoised(keys=["image"], prob=0.25, mean=0, std=0.1), |
| mt.RandAdjustContrastd(keys=["image"], prob=0.25, gamma=(1, 2)), |
| mt.RandGaussianSmoothd(keys=["image"], prob=0.25, sigma_x=(1, 2)), |
| mt.RandHistogramShiftd(keys=["image"], prob=0.25, num_control_points=3), |
| mt.RandGaussianSharpend(keys=["image"], prob=0.25), |
| ] |
| ) |
|
|
| train_xforms.append( |
| LabelsToFlows(keys="label", flow_key="flow") |
| ) |
|
|
| return train_xforms |
|
|
| def get_val_preprocessing(self): |
| val_xforms = [] |
| val_xforms.append(LoadTiffd(keys=["image", "label"], allow_missing_keys=True)) |
| val_xforms.append( |
| mt.EnsureTyped(keys=["image", "label"], data_type="tensor", dtype=torch.float, allow_missing_keys=True) |
| ) |
|
|
| if self.config("prescale", True): |
| print("Prescaling val images to 0..1") |
| val_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True)) |
|
|
| val_xforms.append( |
| mt.ScaleIntensityRangePercentilesd( |
| keys="image", lower=1, upper=99, b_min=0.0, b_max=1.0, channel_wise=True, clip=True |
| ) |
| ) |
| val_xforms.append(LabelsToFlows(keys="label", flow_key="flow", allow_missing_keys=True)) |
|
|
| return val_xforms |
|
|
| def get_train_dataset(self): |
| train_dataset_data = self.config("train#dataset#data") |
| if isinstance(train_dataset_data, list): |
| train_files = train_dataset_data |
| else: |
| train_files, _ = self.train_dataset_data |
| logger.info(f"train files {len(train_files)}") |
| return Dataset(data=train_files, transform=mt.Compose(self.train_preprocessing)) |
|
|
| def get_val_dataset(self): |
| """this is to be used for validation during training""" |
| val_dataset_data = self.config("validate#dataset#data") |
| if isinstance(val_dataset_data, list): |
| valid_files = val_dataset_data |
| else: |
| _, valid_files = self.train_dataset_data |
| return Dataset(data=valid_files, transform=mt.Compose(self.val_preprocessing)) |
|
|
| def set_val_datalist(self, datalist_py): |
| self.parser["validate#dataset#data"] = datalist_py |
| self._props.pop("val_loader", None) |
| self._props.pop("val_dataset", None) |
| self._props.pop("val_sampler", None) |
|
|
| def get_train_sampler(self): |
| if self.config("use_weighted_sampler", False): |
| data = self.train_dataset.data |
| logger.info(f"Using weighted sampler, first item {data[0]}") |
| sample_weights = 1.0 / torch.as_tensor( |
| [item.get("datalist_count", 1.0) for item in data], dtype=torch.float |
| ) |
| |
| |
| num_samples_per_epoch = self.config("num_samples_per_epoch", None) |
| if num_samples_per_epoch is None: |
| num_samples_per_epoch = len(data) |
| logger.warning( |
| "We are using weighted random sampler, but num_samples_per_epoch is not provided, " |
| f"so using {num_samples_per_epoch} full data length as a workaround!" |
| ) |
|
|
| if self.is_distributed: |
| return DistributedWeightedSampler( |
| self.train_dataset, shuffle=True, weights=sample_weights, num_samples=num_samples_per_epoch |
| ) |
| return WeightedRandomSampler(weights=sample_weights, num_samples=num_samples_per_epoch) |
|
|
| if self.is_distributed: |
| return DistributedSampler(self.train_dataset, shuffle=True) |
| return None |
|
|
| def get_val_sampler(self): |
| if self.is_distributed: |
| return DistributedSampler(self.val_dataset, shuffle=False) |
| return None |
|
|
| def get_train_loader(self): |
| sampler = self.train_sampler |
| return DataLoader( |
| self.train_dataset, |
| batch_size=self.config("train#batch_size"), |
| shuffle=(sampler is None), |
| sampler=sampler, |
| pin_memory=True, |
| num_workers=self.config("train#num_workers"), |
| ) |
|
|
| def get_val_loader(self): |
| sampler = self.val_sampler |
| return DataLoader( |
| self.val_dataset, |
| batch_size=self.config("validate#batch_size"), |
| shuffle=False, |
| sampler=sampler, |
| pin_memory=True, |
| num_workers=self.config("validate#num_workers"), |
| ) |
|
|
| def train(self): |
| config = self.config |
| distributed = self.is_distributed |
| sliding_inferrer = config("inferer#sliding_inferer") |
| use_amp = config("amp") |
|
|
| amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ |
| config("amp_dtype") |
| ] |
| if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): |
| amp_dtype = torch.float16 |
| logger.warning( |
| "bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually" |
| ) |
|
|
| use_gradscaler = use_amp and amp_dtype == torch.float16 |
| logger.info(f"Using grad scaler {use_gradscaler} amp_dtype {amp_dtype} use_amp {use_amp}") |
| grad_scaler = GradScaler(enabled=use_gradscaler) |
|
|
| loss_function = config("loss_function") |
| acc_function = config("key_metric") |
|
|
| ckpt_path = config("ckpt_path") |
| channels_last = config("channels_last") |
|
|
| num_epochs_per_saving = config("train#trainer#num_epochs_per_saving") |
| num_epochs_per_validation = config("train#trainer#num_epochs_per_validation") |
| num_epochs = config("train#trainer#max_epochs") |
| val_schedule_list = self.schedule_validation_epochs( |
| num_epochs=num_epochs, num_epochs_per_validation=num_epochs_per_validation |
| ) |
| logger.info(f"Scheduling validation loops at epochs: {val_schedule_list}") |
|
|
| train_loader = self.train_loader |
| val_loader = self.val_loader |
| optimizer = config("optimizer") |
| model = self.network |
|
|
| tb_writer = None |
| csv_path = progress_path = None |
|
|
| if self.global_rank == 0 and ckpt_path is not None: |
| |
| progress_path = os.path.join(ckpt_path, "progress.yaml") |
|
|
| tb_writer = SummaryWriter(log_dir=ckpt_path) |
| logger.info(f"Writing Tensorboard logs to {tb_writer.log_dir}") |
|
|
| if mlflow_is_imported: |
| if config("mlflow_tracking_uri", None) is not None: |
| mlflow.set_tracking_uri(config("mlflow_tracking_uri")) |
| mlflow.set_experiment("vista2d") |
|
|
| mlflow_run_name = config("mlflow_run_name", f'vista2d train fold{config("fold")}') |
| mlflow.start_run( |
| run_name=mlflow_run_name, log_system_metrics=config("mlflow_log_system_metrics", False) |
| ) |
| mlflow.log_params(self.parser.config) |
| mlflow.log_dict(self.parser.config, "hyper_parameters.yaml") |
|
|
| csv_path = os.path.join(ckpt_path, "accuracy_history.csv") |
| self.save_history_csv( |
| csv_path=csv_path, |
| header=["epoch", "metric", "loss", "iter", "time", "train_time", "validation_time", "epoch_time"], |
| ) |
|
|
| do_torch_save = ( |
| (self.global_rank == 0) and ckpt_path and config("ckpt_save") and not config("train#skip", False) |
| ) |
| best_ckpt_path = os.path.join(ckpt_path, "model.pt") |
| intermediate_ckpt_path = os.path.join(ckpt_path, "model_final.pt") |
|
|
| best_metric = float(config("best_metric", -1)) |
| start_epoch = config("start_epoch", 0) |
| best_metric_epoch = -1 |
| pre_loop_time = time.time() |
| report_num_epochs = num_epochs |
| train_time = validation_time = 0 |
| val_acc_history = [] |
|
|
| if start_epoch > 0: |
| val_schedule_list = [v for v in val_schedule_list if v >= start_epoch] |
| if len(val_schedule_list) == 0: |
| val_schedule_list = [start_epoch] |
| print(f"adjusted schedule_list {val_schedule_list}") |
|
|
| logger.info( |
| f"Using num_epochs => {num_epochs}\n " |
| f"Using start_epoch => {start_epoch}\n " |
| f"batch_size => {config('train#batch_size')} \n " |
| f"num_warmup_epochs => {config('train#trainer#num_warmup_epochs')} \n " |
| ) |
|
|
| lr_scheduler = config("lr_scheduler") |
| if lr_scheduler is not None and start_epoch > 0: |
| lr_scheduler.last_epoch = start_epoch |
|
|
| range_num_epochs = range(start_epoch, num_epochs) |
|
|
| if distributed: |
| dist.barrier() |
|
|
| if self.global_rank == 0 and tb_writer is not None and mlflow_is_imported and mlflow.is_tracking_uri_set(): |
| mlflow.log_param("len_train_set", len(train_loader.dataset)) |
| mlflow.log_param("len_val_set", len(val_loader.dataset)) |
|
|
| for epoch in range_num_epochs: |
| report_epoch = epoch |
|
|
| if distributed: |
| if isinstance(train_loader.sampler, DistributedSampler): |
| train_loader.sampler.set_epoch(epoch) |
| dist.barrier() |
|
|
| epoch_time = start_time = time.time() |
|
|
| train_loss, train_acc = 0, 0 |
|
|
| if not config("train#skip", False): |
| train_loss, train_acc = self.train_epoch( |
| model=model, |
| train_loader=train_loader, |
| optimizer=optimizer, |
| loss_function=loss_function, |
| acc_function=acc_function, |
| grad_scaler=grad_scaler, |
| epoch=report_epoch, |
| rank=self.rank, |
| global_rank=self.global_rank, |
| num_epochs=report_num_epochs, |
| use_amp=use_amp, |
| amp_dtype=amp_dtype, |
| channels_last=channels_last, |
| device=config("device"), |
| ) |
|
|
| train_time = time.time() - start_time |
| logger.info( |
| f"Latest training {report_epoch}/{report_num_epochs - 1} " |
| f"loss: {train_loss:.4f} time {train_time:.2f}s " |
| f"lr: {optimizer.param_groups[0]['lr']:.4e}" |
| ) |
|
|
| if self.global_rank == 0 and tb_writer is not None: |
| tb_writer.add_scalar("train/loss", train_loss, report_epoch) |
|
|
| if mlflow_is_imported and mlflow.is_tracking_uri_set(): |
| mlflow.log_metric("train/loss", train_loss, step=report_epoch) |
| mlflow.log_metric("train/epoch_time", train_time, step=report_epoch) |
|
|
| |
| val_acc_mean = -1 |
| if ( |
| len(val_schedule_list) > 0 |
| and epoch + 1 >= val_schedule_list[0] |
| and val_loader is not None |
| and len(val_loader) > 0 |
| ): |
| val_schedule_list.pop(0) |
|
|
| start_time = time.time() |
| torch.cuda.empty_cache() |
|
|
| val_loss, val_acc = self.val_epoch( |
| model=model, |
| val_loader=val_loader, |
| sliding_inferrer=sliding_inferrer, |
| loss_function=loss_function, |
| acc_function=acc_function, |
| epoch=report_epoch, |
| rank=self.rank, |
| global_rank=self.global_rank, |
| num_epochs=report_num_epochs, |
| use_amp=use_amp, |
| amp_dtype=amp_dtype, |
| channels_last=channels_last, |
| device=config("device"), |
| ) |
|
|
| torch.cuda.empty_cache() |
| validation_time = time.time() - start_time |
|
|
| val_acc_mean = float(np.mean(val_acc)) |
| val_acc_history.append((report_epoch, val_acc_mean)) |
|
|
| if self.global_rank == 0: |
| logger.info( |
| f"Latest validation {report_epoch}/{report_num_epochs - 1} " |
| f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc: {val_acc} time: {validation_time:.2f}s" |
| ) |
|
|
| if tb_writer is not None: |
| tb_writer.add_scalar("val/acc", val_acc_mean, report_epoch) |
| tb_writer.add_scalar("val/loss", val_loss, report_epoch) |
| if mlflow_is_imported and mlflow.is_tracking_uri_set(): |
| mlflow.log_metric("val/acc", val_acc_mean, step=report_epoch) |
| mlflow.log_metric("val/epoch_time", validation_time, step=report_epoch) |
|
|
| timing_dict = { |
| "time": f"{(time.time() - pre_loop_time) / 3600:.2f} hr", |
| "train_time": f"{train_time:.2f}s", |
| "validation_time": f"{validation_time:.2f}s", |
| "epoch_time": f"{time.time() - epoch_time:.2f}s", |
| } |
|
|
| if val_acc_mean > best_metric: |
| logger.info(f"New best metric ({best_metric:.6f} --> {val_acc_mean:.6f}). ") |
| best_metric, best_metric_epoch = val_acc_mean, report_epoch |
| save_time = 0 |
| if do_torch_save: |
| save_time = self.checkpoint_save( |
| ckpt=best_ckpt_path, model=model, epoch=best_metric_epoch, best_metric=best_metric |
| ) |
|
|
| if progress_path is not None: |
| self.save_progress_yaml( |
| progress_path=progress_path, |
| ckpt=best_ckpt_path if do_torch_save else None, |
| best_avg_score_epoch=best_metric_epoch, |
| best_avg_score=best_metric, |
| save_time=save_time, |
| **timing_dict, |
| ) |
| if csv_path is not None: |
| self.save_history_csv( |
| csv_path=csv_path, |
| epoch=report_epoch, |
| metric=f"{val_acc_mean:.4f}", |
| loss=f"{train_loss:.4f}", |
| iter=report_epoch * len(train_loader.dataset), |
| **timing_dict, |
| ) |
|
|
| |
| if epoch > max(20, num_epochs / 4) and 0 <= val_acc_mean < 0.01 and config("stop_on_lowacc", True): |
| logger.info( |
| f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. " |
| "Most likely optimization diverged, try setting a smaller learning_rate" |
| f" than {config('learning_rate')}" |
| ) |
| raise ValueError( |
| f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. " |
| "Most likely optimization diverged, try setting a smaller learning_rate" |
| f" than {config('learning_rate')}" |
| ) |
|
|
| |
| if do_torch_save and ((epoch + 1) % num_epochs_per_saving == 0 or (epoch + 1) >= num_epochs): |
| if report_epoch != best_metric_epoch: |
| self.checkpoint_save( |
| ckpt=intermediate_ckpt_path, model=model, epoch=report_epoch, best_metric=val_acc_mean |
| ) |
| else: |
| try: |
| shutil.copyfile(best_ckpt_path, intermediate_ckpt_path) |
| except Exception as err: |
| logger.warning(f"error copying {best_ckpt_path} {intermediate_ckpt_path} {err}") |
| pass |
|
|
| if lr_scheduler is not None: |
| lr_scheduler.step() |
|
|
| if self.global_rank == 0: |
| |
| time_remaining_estimate = train_time * (num_epochs - epoch) |
| if val_loader is not None and len(val_loader) > 0: |
| if validation_time == 0: |
| validation_time = train_time |
| time_remaining_estimate += validation_time * len(val_schedule_list) |
|
|
| logger.info( |
| f"Estimated remaining training time for the current model fold {config('fold')} is " |
| f"{time_remaining_estimate/3600:.2f} hr, " |
| f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, " |
| f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n" |
| ) |
|
|
| |
| train_loader = val_loader = optimizer = None |
|
|
| |
| logger.info(f"Checking to run final testing {config('run_final_testing')}") |
| if config("run_final_testing"): |
| if distributed: |
| dist.barrier() |
| _ckpt_name = best_ckpt_path if os.path.exists(best_ckpt_path) else intermediate_ckpt_path |
| if not os.path.exists(_ckpt_name): |
| logger.info(f"Unable to validate final no checkpoints found {best_ckpt_path}, {intermediate_ckpt_path}") |
| else: |
| |
| |
| gc.collect() |
| torch.cuda.empty_cache() |
| best_metric = self.run_final_testing( |
| pretrained_ckpt_path=_ckpt_name, |
| progress_path=progress_path, |
| best_metric_epoch=best_metric_epoch, |
| pre_loop_time=pre_loop_time, |
| ) |
|
|
| if ( |
| self.global_rank == 0 |
| and tb_writer is not None |
| and mlflow_is_imported |
| and mlflow.is_tracking_uri_set() |
| ): |
| mlflow.log_param("acc_testing", val_acc_mean) |
| mlflow.log_metric("acc_testing", val_acc_mean) |
|
|
| if tb_writer is not None: |
| tb_writer.flush() |
| tb_writer.close() |
|
|
| if mlflow_is_imported and mlflow.is_tracking_uri_set(): |
| mlflow.end_run() |
|
|
| logger.info( |
| f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs}." |
| f"Training time {(time.time() - pre_loop_time)/3600:.2f} hr." |
| ) |
| return best_metric |
|
|
| def run_final_testing(self, pretrained_ckpt_path, progress_path, best_metric_epoch, pre_loop_time): |
| logger.info("Running final best model testing set!") |
|
|
| |
| start_time = time.time() |
|
|
| self._props.pop("network", None) |
| self.parser["pretrained_ckpt_path"] = pretrained_ckpt_path |
| self.parser["validate#evaluator#postprocessing"] = None |
|
|
| val_acc_mean, val_loss, val_acc = self.validate(val_key="testing") |
| validation_time = f"{time.time() - start_time:.2f}s" |
| val_acc_mean = float(np.mean(val_acc)) |
| logger.info(f"Testing: loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc {val_acc} time {validation_time}") |
|
|
| if self.global_rank == 0 and progress_path is not None: |
| self.save_progress_yaml( |
| progress_path=progress_path, |
| ckpt=pretrained_ckpt_path, |
| best_avg_score_epoch=best_metric_epoch, |
| best_avg_score=val_acc_mean, |
| validation_time=validation_time, |
| run_final_testing=True, |
| time=f"{(time.time() - pre_loop_time) / 3600:.2f} hr", |
| ) |
| return val_acc_mean |
|
|
| def validate(self, validation_files=None, val_key=None, datalist=None): |
| if self.config("pretrained_ckpt_name", None) is None and self.config("pretrained_ckpt_path", None) is None: |
| self.parser["pretrained_ckpt_name"] = "model.pt" |
| logger.info("Using default model.pt checkpoint for validation.") |
|
|
| grouping = self.config("validate#grouping", False) |
| if validation_files is None: |
| validation_files = self.read_val_datalists("validate", datalist, val_key=val_key, merge=not grouping) |
| if len(validation_files) == 0: |
| logger.warning(f"No validation files found {datalist} {val_key}!") |
| return 0, 0, 0 |
| if not grouping or not isinstance(validation_files[0], (list, tuple)): |
| validation_files = [validation_files] |
| logger.info(f"validation file groups {len(validation_files)} grouping {grouping}") |
| val_acc_dict = {} |
|
|
| amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ |
| self.config("amp_dtype") |
| ] |
| if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): |
| amp_dtype = torch.float16 |
| logger.warning( |
| "bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually" |
| ) |
|
|
| for datalist_id, group_files in enumerate(validation_files): |
| self.set_val_datalist(group_files) |
| val_loader = self.val_loader |
|
|
| start_time = time.time() |
| val_loss, val_acc = self.val_epoch( |
| model=self.network, |
| val_loader=val_loader, |
| sliding_inferrer=self.config("inferer#sliding_inferer"), |
| loss_function=self.config("loss_function"), |
| acc_function=self.config("key_metric"), |
| rank=self.rank, |
| global_rank=self.global_rank, |
| use_amp=self.config("amp"), |
| amp_dtype=amp_dtype, |
| post_transforms=self.config("validate#evaluator#postprocessing"), |
| channels_last=self.config("channels_last"), |
| device=self.config("device"), |
| ) |
| val_acc_mean = float(np.mean(val_acc)) |
| logger.info( |
| f"Validation {datalist_id} complete, loss_avg: {val_loss:.4f} " |
| f"acc_avg: {val_acc_mean:.4f} acc {val_acc} time {time.time() - start_time:.2f}s" |
| ) |
| val_acc_dict[datalist_id] = val_acc_mean |
| for k, v in val_acc_dict.items(): |
| logger.info(f"group: {k} => {v:.4f}") |
| val_acc_mean = sum(val_acc_dict.values()) / len(val_acc_dict.values()) |
| logger.info(f"Testing group score average: {val_acc_mean:.4f}") |
| return val_acc_mean, val_loss, val_acc |
|
|
| def infer(self, infer_files=None, infer_key=None, datalist=None): |
| if self.config("pretrained_ckpt_name", None) is None and self.config("pretrained_ckpt_path", None) is None: |
| self.parser["pretrained_ckpt_name"] = "model.pt" |
| logger.info("Using default model.pt checkpoint for inference.") |
|
|
| if infer_files is None: |
| infer_files = self.read_val_datalists("infer", datalist, val_key=infer_key, merge=True) |
| if len(infer_files) == 0: |
| logger.warning(f"no file to infer {datalist} {infer_key}.") |
| return |
| logger.info(f"inference files {len(infer_files)}") |
| self.set_val_datalist(infer_files) |
| val_loader = self.val_loader |
|
|
| amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ |
| self.config("amp_dtype") |
| ] |
| if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): |
| amp_dtype = torch.bfloat16 |
| logger.warning( |
| "bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually" |
| ) |
|
|
| start_time = time.time() |
| self.val_epoch( |
| model=self.network, |
| val_loader=val_loader, |
| sliding_inferrer=self.config("inferer#sliding_inferer"), |
| loss_function=None, |
| acc_function=None, |
| rank=self.rank, |
| global_rank=self.global_rank, |
| use_amp=self.config("amp"), |
| amp_dtype=amp_dtype, |
| post_transforms=self.config("infer#evaluator#postprocessing"), |
| channels_last=self.config("channels_last"), |
| device=self.config("device"), |
| ) |
| logger.info(f"Inference complete time {time.time() - start_time:.2f}s") |
| return |
|
|
| @torch.no_grad() |
| def val_epoch( |
| self, |
| model, |
| val_loader, |
| sliding_inferrer, |
| loss_function=None, |
| acc_function=None, |
| epoch=0, |
| rank=0, |
| global_rank=0, |
| num_epochs=0, |
| use_amp=True, |
| amp_dtype=torch.float16, |
| post_transforms=None, |
| channels_last=False, |
| device=None, |
| ): |
| model.eval() |
| distributed = dist.is_available() and dist.is_initialized() |
| memory_format = torch.channels_last if channels_last else torch.preserve_format |
|
|
| run_loss = CumulativeAverage() |
| run_acc = CumulativeAverage() |
| run_loss.append(torch.tensor(0, device=device), count=0) |
|
|
| avg_loss = avg_acc = 0 |
| start_time = time.time() |
|
|
| |
| |
| |
| |
| nonrepeated_data_length = len(val_loader.dataset) |
| sampler = val_loader.sampler |
| if distributed and isinstance(sampler, DistributedSampler) and not sampler.drop_last: |
| nonrepeated_data_length = len(range(sampler.rank, len(sampler.dataset), sampler.num_replicas)) |
|
|
| for idx, batch_data in enumerate(val_loader): |
| data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device) |
| filename = batch_data["image"].meta[ImageMetaKey.FILENAME_OR_OBJ] |
| batch_size = data.shape[0] |
| loss = acc = None |
|
|
| with autocast(enabled=use_amp, dtype=amp_dtype): |
| logits = sliding_inferrer(inputs=data, network=model) |
| data = None |
|
|
| |
| if loss_function is not None: |
| target = batch_data["flow"].as_subclass(torch.Tensor).to(device=logits.device) |
| loss = loss_function(logits, target) |
| run_loss.append(loss.to(device=device), count=batch_size) |
| target = None |
|
|
| pred_mask_all = [] |
|
|
| for b_ind in range(logits.shape[0]): |
| pred_mask, p = LogitsToLabels()(logits=logits[b_ind], filename=filename) |
| pred_mask_all.append(pred_mask) |
|
|
| if acc_function is not None: |
| label = batch_data["label"].as_subclass(torch.Tensor) |
|
|
| for b_ind in range(label.shape[0]): |
| acc = acc_function(pred_mask_all[b_ind], label[b_ind, 0].long()) |
| acc = acc.detach().clone() if isinstance(acc, torch.Tensor) else torch.tensor(acc) |
|
|
| if idx < nonrepeated_data_length: |
| run_acc.append(acc.to(device=device), count=1) |
| else: |
| run_acc.append(torch.zeros_like(acc, device=device), count=0) |
| label = None |
|
|
| avg_loss = loss.cpu() if loss is not None else 0 |
| avg_acc = acc.cpu().numpy() if acc is not None else 0 |
|
|
| logger.info( |
| f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} " |
| f"loss: {avg_loss:.4f} acc {avg_acc} time {time.time() - start_time:.2f}s" |
| ) |
|
|
| if post_transforms: |
| seg = torch.from_numpy(np.stack(pred_mask_all, axis=0).astype(np.int32)).unsqueeze(1) |
| batch_data["seg"] = convert_to_dst_type( |
| seg, batch_data["image"], dtype=torch.int32, device=torch.device("cpu") |
| )[0] |
| for bd in decollate_batch(batch_data): |
| post_transforms(bd) |
|
|
| start_time = time.time() |
|
|
| label = target = data = batch_data = None |
|
|
| if distributed: |
| dist.barrier() |
|
|
| avg_loss = run_loss.aggregate() |
| avg_acc = run_acc.aggregate() |
|
|
| if np.any(avg_acc < 0): |
| dist.barrier() |
| logger.warning(f"Avg accuracy is negative ({avg_acc}), something went wrong!!!!!") |
|
|
| return avg_loss, avg_acc |
|
|
| def train_epoch( |
| self, |
| model, |
| train_loader, |
| optimizer, |
| loss_function, |
| acc_function, |
| grad_scaler, |
| epoch, |
| rank, |
| global_rank=0, |
| num_epochs=0, |
| use_amp=True, |
| amp_dtype=torch.float16, |
| channels_last=False, |
| device=None, |
| ): |
| model.train() |
| memory_format = torch.channels_last if channels_last else torch.preserve_format |
|
|
| run_loss = CumulativeAverage() |
|
|
| start_time = time.time() |
| avg_loss = avg_acc = 0 |
| for idx, batch_data in enumerate(train_loader): |
| data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device) |
| target = batch_data["flow"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device) |
|
|
| optimizer.zero_grad(set_to_none=True) |
|
|
| with autocast(enabled=use_amp, dtype=amp_dtype): |
| logits = model(data) |
|
|
| |
| loss = loss_function(logits.float(), target) |
|
|
| grad_scaler.scale(loss).backward() |
| grad_scaler.step(optimizer) |
| grad_scaler.update() |
|
|
| batch_size = data.shape[0] |
|
|
| run_loss.append(loss, count=batch_size) |
| avg_loss = run_loss.aggregate() |
|
|
| logger.info( |
| f"Epoch {epoch}/{num_epochs} {idx}/{len(train_loader)} " |
| f"loss: {avg_loss:.4f} time {time.time() - start_time:.2f}s " |
| ) |
| start_time = time.time() |
|
|
| optimizer.zero_grad(set_to_none=True) |
|
|
| data = None |
| target = None |
| batch_data = None |
|
|
| return avg_loss, avg_acc |
|
|
| def save_history_csv(self, csv_path=None, header=None, **kwargs): |
| if csv_path is not None: |
| if header is not None: |
| with open(csv_path, "a") as myfile: |
| wrtr = csv.writer(myfile, delimiter="\t") |
| wrtr.writerow(header) |
| if len(kwargs): |
| with open(csv_path, "a") as myfile: |
| wrtr = csv.writer(myfile, delimiter="\t") |
| wrtr.writerow(list(kwargs.values())) |
|
|
| def save_progress_yaml(self, progress_path=None, ckpt=None, **report): |
| if ckpt is not None: |
| report["model"] = ckpt |
|
|
| report["date"] = str(datetime.now())[:19] |
|
|
| if progress_path is not None: |
| yaml.add_representer( |
| float, lambda dumper, value: dumper.represent_scalar("tag:yaml.org,2002:float", f"{value:.4f}") |
| ) |
| with open(progress_path, "a") as progress_file: |
| yaml.dump([report], stream=progress_file, allow_unicode=True, default_flow_style=None, sort_keys=False) |
|
|
| logger.info("Progress:" + ",".join(f" {k}: {v}" for k, v in report.items())) |
|
|
| def checkpoint_save(self, ckpt: str, model: torch.nn.Module, **kwargs): |
| |
| save_time = time.time() |
| if isinstance(model, torch.nn.parallel.DistributedDataParallel): |
| state_dict = model.module.state_dict() |
| else: |
| state_dict = model.state_dict() |
|
|
| if self.config("compile", False): |
| |
| state_dict = OrderedDict( |
| (k[len("_orig_mod.") :] if k.startswith("_orig_mod.") else k, v) for k, v in state_dict.items() |
| ) |
|
|
| torch.save({"state_dict": state_dict, "config": self.parser.config, **kwargs}, ckpt) |
|
|
| save_time = time.time() - save_time |
| logger.info(f"Saving checkpoint process: {ckpt}, {kwargs}, save_time {save_time:.2f}s") |
|
|
| return save_time |
|
|
| def checkpoint_load(self, ckpt: str, model: torch.nn.Module, **kwargs): |
| |
| if not os.path.isfile(ckpt): |
| logger.warning("Invalid checkpoint file: " + str(ckpt)) |
| return |
| checkpoint = torch.load(ckpt, map_location="cpu") |
|
|
| model.load_state_dict(checkpoint["state_dict"], strict=True) |
| epoch = checkpoint.get("epoch", 0) |
| best_metric = checkpoint.get("best_metric", 0) |
|
|
| if self.config("continue", False): |
| if "epoch" in checkpoint: |
| self.parser["start_epoch"] = checkpoint["epoch"] |
| if "best_metric" in checkpoint: |
| self.parser["best_metric"] = checkpoint["best_metric"] |
|
|
| logger.info( |
| f"=> loaded checkpoint {ckpt} (epoch {epoch}) " |
| f"(best_metric {best_metric}) setting start_epoch {self.config('start_epoch')}" |
| ) |
| self.parser["start_epoch"] = int(self.config("start_epoch")) + 1 |
| return |
|
|
| def schedule_validation_epochs(self, num_epochs, num_epochs_per_validation=None, fraction=0.16) -> list: |
| """ |
| Schedule of epochs to validate (progressively more frequently) |
| num_epochs - total number of epochs |
| num_epochs_per_validation - if provided use a linear schedule with this step |
| init_step |
| """ |
|
|
| if num_epochs_per_validation is None: |
| x = (np.sin(np.linspace(0, np.pi / 2, max(10, int(fraction * num_epochs)))) * num_epochs).astype(int) |
| x = np.cumsum(np.sort(np.diff(np.unique(x)))[::-1]) |
| x[-1] = num_epochs |
| x = x.tolist() |
| else: |
| if num_epochs_per_validation >= num_epochs: |
| x = [num_epochs_per_validation] |
| else: |
| x = list(range(num_epochs_per_validation, num_epochs, num_epochs_per_validation)) |
|
|
| if len(x) == 0: |
| x = [0] |
|
|
| return x |
|
|
|
|
def main(**kwargs) -> None:
    """Build a VistaCell workflow from CLI kwargs and run it end to end."""
    flow = VistaCell(**kwargs)
    flow.initialize()
    flow.run()
    flow.finalize()
|
|
|
|
if __name__ == "__main__":
    # Support running this file directly (outside the package): make the parent
    # directory importable so the non-package import branch at the top works.
    from pathlib import Path

    sys.path.append(str(Path(__file__).parent.parent))

    # Use Fire (if installed) to expose main()'s kwargs as command-line flags.
    fire, fire_is_imported = optional_import("fire")
    if fire_is_imported:
        fire.Fire(main)
    else:
        print("Missing package: fire")
|
|