|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from diffusers.models import AutoencoderTiny |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
from diffusers.models.autoencoders.vae import EncoderOutput, DecoderOutput |
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
|
|
|
class Flux2TinyAutoEncoder(ModelMixin, ConfigMixin):
    """Wrap ``AutoencoderTiny`` with extra convolutions to widen/deepen its latent space.

    The inner tiny VAE operates on ``latent_channels // 4`` channels; on top of it
    this module adds:

    - ``extra_encoder``: a stride-2 conv that downsamples the tiny VAE's latents 2x
      spatially and expands them to ``latent_channels`` channels;
    - ``extra_decoder``: the mirror-image transposed conv for the decode path;
    - ``residual_encoder`` / ``residual_decoder``: conv-norm-act residual refinement
      applied after each of the extra convolutions.

    ``encode``/``decode``/``forward`` follow the diffusers output convention:
    a dataclass output when ``return_dict=True``, otherwise a 1-tuple.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        latent_channels: int = 128,
        # NOTE: tuple defaults (not lists) — mutable defaults are shared across
        # calls and would also be mutated in the registered config.
        encoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "silu",
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: float = 3.0,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 0.13025,
    ) -> None:
        """Build the tiny VAE plus the extra encode/decode convolutions.

        Args:
            in_channels: Image channels fed to the encoder.
            out_channels: Image channels produced by the decoder.
            latent_channels: Channel width of this module's latents; the inner
                tiny VAE uses ``latent_channels // 4`` and the extra convs
                bridge the two widths.
            encoder_block_out_channels / decoder_block_out_channels: Per-stage
                widths forwarded to ``AutoencoderTiny``.
            act_fn: Activation name forwarded to ``AutoencoderTiny``.
            upsampling_scaling_factor: Upsample factor per decoder stage.
            num_encoder_blocks / num_decoder_blocks: Blocks per stage.
            latent_magnitude / latent_shift: Latent scaling constants of the
                tiny VAE.
            force_upcast: Whether the tiny VAE upcasts to float32.
            scaling_factor: Latent scaling factor stored in the config.
        """
        super().__init__()

        # The tiny VAE works at a quarter of the requested latent width; the
        # extra conv pair below bridges it to ``latent_channels``.
        self.tiny_vae = AutoencoderTiny(
            in_channels=in_channels,
            out_channels=out_channels,
            encoder_block_out_channels=encoder_block_out_channels,
            decoder_block_out_channels=decoder_block_out_channels,
            act_fn=act_fn,
            latent_channels=latent_channels // 4,
            upsampling_scaling_factor=upsampling_scaling_factor,
            num_encoder_blocks=num_encoder_blocks,
            num_decoder_blocks=num_decoder_blocks,
            latent_magnitude=latent_magnitude,
            latent_shift=latent_shift,
            force_upcast=force_upcast,
            scaling_factor=scaling_factor,
        )

        # 2x spatial downsample + 4x channel expansion after the tiny encoder.
        self.extra_encoder = nn.Conv2d(
            latent_channels // 4, latent_channels,
            kernel_size=4, stride=2, padding=1,
        )
        # Mirror of ``extra_encoder``: 2x upsample + 4x channel reduction.
        self.extra_decoder = nn.ConvTranspose2d(
            latent_channels, latent_channels // 4,
            kernel_size=4, stride=2, padding=1,
        )
        # Residual refinement at full latent width (encode path).
        self.residual_encoder = nn.Sequential(
            nn.Conv2d(latent_channels, latent_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, latent_channels),
            nn.SiLU(),
            nn.Conv2d(latent_channels, latent_channels, kernel_size=3, padding=1),
        )
        # Residual refinement at quarter latent width (decode path).
        self.residual_decoder = nn.Sequential(
            nn.Conv2d(latent_channels // 4, latent_channels // 4, kernel_size=3, padding=1),
            nn.GroupNorm(8, latent_channels // 4),
            nn.SiLU(),
            nn.Conv2d(latent_channels // 4, latent_channels // 4, kernel_size=3, padding=1),
        )

    def encode(self, x: torch.Tensor, return_dict: bool = True) -> "EncoderOutput | tuple":
        """Encode images to widened latents.

        Args:
            x: Input image batch for the tiny VAE encoder.
            return_dict: When ``True``, return an ``EncoderOutput``; otherwise a
                1-tuple ``(latents,)``.

        Returns:
            ``EncoderOutput(latent=...)`` or ``(latents,)``.
        """
        encoded = self.tiny_vae.encode(x, return_dict=False)[0]
        compressed = self.extra_encoder(encoded)
        enhanced = self.residual_encoder(compressed) + compressed

        if return_dict:
            return EncoderOutput(latent=enhanced)
        # BUGFIX: return a 1-tuple per the diffusers ``return_dict=False``
        # convention (as ``tiny_vae.encode(...)[0]`` above relies on). Returning
        # the raw tensor made ``forward``'s ``[0]`` index strip the batch dim.
        return (enhanced,)

    def decode(self, z: torch.Tensor, return_dict: bool = True) -> "DecoderOutput | tuple":
        """Decode widened latents back to images.

        Args:
            z: Latent batch at ``latent_channels`` width.
            return_dict: When ``True``, return a ``DecoderOutput``; otherwise a
                1-tuple ``(sample,)``.

        Returns:
            ``DecoderOutput(sample=...)`` or ``(sample,)``.
        """
        decompressed = self.extra_decoder(z)
        enhanced = self.residual_decoder(decompressed) + decompressed
        decoded = self.tiny_vae.decode(enhanced, return_dict=False)[0]

        if return_dict:
            return DecoderOutput(sample=decoded)
        # BUGFIX: 1-tuple for the same reason as in ``encode``.
        return (decoded,)

    def forward(self, sample: torch.Tensor, return_dict: bool = True) -> "DecoderOutput | tuple":
        """Run a full encode/decode round trip (reconstruction).

        Args:
            sample: Input image batch.
            return_dict: When ``True``, return a ``DecoderOutput``; otherwise a
                1-tuple ``(sample,)``.

        Returns:
            ``DecoderOutput(sample=...)`` or ``(reconstruction,)``.
        """
        encoded = self.encode(sample, return_dict=False)[0]
        decoded = self.decode(encoded, return_dict=False)[0]

        if return_dict:
            return DecoderOutput(sample=decoded)
        return (decoded,)
|
|
|