import datetime
import os
import time
from collections import deque

import torch
import torch.distributed as dist
import yaml
| |
|
class MetricLogger:
    """Collects named SmoothedValue meters and logs training progress.

    Meters are created lazily on the first `update(name=value)` call and
    rendered via each meter's own __str__.
    """
    def __init__(self, delimiter="\t"):
        self.meters = {}  # name -> SmoothedValue
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update (lazily creating) one meter per keyword argument.

        Tensor values are converted to Python scalars via .item().
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if k not in self.meters:
                self.meters[k] = SmoothedValue()
            self.meters[k].update(v)

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {meter}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """All-reduce every meter's running count/total across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress every `print_freq` steps.

        NOTE: requires a sized iterable (len() is used for the ETA and the
        step counter formatting).

        Fixes vs. previous version:
        - `dist.get_rank()` is only called when the process group is actually
          initialized; previously any CUDA-capable but non-distributed run
          crashed with "Default process group has not been initialized".
        - Progress is printed on the main process even without CUDA
          (previously CPU-only runs printed nothing per-iteration).
        - An empty iterable no longer raises ZeroDivisionError in the final
          per-iteration timing summary.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        num_steps = len(iterable)
        space_fmt = ':' + str(len(str(num_steps))) + 'd'
        log_msg = self.delimiter.join([
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ])
        # Print only from the main process; safe when dist is uninitialized.
        is_main_process = (not dist.is_available()
                           or not dist.is_initialized()
                           or dist.get_rank() == 0)
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_steps - 1:
                eta_seconds = iter_time.global_avg * (num_steps - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if is_main_process:
                    print(log_msg.format(
                        i, num_steps, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) guards the empty-iterable case.
        print(f'{header} Total time: {total_time_str} ({total_time / max(num_steps, 1):.4f} s / it)')
| |
|
| |
|
class SmoothedValue:
    """Track a series of values and provide access to smoothed values.

    Keeps a sliding window of the last `window_size` values (for median,
    avg, max, value) plus a running total/count over ALL values seen
    (for global_avg).
    """
    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        # deque(maxlen=...) evicts the oldest value in O(1); the previous
        # list + pop(0) implementation shifted the whole window per update.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt
        self.window_size = window_size

    def update(self, value, n=1):
        """Record `value`; `n` weights it in the running count/total only."""
        self.deque.append(value)  # maxlen handles window eviction
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """All-reduce count/total across processes (window stays local)."""
        if not dist.is_available() or not dist.is_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the window; 0 when no values have been recorded."""
        d = sorted(self.deque)
        n = len(d)
        if n == 0:
            return 0
        if n % 2 == 0:
            return (d[n // 2 - 1] + d[n // 2]) / 2
        return d[n // 2]

    @property
    def avg(self):
        """Mean of the window; 0 when empty."""
        if not self.deque:
            return 0
        return sum(self.deque) / len(self.deque)

    @property
    def global_avg(self):
        """total / count over all recorded values; 0 when empty."""
        if self.count == 0:
            return 0
        return self.total / self.count

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=max(self.deque) if self.deque else 0,
            value=self.deque[-1] if self.deque else 0
        )
| | |
| |
|
| |
|
def load_config(config_path):
    """Load configuration from YAML file.

    Returns the parsed document (typically a dict).
    """
    with open(config_path) as stream:
        return yaml.safe_load(stream)
| |
|
| |
|
def log_to_file(log_file, message):
    """Append a timestamped message to `log_file`; no-op when it is None."""
    if log_file is None:
        return
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(log_file, 'a') as fh:
        fh.write(f"[{stamp}] {message}\n")
        fh.flush()
| |
|
| |
|
def count_parameters(model, verbose=True):
    """Count trainable model parameters.

    Unwraps a DDP/DataParallel wrapper (`model.module`) first. When
    `verbose` is True, prints a per-component breakdown and totals.
    Previously only the header respected `verbose`; the component
    breakdown and totals printed unconditionally, so `verbose=False`
    still produced output — all printing is now gated on the flag.

    Returns:
        int: total number of trainable parameters.
    """
    def count_params(module):
        # Trainable parameters only (requires_grad).
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def format_number(num):
        # Human-readable magnitude, e.g. 1_200_000 -> "1.20M".
        if num >= 1e9:
            return f"{num/1e9:.2f}B"
        elif num >= 1e6:
            return f"{num/1e6:.2f}M"
        elif num >= 1e3:
            return f"{num/1e3:.2f}K"
        else:
            return str(num)

    if hasattr(model, 'module'):
        model = model.module

    total_params = count_params(model)

    if verbose:
        print("\n" + "="*80)
        print("Model Parameter Statistics")
        print("="*80)

        encoder_params = 0
        # presumably ViT-style encoder components — verify against the model class
        for name in ['patch_embed', 'blocks', 'encoder_norm']:
            if hasattr(model, name):
                module = getattr(model, name)
                params = count_params(module)
                encoder_params += params
                print(f"{name:.<35} {params:>15,} ({format_number(params):>8})")

        if hasattr(model, 'head'):
            head_params = count_params(model.head)
            print(f"{'Classification/Regression Head':.<35} {head_params:>15,} ({format_number(head_params):>8})")

        print("\n" + "="*80)
        print(f"{'Encoder Parameters':.<35} {encoder_params:>15,} ({format_number(encoder_params):>8})")
        print(f"{'TOTAL TRAINABLE PARAMETERS':.<35} {total_params:>15,} ({format_number(total_params):>8})")
        print("="*80 + "\n")

    return total_params
| |
|
| |
|
| |
|
def save_checkpoint(state, is_best, checkpoint_dir, filename='checkpoint.pth'):
    """Serialize `state` into `checkpoint_dir`/`filename`.

    When `is_best` is True, also writes a copy to 'checkpoint_best.pth'.
    """
    torch.save(state, os.path.join(checkpoint_dir, filename))
    if is_best:
        torch.save(state, os.path.join(checkpoint_dir, 'checkpoint_best.pth'))
| |
|
| |
|
def load_checkpoint(checkpoint_path, model, optimizer, scheduler, scaler=None):
    """Restore training state from a checkpoint file.

    Loads model/optimizer/scheduler state dicts (and the AMP scaler's,
    when present). Returns (start_epoch, best_metric, best_loss);
    (0, 0.0, 0.0) when the file does not exist.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at '{checkpoint_path}'")
        return 0, 0.0, 0.0

    print(f"Loading checkpoint '{checkpoint_path}'")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    start_epoch = checkpoint['epoch']
    best_metric = checkpoint.get('best_metric', 0.0)
    best_loss = checkpoint.get('best_loss', float('inf'))
    print(f"Loaded checkpoint from epoch {start_epoch}")
    return start_epoch, best_metric, best_loss
| |
|
| |
|
| |
|
class LabelScaler:
    """Standardizes regression labels with fixed dataset statistics."""
    def __init__(self, mean, std):
        # Dataset-level label mean/std; assumed scalars or broadcastable
        # arrays/tensors — confirm against the caller.
        self.mean = mean
        self.std = std

    def transform(self, labels):
        """Standardize: (y - mean) / std."""
        return (labels - self.mean) / self.std