Beyond VLM-Based Rewards: Diffusion-Native Latent Reward Modeling
Paper: 2602.11146
DiNa-LRM is a diffusion-native latent reward model. It achieves reward accuracy competitive with VLM-based reward models while being significantly cheaper for alignment, because it operates directly in the diffusion latent space.
To use this reward model, ensure you have the diffusion_rm package installed:
# 1. Create a new conda environment
conda create -n diffusion-rm python=3.10 -y
conda activate diffusion-rm
# 2. Install the package in editable mode
# This will install all necessary dependencies including torch and diffusers
pip install git+https://github.com/HKUST-C4G/diffusion-rm.git
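You can then verify that the package and its main dependencies import cleanly (a quick sanity check, not an official CLI of the repository):
# 3. [Optional] Sanity-check the installation
python -c "import torch, diffusers, diffusion_rm; print('diffusion_rm OK')"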
In this scenario, the reward model scores the "clean" latent produced by the diffusion transformer directly, before the final VAE decoding step.
import torch
from diffusers import StableDiffusion3Pipeline
from diffusion_rm.models.sd3_rm import encode_prompt
from diffusion_rm.infer.inference import DRMInferencer
# Load SD3.5 Pipeline
device = torch.device('cuda:0')
dtype = torch.bfloat16
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    torch_dtype=dtype
).to(device)
pipe.vae.to(device, dtype=dtype)
pipe.text_encoder.to(device, dtype=dtype)
pipe.text_encoder_2.to(device, dtype=dtype)
pipe.text_encoder_3.to(device, dtype=dtype)
pipe.transformer.to(device, dtype=dtype)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
def compute_text_embeddings(text_encoders, tokenizers, prompts):
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            text_encoders, tokenizers, prompts, max_sequence_length=256
        )
        prompt_embeds = prompt_embeds.to(text_encoders[0].device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(text_encoders[0].device)
    return prompt_embeds, pooled_prompt_embeds
# Initialize DiNa-LRM Scorer
scorer = DRMInferencer(
    pipeline=pipe,
    config_path=None,
    model_path="liuhuohuo/DiNa-LRM-SD35M-12layers",
    device=device,
    model_dtype=dtype,
    load_from_disk=False,
)
# 1. Generate latents (Set output_type='latent' for DiNa-LRM)
prompt = "A girl walking in the street"
with torch.no_grad():
    # Helper to get embeddings
    prompt_embeds, pooled_embeds = compute_text_embeddings(text_encoders, tokenizers, [prompt])
    output = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_embeds,
        num_inference_steps=40,
        guidance_scale=4.5,
        output_type='latent'
    )
latents = output.images
# 2. Compute reward
with torch.no_grad():
    raw_score = scorer.reward(
        text_conds={'encoder_hidden_states': prompt_embeds, 'pooled_projections': pooled_embeds},
        latents=latents,
        u=0.4
    )
score = (raw_score + 10.0) / 10.0
print(f"DiNa-LRM Score: {score.item()}")
# 3. [Optional] decode and save images
with torch.no_grad():
    # Invert the SD3 latent scaling/shift before VAE decoding
    latents_decoded = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents_decoded.to(pipe.vae.dtype), return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save("example.png")
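Because scoring happens in latent space, candidates can be compared without any VAE decoding, which makes DiNa-LRM a natural fit for best-of-N selection. The sketch below is illustrative (the candidate count and seeding scheme are our own choices, not part of the package API); it reuses pipe, scorer, and the embeddings from above and decodes only the winning latent.
# Best-of-N selection: score several candidate latents, keep the best one
num_candidates = 4  # illustrative choice
candidates, scores = [], []
with torch.no_grad():
    for seed in range(num_candidates):
        gen = torch.Generator(device=device).manual_seed(seed)
        out = pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_embeds,
            num_inference_steps=40,
            guidance_scale=4.5,
            output_type='latent',
            generator=gen,
        )
        candidates.append(out.images)
        raw = scorer.reward(
            text_conds={'encoder_hidden_states': prompt_embeds, 'pooled_projections': pooled_embeds},
            latents=out.images,
            u=0.4
        )
        scores.append(((raw + 10.0) / 10.0).item())
best = max(range(num_candidates), key=scores.__getitem__)
print(f"Best candidate: seed {best} with score {scores[best]:.3f}")
best_latents = candidates[best]  # decode only this one with the VAE snippet above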
To score existing images, we first encode the image into the latent space using the VAE encoder.
import torch
import torchvision.transforms as T
from PIL import Image
from diffusers import StableDiffusion3Pipeline
from diffusion_rm.models.sd3_rm import encode_prompt
from diffusion_rm.infer.inference import DRMInferencer
# Load SD3.5 Pipeline
device = torch.device('cuda:0')
dtype = torch.bfloat16
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    torch_dtype=dtype
).to(device)
pipe.vae.to(device, dtype=dtype)
pipe.text_encoder.to(device, dtype=dtype)
pipe.text_encoder_2.to(device, dtype=dtype)
pipe.text_encoder_3.to(device, dtype=dtype)
pipe.transformer.to(device, dtype=dtype)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
def compute_text_embeddings(text_encoders, tokenizers, prompts):
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds = encode_prompt(
            text_encoders, tokenizers, prompts, max_sequence_length=256
        )
        prompt_embeds = prompt_embeds.to(text_encoders[0].device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(text_encoders[0].device)
    return prompt_embeds, pooled_prompt_embeds
# Initialize DiNa-LRM Scorer
scorer = DRMInferencer(
    pipeline=pipe,
    config_path=None,
    model_path="liuhuohuo/DiNa-LRM-SD35M-12layers",
    device=device,
    model_dtype=dtype,
    load_from_disk=False,
)
# 1. Load and Preprocess Image
image_path = "assets/example.png"
raw_image = Image.open(image_path).convert("RGB")
# Map pixel values to [-1, 1], the range expected by the SD3 VAE
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.5], [0.5])
])
image_tensor = transform(raw_image).unsqueeze(0).to(device, dtype=dtype)
prompt = "A girl walking in the street"
with torch.no_grad():
    # Helper to get embeddings
    prompt_embeds, pooled_embeds = compute_text_embeddings(text_encoders, tokenizers, [prompt])
# 2. Encode to Latent Space
with torch.no_grad():
    latents = pipe.vae.encode(image_tensor).latent_dist.sample()
    # Apply SD3-specific scaling and shift
    latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
# 3. Compute Reward
# Note: score normalization is often calculated as: score = (raw_score + 10.0) / 10.0
with torch.no_grad():
    raw_score = scorer.reward(
        text_conds={'encoder_hidden_states': prompt_embeds, 'pooled_projections': pooled_embeds},
        latents=latents,
        u=0.1  # Lower u is recommended for static/clean images
    )
score = (raw_score + 10.0) / 10.0
print(f"Local Image Score: {score.item()}")
Base model: stabilityai/stable-diffusion-3.5-medium