Spaces:
Runtime error
Runtime error
performance
Browse files- interface/app.py +8 -2
- interface/model_loader.py +1 -1
interface/app.py
CHANGED
|
@@ -10,6 +10,9 @@ import io
|
|
| 10 |
from huggingface_hub import snapshot_download
|
| 11 |
import json
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")
|
| 14 |
|
| 15 |
|
|
@@ -52,7 +55,10 @@ default_dxdysxsy = json.dumps(
|
|
| 52 |
)
|
| 53 |
|
| 54 |
def cv_to_pil(img):
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
def random_sample(model_name: str):
|
|
@@ -175,5 +181,5 @@ Double click to add or remove stop points.
|
|
| 175 |
random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
|
| 176 |
)
|
| 177 |
|
| 178 |
-
block.queue(api_open=False)
|
| 179 |
block.launch(show_api=False)
|
|
|
|
| 10 |
from huggingface_hub import snapshot_download
|
| 11 |
import json
|
| 12 |
|
| 13 |
+
# disable if running on another environment
|
| 14 |
+
RESIZE = True
|
| 15 |
+
|
| 16 |
models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")
|
| 17 |
|
| 18 |
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
def cv_to_pil(img):
|
| 58 |
+
img = Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
|
| 59 |
+
if RESIZE:
|
| 60 |
+
img = img.resize((128, 128))
|
| 61 |
+
return img
|
| 62 |
|
| 63 |
|
| 64 |
def random_sample(model_name: str):
|
|
|
|
| 181 |
random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
|
| 182 |
)
|
| 183 |
|
| 184 |
+
# block.queue(api_open=False)
|
| 185 |
block.launch(show_api=False)
|
interface/model_loader.py
CHANGED
|
@@ -12,7 +12,7 @@ class Model:
|
|
| 12 |
):
|
| 13 |
self.truncation = truncation
|
| 14 |
self.use_average_code_as_input = use_average_code_as_input
|
| 15 |
-
ckpt = torch.load(checkpoint_path, map_location="
|
| 16 |
opts = ckpt["opts"]
|
| 17 |
opts["checkpoint_path"] = checkpoint_path
|
| 18 |
self.opts = Namespace(**ckpt["opts"])
|
|
|
|
| 12 |
):
|
| 13 |
self.truncation = truncation
|
| 14 |
self.use_average_code_as_input = use_average_code_as_input
|
| 15 |
+
ckpt = torch.load(checkpoint_path, map_location="cuda")
|
| 16 |
opts = ckpt["opts"]
|
| 17 |
opts["checkpoint_path"] = checkpoint_path
|
| 18 |
self.opts = Namespace(**ckpt["opts"])
|