# ----------------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------------
import os

# Synchronous CUDA launches make device-side errors surface at the offending call.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score

from networks import ImageClassifier
from parser import get_parser
from dataset import create_dataloader
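

# ----------------------------------------------------------------------------
# VALIDATION
# ----------------------------------------------------------------------------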
def check_accuracy(val_dataloader, model, settings):
    """Run a full validation pass and return the balanced accuracy."""
    model.eval()
    label_array = torch.empty(0, dtype=torch.int64, device=device)
    # Predictions are raw logits, so the accumulation buffer must be floating point.
    pred_array = torch.empty(0, dtype=torch.float32, device=device)
    with torch.no_grad():
        with tqdm(val_dataloader, unit='batch', mininterval=0.5) as tbatch:
            tbatch.set_description('Validation')
            for (data, label, _) in tbatch:
                data = data.to(device)
                label = label.to(device)
                pred = model(data).squeeze(1)
                label_array = torch.cat((label_array, label))
                pred_array = torch.cat((pred_array, pred))
    # A logit > 0 corresponds to a sigmoid probability > 0.5 for the positive class.
    accuracy = balanced_accuracy_score(label_array.cpu().numpy(), pred_array.cpu().numpy() > 0)
    print(f'Got balanced accuracy {accuracy:.2f} \n')
    return accuracy
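

# ----------------------------------------------------------------------------
# TRAINING
# ----------------------------------------------------------------------------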
def train(train_dataloader, val_dataloader, model, settings):
    """Train the classifier, validating after every epoch and keeping the best weights."""
    best_accuracy = 0
    lr_decay_counter = 0
    for epoch in range(0, settings.num_epoches):
        model.train()
        with tqdm(train_dataloader, unit='batch', mininterval=0.5) as tepoch:
            tepoch.set_description(f'Epoch {epoch}', refresh=False)
            # Epoch 0 is validation-only: it records the accuracy of the initial
            # model before any weight update.
            if epoch > 0:
                for batch_idx, (data, label, _) in enumerate(tepoch):
                    data = data.to(device)
                    label = label.to(device).float()
                    scores = model(data).squeeze(1)
                    loss = criterion(scores, label).mean()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    tepoch.set_postfix(loss=loss.item())
        accuracy = check_accuracy(val_dataloader, model, settings)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), f'./checkpoint/{settings.name}/weights/best.pt')
            print(f'New best model saved with accuracy {best_accuracy:.4f} \n')
            lr_decay_counter = 0
        elif settings.lr_decay_epochs > 0:
            # No improvement: count this epoch towards the decay patience.
            lr_decay_counter += 1
            if lr_decay_counter == settings.lr_decay_epochs:
                if optimizer.param_groups[0]['lr'] > settings.lr_min:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.1
                    print('Learning rate decayed \n')
                    lr_decay_counter = 0
                else:
                    # Stop training once the learning rate has bottomed out and
                    # validation accuracy is no longer improving.
                    print('Learning rate already at minimum \n')
                    break
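

# ----------------------------------------------------------------------------
# MAIN
# ----------------------------------------------------------------------------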
if __name__ == "__main__":
    parser = get_parser()
    settings = parser.parse_args()
    print(settings)

    device = torch.device(settings.device if torch.cuda.is_available() else 'cpu')
    model = ImageClassifier(settings)
    model.to(device)

    # Save the run configuration next to its checkpoints (keyed by settings.name so
    # different runs do not overwrite each other).
    os.makedirs(f'./checkpoint/{settings.name}/weights/', exist_ok=True)
    with open(f'./checkpoint/{settings.name}/settings.txt', 'w') as f:
        f.write(str(settings))

    train_dataloader = create_dataloader(settings, split='train')
    val_dataloader = create_dataloader(settings, split='val')

    # Optimise only the trainable parameters; BCEWithLogitsLoss uses reduction='none'
    # because the per-sample losses are averaged explicitly in the training loop.
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=settings.lr)
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    train(train_dataloader, val_dataloader, model, settings)
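
# Example invocation (a sketch only: the script filename and flag names are
# assumptions; get_parser() in parser.py must define options matching the
# settings attributes used above, e.g. name, device, lr, lr_min,
# lr_decay_epochs, num_epoches):
#   python train.py --name baseline --device cuda:0 --lr 1e-4 --lr_min 1e-6 \
#       --lr_decay_epochs 5 --num_epoches 50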