Spaces:
Runtime error
Runtime error
Commit
·
171e1d9
1
Parent(s):
d52e07e
updated notebook
Browse files
notebooks/enhance_me_train.ipynb
CHANGED
|
@@ -22,7 +22,7 @@
|
|
| 22 |
},
|
| 23 |
"outputs": [],
|
| 24 |
"source": [
|
| 25 |
-
"!git clone https://github.com/soumik12345/enhance-me\n",
|
| 26 |
"!pip install -qqq wandb streamlit"
|
| 27 |
]
|
| 28 |
},
|
|
@@ -171,7 +171,7 @@
|
|
| 171 |
" enhanced_image = mirnet.infer(original_image)\n",
|
| 172 |
" ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
|
| 173 |
" commons.plot_results(\n",
|
| 174 |
-
" [original_image, ground_truth,
|
| 175 |
" [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
|
| 176 |
" (18, 18),\n",
|
| 177 |
" )"
|
|
@@ -238,7 +238,24 @@
|
|
| 238 |
"outputs": [],
|
| 239 |
"source": [
|
| 240 |
"zero_dce.compile(learning_rate=learning_rate)\n",
|
| 241 |
-
"zero_dce.train(epochs=epochs)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
]
|
| 243 |
}
|
| 244 |
],
|
|
|
|
| 22 |
},
|
| 23 |
"outputs": [],
|
| 24 |
"source": [
|
| 25 |
+
"!git clone https://github.com/soumik12345/enhance-me -b zero-dce\n",
|
| 26 |
"!pip install -qqq wandb streamlit"
|
| 27 |
]
|
| 28 |
},
|
|
|
|
| 171 |
" enhanced_image = mirnet.infer(original_image)\n",
|
| 172 |
" ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
|
| 173 |
" commons.plot_results(\n",
|
| 174 |
+
" [original_image, ground_truth, enhanced_image],\n",
|
| 175 |
" [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
|
| 176 |
" (18, 18),\n",
|
| 177 |
" )"
|
|
|
|
| 238 |
"outputs": [],
|
| 239 |
"source": [
|
| 240 |
"zero_dce.compile(learning_rate=learning_rate)\n",
|
| 241 |
+
"history = zero_dce.train(epochs=epochs)\n",
|
| 242 |
+
"zero_dce.save_weights(os.path.join(experiment_name, \"weights.h5\"))"
|
| 243 |
+
]
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"cell_type": "code",
|
| 247 |
+
"execution_count": null,
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"outputs": [],
|
| 250 |
+
"source": [
|
| 251 |
+
"for index, low_image_file in enumerate(zero_dce.test_low_images):\n",
|
| 252 |
+
" original_image = Image.open(low_image_file)\n",
|
| 253 |
+
" enhanced_image = zero_dce.infer(original_image)\n",
|
| 254 |
+
" commons.plot_results(\n",
|
| 255 |
+
" [original_image, enhanced_image],\n",
|
| 256 |
+
" [\"Original Image\", \"Enhanced Image\"],\n",
|
| 257 |
+
" (18, 18),\n",
|
| 258 |
+
" )"
|
| 259 |
]
|
| 260 |
}
|
| 261 |
],
|