File size: 3,915 Bytes
9c4b1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# ----------------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------------
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import glob
import torch
import shutil
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

from networks import ImageClassifier
from parser import get_parser
from dataset import create_dataloader
from sklearn.metrics import balanced_accuracy_score

def check_accuracy(val_dataloader, model, settings):
    """Evaluate ``model`` on ``val_dataloader`` and return its balanced accuracy.

    Each batch is unpacked as ``(data, label, _)``; the third element is
    ignored. The model output is squeezed along dim 1 and treated as a logit:
    a sample is predicted positive when its logit is > 0.

    Args:
        val_dataloader: iterable of ``(data, label, _)`` batches.
        model: network to evaluate; inputs are moved to the device of its
            parameters.
        settings: parsed CLI settings; unused here, kept for interface
            compatibility.

    Returns:
        The ``sklearn.metrics.balanced_accuracy_score`` of the thresholded
        predictions against the labels.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    model.eval()
    # Follow the model's own device instead of a module-level ``device`` that
    # only exists when this file is executed as a script.
    device = next(model.parameters()).device

    labels = []
    logits = []
    with torch.no_grad():
        with tqdm(val_dataloader, unit='batch', mininterval=0.5) as tbatch:
            tbatch.set_description('Validation')
            for data, label, _ in tbatch:
                # Labels are only needed on the CPU for sklearn, so they are
                # never sent to the accelerator.
                labels.append(label)
                logits.append(model(data.to(device)).squeeze(1).cpu())

    if not labels:
        raise ValueError('validation dataloader yielded no batches')

    # Concatenate once at the end: calling torch.cat inside the loop re-copies
    # the accumulated tensor on every batch. It also mixed an int64 seed tensor
    # with float logits.
    label_array = torch.cat(labels).numpy()
    pred_array = torch.cat(logits).numpy() > 0
    accuracy = balanced_accuracy_score(label_array, pred_array)

    print(f'Got accuracy {accuracy:.2f} \n')
    return accuracy


def train(train_dataloader, val_dataloader, model, settings):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Epoch 0 runs no optimisation steps. It only validates the model as passed
    in, which establishes the baseline the first saved checkpoint must beat.
    (NOTE(review): presumably intentional, e.g. for a pretrained backbone.
    As a result only ``settings.num_epoches - 1`` epochs actually train.)

    After every epoch the model is validated with ``check_accuracy``:

    * On a new best accuracy, the state dict is saved to
      ``./checkpoint/{settings.name}/weights/best.pt`` and the plateau counter
      is reset.
    * Otherwise, if ``settings.lr_decay_epochs > 0``, the counter is
      incremented. After ``lr_decay_epochs`` non-improving epochs, every
      param group's lr is multiplied by 0.1, provided the first group's lr is
      still above ``settings.lr_min``. If the lr is already at or below
      ``lr_min``, training stops early.

    Relies on the module-level ``optimizer`` and ``criterion`` globals.

    Args:
        train_dataloader: iterable of ``(data, label, _)`` training batches.
        val_dataloader: iterable of ``(data, label, _)`` validation batches.
        model: network to train; inputs are moved to the device of its
            parameters.
        settings: parsed CLI settings (``num_epoches``, ``name``,
            ``lr_decay_epochs``, ``lr_min`` are read here).

    Returns:
        The best validation accuracy seen (0 if no epoch improved on 0).
    """
    # Follow the model's own device instead of a module-level ``device`` that
    # only exists when this file is executed as a script.
    device = next(model.parameters()).device
    best_accuracy = 0
    lr_decay_counter = 0
    for epoch in range(settings.num_epoches):
        model.train()
        with tqdm(train_dataloader, unit='batch', mininterval=0.5) as tepoch:
            tepoch.set_description(f'Epoch {epoch}', refresh=False)
            if epoch > 0:
                for data, label, _ in tepoch:
                    data = data.to(device)
                    # BCEWithLogitsLoss needs float targets.
                    label = label.to(device).float()

                    scores = model(data).squeeze(1)

                    loss = criterion(scores, label).mean()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    tepoch.set_postfix(loss=loss.item())

        accuracy = check_accuracy(val_dataloader, model, settings)

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), f'./checkpoint/{settings.name}/weights/best.pt')

            print(f'New best model saved with accuracy {best_accuracy:.4f} \n')
            lr_decay_counter = 0

        elif settings.lr_decay_epochs > 0:
            lr_decay_counter += 1
            if lr_decay_counter == settings.lr_decay_epochs:
                if optimizer.param_groups[0]['lr'] > settings.lr_min:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.1
                    print('Learning rate decayed \n')
                    lr_decay_counter = 0
                else:
                    # No further decay is possible; stop training early.
                    print('Learning rate already at minimum \n')
                    break

    return best_accuracy

if __name__ == "__main__":
    parser = get_parser()
    settings = parser.parse_args()
    print(settings)

    # NOTE: ``device``, ``optimizer`` and ``criterion`` are deliberately bound at
    # module scope. The functions above read some of them as globals (``train``
    # uses ``optimizer`` and ``criterion``), so do not move this into a function.
    device = torch.device(settings.device if torch.cuda.is_available() else 'cpu')

    model = ImageClassifier(settings)
    model.to(device)

    run_dir = f'./checkpoint/{settings.name}'
    os.makedirs(f'{run_dir}/weights/', exist_ok=True)

    # Record the settings next to this run's weights. Writing them to the
    # shared ./checkpoint/ root made every run overwrite the previous record.
    with open(f'{run_dir}/settings.txt', 'w', encoding='utf-8') as f:
        f.write(str(settings))

    train_dataloader = create_dataloader(settings, split='train')
    val_dataloader = create_dataloader(settings, split='val')

    # Only optimise parameters that are not frozen.
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=settings.lr)

    # Per-sample losses; ``train`` averages them with ``.mean()``.
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    train(train_dataloader, val_dataloader, model, settings)