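# Standalone SMPL joint regressor: loads the model data from a pickle and maps
# per-joint axis-angle pose vectors to posed 3D joint locations, without
# computing the full skinned mesh.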
import torch
import numpy as np
from torch import nn
import pickle as pkl
import torch.nn.functional as F


class Struct(object):
    """Thin wrapper that exposes the keys of the SMPL pickle as attributes."""

    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def to_np(array, dtype=np.float32):
    """Densify scipy.sparse arrays and cast everything else to a numpy array."""
    if 'scipy.sparse' in str(type(array)):
        array = array.todense()
    return np.array(array, dtype=dtype)


class Get_Joints(nn.Module):
    """Regresses posed SMPL joint locations from axis-angle pose vectors."""

    def __init__(self, path, batch_size=300) -> None:
        super().__init__()
        # Shape coefficients are fixed to zero (mean shape); only the first
        # `batch` rows are used in forward().
        self.betas = nn.parameter.Parameter(torch.zeros([batch_size, 10], dtype=torch.float32), requires_grad=False)
        # Load the SMPL model data from the pickle file.
        with open(path, "rb") as f:
            smpl_prior = pkl.load(f, encoding="latin1")
        data_struct = Struct(**smpl_prior)
        self.v_template = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.v_template)), requires_grad=False)
        self.shapedirs = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.shapedirs)), requires_grad=False)
        self.J_regressor = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.J_regressor)), requires_grad=False)
        # Pose blend-shape basis, flattened to [num_pose_basis, V * 3]
        # (loaded for completeness; not used when only joints are needed).
        posedirs = torch.from_numpy(to_np(data_struct.posedirs))
        num_pose_basis = posedirs.shape[-1]
        posedirs = posedirs.reshape([-1, num_pose_basis]).permute(1, 0)
        self.posedirs = nn.parameter.Parameter(posedirs, requires_grad=False)
        # Kinematic tree: parents[i] is the parent joint of joint i; the root has no parent.
        self.parents = nn.parameter.Parameter(torch.from_numpy(to_np(data_struct.kintree_table)[0]).long(), requires_grad=False)
        self.parents[0] = -1
        # Constant helpers registered as parameters so they follow the module's device.
        self.ident = nn.parameter.Parameter(torch.eye(3), requires_grad=False)
        self.zeros = nn.parameter.Parameter(torch.zeros([1, 1]), requires_grad=False)

    def blend_shapes(self, betas, shape_disps):
        # [B, 10] x [V, 3, 10] -> [B, V, 3] per-vertex shape offsets.
        blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
        return blend_shape

    def vertices2joints(self, J_regressor, vertices):
        # Regress joint locations [B, J, 3] from mesh vertices [B, V, 3].
        return torch.einsum('bik,ji->bjk', [vertices, J_regressor])

    def batch_rodrigues(
        self,
        rot_vecs,
        epsilon=1e-8,
    ):
        """Convert axis-angle vectors [N, 3] to rotation matrices [N, 3, 3]."""
        batch_size = rot_vecs.shape[0]
        angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
        rot_dir = rot_vecs / angle
        cos = torch.unsqueeze(torch.cos(angle), dim=1)
        sin = torch.unsqueeze(torch.sin(angle), dim=1)
        # Bx1 arrays
        rx, ry, rz = torch.split(rot_dir, 1, dim=1)
        zeros = self.zeros.repeat(batch_size, 1)
        # Skew-symmetric cross-product matrix K of the (normalized) rotation axis.
        K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))
        ident = self.ident.unsqueeze(0)
        # Rodrigues' formula: R = I + sin(theta) * K + (1 - cos(theta)) * K @ K
        rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
        return rot_mat

    def transform_mat(self, R, t):
        # Build homogeneous 4x4 transforms [N, 4, 4] from rotations [N, 3, 3]
        # and translations [N, 3, 1].
        return torch.cat([F.pad(R, [0, 0, 0, 1]),
                          F.pad(t, [0, 0, 0, 1], value=1)], dim=2)

    def batch_rigid_transform(
        self,
        rot_mats,
        joints,
        parents,
    ):
        """Apply the kinematic chain to rest-pose joints; returns posed joints [B, J, 3]."""
        joints = torch.unsqueeze(joints, dim=-1)
        # Express each joint relative to its parent in the rest pose.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, parents[1:]]
        # Local 4x4 transform of every joint: its rotation plus the offset from its parent.
        transforms_mat = self.transform_mat(
            rot_mats.reshape(-1, 3, 3),
            rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
        # Compose transforms down the kinematic chain: each joint's global transform
        # is its parent's global transform times its own local transform.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, parents.shape[0]):
            curr_res = torch.matmul(transform_chain[parents[i]],
                                    transforms_mat[:, i])
            transform_chain.append(curr_res)
        transforms = torch.stack(transform_chain, dim=1)
        # The last column of the transformations contains the posed joints
        posed_joints = transforms[:, :, :3, 3]
        return posed_joints

    def forward(self, pose, trans=None):
        """pose: [B, J*3] axis-angle per joint; trans: optional [B, 3] global translation."""
        pose = pose.float()
        batch = pose.shape[0]
        # Zero betas -> mean-shape template; batch must not exceed the preallocated batch_size.
        betas = self.betas[:batch]
        v_shaped = self.v_template + self.blend_shapes(betas, self.shapedirs)
        # Rest-pose joint locations regressed from the shaped template.
        J = self.vertices2joints(self.J_regressor, v_shaped)
        # Per-joint rotation matrices from the axis-angle pose.
        rot_mats = self.batch_rodrigues(pose.view(-1, 3)).view([batch, -1, 3, 3])
        J_transformed = self.batch_rigid_transform(rot_mats, J, self.parents)
        if trans is not None:
            J_transformed += trans.unsqueeze(dim=1)
        return J_transformed
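

# Minimal usage sketch (not part of the original Space): "smpl_model.pkl" is a
# placeholder path to a standard SMPL pickle (24 joints, 10 shape components),
# and the zero pose below is just an example input.
if __name__ == "__main__":
    get_joints = Get_Joints("smpl_model.pkl", batch_size=4)
    pose = torch.zeros(4, 72)        # 24 joints x 3 axis-angle parameters (rest pose)
    trans = torch.zeros(4, 3)        # optional global root translation
    joints = get_joints(pose, trans)
    print(joints.shape)              # expected: torch.Size([4, 24, 3])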