import random
import torch
import numpy as np
import torch.nn.functional as F
from src.PeptiVerse.inference import PeptiVersePredictor
from src.utils.model_utils import _print
class MadSBMSampler:
    """Discrete-time sampler for a MadSBM masked-diffusion language model.

    Starting from a (typically fully masked) sequence, the sampler integrates
    the reverse process over ``num_steps`` steps: at each step it queries the
    model for a control field and denoising logits, converts the control field
    into per-position jump probabilities, and replaces masked positions that
    "jump" with tokens drawn from the (temperature- and nucleus-filtered)
    denoising distribution. Optional best-of-M guidance re-ranks candidate
    sequences by predicted binding affinity via PeptiVerse.
    """

    def __init__(self, model, config, device, guidance=None):
        self.config = config
        self.device = device
        self.model = model
        self.tokenizer = model.tokenizer
        self.mask_id = self.tokenizer.mask_token_id
        # Smallest time value the model was trained on; sampling runs on
        # t in [eps, 1 - eps].
        self.eps = config.time_embed.min_time
        self.seed_everything(seed=42)
        # Bug fix: always record the guidance flag and always define
        # self.peptiverse, so sample() can safely check for the predictor.
        # (Previously both attributes only existed when guidance was truthy,
        # and sample() crashed with AttributeError for unguided sampling.)
        self.guidance = guidance
        self.peptiverse = None
        if guidance:
            self.peptiverse = PeptiVersePredictor(
                manifest_path="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt",
                classifier_weight_root="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse",
                device=self.device
            )

    @torch.inference_mode()
    def sample(self, xt, num_steps, tracer, target_toks=None, guidance=None):
        """Run the reverse (denoising) process for a single sequence.

        Args:
            xt: (1, L) LongTensor of token ids, typically all mask tokens.
            num_steps: number of reverse integration steps; must be > 0.
            tracer: object with a ``log_step(xt=..., step_idx=...)`` method,
                called once per step (plus once before and once after).
            target_toks: target token ids passed to PeptiVerse; required when
                guidance or affinity scoring is used.
            guidance: if truthy, use best-of-M binding-affinity guidance when
                choosing replacement tokens.

        Returns:
            (xt, binding_affin): the sampled sequence and its predicted
            binding affinity, or ``None`` for the affinity when no PeptiVerse
            predictor was initialized.

        Raises:
            ValueError: if ``num_steps`` is not positive.
            RuntimeError: if guidance is requested but the sampler was built
                without a PeptiVerse predictor.
        """
        if num_steps <= 0:
            raise ValueError("num_steps must be a positive integer")
        if guidance and self.peptiverse is None:
            raise RuntimeError(
                "Guidance requested but sampler was initialized without a "
                "PeptiVerse predictor (pass guidance=True to __init__)."
            )
        xt = xt.clone()
        B, L = xt.shape
        assert B == 1, "Do only 1 sequence at a time"
        t_max = 1.0 - self.eps
        dt = 1.0 / num_steps
        attn_mask = torch.ones_like(xt, device=self.device)
        action_traj = {}   # per-step action functional, kept for debugging/evals
        tot_action = 0.0   # time-integrated action
        tracer.log_step(xt=xt, step_idx=0)
        converge_idx = num_steps  # first step at which no mask tokens remain
        converged = False
        final_logits = None       # logits from the last step, for greedy fill-in
        for k in range(num_steps):
            # t decreases from 1 --> 0: the model was trained with t=1 noise
            # and t=0 clean.
            prog = (k + 1) / float(num_steps)
            t_val = t_max - (t_max - self.eps) * prog
            t = torch.full((B,), fill_value=float(t_val), device=self.device)
            # Predicted control field and denoising logits; each (B, L, V).
            outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)
            u_tilt = outs['dit']
            total_logits = outs['madsbm']
            esm_logits = outs['esm']
            # Action functional for evals; the ablation drops the ESM reference
            # distribution and falls back to a uniform one.
            if self.config.model.ablate:
                actional = self.compute_action(u_tilt, esm_logits=None)
            else:
                actional = self.compute_action(u_tilt, esm_logits=esm_logits)
            action_traj[f"action_step_{k+1}"] = actional
            tot_action += (actional * dt)
            # Jump rates and jump probabilities: P(jump) = 1 - exp(-rate * dt).
            r_theta = torch.exp(u_tilt * self.config.sampling.rate_scale)
            R_tot = r_theta.sum(dim=-1)  # (B, L)
            # Clamp the exponent for numerical safety before exponentiating.
            rate = (- R_tot * self.config.sampling.jump_scale * dt).clamp(min=-40.0, max=0.0)
            jump_prob = 1.0 - torch.exp(rate)
            # Temperature-scale and nucleus-filter the denoising logits.
            logits = total_logits.clone()
            logits /= self.config.sampling.tau
            logits = self.top_p_filter(logits, self.config.sampling.top_p)
            # Sample candidate replacement tokens.
            probs = F.softmax(logits, dim=-1)
            probs = probs.view(-1, probs.size(-1))
            candidate_toks = torch.multinomial(probs, 1).view(B, L)
            # A position updates only if it jumps AND is still masked.
            rand = torch.rand(B, L, device=self.device)
            updatable = (rand < jump_prob) & self.is_masked(xt)
            if guidance:
                chosen_candidate = self.binding_guidance(probs, target_toks, B, L)
                xt[updatable] = chosen_candidate[updatable]
            else:
                xt[updatable] = candidate_toks[updatable]
            tracer.log_step(xt=xt, step_idx=k + 1)
            if k == num_steps - 1:
                final_logits = total_logits
            if not converged and not self.is_masked(xt).any():
                converge_idx = k + 1
                converged = True
        # Greedily fill any positions that never jumped. Computed after the
        # loop so the names are always bound (previously a NameError path).
        still_masked = self.is_masked(xt)
        if still_masked.any() and final_logits is not None:
            final_toks = final_logits.argmax(dim=-1)
            xt[still_masked] = final_toks[still_masked]
        tracer.log_step(xt=xt, step_idx=num_steps + 1)
        # Bug fix: only score affinity when a predictor was actually loaded.
        binding_affin = None
        if self.peptiverse is not None:
            binding_affin = self.peptiverse.predict_binding_affinity(
                mode='wt',
                target_ids=target_toks,
                binder_ids=xt
            )['affinity']
        return xt, binding_affin

    def binding_guidance(self, probs, target_toks, B, L):
        """Best-of-M affinity guidance.

        Draws M candidate sequences from ``probs`` ((B*L, V) row-stochastic),
        scores each with the PeptiVerse binding-affinity predictor, and
        returns one candidate sampled with probability
        softmax(affinity / tau) — higher-affinity candidates are favored.
        """
        M = self.config.sampling.M
        candidate_toks = [
            torch.multinomial(probs, 1).view(B, L) for _ in range(M)
        ]
        affinities = []
        for toks in candidate_toks:
            pred = self.peptiverse.predict_binding_affinity(
                mode='wt',
                target_ids=target_toks,
                binder_ids=toks.detach()
            )['affinity']
            affinities.append(pred)
        affinities = torch.tensor(affinities, dtype=torch.float32)
        weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
        chosen_idx = torch.multinomial(weights, 1).item()
        return candidate_toks[chosen_idx]

    def compute_action(self, u_tilt, esm_logits=None):
        """Computes the action functional for evals.

        Uses psi(u) = exp(u) - u - 1 weighted by a reference distribution R0:
        softmax of the ESM logits when given, else uniform 1/V (in both cases
        R0 sums to 1 over the vocabulary axis). Returns a Python float: the
        mean per-token action.
        """
        if esm_logits is not None:
            R0 = torch.softmax(esm_logits, dim=-1)
        else:
            R0 = 1.0 / self.tokenizer.vocab_size
        psi_u = torch.exp(u_tilt) - u_tilt - 1.0
        action_per_tok = (R0 * psi_u).sum(dim=-1)  # R0 goes to 1 in both cases
        return action_per_tok.mean().item()

    def top_p_filter(self, logits, p_val):
        """
        Implementation of nucleus / top-p sampling.

        Masks out (sets to -inf, in place) the tokens that contribute to the
        bottom (1 - p) cumulative probability mass along the last axis; the
        first token above the threshold is always kept. Returns ``logits``.
        """
        # Sort logits and get cumulative probabilities.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative prob > p_val threshold.
        sorted_idx_to_remove = cum_probs > p_val
        # Shift the mask right so the first token crossing the threshold is kept.
        sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
        sorted_idx_to_remove[..., 0] = 0
        # Scatter the sorted-order mask back to vocabulary order; every slot is
        # written because sorted_indices is a permutation of the last axis.
        idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
        logits[idx_to_remove] = float('-inf')
        return logits

    def is_masked(self, xt):
        """Boolean tensor, True where ``xt`` equals the mask token id."""
        return (xt == self.mask_id)

    def seed_everything(self, seed):
        """Seed python, numpy, and torch (CPU + all CUDA devices) for
        reproducibility; a ``None`` seed is a no-op."""
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # if using multi-GPU
            # Deterministic cuDNN kernels trade speed for reproducibility.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False