SkillForge45 commited on
Commit
8388b4e
·
verified ·
1 Parent(s): d27acbb

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +33 -0
train.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import ConditionalGenerator, Discriminator
2
+ from data_loader import MultiStyleDataset
3
+ import torch.optim as optim
4
+
5
+
6
+ num_styles = len(os.listdir(".styles/"))
7
+ G = ConditionalGenerator(num_styles)
8
+ D = Discriminator(num_styles)
9
+
10
+ opt_G = optim.Adam(G.parameters(), lr=2e-4)
11
+ opt_D = optim.Adam(D.parameters(), lr=2e-4)
12
+
13
+ dataset = MultiStyleDataset(".styles/")
14
+
15
+ for epoch in range(100):
16
+ for img, style_id in dataset:
17
+
18
+ fake_img = G(img.unsqueeze(0), torch.tensor([style_id]))
19
+
20
+
21
+ real_loss = torch.mean((D(img.unsqueeze(0), torch.tensor([style_id])) - 1)**2)
22
+ fake_loss = torch.mean(D(fake_img.detach(), torch.tensor([style_id]))**2)
23
+ loss_D = (real_loss + fake_loss) / 2
24
+
25
+ opt_D.zero_grad()
26
+ loss_D.backward()
27
+ opt_D.step()
28
+
29
+
30
+ loss_G = torch.mean((D(fake_img, torch.tensor([style_id])) - 1)**2
31
+ opt_G.zero_grad()
32
+ loss_G.backward()
33
+ opt_G.step()