from diffusers.schedulers import UniPCMultistepScheduler import torch class ViBTScheduler(UniPCMultistepScheduler): def __init__(self, **kwargs): super().__init__(**{**kwargs, "use_flow_sigmas": True}) self.set_parameters() def set_parameters(self, noise_scale=1.0, shift_gamma=5.0, seed=None): self.noise_scale = noise_scale self.config.flow_shift = shift_gamma self.generator = ( None if seed is None else torch.Generator("cuda").manual_seed(seed) ) def step(self, model_output, timestep, sample, **kwargs): delta_t = ( max(self.timesteps[self.timesteps < timestep]) - timestep if any(self.timesteps < timestep) else -timestep - 1 ) / 1000 current_t = (timestep + 1) / 1000.0 eta = (-delta_t * (current_t + delta_t) / current_t) ** 0.5 noise = torch.randn( sample.shape, generator=self.generator, device=sample.device, dtype=sample.dtype, ) latents = sample + delta_t * model_output + eta * self.noise_scale * noise return (latents,) @classmethod def from_scheduler( cls, scheduler: UniPCMultistepScheduler, noise_scale=1.0, shift_gamma=5.0 ): obj = cls.__new__(cls) obj.__dict__ = scheduler.__dict__.copy() obj.set_parameters(noise_scale, shift_gamma) return obj