ViBT / vibt /scheduler.py
Yuanshi's picture
init
0b31b45
from diffusers.schedulers import UniPCMultistepScheduler
import torch
class ViBTScheduler(UniPCMultistepScheduler):
    """UniPC-derived scheduler whose ``step`` is a single stochastic update.

    The inherited constructor and timestep schedule are reused (always with
    ``use_flow_sigmas=True``), but ``step`` is replaced by::

        sample + delta_t * model_output + eta * noise_scale * noise

    where ``delta_t`` is the (negative) gap to the next lower timestep and
    ``eta = sqrt(-delta_t * (t + delta_t) / t)``, with ``t = (timestep + 1) / 1000``.
    """

    def __init__(self, **kwargs):
        """Create the scheduler.

        All keyword arguments are forwarded to ``UniPCMultistepScheduler``;
        ``use_flow_sigmas`` is overridden to ``True`` regardless of what the
        caller passed. Sampling parameters are then reset to their defaults
        via ``set_parameters()``.
        """
        super().__init__(**{**kwargs, "use_flow_sigmas": True})
        self.set_parameters()

    def set_parameters(self, noise_scale=1.0, shift_gamma=5.0, seed=None):
        """Configure the noise scale, the flow shift and the noise generator.

        Args:
            noise_scale: Multiplier applied to the noise term added in ``step``.
            shift_gamma: Value written to ``self.config.flow_shift``.
            seed: Seed for the noise generator. ``None`` leaves
                ``self.generator`` unset, so ``step`` draws from torch's
                global RNG.
        """
        self.noise_scale = noise_scale
        self.config.flow_shift = shift_gamma
        if seed is None:
            self.generator = None
        else:
            # Prefer CUDA (the original behaviour) but fall back to CPU so a
            # seeded scheduler can still be built on machines without a GPU.
            generator_device = "cuda" if torch.cuda.is_available() else "cpu"
            self.generator = torch.Generator(generator_device).manual_seed(seed)

    def step(self, model_output, timestep, sample, **kwargs):
        """Advance ``sample`` from ``timestep`` to the next lower timestep.

        Args:
            model_output: Model prediction; scaled by ``delta_t`` and added to
                ``sample``.
            timestep: Current timestep, compared against ``self.timesteps``.
            sample: Current latents.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A 1-tuple ``(latents,)`` holding the updated sample. A tuple is
            returned unconditionally.
        """
        # Largest scheduled timestep strictly below the current one; if there
        # is none this is the last step and the target is timestep -1, which
        # corresponds to t == 0 under the (timestep + 1) / 1000 mapping.
        lower_mask = self.timesteps < timestep
        if lower_mask.any():
            delta_t = (self.timesteps[lower_mask].max() - timestep) / 1000
        else:
            delta_t = (-timestep - 1) / 1000

        current_t = (timestep + 1) / 1000.0
        # Noise amplitude; it is zero on the final step because
        # current_t + delta_t == 0 there.
        eta = (-delta_t * (current_t + delta_t) / current_t) ** 0.5

        # torch.randn requires the generator and the output to be on the same
        # device, so sample on the generator's device and move afterwards
        # (the move is a no-op when the devices already match).
        noise_device = (
            sample.device if self.generator is None else self.generator.device
        )
        noise = torch.randn(
            sample.shape,
            generator=self.generator,
            device=noise_device,
            dtype=sample.dtype,
        ).to(sample.device)

        latents = sample + delta_t * model_output + eta * self.noise_scale * noise
        return (latents,)

    @classmethod
    def from_scheduler(
        cls,
        scheduler: UniPCMultistepScheduler,
        noise_scale=1.0,
        shift_gamma=5.0,
        seed=None,
    ):
        """Build a ``ViBTScheduler`` from an existing UniPC scheduler's state.

        ``__init__`` is bypassed: the new object receives a shallow copy of
        ``scheduler.__dict__``, so objects referenced by the source scheduler's
        attributes are shared rather than duplicated.
        NOTE(review): if the config object is stored in the instance dict, the
        ``flow_shift`` update below is also visible on ``scheduler`` - confirm.

        Args:
            scheduler: Scheduler whose instance state is copied.
            noise_scale: Forwarded to ``set_parameters``.
            shift_gamma: Forwarded to ``set_parameters``.
            seed: Forwarded to ``set_parameters``; ``None`` (the default)
                leaves the noise generator unset.

        Returns:
            The new ``ViBTScheduler`` instance.
        """
        obj = cls.__new__(cls)
        obj.__dict__ = scheduler.__dict__.copy()
        obj.set_parameters(noise_scale, shift_gamma, seed)
        return obj