Tags: Unconditional Image Generation, Diffusers, Safetensors, English, bitdance, imagenet, class-conditional, custom-pipeline
Instructions to use BiliSakura/BitDance-ImageNet-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
How to use BiliSakura/BitDance-ImageNet-diffusers with Diffusers:
```sh
pip install -U diffusers transformers accelerate
```

```py
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/BitDance-ImageNet-diffusers",
    dtype=torch.bfloat16,
    device_map="cuda",
)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```

This is the stock Diffusers text-to-image snippet; a class-conditional sketch follows the notebook links below.

- Notebooks
- Google Colab
- Kaggle
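Because this checkpoint is tagged class-conditional and custom-pipeline, the text-prompt call above may not match the actual interface; class-conditional ImageNet models are usually sampled from a class index. The sketch below is an assumption, not a confirmed API: `trust_remote_code=True` is the standard diffusers flag for repositories that ship custom pipeline code, while the `class_labels` argument and the step count are hypothetical.

```py
import torch
from diffusers import DiffusionPipeline

# trust_remote_code=True lets diffusers load custom pipeline code from the repo.
pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/BitDance-ImageNet-diffusers",
    dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

# Hypothetical call signature: an ImageNet-1k class index instead of a text
# prompt (207 = "golden retriever" in the standard label set).
image = pipe(class_labels=[207], num_inference_steps=50).images[0]
image.save("golden_retriever.png")
```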
The repository also ships a custom autoencoder wrapper that exposes the BitDance VQ model through the diffusers `ModelMixin`/`ConfigMixin` interface:

```py
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors.torch import load_file as load_safetensors

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

try:
    from .transformer.qae import VQModel
except ImportError:  # pragma: no cover
    from transformer.qae import VQModel


class BitDanceImageNetAutoencoder(ModelMixin, ConfigMixin):
    """Diffusers-compatible wrapper around the BitDance VQ autoencoder."""

    @register_to_config
    def __init__(self, ddconfig: Dict[str, Any], num_codebooks: int = 4):
        super().__init__()
        self.runtime_model = VQModel(ddconfig, num_codebooks)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        # Load config and weights from a local directory laid out in the
        # standard diffusers format (config.json + safetensors weights).
        del args, kwargs  # extra diffusers loading kwargs are ignored
        model_dir = Path(pretrained_model_name_or_path)
        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            ddconfig=config["ddconfig"],
            num_codebooks=int(config.get("num_codebooks", 4)),
        )
        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        model.runtime_model.load_state_dict(state, strict=True)
        model.eval()
        return model

    def encode(self, x: torch.Tensor):
        # Map images to the quantized latent space.
        return self.runtime_model.encode(x)

    def decode(self, z: torch.Tensor):
        # Reconstruct images from quantized latents.
        return self.runtime_model.decode(z)

    def forward(self, z: torch.Tensor):
        return self.decode(z)
```
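A minimal round-trip sketch for the wrapper above. The local checkpoint directory `./vae`, the 256x256 input resolution, and the tuple handling of `encode`'s return value are assumptions about the underlying `VQModel`:

```py
import torch

# Hypothetical local directory holding config.json and
# diffusion_pytorch_model.safetensors in the layout from_pretrained expects.
vae = BitDanceImageNetAutoencoder.from_pretrained("./vae")

# Assumption: the autoencoder takes 256x256 RGB images scaled to [-1, 1].
x = torch.rand(1, 3, 256, 256) * 2 - 1

with torch.no_grad():
    z = vae.encode(x)
    # VQ-style encoders often return (latents, *aux); keep only the latents.
    latents = z[0] if isinstance(z, tuple) else z
    recon = vae.decode(latents)
```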