Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| import trimesh | |
| from pathlib import Path | |
# All inference in this demo runs on the CPU.
device = torch.device("cpu")
# Load the TorchScript model once at import time. map_location places the
# weights directly on `device`, so a model that was exported from a GPU
# machine still loads on a CPU-only host.
model = torch.jit.load('model_scripted.pt', map_location=device)
def normalize_vertices(verts):
    """Center the vertices and scale every axis independently into [-1, 1].

    Args:
        verts: Tensor of vertex positions with one vertex per row (the
            caller in this file passes an ``(N, 3)`` float tensor).

    Returns:
        A tensor of the same shape whose per-axis mean is 0 and whose
        furthest point along each axis lies at exactly 1 or -1. An axis on
        which all vertices coincide stays at 0. An empty input is returned
        unchanged.
    """
    if verts.numel() == 0:
        # Reducing over zero vertices is undefined (max raises), so there is
        # nothing to normalize.
        return verts

    centered = verts - verts.mean(dim=0)

    # Largest absolute coordinate per axis; each axis is scaled on its own.
    extent = centered.abs().max(dim=0).values

    # A flat axis has extent 0. Divide by 1 there instead, which leaves the
    # (all-zero) coordinates untouched and avoids a division by zero.
    extent = torch.where(extent == 0, torch.ones_like(extent), extent)

    return centered / extent
def plot_3d_results(verts, faces, uv_seam_edge_indices):
    """Build a Plotly figure of the mesh with the given edges highlighted.

    Args:
        verts: Tensor of vertex positions; columns 0, 1 and 2 are read as
            x, y and z.
        faces: Tensor of triangles; columns 0, 1 and 2 are vertex indices.
        uv_seam_edge_indices: Iterable of edges. ``edge[0]`` and ``edge[1]``
            are the indices of the edge's two end vertices.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a translucent light-blue
        mesh and one red line segment per edge.
    """
    points = verts.cpu().numpy()
    triangles = faces.cpu().numpy()

    surface = go.Mesh3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        i=triangles[:, 0],
        j=triangles[:, 1],
        k=triangles[:, 2],
        color='lightblue',
        opacity=0.50,
        name='Mesh',
    )

    # Every edge contributes (start, end, None) to each coordinate list; the
    # None entry separates one segment from the next in the single line trace.
    seam_x, seam_y, seam_z = [], [], []
    for edge in uv_seam_edge_indices:
        (ax, ay, az), (bx, by, bz) = points[edge[0]], points[edge[1]]
        seam_x += [ax, bx, None]
        seam_y += [ay, by, None]
        seam_z += [az, bz, None]

    seams = go.Scatter3d(
        x=seam_x,
        y=seam_y,
        z=seam_z,
        mode='lines',
        line=dict(color='red', width=2),
        name='Predicted Edges',
    )

    def axis_style(background):
        # The three axes differ only in their background colour.
        return dict(
            nticks=4,
            backgroundcolor=background,
            gridcolor="white",
            showbackground=True,
            zerolinecolor="white",
        )

    figure = go.Figure(data=[surface, seams])
    figure.update_layout(
        scene=dict(
            xaxis=axis_style("rgb(200, 200, 230)"),
            yaxis=axis_style("rgb(230, 200,230)"),
            zaxis=axis_style("rgb(230, 230,200)"),
            camera=dict(up=dict(x=0, y=1, z=0), eye=dict(x=1.25, y=1.25, z=1.25)),
        ),
        title_text='Predicted Edges',
    )
    return figure
def _weld_vertices(vertices, faces):
    """Merge vertices with identical coordinates and re-index the faces.

    Args:
        vertices: Float tensor of vertex positions, one vertex per row.
        faces: Iterable of faces, each an iterable of row indices into
            ``vertices``.

    Returns:
        A ``(verts, new_faces)`` pair. ``verts`` has one row per distinct
        position that is referenced by a face, in first-seen order.
        ``new_faces`` is a list of index lists into ``verts``.
    """
    # One bulk conversion to hashable tuples instead of a tensor -> list
    # round trip for every face corner.
    coords = [tuple(row) for row in vertices.tolist()]

    index_of = {}  # coordinate tuple -> index in the welded vertex list
    kept = []      # original row index of each welded vertex
    new_faces = []
    for face in faces:
        new_face = []
        for orig_index in face:
            key = coords[orig_index]
            if key not in index_of:
                index_of[key] = len(kept)
                kept.append(int(orig_index))
            new_face.append(index_of[key])
        new_faces.append(new_face)

    verts = vertices[torch.tensor(kept, dtype=torch.long)]
    return verts, new_faces


def _unique_edges(faces):
    """Return the set of undirected edges of the given triangles.

    Args:
        faces: Iterable of faces, each with exactly three vertex indices.

    Returns:
        A set of ``(low, high)`` index tuples; sorting each pair makes
        ``(a, b)`` and ``(b, a)`` the same edge.
    """
    edge_set = set()
    for v1, v2, v3 in faces:
        edge_set.add(tuple(sorted((v1, v2))))
        edge_set.add(tuple(sorted((v2, v3))))
        edge_set.add(tuple(sorted((v1, v3))))
    return edge_set


def generate_prediction(file_input, treshold_value=0.5):
    """Predict the UV seam edges of a mesh file and plot them.

    Args:
        file_input: Path of the mesh file to load with ``trimesh``. A falsy
            value (``None``, empty string) means nothing was selected.
        treshold_value: Edges whose sigmoid-activated model output is
            strictly greater than this value are reported as seam edges.

    Returns:
        The figure produced by ``plot_3d_results`` for the welded mesh and
        its predicted seam edges, or ``None`` when ``file_input`` is falsy.

    Raises:
        ValueError: If the loaded mesh contains no faces.
    """
    if not file_input:
        return None

    mesh = trimesh.load_mesh(file_input)
    if len(mesh.faces) == 0:
        # Without faces there are no edges to classify; fail with a clear
        # message instead of an opaque tensor error further down.
        raise ValueError("The selected mesh contains no faces.")

    vertices = normalize_vertices(torch.tensor(mesh.vertices, dtype=torch.float32))

    # NOTE: this preprocessing is a plain Python loop over every face corner;
    # a vectorised implementation would be preferable for large meshes.
    verts, new_faces = _weld_vertices(vertices, mesh.faces)
    edges = torch.tensor(list(_unique_edges(new_faces)), dtype=torch.long)
    faces = torch.tensor(new_faces, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        # Inputs are moved to the model's device; outputs come back to the
        # CPU below so they can index the CPU-resident edge tensor.
        logits = model(verts.to(device), edges.to(device))
        probabilities = torch.sigmoid(logits)

    # reshape(-1) yields one flag per edge whether or not the model output
    # carries a trailing singleton dimension, and never collapses to 0-dim.
    seam_mask = (probabilities > treshold_value).reshape(-1).cpu()
    uv_seam_edges = edges[seam_mask].tolist()

    # Hand the welded mesh and the selected edges to the Plotly renderer.
    return plot_3d_results(verts, faces, uv_seam_edges)
def run_gradio():
    """Assemble the Gradio interface for the demo and launch it.

    The interface offers a file explorer restricted to ``.obj`` files, a
    threshold slider, a "Predict" button and a plot output. Clicking the
    button calls ``generate_prediction`` with the selected file and the
    slider value and shows the returned figure in the plot.
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from source whose indentation was lost - confirm that the
    # column is meant to sit inside the row.
    with gr.Blocks() as app:
        gr.Label("Proof of concept demo. Predict UV seams on a 3D sphere meshes.")

        with gr.Row():
            mesh_picker = gr.FileExplorer(
                label="Sphere Prototype Model",
                file_count='single',
                value='randomSphere_180.obj',
                glob='**/*.obj',
            )
            with gr.Column():
                seam_plot = gr.Plot()
                threshold_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.6,
                    label="Threshold",
                )
                predict_button = gr.Button("Predict")

        predict_button.click(
            generate_prediction,
            inputs=[mesh_picker, threshold_slider],
            outputs=seam_plot,
        )

    app.launch()
# Launch the demo only when this file is executed as a script, so that
# importing the module (for example to reuse generate_prediction) does not
# start a web server as a side effect.
if __name__ == "__main__":
    run_gradio()