Spaces:
Running
Running
| import gradio as gr | |
| from paths import * | |
| import numpy as np | |
| from vision_tower import DINOv2_MLP | |
| from transformers import AutoImageProcessor | |
| import torch | |
| import os | |
| import matplotlib.pyplot as plt | |
| import io | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
# Fetch the orientation-model checkpoint from the Hugging Face Hub, caching it
# under the current directory, and log where it landed.
ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./')
print(ckpt_path)
save_path = './'  # NOTE(review): not referenced by any code visible in this file
device = 'cpu'
# DINOv2_MLP model in 'large' mode. out_dim packs the per-bin scores that
# get_3angle slices: 360 azimuth + 180 polar + 60 rotation bins, plus 2 more
# outputs that nothing in this file reads.
dino = DINOv2_MLP(
    dino_mode = 'large',
    in_dim = 1024,
    out_dim = 360+180+60+2,
    evaluate = True,
    mask_dino = False,
    frozen_back = False
).to(device)
dino.eval()
print('model create')
# NOTE(review): torch.load unpickles the downloaded file. Consider passing
# weights_only=True if the installed torch version supports it.
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
print('weight loaded')
# Image preprocessor for DINO_LARGE (a name provided by the star import from `paths`).
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
def get_3angle(image):
    """Predict (azimuth, polar, rotation) in degrees for a single image.

    The model output is read as three consecutive groups of bins, and each angle
    is the index of the highest-scoring bin in its group, shifted by an offset:
      * azimuth  : bins [0, 360)   -> 0 .. 359
      * polar    : bins [360, 540) -> index - 90, i.e. -90 .. 89
      * rotation : bins [540, 600) -> index - 30, i.e. -30 .. 29

    Args:
        image: Image accepted by ``val_preprocess`` (whatever the Gradio input
            delivers). Assumes a single image, i.e. a batch of one. TODO confirm
            if batching is ever added.

    Returns:
        A length-3 float tensor ``[azimuth, polar, rotation]``.
    """
    features = val_preprocess(images = image)
    pixel_array = np.array(features['pixel_values'])
    features['pixel_values'] = torch.from_numpy(pixel_array).to(device)
    with torch.no_grad():
        scores = dino(features)

    azimuth = torch.argmax(scores[:, 0:360], dim=-1)
    polar = torch.argmax(scores[:, 360:540], dim=-1) - 90
    rotation = torch.argmax(scores[:, 540:600], dim=-1) - 30

    angles = torch.zeros(3)
    for slot, value in enumerate((azimuth, polar, rotation)):
        angles[slot] = value
    return angles
def scale(x, factor=3):
    """Enlarge a projected axis vector so it is visible on the plot.

    Args:
        x: Number or NumPy array to scale.
        factor: Multiplier applied to ``x``. Defaults to 3, the value that was
            previously hard-coded.

    Returns:
        ``x * factor``.
    """
    return x * factor
def get_proj2D_XYZ(phi, theta, gamma):
    """Project the rotated object axes onto the 2D image plane.

    Args:
        phi: Azimuth angle in radians.
        theta: Polar angle in radians.
        gamma: In-plane rotation angle in radians.

    Returns:
        Tuple ``(x, y, z)`` of 2-element NumPy arrays. Each one is the 2D
        projection of one axis, passed through ``scale``.
    """
    sin_phi, cos_phi = np.sin(phi), np.cos(phi)
    sin_theta, cos_theta = np.sin(theta), np.cos(theta)
    sin_gamma, cos_gamma = np.sin(gamma), np.cos(gamma)

    x_axis = np.array([
        -1 * sin_phi * cos_gamma - cos_phi * sin_theta * sin_gamma,
        sin_phi * sin_gamma - cos_phi * sin_theta * cos_gamma,
    ])
    y_axis = np.array([
        -1 * cos_phi * cos_gamma + sin_phi * sin_theta * sin_gamma,
        cos_phi * sin_gamma + sin_phi * sin_theta * cos_gamma,
    ])
    z_axis = np.array([
        cos_theta * sin_gamma,
        cos_theta * cos_gamma,
    ])
    return scale(x_axis), scale(y_axis), scale(z_axis)
# Draw one projected coordinate axis as a 2D arrow.
def draw_axis(ax, origin, vector, color, label=None):
    """Draw a single axis arrow (and optional label) on a Matplotlib Axes.

    Args:
        ax: Matplotlib Axes to draw on.
        origin: ``(x, y)`` start point of the arrow, in data coordinates.
        vector: ``(dx, dy)`` arrow displacement, in data coordinates.
        color: Matplotlib color for the arrow and its label.
        label: Optional text placed just beyond the arrow tip. Omitted when None.
    """
    # angles='xy', scale_units='xy', scale=1 make the arrow span exactly
    # `vector` in data units, so it lines up with the image extent.
    ax.quiver(origin[0], origin[1], vector[0], vector[1], angles='xy', scale_units='xy', scale=1, color=color)
    if label is not None:
        # Place the text at 110% of the vector so it sits past the arrowhead.
        ax.text(origin[0] + vector[0] * 1.1, origin[1] + vector[1] * 1.1, label, color=color, fontsize=12)
def figure_to_img(fig):
    """Render a Matplotlib figure into an in-memory PIL image.

    The figure is saved as JPEG into a temporary byte buffer and decoded back.
    The decoded image is copied before the buffer is closed, so the returned
    image does not depend on the buffer.
    """
    buffer = io.BytesIO()
    with buffer:
        fig.savefig(buffer, format='JPG', bbox_inches='tight')
        buffer.seek(0)
        decoded = Image.open(buffer)
        # copy() loads the pixel data while the buffer is still open.
        result = decoded.copy()
    return result
def generate_mutimodal(img):
    """Predict an object's orientation and overlay the predicted axes on the image.

    Gradio callback for the interface defined in this file.

    Args:
        img: Uploaded image as an ``(H, W[, C])`` array. NOTE(review): assumes
            the ``gr.Image`` input delivers a NumPy array (Gradio's default
            ``type``). Confirm if the component is changed.

    Returns:
        ``[res_img, azimuth, polar, rotation]``: the rendered overlay as a PIL
        image, followed by the three predicted angles in degrees as floats.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        # Without an upload the callback would otherwise fail further down
        # with an unhelpful exception.
        raise gr.Error("Please upload an image first.")

    angles = get_3angle(img)
    # Plain Python floats, used both for the trigonometry and as return values.
    azimuth, polar, rotation = (float(a) for a in angles)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        h, w = img.shape[:2]
        # Fit the image inside the [-5, 5] square while keeping its aspect ratio.
        if h > w:
            extent = [-5 * w / h, 5 * w / h, -5, 5]
        else:
            extent = [-5, 5, -5 * h / w, 5 * h / w]
        ax.imshow(img, extent=extent, zorder=0, aspect='auto')  # extent sets the image's display range
        origin = np.array([0, 0])

        # Convert the predicted angles to radians; rotation is negated before projecting.
        phi = np.radians(azimuth)
        theta = np.radians(polar)
        gamma = np.radians(-1 * rotation)

        # Projected (2D) directions of the rotated axes.
        rot_x, rot_y, rot_z = get_proj2D_XYZ(phi, theta, gamma)
        draw_axis(ax, origin, rot_y, 'g')
        draw_axis(ax, origin, rot_z, 'b')
        draw_axis(ax, origin, rot_x, 'r')

        # Hide axes and grid, and fix the visible range to match the extent.
        ax.set_axis_off()
        ax.grid(False)
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        res_img = figure_to_img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each request in the long-running server leaks one figure.
        plt.close(fig)

    return [res_img, azimuth, polar, rotation]
# Gradio UI: one image in; the overlay image and the three predicted angles out.
server = gr.Interface(
    flagging_mode='never',
    fn=generate_mutimodal,
    inputs=[
        gr.Image(height=512, width=512, label="upload your image")
    ],
    outputs=[
        gr.Image(height=512, width=512, label="result image"),
        gr.Textbox(lines=1, label='Azimuth(0~360°)'),
        gr.Textbox(lines=1, label='Polar(-90~90°)'),
        # get_3angle decodes rotation from 60 bins offset by -30, so values lie
        # in -30..29. The previous '-90~90°' label did not match.
        gr.Textbox(lines=1, label='Rotation(-30~30°)'),
    ],
)
server.launch()