| | import random |
| |
|
| | import torch |
| | import numpy as np |
| |
|
| | import torch.nn.functional as F |
| |
|
| | from src.PeptiVerse.inference import PeptiVersePredictor |
| | from src.utils.model_utils import _print |
| |
|
| |
|
class MadSBMSampler:
    """Iterative masked-token sampler with optional PeptiVerse binding guidance.

    Starting from a sequence that contains mask tokens, each step:

    1. Runs ``model`` and reads its ``'dit'``, ``'madsbm'`` and ``'esm'`` outputs.
    2. Turns ``'dit'`` into a per-position jump probability.
    3. Proposes replacement tokens from the temperature-scaled, top-p filtered
       ``'madsbm'`` logits.
    4. Commits a proposal only at positions that are still masked and jump.

    Positions still masked after the final step are filled by argmax of the
    last step's ``'madsbm'`` logits.
    """

    # Default PeptiVerse artefact locations; override through ``__init__``.
    DEFAULT_MANIFEST_PATH = "/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt"
    DEFAULT_CLASSIFIER_WEIGHT_ROOT = "/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse"

    def __init__(self, model, config, device, guidance=None, *, seed=42,
                 manifest_path=None, classifier_weight_root=None):
        """Store the model and config, seed the RNGs, and optionally load PeptiVerse.

        Args:
            model: Callable model exposing ``.tokenizer`` (with ``mask_token_id``
                and ``vocab_size``). Called as
                ``model(input_ids=..., attention_mask=..., t=...)`` and expected
                to return a mapping with keys ``'dit'``, ``'madsbm'`` and ``'esm'``.
            config: Config object. Read attributes are ``time_embed.min_time``,
                ``model.ablate`` and ``sampling.{rate_scale, jump_scale, tau,
                top_p, M}``.
            device: Torch device used for all tensors created here.
            guidance: If truthy, the PeptiVerse predictor is loaded eagerly.
                Otherwise it is loaded on first use.
            seed: Seed passed to :meth:`seed_everything`. ``None`` skips seeding.
            manifest_path: PeptiVerse manifest path. Defaults to
                ``DEFAULT_MANIFEST_PATH``.
            classifier_weight_root: PeptiVerse weight directory. Defaults to
                ``DEFAULT_CLASSIFIER_WEIGHT_ROOT``.
        """
        self.config = config
        self.device = device
        self.model = model
        self.tokenizer = model.tokenizer
        self.mask_id = self.tokenizer.mask_token_id
        self.eps = config.time_embed.min_time
        self.seed_everything(seed=seed)

        # Always defined, so later attribute access cannot fail for unguided samplers.
        self.guidance = guidance
        self.peptiverse = None
        self._manifest_path = (
            manifest_path if manifest_path is not None else self.DEFAULT_MANIFEST_PATH
        )
        self._classifier_weight_root = (
            classifier_weight_root
            if classifier_weight_root is not None
            else self.DEFAULT_CLASSIFIER_WEIGHT_ROOT
        )

        # Diagnostics from the most recent sample() call.
        self.last_action_traj = {}
        self.last_total_action = 0.0
        self.last_converge_idx = None

        if guidance:
            # Guided runs pay the loading cost up front, as before.
            self._get_peptiverse()

    def _get_peptiverse(self):
        """Return the PeptiVerse predictor, constructing it on first use."""
        if self.peptiverse is None:
            self.peptiverse = PeptiVersePredictor(
                manifest_path=self._manifest_path,
                classifier_weight_root=self._classifier_weight_root,
                device=self.device,
            )
        return self.peptiverse

    @torch.inference_mode()
    def sample(self, xt, num_steps, tracer, target_toks=None, guidance=None):
        """Unmask ``xt`` over ``num_steps`` jump steps and score the result.

        Args:
            xt: Token ids of shape ``(1, L)``. Positions equal to the mask id are
                filled in. The input is cloned, not modified in place.
            num_steps: Number of steps; must be >= 1.
            tracer: Object with ``log_step``. It is called for the initial state
                (step 0), after every step, and after the final greedy fill
                (step ``num_steps + 1``).
            target_toks: Passed to PeptiVerse as ``target_ids``.
            guidance: If truthy, proposals come from :meth:`binding_guidance`
                instead of a single multinomial draw.

        Returns:
            ``(xt, binding_affin)``: the unmasked sequence, and the ``'affinity'``
            entry of PeptiVerse's prediction for it.

        Raises:
            ValueError: If ``num_steps < 1``.

        Side effects:
            Sets ``last_action_traj``, ``last_total_action`` and
            ``last_converge_idx``.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        xt = xt.clone()
        B, L = xt.shape
        assert B == 1, "Do only 1 sequence at a time"

        t_max = 1.0 - self.eps
        dt = 1.0 / num_steps
        attn_mask = torch.ones_like(xt, device=self.device)
        ablate = self.config.model.ablate

        action_traj = {}
        tot_action = 0.0
        # Step after which no masks remained; stays num_steps if masks survive the loop.
        converge_idx = num_steps
        converged = False

        tracer.log_step(xt=xt, step_idx=0)

        for k in range(num_steps):
            # Anneal t linearly from t_max toward eps. The first step is already
            # one increment below t_max and the last step lands exactly on eps.
            prog = (k + 1) / float(num_steps)
            t_val = t_max - (t_max - self.eps) * prog
            t = torch.full((B,), fill_value=float(t_val), device=self.device)

            outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)
            u_tilt = outs['dit']
            total_logits = outs['madsbm']
            esm_logits = outs['esm']

            # Eval-only diagnostic. Under ablation the reference rate is uniform.
            actional = self.compute_action(
                u_tilt, esm_logits=None if ablate else esm_logits
            )
            action_traj[f"action_step_{k+1}"] = actional
            tot_action += actional * dt

            jump_prob = self._jump_prob(u_tilt, dt)
            # RNG order (proposal draw, then jump draw, then guidance draws) is
            # kept identical so seeded runs reproduce.
            probs, candidate_toks = self._propose(total_logits, B, L)
            can_jump = torch.rand(B, L, device=self.device) < jump_prob
            updatable = can_jump & self.is_masked(xt)

            if guidance:
                chosen = self.binding_guidance(probs, target_toks, B, L)
            else:
                chosen = candidate_toks
            xt[updatable] = chosen[updatable]

            tracer.log_step(xt=xt, step_idx=k + 1)

            if not converged and not self.is_masked(xt).any():
                converge_idx = k + 1
                converged = True

        # Greedily fill any positions never unmasked, using the final step's raw
        # (untempered, unfiltered) logits.
        still_masked = self.is_masked(xt)
        if still_masked.any():
            final_toks = total_logits.argmax(dim=-1)
            xt[still_masked] = final_toks[still_masked]

        tracer.log_step(xt, num_steps + 1)

        self.last_action_traj = action_traj
        self.last_total_action = tot_action
        self.last_converge_idx = converge_idx

        binding_affin = self._get_peptiverse().predict_binding_affinity(
            mode='wt',
            target_ids=target_toks,
            binder_ids=xt,
        )['affinity']

        return xt, binding_affin

    def _jump_prob(self, u_tilt, dt):
        """Return the per-position probability of jumping within one step of length ``dt``.

        Computes ``1 - exp(-R_tot * jump_scale * dt)``, where
        ``R_tot = sum_v exp(u_tilt * rate_scale)`` over the last dim.
        The exponent is clamped to [-40, 0].
        """
        r_theta = torch.exp(u_tilt * self.config.sampling.rate_scale)
        R_tot = r_theta.sum(dim=-1)
        rate = (-R_tot * self.config.sampling.jump_scale * dt).clamp(min=-40.0, max=0.0)
        return 1.0 - torch.exp(rate)

    def _propose(self, total_logits, B, L):
        """Draw one candidate token per position from tempered, top-p filtered logits.

        Returns:
            ``(probs, candidate_toks)``: ``probs`` flattened to ``(B * L, V)``
            and ``candidate_toks`` reshaped to ``(B, L)``.
        """
        logits = total_logits.clone()
        logits /= self.config.sampling.tau
        logits = self.top_p_filter(logits, self.config.sampling.top_p)

        probs = F.softmax(logits, dim=-1)
        probs = probs.view(-1, probs.size(-1))
        candidate_toks = torch.multinomial(probs, 1).view(B, L)
        return probs, candidate_toks

    def binding_guidance(self, probs, target_toks, B, L):
        """Pick one of ``config.sampling.M`` candidate sequences, weighted by predicted affinity.

        Args:
            probs: Row-stochastic matrix of shape ``(B * L, V)``.
            target_toks: Passed to PeptiVerse as ``target_ids``.
            B, L: Shape used to reshape each draw to ``(B, L)``.

        Returns:
            The chosen candidate, shape ``(B, L)``. It is drawn with probability
            ``softmax(affinities / config.sampling.tau)``.
        """
        M = self.config.sampling.M
        candidate_toks = [torch.multinomial(probs, 1).view(B, L) for _ in range(M)]

        predictor = self._get_peptiverse()
        affinities = torch.tensor(
            [
                predictor.predict_binding_affinity(
                    mode='wt',
                    target_ids=target_toks,
                    binder_ids=toks.detach(),
                )['affinity']
                for toks in candidate_toks
            ],
            dtype=torch.float32,
        )
        # NOTE(review): reuses the token-sampling temperature as the selection
        # temperature; confirm this coupling is intended.
        weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
        chosen_idx = torch.multinomial(weights, 1).item()

        return candidate_toks[chosen_idx]

    def compute_action(self, u_tilt, esm_logits=None):
        """Compute the action functional for evals.

        Per position, computes ``sum_v R0_v * (exp(u_v) - u_v - 1)``.
        ``R0`` is ``softmax(esm_logits)`` over the last dim, or the uniform
        rate ``1 / tokenizer.vocab_size`` when ``esm_logits`` is None.

        Returns:
            The mean over all positions, as a Python float.
        """
        if esm_logits is not None:
            R0 = torch.softmax(esm_logits, dim=-1)
        else:
            R0 = 1.0 / self.tokenizer.vocab_size

        psi_u = torch.exp(u_tilt) - u_tilt - 1.0
        action_per_tok = (R0 * psi_u).sum(dim=-1)

        return action_per_tok.mean().item()

    def top_p_filter(self, logits, p_val):
        """Apply nucleus / top-p filtering.

        Sets to ``-inf`` the logits of tokens outside the smallest
        highest-probability set whose cumulative probability exceeds
        ``p_val``. The top token is always kept.

        ``logits`` is modified in place and also returned. With
        ``p_val >= 1`` nothing is filtered. This avoids dropping tail tokens
        when the float cumsum rounds above 1.0.
        """
        if p_val >= 1.0:
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_idx_to_remove = cum_probs > p_val
        # Shift right so the token that crosses the threshold is still kept.
        sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
        sorted_idx_to_remove[..., 0] = 0

        # Map the removal mask from sorted order back to vocabulary order.
        idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
        logits[idx_to_remove] = float('-inf')
        return logits

    def is_masked(self, xt):
        """Return a boolean tensor that is True where ``xt`` equals the mask token id."""
        return (xt == self.mask_id)

    def seed_everything(self, seed):
        """Seed Python, NumPy and torch (CPU and all CUDA devices); ``None`` is a no-op.

        When CUDA is available, this also sets cuDNN to deterministic mode and
        disables cuDNN benchmarking. That is a process-global change.
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
|