import argparse
import json
import os
import random
import uuid
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from functools import partial
from math import log2
from time import sleep, time
from typing import Any, Dict

import git
import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.utils.data.distributed
import wandb
from PIL import Image
from torch import distributed as dist
from torch import optim
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm

import unik3d.datasets as datasets
from unik3d.datasets import (ConcatDataset, DistributedSamplerNoDuplicate,
                             collate_fn, get_weights)
from unik3d.models import UniK3D
from unik3d.ops.scheduler import CosineScheduler
from unik3d.utils import (barrier, format_seconds, is_main_process,
                          log_train_artifacts, validate)
from unik3d.utils.distributed import (create_local_process_group,
                                      local_broadcast_process_authkey,
                                      setup_multi_processes, setup_slurm,
                                      sync_string_across_gpus,
                                      sync_tensor_across_gpus)
from unik3d.utils.ema_torch import (DummyExponentialMovingAverage,
                                    ExponentialMovingAverage)
from unik3d.utils.misc import calculate_mean_values
EMA_INTERVAL = 10
EMA_TAU = 10000
EMA_START = 50000
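# Note: the EMA handle is updated only every EMA_INTERVAL optimizer steps (see the
# training loop below), which is why its decay and tau arguments are rescaled by
# EMA_INTERVAL when the handle is constructed.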
MAP_DTYPE = {
    "f16": torch.float16,
    "bf16": torch.bfloat16,
    "f32": torch.float32,
}
def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device):
    # Gather loss names and values from all ranks, then average entries sharing a key.
    keys = list(dict_.keys())
    values = torch.tensor(list(dict_.values()), device=device)
    keys = sync_string_across_gpus(keys, device)
    values = sync_tensor_across_gpus(values, dim=0).cpu().tolist()
    dict_ = calculate_mean_values(keys, values)
    return dict_
def main_worker(config: Dict[str, Any], args: argparse.Namespace):
    current_process = psutil.Process(os.getpid())
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = config["generic"]["seed"]

    if not args.distributed:
        args.rank = 0
        args.local_rank = 0
        args.world_size = 1
    else:
        # Initialize the distributed backend, which takes care of synchronizing nodes/GPUs.
        setup_multi_processes(config)
        is_slurm = "SLURM_PROCID" in os.environ
        if is_slurm:
            setup_slurm("nccl", port=args.master_port)
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.local_rank = device = int(os.environ["LOCAL_RANK"])
        if not is_slurm:
            import datetime

            dist.init_process_group(
                "nccl",
                rank=args.rank,
                world_size=args.world_size,
                timeout=datetime.timedelta(seconds=30 * 60),
            )
        torch.cuda.set_device(device)
        create_local_process_group()
        local_broadcast_process_authkey()
        print(
            f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}."
        )
        config["training"]["batch_size"] = int(
            config["training"]["batch_size"] / args.world_size
        )
        dist.barrier()

    # Fix the seed. It differs for every machine to avoid sampling
    # the same element across machines.
    seed = seed + args.rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    batch_size = config["training"]["batch_size"]
    if is_main_process():
        print("Config: ", args.config_file)
        print(
            f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}"
        )
        print("BatchSize per GPU: ", batch_size)
        print(
            f"Divided into {config['training']['nsteps_accumulation_gradient']} accumulation steps"
        )
    ##############################
    ########### MODEL ############
    ##############################
    # Build model
    model = UniK3D(config).to(device)
    model.eval()
    print(f"MODEL: {model.__class__.__name__} at {model.device}")

    torch.cuda.empty_cache()
    if args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(
            model,
            find_unused_parameters=False,
            device_ids=[device],
            output_device=device,
        )

    ##############################
    ######### OPTIMIZER ##########
    ##############################
    dtype_16bit = config["training"]["f16"]
    is_16bit = dtype_16bit != "f32"
    clipping = config["training"].get("clipping", None)

    # Optimize
    ddp_model = model.module if args.distributed else model
    params = ddp_model.get_params(config)
    optimizer = optim.AdamW(
        params,
        eps=6e-8 if is_16bit else 1e-8,  # smallest subnormal fp16 number is 5.96e-8
        # amsgrad=is_16bit,  # use max instead of avg v_hat, avoid small number divisions?
    )
    # Load model
    step = 0
    if config["training"].get("pretrained", None) is not None:
        ddp_model.load_pretrained(config["training"]["pretrained"])
        pretrained = torch.load(
            config["training"]["pretrained"], map_location="cpu", weights_only=False
        )
        try:
            optimizer.load_state_dict(pretrained["optimizer"])
        except Exception as e:
            if is_main_process():
                print("Could not load optimizer state dict:", e)
        step = pretrained.get("step", 0)
        ddp_model.pixel_decoder.steps = step

    # EMA
    ema_class = (
        ExponentialMovingAverage
        if config["training"]["ema"] > 0.0
        else DummyExponentialMovingAverage
    )
    ema_handle = ema_class(
        ddp_model.parameters_grad(),
        1 - (1 - config["training"]["ema"]) * EMA_INTERVAL,
        update_after_step=config["training"]["warmup_iters"] / EMA_INTERVAL,
        switch=True,
        tau=EMA_TAU // EMA_INTERVAL,
    )
    setattr(ema_handle, "num_updates", step // EMA_INTERVAL)
    ##############################
    ######### GENERICS ###########
    ##############################
    resize_method = config["data"].get("resize_method", "hard")
    crop = config["data"].get("crop", "garg")
    augmentations_db = config["data"].get("augmentations", {})
    shape_constraints = config["data"].get("shape_constraints", {})
    image_shape = config["data"]["image_shape"]
    mini = config["data"]["mini"]
    nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"]
    batch_size = config["training"]["batch_size"]
    clipping_fn = torch.nn.utils.clip_grad_norm_
    is_shell = int(os.environ.get("SHELL_JOB", 0))
    run_id = sync_string_across_gpus(
        [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device
    )[0]
    if not is_shell and is_main_process():
        repo_folder = os.path.dirname(os.path.realpath(__file__))
        try:
            repo = git.Repo(repo_folder)
            current_head = repo.head if repo.head.is_detached else repo.active_branch
            notes = f"MESSAGE: {current_head.commit.message} HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}"
        except Exception:
            print(f"Problem with {repo_folder}, does it exist?")
            notes = ""
        # Restore the original (global) batch size; it is not used by other calls from now on.
        if args.distributed:
            config["training"]["batch_size"] = (
                config["training"]["batch_size"] * args.world_size
            )
        wandb.init(
            project="UniK3D",
            name=run_id,
            config=config,
            tags=None,
            notes=notes,
            dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")),
        )
        wandb.watch(model)
    ##############################
    ########## DATASET ###########
    ##############################
    # Datasets loading
    train_datasets, val_datasets = {}, {}
    if is_main_process():
        print("Loading training datasets...")

    dims = 0
    for dataset in config["data"]["train_datasets"]:
        assert hasattr(datasets, dataset), f"{dataset} not a custom dataset"
        train_dataset: datasets.BaseDataset = getattr(datasets, dataset)
        train_datasets[dataset] = train_dataset(
            image_shape=image_shape,
            split_file=train_dataset.train_split,
            test_mode=False,
            crop=crop,
            augmentations_db=augmentations_db,
            shape_constraints=shape_constraints,
            normalize=config["data"].get("normalization", "imagenet"),
            resize_method=resize_method,
            mini=mini,
            num_frames=config["data"].get("num_frames", 1),
            fps_range=[1, 5],
            num_copies=config["data"]["pair"],
        )
        dim = (
            train_datasets[dataset].dataset._addr.numel() * 8
            + train_datasets[dataset].dataset._lst.numel()
        ) / (2**20)
        if hasattr(train_datasets[dataset], "sequences"):
            dim += (
                train_datasets[dataset].sequences._addr.numel() * 8
                + train_datasets[dataset].sequences._lst.numel()
            ) / (2**20)
        dims = dims + dim
        if is_main_process():
            print(f"{dataset}: {dim:.1f}MB")

    if is_main_process():
        print(f"All training datasets loaded, with total size: {dims:.1f}MB")
    barrier()

    assert batch_size % config["data"]["pair"] == 0
    batch_size = batch_size // config["data"]["pair"]
    assert batch_size % nsteps_accumulation_gradient == 0
    # Each per-GPU batch is split into nsteps_accumulation_gradient chunks of this size.
    batch_chunk = batch_size // nsteps_accumulation_gradient

    train_dataset = ConcatDataset(
        list(train_datasets.values()),
        shape_constraints=shape_constraints,
    )

    if is_main_process():
        print("Loading validation datasets...")
    for dataset in config["data"]["val_datasets"]:
        val_dataset: datasets.BaseDataset = getattr(datasets, dataset)
        val_datasets[dataset] = val_dataset(
            image_shape=image_shape,
            split_file=val_dataset.test_split,
            test_mode=True,
            crop=crop,
            shape_constraints=shape_constraints,
            augmentations_db=augmentations_db,
            normalize=config["data"].get("normalization", "imagenet"),
            resize_method=resize_method,
            num_frames=1,
            mini=1.0,
            num_copies=1,
        )
    # Dataset samplers, create distributed sampler pinned to rank
    if args.distributed:
        sampling = deepcopy(config["data"]["sampling"])
        weights, num_samples = get_weights(train_datasets, sampling)
        train_sampler = torch.utils.data.WeightedRandomSampler(
            weights, num_samples, replacement=True
        )
        valid_samplers = {
            k: DistributedSamplerNoDuplicate(
                v,
                num_replicas=args.world_size,
                rank=args.rank,
                shuffle=False,
                drop_last=False,
            )
            for k, v in val_datasets.items()
        }
    else:
        train_sampler = RandomSampler(train_dataset)
        valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()}
    train_sampler = torch.utils.data.BatchSampler(
        train_sampler, batch_size=batch_size, drop_last=True
    )

    # Dataset loader
    val_batch_size = 1
    num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4))
    train_loader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        sampler=train_sampler,
        pin_memory=True,
        collate_fn=partial(collate_fn, is_batched=True),
        persistent_workers=num_workers > 0,
    )
    val_loaders = {
        name_dataset: DataLoader(
            dataset,
            batch_size=val_batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=valid_samplers[name_dataset],
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, is_batched=False),
        )
        for name_dataset, dataset in val_datasets.items()
    }
    # SCHEDULERS!
    scheduler_wd = CosineScheduler(
        optimizer,
        key="weight_decay",
        init_value=config["training"]["wd"],
        base_value=config["training"]["wd"],
        final_value=config["training"]["wd_final"],
        warmup_iters=0,
        total_iters=config["training"]["n_iters"],
        flat_iters=config["training"]["warmup_iters"],
        step_init=step - 1,
    )
    scheduler_lr = CosineScheduler(
        optimizer,
        key="lr",
        init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0),
        final_value=config["training"]["lr_final"],
        warmup_iters=5000,
        flat_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )
    scheduler_betas = CosineScheduler(
        optimizer,
        key="betas",
        init_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
        base_value=0.85 if config["training"].get("cycle_betas", True) else 0.9,
        final_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
        warmup_iters=config["training"]["warmup_iters"],
        total_iters=config["training"]["n_iters"],
        step_init=step - 1,
    )

    # Set loss scaler for half precision training + sanity zeroing grads
    dtype = MAP_DTYPE[dtype_16bit]
    if not torch.cuda.is_bf16_supported() and is_16bit:
        dtype = torch.float16
    context = torch.autocast(device_type="cuda", dtype=dtype, enabled=is_16bit)
    # Use float16 to check for instability at inference and avoid bfloat16 for coarseness.
    context_val = torch.autocast(
        device_type="cuda", dtype=torch.float16, enabled=is_16bit
    )
    optimizer.zero_grad(set_to_none=True)
    ##############################
    ########## TRAINING ##########
    ##############################
    # Remember that if the i-th layer is frozen, gradient checkpointing breaks in the
    # (i+1)-th layer: CheckpointFunction sees the (i+1)-th input as not requiring
    # gradient, so the (i+1)-th layer gets no grads. To fix it, call requires_grad_()
    # on the inputs coming from the frozen layer.
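    # A minimal sketch of that workaround (hypothetical `frozen_layer` / `next_layer`
    # names, not part of this script):
    #
    #   from torch.utils.checkpoint import checkpoint
    #   feats = frozen_layer(x)           # produced without grad by the frozen layer
    #   feats = feats.requires_grad_()    # re-attach grad so checkpointing backprops
    #   out = checkpoint(next_layer, feats, use_reentrant=True)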
    ddp_model.train()
    start = time()
    n_steps = config["training"]["n_iters"]
    init_steps = int(step)
    track_pbar = is_shell
    if is_main_process():
        print("Is a shell job?", is_shell)
        print("Use dtype:", dtype if is_16bit else torch.float32)
        print(
            f'Train for {config["training"]["n_iters"]} steps, validate every {config["training"]["validation_interval"]} steps'
        )
        print(f"START with {num_workers} workers")
        if track_pbar:
            pbar = tqdm(total=n_steps - init_steps)

    scaler = torch.amp.GradScaler(
        "cuda",
        init_scale=2**14 if dtype_16bit == "f16" else 2**40,
        enabled=is_16bit,
        growth_factor=1.2,
        backoff_factor=0.8,
        growth_interval=500,
    )
    track_losses, track_grad = {}, {}
    system_memory = dict(psutil.virtual_memory()._asdict())["available"] / 2**30
    cpid_memory = current_process.memory_info()[0] / 2.0**30
    gpu_mem = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 2**30
    while True:
        for j, batches in enumerate(train_loader):
            system_memory = (
                0.99 * system_memory
                + 0.01 * dict(psutil.virtual_memory()._asdict())["available"] / 2**30
            )
            cpid_memory = (
                0.99 * cpid_memory + 0.01 * current_process.memory_info()[0] / 2.0**30
            )
            gpu_mem = (
                0.99 * gpu_mem
                + 0.01
                * (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0])
                / 2**30
            )
            if j % 1000 == 0 and is_main_process():
                print(f"System information at step {j}")
                print(f"System-wide RAM available: {system_memory:.2f}GB")
                print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%")
                print(f"GPU memory utilized: {gpu_mem:.2f}GB")

            batches["data"] = {
                k: v.to(model.device, non_blocking=True)
                for k, v in batches["data"].items()
            }
            for idx in range(nsteps_accumulation_gradient):
                batch = {}
                batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk)
                batch["data"] = {k: v[batch_slice] for k, v in batches["data"].items()}
                batch["img_metas"] = batches["img_metas"][batch_slice]
                with (
                    model.no_sync()
                    if idx < nsteps_accumulation_gradient - 1
                    else nullcontext()
                ):
                    with context:
                        preds, losses = model(batch["data"], batch["img_metas"])
                        loss = sum(losses["opt"].values())
                    scaler.scale(loss).backward()

                losses_dict = {
                    k: v.detach() for loss in losses.values() for k, v in loss.items()
                }
                track_losses.update(
                    {
                        k: track_losses.get(k, 0.0)
                        + torch.nan_to_num(v, nan=1e5, posinf=1e5, neginf=1e5)
                        for k, v in losses_dict.items()
                    }
                )
            ddp_model.loss_history = track_losses

            if clipping is not None:
                scaler.unscale_(optimizer)
                grad_norm = clipping_fn(ddp_model.parameters_grad(), clipping)
                if torch.isfinite(grad_norm):
                    track_losses.update(
                        {"Grad_Norm": track_losses.get("Grad_Norm", 0.0) + grad_norm}
                    )
            # If this triggers, there is a deeper issue: either a log/sqrt of a
            # negative loss, or the inputs create large values and destroy the
            # model weights.
            if is_16bit and scaler.get_scale() < 1:
                raise ValueError("Scale went less than 1, ISSUE!!!")

            scaler.step(optimizer)
            scaler.update()
            scheduler_wd.step()
            scheduler_lr.step()
            scheduler_betas.step()
            # Step the unwrapped module so this also works without DDP.
            ddp_model.step()
            optimizer.zero_grad(set_to_none=True)
            if step % EMA_INTERVAL == 0:
                ema_handle.update()
            if is_main_process() and track_pbar:
                pbar.update(1)

            step += 1
            # LOGGING
            if step % 100 == 0 and is_main_process():
                log_num = min(10, preds["depth"].shape[0])
                log_train_artifacts(
                    batch["data"]["image"][-log_num:, 0].float(),
                    (
                        batch["data"]["depth"][-log_num:, 0].float()
                        if "depth" in batch["data"]
                        else []
                    ),
                    preds["depth"][-log_num:, 0].detach().float(),
                    infos={
                        k: v[-log_num:, 0] for k, v in preds.get("infos", {}).items()
                    },
                    step=step,
                )

            if step % 50 == 0:
                track_losses = {
                    k: v / (50 * nsteps_accumulation_gradient)
                    for k, v in track_losses.items()
                }
                # Grad norm is tracked once per step, not once per accumulation chunk.
                if "Grad_Norm" in track_losses:
                    track_losses["Grad_Norm"] = (
                        track_losses["Grad_Norm"] * nsteps_accumulation_gradient
                    )
                track_losses = aggregate_sync_losses(track_losses, device=model.device)
                if is_main_process():
                    elapsed = int(time() - start)
                    eta = int(elapsed * (n_steps - step) / max(1, step - init_steps))
                    print(
                        f"Step {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]"
                    )
                    try:
                        wandb.log(
                            {
                                **{f"Train/{k}": v for k, v in track_losses.items()},
                                "Train/lr": scheduler_lr.get()[-1],
                                "Train/wd": scheduler_wd.get()[-2],
                                "Train/scale_f16": log2(scaler.get_scale()),
                            },
                            step=step,
                        )
                    except Exception as e:
                        print("Not logging loss because of:", e)
                    if step % 100 == 0:
                        log_loss_dict = {
                            f"Train/{k}": v for k, v in track_losses.items()
                        }
                        print(
                            ", ".join(
                                [f"{k}: {v:.5f}" for k, v in log_loss_dict.items()]
                            )
                        )
                # Reinit every 50 steps: each window averages its own 50 steps.
                track_losses = {}
            # Validation
            is_last_step = step >= config["training"]["n_iters"]
            is_validation = step % config["training"]["validation_interval"] == 0
            if is_last_step or is_validation:
                torch.cuda.empty_cache()
                barrier()
                if is_main_process():
                    print(f"Validation at {step}th step...")
                ddp_model.eval()
                start_validation = time()
                with torch.no_grad(), ema_handle.average_parameters():
                    validate(
                        model,
                        test_loaders=val_loaders,
                        step=step,
                        run_id=run_id,
                        idxs=(64, 96, 224, 256),  # random
                        context=context_val,
                    )
                if is_main_process():
                    print(f"Elapsed: {format_seconds(int(time() - start_validation))}")
                ddp_model.train()
                torch.cuda.empty_cache()

            if step >= config["training"]["n_iters"]:
                if is_main_process() and track_pbar:
                    pbar.close()
                wandb.finish(0)
                if args.distributed:
                    dist.destroy_process_group()
                return 0
if __name__ == "__main__":
    if "SLURM_PROCID" in os.environ:
        os.environ["TRITON_CACHE_DIR"] = "/tmp"

    # Arguments
    parser = argparse.ArgumentParser(
        description="Training script", conflict_handler="resolve"
    )
    parser.add_argument("--config-file", type=str, required=True)
    parser.add_argument("--master-port", type=str)
    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    with open(args.config_file, "r") as f:
        config = json.load(f)

    deterministic = config["generic"].get("deterministic", True)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.set_num_threads(1)

    main_worker(config, args)
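# Launch examples (hypothetical script/config names, shown for illustration only):
#   single process:  python train.py --config-file configs/train.json
#   multi-GPU:       torchrun --nproc_per_node=4 train.py --config-file configs/train.json --distributed
# torchrun exports RANK, WORLD_SIZE and LOCAL_RANK, which the distributed branch above reads.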