import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('./')

from model.dit import CogVideoXTransformer3DModel


class PointEmbed(nn.Module):
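    """Fourier feature embedding for 3D point clouds.

    Projects xyz coordinates onto a fixed power-of-two frequency basis,
    takes sin/cos of the projections, concatenates the raw coordinates,
    and maps the result to `dim` channels with a linear layer.

    Example (shapes only):
        emb = PointEmbed(hidden_dim=96, dim=512)
        out = emb(torch.rand(2, 1024, 3))  # -> (2, 1024, 512)
    """
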
    def __init__(self, hidden_dim=96, dim=512):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e)  # 3 x (embedding_dim // 2)
        self.mlp = nn.Linear(self.embedding_dim + 3, dim)
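
    # Project points onto the frequency basis and return concatenated
    # sin/cos features of shape (B, N, embedding_dim).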
    @staticmethod
    def embed(input, basis):
        projections = torch.einsum('bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings
    def forward(self, input):
        # input: B x N x 3
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))  # B x N x C
        return embed


class PositionalEncoding(nn.Module):
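    """Standard sinusoidal positional encoding (Vaswani et al., 2017); not used in the final model."""
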
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model) for sequence-first inputs
        self.register_parameter('pe', nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class MDM_DiT(nn.Module):
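    """Motion diffusion model (MDM) with a DiT backbone.

    Denoises a sequence of point-cloud frames conditioned on the initial
    point cloud plus four condition tokens: applied force, E and nu
    (presumably Young's modulus and Poisson's ratio), and a drag point.
    """
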
    def __init__(self, n_points, n_frame, n_feats, model_config):
        super().__init__()
        self.n_points = n_points
        self.n_feats = n_feats
        self.latent_dim = model_config.latent_dim
        self.cond_seq_length = 4  # force, E, nu, drag_point tokens
        self.cond_frame = 1 if model_config.frame_cond else 0
        self.dit = CogVideoXTransformer3DModel(
            sample_points=n_points,
            sample_frames=n_frame + self.cond_frame,
            in_channels=n_feats,
            num_layers=model_config.n_layers,
            num_attention_heads=self.latent_dim // 64,
            cond_seq_length=self.cond_seq_length,
        )
        self.input_encoder = PointEmbed(dim=self.latent_dim)
        # self.init_cond_encoder = PointEmbed(dim=self.latent_dim)
        self.E_cond_encoder = nn.Linear(1, self.latent_dim)
        self.nu_cond_encoder = nn.Linear(1, self.latent_dim)
        self.force_cond_encoder = nn.Linear(3, self.latent_dim)
        self.drag_point_encoder = nn.Linear(3, self.latent_dim)

    def enable_gradient_checkpointing(self):
        self.dit._set_gradient_checkpointing(True)
    def forward(self, x, timesteps, init_pc, force, E, nu, drag_mask, drag_point,
                floor_height=None, coeff=None, y=None, null_emb=0):
        """
        x: [batch_size, frame, n_points, n_feats], denoted x_t in the paper
        timesteps: [batch_size] (int)
        """
        bs, n_frame, n_points, n_feats = x.shape
        init_pc = init_pc.reshape(bs, n_points, n_feats)

        # Each condition becomes one sequence token of shape (bs, 1, feat_dim).
        force = force.unsqueeze(1)
        E = E.unsqueeze(1)
        nu = nu.unsqueeze(1)
        drag_point = drag_point.unsqueeze(1)

        # Prepend the initial point cloud as an extra conditioning frame.
        x = torch.cat([init_pc.unsqueeze(1), x], dim=1)
        n_frame += 1

        encoder_hidden_states = torch.cat([
            self.force_cond_encoder(force), self.E_cond_encoder(E),
            self.nu_cond_encoder(nu), self.drag_point_encoder(drag_point),
        ], dim=1)  # (bs, cond_seq_length, latent_dim)
        hidden_states = self.input_encoder(
            x.reshape(bs * n_frame, n_points, n_feats)
        ).reshape(bs, n_frame, n_points, self.latent_dim)
        full_seq = torch.cat(
            [encoder_hidden_states, hidden_states.reshape(bs, n_frame * n_points, self.latent_dim)],
            dim=1)

        # Drop the conditioning frame (when frame_cond is on) and predict
        # positions as residuals relative to the initial point cloud.
        output = self.dit(full_seq, timesteps).reshape(bs, n_frame, n_points, 3)[:, self.cond_frame:]
        output = output + init_pc.unsqueeze(1)
        return output


if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    point_num = 512
    frame_num = 6

    # Minimal config stub for a smoke test; these field values are
    # assumptions for illustration, not the training configuration.
    model_config = SimpleNamespace(latent_dim=512, frame_cond=True, n_layers=4)

    x = torch.randn(2, frame_num, point_num, 3, device=device, dtype=torch.float16)
    timesteps = torch.tensor([999, 999], device=device).long()  # integer timesteps
    init_pc = torch.randn(2, 1, point_num, 3, device=device, dtype=torch.float16)
    force = torch.randn(2, 3, device=device, dtype=torch.float16)
    E = torch.randn(2, 1, device=device, dtype=torch.float16)
    nu = torch.randn(2, 1, device=device, dtype=torch.float16)
    drag_mask = torch.zeros(2, point_num, device=device, dtype=torch.float16)
    drag_point = torch.randn(2, 3, device=device, dtype=torch.float16)

    model = MDM_DiT(point_num, frame_num, 3, model_config).to(device).to(torch.float16)
    output = model(x, timesteps, init_pc, force, E, nu, drag_mask, drag_point)
    print(output.shape)
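    # Expected (with frame_cond=True): torch.Size([2, 6, 512, 3])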