Unconditional Image Generation
Diffusers
Safetensors
English
bitdance
imagenet
class-conditional
custom-pipeline
Instructions to use BiliSakura/BitDance-ImageNet-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/BitDance-ImageNet-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/BitDance-ImageNet-diffusers",
    dtype=torch.bfloat16,
    device_map="cuda",
)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
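Note that the model card tags this checkpoint as class-conditional with a custom pipeline, so the auto-generated text-prompt snippet above may not match the pipeline's real call signature. A minimal sketch follows, assuming the custom pipeline must be loaded with trust_remote_code=True and accepts ImageNet class indices through a class_labels argument; both are assumptions borrowed from diffusers' DiTPipeline convention and are not confirmed by this repository.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/BitDance-ImageNet-diffusers",
    dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,  # assumption: required to run the repo's custom pipeline code
)

# assumption: DiT-style class_labels argument; 207 = "golden retriever" in ImageNet-1k
image = pipe(class_labels=[207]).images[0]
image.save("sample.png")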
The repository also ships a small custom module, shown below, that wraps the BitDance VQ autoencoder for use with diffusers:
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors.torch import load_file as load_safetensors

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# Resolve the VQ model import both when this file is loaded as part of the
# package and when it is executed from the repository root.
try:
    from .transformer.qae import VQModel
except ImportError:  # pragma: no cover
    from transformer.qae import VQModel


class BitDanceImageNetAutoencoder(ModelMixin, ConfigMixin):
    """Diffusers-compatible wrapper around the BitDance VQ autoencoder."""

    @register_to_config
    def __init__(self, ddconfig: Dict[str, Any], num_codebooks: int = 4):
        super().__init__()
        self.runtime_model = VQModel(ddconfig, num_codebooks)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        # Bypass ModelMixin.from_pretrained: read config.json and the
        # safetensors checkpoint directly from a local directory.
        del args, kwargs
        model_dir = Path(pretrained_model_name_or_path)
        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(ddconfig=config["ddconfig"], num_codebooks=int(config.get("num_codebooks", 4)))
        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        model.runtime_model.load_state_dict(state, strict=True)
        model.eval()
        return model

    def encode(self, x: torch.Tensor):
        return self.runtime_model.encode(x)

    def decode(self, z: torch.Tensor):
        return self.runtime_model.decode(z)

    def forward(self, z: torch.Tensor):
        # Decoding is the default forward pass.
        return self.decode(z)
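A minimal usage sketch for the wrapper above. Assumptions not confirmed by the repository: the weights live in a local directory ./bitdance-autoencoder containing config.json and diffusion_pytorch_model.safetensors (the layout this from_pretrained override expects), the model takes 256x256 RGB input, and VQModel.encode returns the latent tensor directly.

import torch

vae = BitDanceImageNetAutoencoder.from_pretrained("./bitdance-autoencoder")  # hypothetical local path
x = torch.randn(1, 3, 256, 256)  # dummy image batch (assumed resolution)
with torch.no_grad():
    z = vae.encode(x)      # image -> quantized latent; exact return type depends on VQModel
    x_hat = vae.decode(z)  # latent -> reconstructed image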