| import argparse |
| import sys |
| import yaml |
| import torch |
| import numpy as np |
| import pickle |
| from argparse import Namespace |
|
|
| from pathlib import Path |
|
|
| basedir = Path(__file__).resolve().parent.parent |
| sys.path.append(str(basedir)) |
|
|
| from src import utils |
| from src.utils import dict_to_namespace, namespace_to_dict |
| from src.analysis.visualization_utils import mols_to_pdbfile, mol_as_pdb |
| from src.data.data_utils import TensorDict, Residues |
| from src.data.postprocessing import process_all |
| from src.model.lightning import DrugFlow |
| from src.sbdd_metrics.evaluation import compute_all_metrics_drugflow |
|
|
| from tqdm import tqdm |
| from pdb import set_trace |
|
|
|
|
def combine(base_args, override_args):
    """Recursively merge ``override_args`` into ``base_args`` (in place).

    For every attribute of ``override_args``:

    * if ``base_args`` lacks it or holds ``None``, the value is added;
    * if both sides are ``Namespace`` objects, they are merged recursively;
    * otherwise the override value replaces the base value.

    Every addition/replacement is printed so the effective configuration is
    visible in the logs.

    Args:
        base_args: Namespace to update (mutated in place).
        override_args: Namespace whose values take precedence.

    Returns:
        ``base_args`` after the merge.
    """
    assert not isinstance(base_args, dict)
    assert not isinstance(override_args, dict)

    arg_dict = base_args.__dict__
    for key, value in override_args.__dict__.items():
        current = arg_dict.get(key)  # None both when missing and when explicitly None
        if current is None:
            print(f"Add parameter {key}: {value}")
            arg_dict[key] = value
        elif isinstance(value, Namespace) and isinstance(current, Namespace):
            # Only recurse when both sides are namespaces; recursing into a
            # scalar base value would fail on its missing ``__dict__``.
            arg_dict[key] = combine(current, value)
        else:
            print(f"Replace parameter {key}: {current} -> {value}")
            arg_dict[key] = value
    return base_args
|
|
|
|
def path_to_str(input_dict):
    """Replace every ``pathlib.Path`` in a nested dict with its string form.

    The dict (and any nested dicts) is modified in place and also returned,
    so the result can be passed straight to a serializer such as ``yaml.dump``.
    Paths stored inside plain lists or tuples, including dicts nested in
    them, are converted as well.

    Args:
        input_dict: dict whose values may be ``Path`` objects, nested dicts,
            lists/tuples or arbitrary other values (left untouched).

    Returns:
        ``input_dict`` with all reachable ``Path`` values converted to ``str``.
    """
    def _convert(value):
        # Convert a single value, descending into containers.
        if isinstance(value, dict):
            return path_to_str(value)
        if isinstance(value, Path):
            return str(value)
        if type(value) is list:
            return [_convert(v) for v in value]
        if type(value) is tuple:
            return tuple(_convert(v) for v in value)
        return value

    # Reassigning values of existing keys while iterating is safe: the
    # dict's size never changes.
    for key, value in input_dict.items():
        input_dict[key] = _convert(value)
    return input_dict
|
|
|
|
def sample(cfg, model_params, samples_dir, job_id=0, n_jobs=1):
    """Sample ligands with a DrugFlow checkpoint and write them to ``samples_dir``.

    Batches of the ``cfg.set`` dataloader are split round-robin across
    ``n_jobs`` jobs; this call handles only batches whose index ``i``
    satisfies ``i % n_jobs == job_id``.

    For each sampled ligand named ``name`` (with any ``.sdf`` suffix removed),
    files are written to ``samples_dir/<name>/``:
    ``{idx}_ligand.sdf``, ``{idx}_pocket.pdb`` and one
    ``{idx}_ligand_{prop}.pdb`` per property found on the ligand's first atom.
    ``idx`` counts the samples already written for ``name`` by this call.

    Args:
        cfg: sampling config namespace (reads ``checkpoint``, ``device``,
            ``set``, ``n_samples``, ``sample_with_ground_truth_size``,
            ``postprocess``).
        model_params: keyword arguments forwarded to
            ``DrugFlow.load_from_checkpoint``.
        samples_dir: root output directory.
        job_id: index of this job, expected in ``[0, n_jobs)``.
        n_jobs: total number of parallel jobs.

    Raises:
        Exception: a failure in ``model.sample`` is re-raised unless
            ``cfg.set == 'train'``, in which case that batch is skipped.
    """
    print('Sampling...')
    model = DrugFlow.load_from_checkpoint(cfg.checkpoint, map_location=cfg.device, strict=False,
                                          **model_params)
    model.setup(stage='fit' if cfg.set == 'train' else cfg.set)
    model.eval().to(cfg.device)

    # e.g. model.val_dataloader() / model.test_dataloader() / model.train_dataloader()
    dataloader = getattr(model, f'{cfg.set}_dataloader')()
    print(f'Real batch size is {dataloader.batch_size * cfg.n_samples}')

    name2count = {}  # ligand name -> number of samples written so far
    for i, data in enumerate(tqdm(dataloader)):
        if i % n_jobs != job_id:
            print(f'Skipping batch {i}')
            continue

        new_data = {
            'ligand': TensorDict(**data['ligand']).to(cfg.device),
            'pocket': Residues(**data['pocket']).to(cfg.device),
        }
        try:
            rdmols, rdpockets, names = model.sample(
                data=new_data,
                n_samples=cfg.n_samples,
                num_nodes=("ground_truth" if cfg.sample_with_ground_truth_size else None)
            )
        except Exception as e:
            if cfg.set != 'train':
                raise  # bare raise preserves the original traceback
            # On the training set, failed batches are logged and skipped.
            failed_names = data['ligand']['name']
            print(f'Failed to sample for {failed_names}: {e}')
            continue

        for mol, pocket, name in zip(rdmols, rdpockets, names):
            name = name.replace('.sdf', '')
            idx = name2count.setdefault(name, 0)
            output_dir = Path(samples_dir, name)
            output_dir.mkdir(parents=True, exist_ok=True)
            if cfg.postprocess:
                mol = process_all(mol, largest_frag=True, adjust_aromatic_Ns=True, relax_iter=0)

            # Property names are taken from the first atom; an atom-less
            # molecule has none (indexing it would raise).
            atom_props = mol.GetAtoms()[0].GetPropsAsDict().keys() if mol.GetNumAtoms() > 0 else ()
            for prop in atom_props:
                # Store the mean of the per-atom values as a molecule-level property.
                mol.SetDoubleProp(prop, np.mean([a.GetDoubleProp(prop) for a in mol.GetAtoms()]))

                # Per-property PDB; presumably mol_as_pdb puts `prop` into the
                # B-factor column -- confirm in visualization_utils.
                out_pdb_path = Path(output_dir, f'{idx}_ligand_{prop}.pdb')
                mol_as_pdb(mol, out_pdb_path, bfactor=prop)

            out_sdf_path = Path(output_dir, f'{idx}_ligand.sdf')
            out_pdb_path = Path(output_dir, f'{idx}_pocket.pdb')
            utils.write_sdf_file(out_sdf_path, [mol])
            mols_to_pdbfile([pocket], out_pdb_path)

            name2count[name] += 1
|
|
|
|
def evaluate(cfg, model_params, samples_dir):
    """Compute metrics for the samples in ``samples_dir`` and save them there.

    Writes ``metrics_data.pkl`` (the raw data returned by
    ``compute_all_metrics_drugflow``), ``metrics_detailed.csv`` and
    ``metrics_aggregated.csv`` into ``samples_dir``.

    Args:
        cfg: sampling config namespace (reads ``reduce``, ``n_samples`` and the
            optional ``exclude_evaluators``; a missing or null value means
            "exclude nothing").
        model_params: dict whose ``'train_params'`` entry provides ``gnina``
            and ``datadir``.
        samples_dir: directory containing the generated samples.
    """
    print('Evaluation...')
    exclude_evaluators = getattr(cfg, 'exclude_evaluators', None)
    data, table_detailed, table_aggregated = compute_all_metrics_drugflow(
        in_dir=samples_dir,
        gnina_path=model_params['train_params'].gnina,
        reduce_path=cfg.reduce,
        reference_smiles_path=Path(model_params['train_params'].datadir, 'train_smiles.npy'),
        n_samples=cfg.n_samples,
        exclude_evaluators=[] if exclude_evaluators is None else exclude_evaluators,
    )
    # NOTE: pickle files should only ever be reloaded from trusted locations.
    with open(Path(samples_dir, 'metrics_data.pkl'), 'wb') as f:
        pickle.dump(data, f)
    # presumably pandas DataFrames (given .to_csv(index=False)) -- confirm
    table_detailed.to_csv(Path(samples_dir, 'metrics_detailed.csv'), index=False)
    table_aggregated.to_csv(Path(samples_dir, 'metrics_aggregated.csv'), index=False)
|
|
|
|
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--config', type=str)
    p.add_argument('--job_id', type=int, default=0, help='Job ID')
    p.add_argument('--n_jobs', type=int, default=1, help='Number of jobs')
    args = p.parse_args()

    # Load the YAML run configuration and switch to attribute-style access.
    with open(args.config, 'r') as f:
        cfg = yaml.safe_load(f)
    cfg = dict_to_namespace(cfg)
    # Validate before any expensive work (seeding, checkpoint loading).
    assert cfg.set in {'val', 'test', 'train'}

    utils.set_deterministic(seed=cfg.seed)
    utils.disable_rdkit_logging()

    # Hyper-parameters stored in the checkpoint; an optional non-null
    # `model_args` section of the config overrides them via `combine`.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which may reject non-tensor hyper-parameters --
    # confirm against the installed version. Only load trusted checkpoints.
    model_params = torch.load(cfg.checkpoint, map_location=cfg.device)['hyper_parameters']
    model_args = getattr(cfg, 'model_args', None)
    if model_args is not None:
        ckpt_args = dict_to_namespace(model_params)
        model_params = combine(ckpt_args, model_args).__dict__

    ckpt_path = Path(cfg.checkpoint)
    ckpt_name = ckpt_path.parts[-1].split('.')[0]
    n_steps = model_params['simulation_params'].n_steps
    run_name = f'{ckpt_name}_T={n_steps}'
    # Use `sample_outdir` when configured, otherwise a `samples` folder two
    # levels above the checkpoint. This must be an explicit check:
    # `Path(...) or Path(...)` never falls back because Path objects are
    # always truthy (and Path(None, ...) raises TypeError).
    sample_outdir = getattr(cfg, 'sample_outdir', None)
    if sample_outdir:
        samples_dir = Path(sample_outdir, cfg.set, run_name)
    else:
        samples_dir = Path(ckpt_path.parent.parent, 'samples', cfg.set, run_name)
    samples_dir.mkdir(parents=True, exist_ok=True)

    # Record the effective model and sampling parameters next to the samples.
    with open(Path(samples_dir, 'model_params.yaml'), 'w') as f:
        yaml.dump(path_to_str(namespace_to_dict(model_params)), f)
    with open(Path(samples_dir, 'sampling_params.yaml'), 'w') as f:
        yaml.dump(path_to_str(namespace_to_dict(cfg)), f)

    if cfg.sample:
        sample(cfg, model_params, samples_dir, job_id=args.job_id, n_jobs=args.n_jobs)

    if cfg.evaluate:
        assert args.job_id == 0 and args.n_jobs == 1, 'Evaluation is not parallelised on GPU machines'
        evaluate(cfg, model_params, samples_dir)