Spaces:
Sleeping
Sleeping
| # ---------------------------------------------------------------------------- | |
| # IMPORTS | |
| # ---------------------------------------------------------------------------- | |
| import os | |
| os.environ['CUDA_LAUNCH_BLOCKING'] = '1' | |
| import glob | |
| import torch | |
| import shutil | |
| from tqdm import tqdm | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from networks import ImageClassifier | |
| from parser import get_parser | |
| from dataset import create_dataloader | |
| from sklearn.metrics import balanced_accuracy_score | |
def check_accuracy(val_dataloader, model, settings):
    """Evaluate ``model`` on ``val_dataloader`` and return its balanced accuracy.

    The model is switched to eval mode and is left in that mode on return.
    Each batch must unpack as ``(data, label, _)``; the third element is
    ignored. The model output is squeezed on dim 1 and thresholded at 0:
    a sample counts as a positive prediction when its score is > 0.

    Relies on the module-level ``device`` bound by the script entry point.

    Args:
        val_dataloader: Iterable yielding validation batches.
        model: Network to evaluate; invoked as ``model(data)``.
        settings: Run configuration. Not read by this function; kept so the
            call signature matches ``train``.

    Returns:
        The value returned by ``sklearn.metrics.balanced_accuracy_score``.
    """
    model.eval()

    # Collect per-batch tensors and concatenate once at the end. Growing a
    # tensor with torch.cat on every batch re-copies everything gathered so
    # far, and seeding the prediction buffer as int64 only worked through
    # dtype promotion against the floating-point model scores.
    label_batches = []
    pred_batches = []
    with torch.no_grad():
        with tqdm(val_dataloader, unit='batch', mininterval=0.5) as tbatch:
            tbatch.set_description('Validation')
            for data, label, _ in tbatch:
                data = data.to(device)
                label = label.to(device)
                pred_batches.append(model(data).squeeze(1))
                label_batches.append(label)

    if label_batches:
        label_array = torch.cat(label_batches)
        pred_array = torch.cat(pred_batches)
    else:
        # Empty loader: hand the metric the same empty tensors the previous
        # implementation would have produced.
        label_array = torch.empty(0, dtype=torch.int64, device=device)
        pred_array = torch.empty(0, dtype=torch.int64, device=device)

    accuracy = balanced_accuracy_score(label_array.cpu().numpy(), pred_array.cpu().numpy() > 0)
    print(f'Got accuracy {accuracy:.2f} \n')
    return accuracy
def train(train_dataloader, val_dataloader, model, settings):
    """Run the training loop with per-epoch validation and checkpointing.

    For every epoch the model is put in train mode, optimised over
    ``train_dataloader`` (epoch 0 performs no optimisation steps and only
    validates), then scored with ``check_accuracy``. A strictly better
    validation score saves the weights to
    ``./checkpoint/{settings.name}/weights/best.pt``. Otherwise, when
    ``settings.lr_decay_epochs`` is positive, a counter of non-improving
    epochs is advanced; once it reaches ``settings.lr_decay_epochs`` the
    learning rate of every param group is multiplied by 0.1, unless the first
    group's rate is already at or below ``settings.lr_min``, in which case
    training stops.

    Relies on the module-level ``device``, ``optimizer`` and ``criterion``
    bound by the script entry point.

    Args:
        train_dataloader: Iterable yielding ``(data, label, _)`` batches.
        val_dataloader: Validation batches, forwarded to ``check_accuracy``.
        model: Network being trained.
        settings: Run configuration; reads ``num_epoches``, ``name``,
            ``lr_decay_epochs`` and ``lr_min``.
    """
    best_accuracy = 0
    stale_epochs = 0  # evaluations without improvement since the last reset

    for epoch in range(settings.num_epoches):
        model.train()
        with tqdm(train_dataloader, unit='batch', mininterval=0.5) as progress:
            progress.set_description(f'Epoch {epoch}', refresh=False)
            # Epoch 0 is evaluation-only: the progress bar is opened but
            # never iterated, so no batches are loaded or trained on.
            batches = progress if epoch > 0 else ()
            for inputs, targets, _ in batches:
                inputs = inputs.to(device)
                targets = targets.to(device).float()

                logits = model(inputs).squeeze(1)
                loss = criterion(logits, targets).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                progress.set_postfix(loss=loss.item())

        accuracy = check_accuracy(val_dataloader, model, settings)

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), f'./checkpoint/{settings.name}/weights/best.pt')
            print(f'New best model saved with accuracy {best_accuracy:.4f} \n')
            stale_epochs = 0
            continue

        # No improvement: nothing more to do when LR decay is disabled.
        if not settings.lr_decay_epochs > 0:
            continue

        stale_epochs += 1
        if stale_epochs != settings.lr_decay_epochs:
            continue

        # Patience exhausted: stop if the rate cannot be lowered any further.
        if not optimizer.param_groups[0]['lr'] > settings.lr_min:
            print('Learning rate already at minimum \n')
            break

        for group in optimizer.param_groups:
            group['lr'] *= 0.1
        print('Learning rate decayed \n')
        stale_epochs = 0
if __name__ == "__main__":
    # Parse the command-line configuration and echo it to stdout.
    parser = get_parser()
    settings = parser.parse_args()
    print(settings)

    # NOTE: `device`, `optimizer` and `criterion` must stay bound at module
    # level -- train() and check_accuracy() read them as globals.
    device = torch.device(settings.device if torch.cuda.is_available() else 'cpu')

    model = ImageClassifier(settings)
    model.to(device)

    # Everything belonging to this run lives under ./checkpoint/<name>/.
    run_dir = f'./checkpoint/{settings.name}'
    os.makedirs(f'{run_dir}/weights/', exist_ok=True)

    # Store the settings dump inside the run directory. It was previously
    # written to the shared ./checkpoint/settings.txt (an f-string with no
    # placeholder), so each run overwrote the settings of the one before.
    with open(f'{run_dir}/settings.txt', 'w') as f:
        f.write(str(settings))

    train_dataloader = create_dataloader(settings, split='train')
    val_dataloader = create_dataloader(settings, split='val')

    # Only parameters that require gradients are handed to the optimizer.
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=settings.lr)
    # reduction='none' yields per-sample losses; train() averages them itself.
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    train(train_dataloader, val_dataloader, model, settings)