wli1995 commited on
Commit
442b121
·
verified ·
1 Parent(s): c6a60e4

Upload gradio demo

Browse files
.gitattributes CHANGED
@@ -44,3 +44,4 @@ images/face/00_00.png filter=lfs diff=lfs merge=lfs -text
44
  images/image/02.png filter=lfs diff=lfs merge=lfs -text
45
  images/result_0.png filter=lfs diff=lfs merge=lfs -text
46
  images/result_1.png filter=lfs diff=lfs merge=lfs -text
 
 
44
  images/image/02.png filter=lfs diff=lfs merge=lfs -text
45
  images/result_0.png filter=lfs diff=lfs merge=lfs -text
46
  images/result_1.png filter=lfs diff=lfs merge=lfs -text
47
+ assert/gradio_demo.JPG filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -63,7 +63,30 @@ Input Data:
63
  | `-- 02.png
64
 
65
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
 
67
 
68
  #### Inference with AX650 Host, such as M4N-Dock(爱芯派Pro)
69
 
 
63
  | `-- 02.png
64
 
65
  ```
66
+ #### Inference with M.2 Accelerator card
67
+ ```
68
+ $cd python
69
+ $python3 gradio_demo.py
70
+ [INFO] Available providers: ['AXCLRTExecutionProvider']
71
+ [INFO] Using provider: AXCLRTExecutionProvider
72
+ [INFO] SOC Name: AX650N
73
+ [INFO] VNPU type: VNPUType.DISABLED
74
+ [INFO] Compiler version: 5.0-patch1 6d9cc640
75
+ [INFO] Using provider: AXCLRTExecutionProvider
76
+ [INFO] SOC Name: AX650N
77
+ [INFO] VNPU type: VNPUType.DISABLED
78
+ [INFO] Compiler version: 5.0-patch1 681a0b38
79
+ [INFO] Using provider: AXCLRTExecutionProvider
80
+ [INFO] SOC Name: AX650N
81
+ [INFO] VNPU type: VNPUType.DISABLED
82
+ [INFO] Compiler version: 4.2-dirty 5e72cf06-dirty
83
+ * Running on local URL: http://0.0.0.0:7860
84
+ * To create a public link, set `share=True` in `launch()`.
85
+ ```
86
+ Then use the M.2 Accelerator card IP instead of the 0.0.0.0, and use chrome open the URL: http://[your ip]:7860
87
 
88
+ ![gradio face](./assert/gradio_face.JPG)
89
+ ![gradio demo](./assert/gradio_demo.JPG)
90
 
91
  #### Inference with AX650 Host, such as M4N-Dock(爱芯派Pro)
92
 
assert/gradio_demo.JPG ADDED

Git LFS Details

  • SHA256: 5ffbf568e8f2a013c33aff71dadb3ff3d8b9e521c65e825b3873a59ad8385da2
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
assert/gradio_face.JPG ADDED
python/gradio_demo.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import tempfile
4
+ import numpy as np
5
+ import axengine as axe
6
+ import cv2
7
+ from utils.restoration_helper import RestoreHelper
8
+
9
+ restore_helper = RestoreHelper(
10
+ upscale_factor=1,
11
+ face_size=512,
12
+ crop_ratio=(1, 1),
13
+ det_model="../model/yolov5l-face.axmodel",
14
+ res_model="../model/codeformer.axmodel",
15
+ bg_model="../model/realesrgan-x2.axmodel",
16
+ save_ext='png',
17
+ use_parse=True
18
+ )
19
+
20
+ def face(img_path, session):
21
+
22
+ output_names = [x.name for x in session.get_outputs()]
23
+ input_name = session.get_inputs()[0].name
24
+
25
+ ori_image = cv2.imread(img_path)
26
+ h, w = ori_image.shape[:2]
27
+ image = cv2.resize(ori_image, (512, 512))
28
+ image = (image[..., ::-1] /255.0).astype(np.float32)
29
+
30
+ mean = [0.5, 0.5, 0.5]
31
+ std = [0.5, 0.5, 0.5]
32
+ image = ((image - mean) / std).astype(np.float32)
33
+
34
+ #image = (image /1.0).astype(np.float32)
35
+ img = np.transpose(np.expand_dims(np.ascontiguousarray(image), axis=0), (0,3,1,2))
36
+
37
+ # Use the model to generate super-resolved images
38
+ sr = session.run(output_names, {input_name: img})
39
+
40
+ #sr_y_image = imgproc.array_to_image(sr)
41
+ sr = np.transpose(sr[0].squeeze(0), (1,2,0))
42
+ sr = (sr*std + mean).astype(np.float32)
43
+
44
+ # Save image
45
+ ndarr = np.clip((sr*255.0), 0, 255.0).astype(np.uint8)
46
+ out_image = cv2.resize(ndarr[..., ::-1], (w, h))
47
+
48
+ return out_image
49
+
50
+ def full_image(img_path, restore_helper=restore_helper):
51
+
52
+ restore_helper.clean_all()
53
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
54
+
55
+ restore_helper.read_image(img)
56
+ # get face landmarks for each face
57
+ num_det_faces = restore_helper.get_face_landmarks_5(
58
+ only_center_face=False, resize=640, eye_dist_threshold=5)
59
+ # align and warp each face
60
+ restore_helper.align_warp_face()
61
+ # face restoration for each cropped face
62
+ for idx, cropped_face in enumerate(restore_helper.cropped_faces):
63
+ # prepare data
64
+ cropped_face_t = (cropped_face.astype(np.float32) / 255.0) * 2.0 - 1.0
65
+ cropped_face_t = np.transpose(
66
+ np.expand_dims(np.ascontiguousarray(cropped_face_t[...,::-1]), axis=0),
67
+ (0,3,1,2)
68
+ )
69
+ #print('cropped_face_t', cropped_face_t.shape)
70
+
71
+ try:
72
+ ort_outs = restore_helper.rs_sessison.run(
73
+ restore_helper.rs_output,
74
+ {restore_helper.rs_input: cropped_face_t}
75
+ )
76
+ restored_face = ort_outs[0]
77
+ restored_face = (restored_face.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
78
+ restored_face = np.clip(restored_face[...,::-1], 0, 255).astype(np.uint8)
79
+ except Exception as error:
80
+ print(f'\tFailed inference for CodeFormer: {error}')
81
+ restored_face = (cropped_face_t.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
82
+ restored_face = np.clip(restored_face, 0, 255).astype(np.uint8)
83
+
84
+ restored_face = restored_face.astype('uint8')
85
+ restore_helper.add_restored_face(restored_face, cropped_face)
86
+
87
+ # upsample the background
88
+ # Now only support RealESRGAN for upsampling background
89
+ bg_img = restore_helper.background_upsampling(img)
90
+ restore_helper.get_inverse_affine(None)
91
+ # paste each restored face to the input image
92
+ restored_img = restore_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
93
+
94
+ return restored_img
95
+
96
+
97
+ def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
98
+ if not input_img_path:
99
+ raise gr.Error("未上传图片")
100
+
101
+ # 加载图像
102
+ progress(0.3, desc="加载图像...")
103
+
104
+ # 根据模型选择调用不同函数
105
+ if model_name == "Face":
106
+ out = face(input_img_path, session=restore_helper.rs_sessison)
107
+ else:
108
+ out = full_image(input_img_path, restore_helper=restore_helper)
109
+
110
+ progress(0.9, desc="保存结果...")
111
+
112
+ # 保存到临时文件
113
+ output_path = os.path.join(tempfile.gettempdir(), "restore_output.jpg")
114
+ cv2.imwrite(output_path, out)
115
+
116
+ progress(1.0, desc="完成!")
117
+ return output_path
118
+
119
+
120
+ # ==============================
121
+ # Gradio 界面
122
+ # ==============================
123
+ custom_css = """
124
+ body, .gradio-container {
125
+ font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
126
+ }
127
+ .model-buttons .wrap {
128
+ display: flex;
129
+ gap: 10px;
130
+ }
131
+ .model-buttons .wrap label {
132
+ background-color: #f0f0f0;
133
+ padding: 10px 20px;
134
+ border-radius: 8px;
135
+ cursor: pointer;
136
+ text-align: center;
137
+ font-weight: 600;
138
+ border: 2px solid transparent;
139
+ flex: 1;
140
+ }
141
+ .model-buttons .wrap label:hover {
142
+ background-color: #e0e0e0;
143
+ }
144
+ .model-buttons .wrap input[type="radio"]:checked + label {
145
+ background-color: #4CAF50;
146
+ color: white;
147
+ border-color: #45a049;
148
+ }
149
+ """
150
+
151
+ with gr.Blocks(title="人脸修复工具") as demo:
152
+ gr.Markdown("## 🎨 人脸修复演示DEMO")
153
+
154
+ with gr.Row(equal_height=True):
155
+ # 左侧:输入区
156
+ with gr.Column(scale=1, min_width=300):
157
+ gr.Markdown("### 📤 输入")
158
+ input_image = gr.Image(
159
+ type="filepath",
160
+ label="上传图片",
161
+ sources=["upload"],
162
+ height=300
163
+ )
164
+
165
+ gr.Markdown("### 🔧 选择修复模式")
166
+ model_choice = gr.Radio(
167
+ choices=["Face", "Full image"],
168
+ value="Face",
169
+ label=None,
170
+ elem_classes="model-buttons"
171
+ )
172
+
173
+ run_btn = gr.Button("🚀 开始修复", variant="primary")
174
+
175
+ # 右侧:输出区
176
+ with gr.Column(scale=1, min_width=600):
177
+ gr.Markdown("### 🖼️ 修复结果")
178
+ output_image = gr.Image(
179
+ label="修复后图片",
180
+ interactive=False,
181
+ height=600
182
+ )
183
+ download_btn = gr.File(label="📥 下载修复图片")
184
+
185
+ # 绑定事件
186
+ def on_colorize(img_path, model, progress=gr.Progress()):
187
+ if img_path is None:
188
+ raise gr.Error("请先上传图片!")
189
+ try:
190
+ result_path = colorize_image(img_path, model, progress=progress)
191
+ return result_path, result_path
192
+ except Exception as e:
193
+ raise gr.Error(f"处理失败: {str(e)}")
194
+
195
+ run_btn.click(
196
+ fn=on_colorize,
197
+ inputs=[input_image, model_choice],
198
+ outputs=[output_image, download_btn]
199
+ )
200
+
201
+ # 启动
202
+ if __name__ == "__main__":
203
+ demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft(), css=custom_css)
python/utils/__pycache__/face_detector.cpython-313.pyc ADDED
Binary file (8.11 kB). View file
 
python/utils/__pycache__/general.cpython-313.pyc ADDED
Binary file (19.2 kB). View file
 
python/utils/__pycache__/restoration_helper.cpython-313.pyc ADDED
Binary file (27.9 kB). View file