Beyond VLM-Based Rewards: Diffusion-Native Latent Reward Modeling

arXiv GitHub Hugging Face


DiNa-LRM: A diffusion-native latent reward model. It achieves competitive reward accuracy while being significantly cheaper to use for alignment, because it operates directly in the diffusion latent space instead of on decoded images.


🚀 Quick Start

To use this reward model, first install the diffusion_rm package:

# 1. Create a new conda environment
conda create -n diffusion-rm python=3.10 -y
conda activate diffusion-rm

# 2. Install the package directly from GitHub (regular, non-editable install)
# This will install all necessary dependencies including torch and diffusers
pip install git+https://github.com/HKUST-C4G/diffusion-rm.git

# Alternatively, for an editable (development) install:
#   git clone https://github.com/HKUST-C4G/diffusion-rm.git
#   cd diffusion-rm && pip install -e .

1. Reward from Pipeline-Generated Latents

In this scenario, the model scores the "clean" (fully denoised) latent that the pipeline returns when called with output_type='latent', i.e. before the final VAE decoding step.

import torch
from diffusers import StableDiffusion3Pipeline
from diffusion_rm.models.sd3_rm import encode_prompt
from diffusion_rm.infer.inference import DRMInferencer

# Load the SD3.5-Medium pipeline in bfloat16 on the first CUDA device.
device = torch.device('cuda:0')
dtype = torch.bfloat16
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    torch_dtype=dtype,
).to(device)

# Pin every sub-model to the target device/dtype explicitly.
# NOTE(review): torch_dtype + .to(device) above should already do this;
# the explicit calls are kept as a safeguard.
for _module in (
    pipe.vae,
    pipe.text_encoder,
    pipe.text_encoder_2,
    pipe.text_encoder_3,
    pipe.transformer,
):
    _module.to(device, dtype=dtype)

# Text encoders and their tokenizers, index-aligned (1st, 2nd, 3rd).
text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]

def compute_text_embeddings(text_encoders, tokenizers, prompts):
    """Encode ``prompts`` via ``encode_prompt`` with gradients disabled.

    Uses ``max_sequence_length=256``. Returns the tuple
    ``(prompt_embeds, pooled_prompt_embeds)``, both moved to the device of
    the first text encoder.
    """
    with torch.no_grad():
        embeds, pooled = encode_prompt(
            text_encoders, tokenizers, prompts, max_sequence_length=256
        )
        target_device = text_encoders[0].device
        return embeds.to(target_device), pooled.to(target_device)

# Build the DiNa-LRM scorer on top of the already-loaded SD3.5 pipeline,
# fetching the reward-model weights from the Hugging Face Hub.
scorer = DRMInferencer(
    pipeline=pipe,
    config_path=None,
    model_path="liuhuohuo/DiNa-LRM-SD35M-12layers",
    device=device,
    model_dtype=dtype,
    load_from_disk=False,
)

prompt = "A girl walking in the street"

with torch.no_grad():
    # 1. Encode the prompt and sample latents; output_type='latent' makes the
    #    pipeline return latents instead of decoded images.
    prompt_embeds, pooled_embeds = compute_text_embeddings(text_encoders, tokenizers, [prompt])
    latents = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_embeds,
        num_inference_steps=40,
        guidance_scale=4.5,
        output_type='latent',
    ).images

    # 2. Score the latents, then rescale the raw reward as (raw + 10) / 10.
    text_conds = {'encoder_hidden_states': prompt_embeds, 'pooled_projections': pooled_embeds}
    raw_score = scorer.reward(text_conds=text_conds, latents=latents, u=0.4)
    score = (raw_score + 10.0) / 10.0
    print(f"DiNa-LRM Score: {score.item()}")

    # 3. [Optional] Undo SD3's latent scale/shift, decode with the VAE, and
    #    convert to a PIL image.
    vae_cfg = pipe.vae.config
    latents_decoded = latents / vae_cfg.scaling_factor + vae_cfg.shift_factor
    image = pipe.vae.decode(latents_decoded.to(pipe.vae.dtype), return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]

image.save("example.png")

2. Reward from Local Image Files

To score an existing image, first encode it into the latent space with the VAE encoder and apply SD3's latent shift and scaling factors, then pass the result to the reward model.

import torch
import torchvision.transforms as T
from PIL import Image
from diffusers import StableDiffusion3Pipeline
from diffusion_rm.models.sd3_rm import encode_prompt
from diffusion_rm.infer.inference import DRMInferencer

# Load the SD3.5-Medium pipeline in bfloat16 on the first CUDA device.
device = torch.device('cuda:0')
dtype = torch.bfloat16
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    torch_dtype=dtype,
).to(device)

# Pin every sub-model to the target device/dtype explicitly.
# NOTE(review): torch_dtype + .to(device) above should already do this;
# the explicit calls are kept as a safeguard.
for _module in (
    pipe.vae,
    pipe.text_encoder,
    pipe.text_encoder_2,
    pipe.text_encoder_3,
    pipe.transformer,
):
    _module.to(device, dtype=dtype)

# Text encoders and their tokenizers, index-aligned (1st, 2nd, 3rd).
text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]

def compute_text_embeddings(text_encoders, tokenizers, prompts):
    """Encode ``prompts`` via ``encode_prompt`` with gradients disabled.

    Uses ``max_sequence_length=256``. Returns the tuple
    ``(prompt_embeds, pooled_prompt_embeds)``, both moved to the device of
    the first text encoder.
    """
    with torch.no_grad():
        embeds, pooled = encode_prompt(
            text_encoders, tokenizers, prompts, max_sequence_length=256
        )
        target_device = text_encoders[0].device
        return embeds.to(target_device), pooled.to(target_device)

# Initialize DiNa-LRM Scorer
scorer = DRMInferencer(
    pipeline=pipe,
    config_path=None,
    model_path="liuhuohuo/DiNa-LRM-SD35M-12layers",
    device=device,
    model_dtype=dtype,
    load_from_disk=False,
)

# 1. Load and preprocess the image: ToTensor gives values in [0, 1], and
#    Normalize(0.5, 0.5) maps them to [-1, 1]; unsqueeze adds a batch dim.
image_path = "assets/example.png"
# Context manager closes the underlying file; convert() returns a new image.
with Image.open(image_path) as img:
    raw_image = img.convert("RGB")
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
])
image_tensor = transform(raw_image).unsqueeze(0).to(device, dtype=dtype)

prompt = "A girl walking in the street"

# Everything below is inference only, so run it without autograd tracking
# (matching example 1; avoids building a graph through the reward model).
with torch.no_grad():
    prompt_embeds, pooled_embeds = compute_text_embeddings(text_encoders, tokenizers, [prompt])

    # 2. Encode to latent space.
    # NOTE: .sample() draws from the VAE posterior, so repeated runs can give
    # slightly different scores; latent_dist.mode() would be deterministic.
    latents = pipe.vae.encode(image_tensor).latent_dist.sample()
    # Apply SD3-specific shift and scaling (inverse of the decode-time transform).
    latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

    # 3. Compute reward and rescale it as (raw_score + 10) / 10.
    raw_score = scorer.reward(
        text_conds={'encoder_hidden_states': prompt_embeds, 'pooled_projections': pooled_embeds},
        latents=latents,
        u=0.1,  # Lower u is recommended for static/clean images
    )
    score = (raw_score + 10.0) / 10.0

print(f"Local Image Score: {score.item()}")
Downloads last month
25
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for liuhuohuo/DiNa-LRM-SD35M-12layers

Finetuned
(52)
this model

Paper for liuhuohuo/DiNa-LRM-SD35M-12layers