| | import os, json |
| | from pathlib import Path |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader |
| | import optuna |
| | from datasets import load_from_disk, DatasetDict |
| | from scipy.stats import spearmanr |
| | from lightning.pytorch import seed_everything |
# Seed all RNGs up front via Lightning (python/numpy/torch per its docs) so
# runs are reproducible. NOTE(review): a fixed literal seed — confirm this is
# intended rather than a CLI-configurable seed.
seed_everything(1986)

# Every tensor/model in this script is moved to this device: first CUDA GPU
# when available, otherwise CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Spearman rank correlation, mapped to 0.0 when it is undefined.

    scipy returns NaN (or None) for degenerate inputs such as constant
    arrays; this wrapper makes the metric safe to maximize directly.
    """
    rho = spearmanr(y_true, y_pred).correlation
    return 0.0 if rho is None or np.isnan(rho) else float(rho)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
    """Bucket affinity values into 3 classes.

    Class 0: high (y >= 9.0), class 1: mid (7.0 <= y < 9.0),
    class 2: low (y < 7.0).
    """
    # Start everything at "mid" and overwrite the two tail buckets.
    cls = torch.full_like(y, 1, dtype=torch.long)
    cls[y >= 9.0] = 0
    cls[y < 7.0] = 2
    return cls
| |
|
| |
|
| | |
| | |
| | |
def load_split_paired(path: str):
    """Load a paired DatasetDict from disk and return its (train, val) splits.

    Raises:
        ValueError: if the object on disk is not a DatasetDict or lacks
            either the 'train' or 'val' split.
    """
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    for split in ("train", "val"):
        if split not in dd:
            raise ValueError(f"DatasetDict missing train/val at {path}")
    return dd["train"], dd["val"]
| |
|
| |
|
| | |
| | |
| | |
def collate_pair_pooled(batch):
    """Collate pooled-embedding samples into (target, binder, label) tensors.

    Each sample supplies flat 'target_embedding' / 'binder_embedding'
    vectors and a scalar 'label'.
    """
    targets = torch.tensor([sample["target_embedding"] for sample in batch], dtype=torch.float32)
    binders = torch.tensor([sample["binder_embedding"] for sample in batch], dtype=torch.float32)
    labels = torch.tensor([float(sample["label"]) for sample in batch], dtype=torch.float32)
    return targets, binders, labels
| |
|
| |
|
| | |
| | |
| | |
def collate_pair_unpooled(batch):
    """Pad variable-length token embeddings into batched tensors with masks.

    Returns (Pt, Mt, Pb, Mb, y) where Pt/Pb are zero-padded (B, L, H)
    float tensors, Mt/Mb are boolean validity masks taken from the
    samples' attention masks, and y holds the float labels.
    """
    n = len(batch)
    # Hidden sizes come from the first token of the first sample.
    dim_t = len(batch[0]["target_embedding"][0])
    dim_b = len(batch[0]["binder_embedding"][0])
    max_t = max(int(sample["target_length"]) for sample in batch)
    max_b = max(int(sample["binder_length"]) for sample in batch)

    Pt = torch.zeros(n, max_t, dim_t, dtype=torch.float32)
    Pb = torch.zeros(n, max_b, dim_b, dtype=torch.float32)
    Mt = torch.zeros(n, max_t, dtype=torch.bool)
    Mb = torch.zeros(n, max_b, dtype=torch.bool)
    y = torch.tensor([float(sample["label"]) for sample in batch], dtype=torch.float32)

    for i, sample in enumerate(batch):
        emb_t = torch.tensor(sample["target_embedding"], dtype=torch.float32)
        emb_b = torch.tensor(sample["binder_embedding"], dtype=torch.float32)
        len_t, len_b = emb_t.shape[0], emb_b.shape[0]
        Pt[i, :len_t] = emb_t
        Pb[i, :len_b] = emb_b
        # Masks are truncated to the actual embedding lengths.
        Mt[i, :len_t] = torch.tensor(sample["target_attention_mask"][:len_t], dtype=torch.bool)
        Mb[i, :len_b] = torch.tensor(sample["binder_attention_mask"][:len_b], dtype=torch.bool)

    return Pt, Mt, Pb, Mb, y
| |
|
| |
|
| | |
| | |
| | |
class CrossAttnPooled(nn.Module):
    """Cross-attention head over pooled (single-vector) embeddings.

    Each pooled vector is treated as a length-1 token sequence so target
    and binder can alternately attend to each other. A shared trunk then
    emits a scalar regression output and 3-way classification logits.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Project both sides into a common hidden space.
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        def make_layer():
            # One post-norm alternating cross-attention layer.
            return nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
            })

        self.layers = nn.ModuleList(make_layer() for _ in range(n_layers))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        # (B, H) -> (1, B, hidden): one "token" per side in the seq-first
        # layout expected by batch_first=False attention modules.
        t = self.t_proj(t_vec).unsqueeze(0)
        b = self.b_proj(b_vec).unsqueeze(0)

        for layer in self.layers:
            # Target attends to binder, then its feed-forward.
            # LayerNorm normalizes the last dim, so no transposes are needed.
            attended, _ = layer["attn_tb"](t, b, b)
            t = layer["n1t"](t + attended)
            t = layer["n2t"](t + layer["fft"](t))

            # Binder attends to the freshly updated target.
            attended, _ = layer["attn_bt"](b, t, t)
            b = layer["n1b"](b + attended)
            b = layer["n2b"](b + layer["ffb"](b))

        # Drop the length-1 sequence dim and fuse both sides.
        h = self.shared(torch.cat([t.squeeze(0), b.squeeze(0)], dim=-1))
        return self.reg(h).squeeze(-1), self.cls(h)
| |
|
| |
|
class CrossAttnUnpooled(nn.Module):
    """Cross-attention over per-token embeddings with padding masks.

    Target and binder token sequences alternately attend to each other;
    masked mean-pooling of both sides feeds a shared trunk that emits a
    scalar regression output and 3-way classification logits.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Project both token streams into a common hidden space.
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        def make_layer():
            # One post-norm alternating cross-attention layer.
            return nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
            })

        self.layers = nn.ModuleList(make_layer() for _ in range(n_layers))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        """Mean over the sequence dim counting only positions where M is True."""
        weights = M.unsqueeze(-1).float()
        # clamp avoids division by zero for an all-padding row.
        return (X * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

    def forward(self, T, Mt, B, Mb):
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        # MultiheadAttention's key_padding_mask wants True at PADDED keys,
        # our masks mark VALID positions, hence the inversion.
        pad_t = ~Mt
        pad_b = ~Mb

        for layer in self.layers:
            # Target tokens attend to binder tokens, then feed-forward.
            attended, _ = layer["attn_tb"](T, Bx, Bx, key_padding_mask=pad_b)
            T = layer["n1t"](T + attended)
            T = layer["n2t"](T + layer["fft"](T))

            # Binder tokens attend to the freshly updated target tokens.
            attended, _ = layer["attn_bt"](Bx, T, T, key_padding_mask=pad_t)
            Bx = layer["n1b"](Bx + attended)
            Bx = layer["n2b"](Bx + layer["ffb"](Bx))

        z = torch.cat([self.masked_mean(T, Mt), self.masked_mean(Bx, Mb)], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
| |
|
| |
|
| | |
| | |
| | |
@torch.no_grad()
def eval_spearman_pooled(model, loader):
    """Spearman rho between labels and regression outputs over a pooled loader."""
    model.eval()
    truths, preds = [], []
    for t, b, y in loader:
        pred, _ = model(t.to(DEVICE, non_blocking=True), b.to(DEVICE, non_blocking=True))
        truths.append(y.numpy())
        preds.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(truths), np.concatenate(preds))
| |
|
@torch.no_grad()
def eval_spearman_unpooled(model, loader):
    """Spearman rho between labels and regression outputs over an unpooled loader."""
    model.eval()
    truths, preds = [], []
    for T, Mt, B, Mb, y in loader:
        inputs = [x.to(DEVICE, non_blocking=True) for x in (T, Mt, B, Mb)]
        pred, _ = model(*inputs)
        truths.append(y.numpy())
        preds.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(truths), np.concatenate(preds))
| |
|
def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """One optimization pass over a pooled loader.

    The loss is regression loss plus `cls_w` times the 3-way classification
    loss on the derived affinity classes; gradients are clipped to `clip`
    (skipped when clip is None).
    """
    model.train()
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(t, b)
        total = loss_reg(pred, y) + cls_w * loss_cls(logits, affinity_to_class_tensor(y))
        total.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()
| |
|
def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """One optimization pass over an unpooled (token + mask) loader.

    Same objective as the pooled variant: regression loss plus `cls_w`
    times the classification loss, with gradient clipping at `clip`
    (skipped when clip is None).
    """
    model.train()
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(T, Mt, B, Mb)
        total = loss_reg(pred, y) + cls_w * loss_cls(logits, affinity_to_class_tensor(y))
        total.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()
| |
|
| |
|
| | |
| | |
| | |
def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
    """Optuna objective: train a cross-attention model and return the best
    validation Spearman rho seen during up to 60 epochs.

    Supports median pruning via per-epoch `trial.report` and early stops
    after 10 epochs without improvement. The `suggest_*` call order is
    fixed so trials stay reproducible.
    """
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.4)
    hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
    n_heads = trial.suggest_categorical("n_heads", [4, 8])
    n_layers = trial.suggest_int("n_layers", 1, 4)
    cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])

    # Pick the model/collate/eval/train quartet for this embedding layout.
    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_pooled
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_unpooled
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    best = -1e9
    bad = 0
    patience = 10

    for ep in range(1, 61):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        # Let the pruner see intermediate values.
        trial.report(rho, ep)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if rho > best + 1e-6:
            best = rho
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    return float(best)
| |
|
| |
|
| | |
| | |
| | |
def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
    """Full pipeline: Optuna hyperparameter search, then refit the best
    configuration with longer patience and save the resulting artifacts.

    Args:
        dataset_path: path to a paired DatasetDict with 'train'/'val' splits.
        out_dir: directory receiving optuna_trials.csv, best_model.pt and
            best_params.json (created if missing).
        mode: 'pooled' (vector embeddings) or anything else for 'unpooled'
            (per-token embeddings with masks).
        n_trials: number of Optuna trials to run.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")

    # Search maximizes validation Spearman rho with median pruning.
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)

    study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
    best = study.best_trial
    best_params = dict(best.params)

    # Unpack the winning configuration for the final refit.
    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    dropout = float(best_params["dropout"])
    hidden = int(best_params["hidden_dim"])
    n_heads = int(best_params["n_heads"])
    n_layers = int(best_params["n_layers"])
    cls_w = float(best_params["cls_weight"])
    batch = int(best_params["batch_size"])

    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    # Rebuild the model/loaders to match the embedding layout for this mode
    # (mirrors the branch inside objective_crossattn).
    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_pooled
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_unpooled
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    # Refit for up to 200 epochs with early stopping on validation rho
    # (patience 20, doubled vs. the search phase).
    best_rho = -1e9
    bad = 0
    patience = 20
    best_state = None  # CPU copy of the best checkpoint's weights

    for ep in range(1, 201):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        if rho > best_rho + 1e-6:
            best_rho = rho
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    # Restore the best-scoring weights before saving.
    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
    with open(out_dir / "best_params.json", "w") as f:
        json.dump(best_params, f, indent=2)

    print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")
| |
|
| |
|
| | if __name__ == "__main__": |
| | import argparse |
| | ap = argparse.ArgumentParser() |
| | ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)") |
| | ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True) |
| | ap.add_argument("--out_dir", type=str, required=True) |
| | ap.add_argument("--n_trials", type=int, default=50) |
| | args = ap.parse_args() |
| |
|
| | run( |
| | dataset_path=args.dataset_path, |
| | out_dir=args.out_dir, |
| | mode=args.mode, |
| | n_trials=args.n_trials, |
| | ) |
| |
|