SkillForge45 committed on
Commit 54d9695 · verified · 1 Parent(s): 9de1b16

Create model.py

Files changed (1)
  1. model.py +227 -0
model.py ADDED
@@ -0,0 +1,227 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding, as used in DDPM."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        self.register_buffer('emb', emb)

    def forward(self, time):
        emb = time[:, None] * self.emb[None, :]
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        return emb


class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map."""

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (channels // num_heads) ** -0.5

        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.qkv(self.norm(x))
        q, k, v = qkv.chunk(3, dim=1)

        # q, v: (b, heads, h*w, c//heads); k: (b, heads, c//heads, h*w)
        q = q.view(b, self.num_heads, -1, h * w).permute(0, 1, 3, 2)
        k = k.view(b, self.num_heads, -1, h * w)
        v = v.view(b, self.num_heads, -1, h * w).permute(0, 1, 3, 2)

        attn = torch.softmax((q @ k) * self.scale, dim=-1)
        out = (attn @ v).permute(0, 1, 3, 2).reshape(b, -1, h, w)
        # Residual connection back to the block input
        return self.proj(out) + x


class ResBlock(nn.Module):
    """Residual block with FiLM-style (scale, shift) conditioning on time and text embeddings."""

    def __init__(self, in_ch, out_ch, time_emb_dim, text_emb_dim, dropout=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim + text_emb_dim, out_ch * 2))

        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, padding=1))

        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, padding=1))

        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, time_emb, text_emb):
        emb = self.mlp(torch.cat([time_emb, text_emb], dim=-1))
        scale, shift = torch.chunk(emb, 2, dim=1)

        h = self.block1(x)
        h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        h = self.block2(h)

        return h + self.res_conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, dim=64, dim_mults=(1, 2, 4, 8)):
        super().__init__()
        dims = [dim * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        # Time and text embeddings
        self.time_mlp = nn.Sequential(
            TimeEmbedding(dim),
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim))

        # Text conditioning (the CLIP ViT-B/32 text encoder has hidden size 512)
        self.text_proj = nn.Linear(512, dim * 4)

        # Initial convolution
        self.init_conv = nn.Conv2d(in_channels, dim, 3, padding=1)

        # Downsample blocks
        self.downs = nn.ModuleList()
        for ind, (in_dim, out_dim) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            self.downs.append(nn.ModuleList([
                ResBlock(in_dim, in_dim, dim, dim * 4),
                ResBlock(in_dim, in_dim, dim, dim * 4),
                AttentionBlock(in_dim),
                nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1) if not is_last else nn.Conv2d(in_dim, out_dim, 3, padding=1)
            ]))

        # Middle blocks
        self.mid_block1 = ResBlock(dims[-1], dims[-1], dim, dim * 4)
        self.mid_attn = AttentionBlock(dims[-1])
        self.mid_block2 = ResBlock(dims[-1], dims[-1], dim, dim * 4)

        # Upsample blocks: the first ResBlock consumes the skip connection,
        # and each upsample maps out_dim back to in_dim so the channels match the next skip
        self.ups = nn.ModuleList()
        for ind, (in_dim, out_dim) in enumerate(reversed(in_out)):
            is_last = ind >= (len(in_out) - 1)
            self.ups.append(nn.ModuleList([
                ResBlock(out_dim + in_dim, out_dim, dim, dim * 4),
                ResBlock(out_dim, out_dim, dim, dim * 4),
                AttentionBlock(out_dim),
                nn.ConvTranspose2d(out_dim, in_dim, 4, 2, 1) if not is_last else nn.Conv2d(out_dim, in_dim, 3, padding=1)
            ]))

        # Final blocks
        self.final_block1 = ResBlock(dim * 2, dim, dim, dim * 4)
        self.final_block2 = ResBlock(dim, dim, dim, dim * 4)
        self.final_conv = nn.Conv2d(dim, out_channels, 3, padding=1)

    def forward(self, x, time, text_emb):
        t = self.time_mlp(time)
        text_emb = self.text_proj(text_emb)

        x = self.init_conv(x)
        h = [x]

        # Downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t, text_emb)
            x = block2(x, t, text_emb)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # Bottleneck
        x = self.mid_block1(x, t, text_emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t, text_emb)

        # Upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat([x, h.pop()], dim=1)
            x = block1(x, t, text_emb)
            x = block2(x, t, text_emb)
            x = attn(x)
            x = upsample(x)

        # Final
        x = torch.cat([x, h.pop()], dim=1)
        x = self.final_block1(x, t, text_emb)
        x = self.final_block2(x, t, text_emb)
        return self.final_conv(x)


class DiffusionModel(nn.Module):
    def __init__(self, model, betas, device):
        super().__init__()
        self.model = model
        self.device = device

        # Precompute the DDPM noise-schedule terms on the target device
        self.betas = betas.to(device)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Frozen CLIP text encoder for text conditioning
        self.clip = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        for param in self.clip.parameters():
            param.requires_grad = False

    def get_text_emb(self, prompts):
        inputs = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
        # Mean-pool the token embeddings into one vector per prompt
        return self.clip(**inputs).last_hidden_state.mean(dim=1)

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: sample x_t ~ q(x_t | x_0)."""
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alpha_cumprod = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        return sqrt_alpha_cumprod * x_start + sqrt_one_minus_alpha_cumprod * noise

    def p_losses(self, x_start, text, t, noise=None):
        """DDPM training objective: MSE between the true and predicted noise."""
        if noise is None:
            noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start, t, noise)
        text_emb = self.get_text_emb(text)
        predicted_noise = self.model(x_noisy, t, text_emb)

        return F.mse_loss(noise, predicted_noise)

    @torch.no_grad()
    def sample(self, prompts, image_size=256, batch_size=4, channels=3, cfg_scale=7.5):
        shape = (batch_size, channels, image_size, image_size)
        x = torch.randn(shape, device=self.device)

        text_emb = self.get_text_emb(prompts)
        uncond_emb = self.get_text_emb([""] * batch_size)

        for i in reversed(range(0, len(self.betas))):
            t = torch.full((batch_size,), i, device=self.device, dtype=torch.long)

            # Classifier-free guidance: extrapolate from the unconditional prediction
            noise_pred = self.model(x, t, text_emb)
            noise_pred_uncond = self.model(x, t, uncond_emb)
            noise_pred = noise_pred_uncond + cfg_scale * (noise_pred - noise_pred_uncond)

            alpha = self.alphas[t].view(-1, 1, 1, 1)
            alpha_cumprod = self.alphas_cumprod[t].view(-1, 1, 1, 1)
            beta = self.betas[t].view(-1, 1, 1, 1)

            # No noise is added at the final step (t = 0)
            if i > 0:
                noise = torch.randn_like(x)
            else:
                noise = torch.zeros_like(x)

            # DDPM posterior mean plus noise scaled by sqrt(beta_t)
            x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * noise_pred) + torch.sqrt(beta) * noise

        # Map from [-1, 1] to uint8 pixel values
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x
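

if __name__ == "__main__":
    # Illustrative usage sketch: the hyperparameters below (1000 timesteps, a linear
    # beta schedule, 64x64 dummy images, the two example prompts) are assumptions,
    # chosen only to show how UNet and DiffusionModel fit together.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    timesteps = 1000
    betas = torch.linspace(1e-4, 0.02, timesteps)  # assumed linear DDPM schedule

    unet = UNet(in_channels=3, out_channels=3, dim=64).to(device)
    diffusion = DiffusionModel(unet, betas, device).to(device)

    # One training step on a dummy batch of images scaled to [-1, 1]
    images = torch.rand(2, 3, 64, 64, device=device) * 2 - 1
    t = torch.randint(0, timesteps, (2,), device=device)
    loss = diffusion.p_losses(images, ["a red cube", "a blue sphere"], t)
    loss.backward()
    print("loss:", loss.item())

    # Sampling with classifier-free guidance
    samples = diffusion.sample(["a red cube", "a blue sphere"], image_size=64, batch_size=2)
    print("samples:", samples.shape)  # (2, 3, 64, 64), uint8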