add some comments
Browse files- text2world_hf.py +6 -1
text2world_hf.py
CHANGED
|
@@ -3,12 +3,15 @@ import argparse
|
|
| 3 |
import torch
|
| 4 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 5 |
|
|
|
|
| 6 |
from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
| 7 |
from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
| 8 |
import .cosmos1.utils.log as log
|
| 9 |
import .cosmos1.utils.misc as misc
|
| 10 |
from .cosmos1.utils.io import read_prompts_from_file, save_video
|
| 11 |
|
|
|
|
|
|
|
| 12 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
| 13 |
model_type = "DiffusionText2World"
|
| 14 |
def __init__(self, **kwargs):
|
|
@@ -38,6 +41,7 @@ class DiffusionText2WorldConfig(PretrainedConfig):
|
|
| 38 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
| 39 |
|
| 40 |
|
|
|
|
| 41 |
class DiffusionText2World(PreTrainedModel):
|
| 42 |
config_class = DiffusionText2WorldConfig
|
| 43 |
|
|
@@ -69,6 +73,7 @@ class DiffusionText2World(PreTrainedModel):
|
|
| 69 |
seed=config.seed,
|
| 70 |
)
|
| 71 |
|
|
|
|
| 72 |
def forward(self, prompt):
|
| 73 |
cfg = self.config
|
| 74 |
# Handle multiple prompts if prompt file is provided
|
|
@@ -118,7 +123,7 @@ class DiffusionText2World(PreTrainedModel):
|
|
| 118 |
log.info(f"Saved prompt to {prompt_save_path}")
|
| 119 |
|
| 120 |
def save_pretrained(self, save_directory, **kwargs):
|
| 121 |
-
# We don't save anything
|
| 122 |
pass
|
| 123 |
|
| 124 |
@classmethod
|
|
|
|
| 3 |
import torch
|
| 4 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 5 |
|
| 6 |
+
# TODO: This is a bug to fix. Huggingface cannot download .cosmos1.models.diffusion.inference.inference_utils because it's in a subfolder.
|
| 7 |
from .cosmos1.models.diffusion.inference.inference_utils import add_common_arguments, validate_args
|
| 8 |
from .cosmos1.models.diffusion.inference.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
| 9 |
import .cosmos1.utils.log as log
|
| 10 |
import .cosmos1.utils.misc as misc
|
| 11 |
from .cosmos1.utils.io import read_prompts_from_file, save_video
|
| 12 |
|
| 13 |
+
|
| 14 |
+
# custom config class
|
| 15 |
class DiffusionText2WorldConfig(PretrainedConfig):
|
| 16 |
model_type = "DiffusionText2World"
|
| 17 |
def __init__(self, **kwargs):
|
|
|
|
| 41 |
self.offload_guardrail_models = kwargs.get("offload_guardrail_models", False)
|
| 42 |
|
| 43 |
|
| 44 |
+
# custom model calss
|
| 45 |
class DiffusionText2World(PreTrainedModel):
|
| 46 |
config_class = DiffusionText2WorldConfig
|
| 47 |
|
|
|
|
| 73 |
seed=config.seed,
|
| 74 |
)
|
| 75 |
|
| 76 |
+
# modifed from text2world.py demo function
|
| 77 |
def forward(self, prompt):
|
| 78 |
cfg = self.config
|
| 79 |
# Handle multiple prompts if prompt file is provided
|
|
|
|
| 123 |
log.info(f"Saved prompt to {prompt_save_path}")
|
| 124 |
|
| 125 |
def save_pretrained(self, save_directory, **kwargs):
|
| 126 |
+
# We don't save anything, but need this function to override
|
| 127 |
pass
|
| 128 |
|
| 129 |
@classmethod
|