| import argparse |
| import os |
| import pickle |
| import sys |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| from rdkit import Chem |
| from scipy.stats import wasserstein_distance |
| from scipy.spatial.distance import jensenshannon |
| from tqdm import tqdm |
|
|
| basedir = Path(__file__).resolve().parent.parent.parent |
| sys.path.append(str(basedir)) |
|
|
| from src.data.data_utils import atom_encoder, bond_encoder, encode_atom |
| from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics, get_data_type |
| from src.sbdd_metrics.metrics import FullEvaluator |
|
|
|
|
# Metric name -> expected dtype mapping, taken straight from the evaluator.
# NOTE(review): `data_types` is a redundant lowercase alias of DATA_TYPES —
# presumably kept for backward compatibility; confirm no external importer
# relies on it before removing.
DATA_TYPES = data_types = FullEvaluator().dtypes

# Ligand-level medicinal-chemistry (and strain-energy) properties compared
# against the reference set via Wasserstein distance.
MEDCHEM_PROPS = [
    'medchem.qed',
    'medchem.sa',
    'medchem.logp',
    'medchem.lipinski',
    'medchem.size',
    'medchem.n_rotatable_bonds',
    'energy.energy',
]

# Docking-related scores produced by gnina, compared the same way.
DOCKING_PROPS = [
    'gnina.vina_score',
    'gnina.gnina_score',
    'gnina.vina_efficiency',
    'gnina.gnina_efficiency',
]

# Protein-ligand interaction counts; the `.normalized` variants are the raw
# counts divided by molecule size (added in `flatten_distribution`).
RELEVANT_INTERACTIONS = [
    'interactions.HBAcceptor',
    'interactions.HBDonor',
    'interactions.HB',
    'interactions.PiStacking',
    'interactions.Hydrophobic',

    'interactions.HBAcceptor.normalized',
    'interactions.HBDonor.normalized',
    'interactions.HB.normalized',
    'interactions.PiStacking.normalized',
    'interactions.Hydrophobic.normalized'
]
|
|
|
|
def compute_discrete_distributions(smiles, name):
    """Compute normalized atom-type and bond-type frequency distributions.

    Hydrogens are stripped before counting. Atoms whose type is not covered
    by `atom_encoder` are skipped.

    Args:
        smiles: iterable of SMILES strings.
        name: label shown in the tqdm progress bar.

    Returns:
        Tuple ``(atom_distribution, bond_distribution)`` of 1-D numpy arrays
        indexed by `atom_encoder` / `bond_encoder` codes; each sums to 1
        (or stays all-zero if no atoms/bonds were counted).
    """
    atom_counter = Counter()
    bond_counter = Counter()

    for smi in tqdm(smiles, desc=name):
        mol = Chem.MolFromSmiles(smi)
        # Fix: MolFromSmiles returns None for unparsable SMILES — skip
        # instead of crashing with AttributeError below.
        if mol is None:
            continue
        mol = Chem.RemoveAllHs(mol, sanitize=False)
        for atom in mol.GetAtoms():
            try:
                encoded_atom = encode_atom(atom, atom_encoder=atom_encoder)
            except KeyError:
                # Atom type not present in the encoder — ignore it.
                continue
            atom_counter[encoded_atom] += 1
        for bond in mol.GetBonds():
            bond_counter[bond_encoder[str(bond.GetBondType())]] += 1

    atom_distribution = np.zeros(len(atom_encoder))
    bond_distribution = np.zeros(len(bond_encoder))

    for k, v in atom_counter.items():
        atom_distribution[k] = v
    for k, v in bond_counter.items():
        bond_distribution[k] = v

    # Fix: guard against empty inputs — 0/0 would silently produce NaN arrays.
    atom_total = atom_distribution.sum()
    bond_total = bond_distribution.sum()
    if atom_total > 0:
        atom_distribution = atom_distribution / atom_total
    if bond_total > 0:
        bond_distribution = bond_distribution / bond_total

    return atom_distribution, bond_distribution
|
|
|
|
def flatten_distribution(data, name, table):
    """Flatten per-molecule metric dicts into per-metric value lists.

    Only entries for valid molecules (``item['medchem.valid'] is True``) are
    kept. Two kinds of derived keys are added before flattening:
    ``interactions.HB`` (acceptor + donor counts) and, for every
    ``interactions.*`` key, a ``.normalized`` variant (count divided by the
    molecule size looked up in `table`).

    Args:
        data: list of dicts, one per generated molecule, containing at least
            'medchem.valid', 'sdf_file' and 'sample' keys. Not modified —
            derived keys are added to per-item copies.
        name: label for the tqdm progress bar.
        table: DataFrame with 'sdf_file', 'sample' and 'medchem.size' columns.

    Returns:
        ``defaultdict(list)`` mapping metric name -> flat list of values.
    """
    aux = ['sample', 'sdf_file', 'pdb_file']
    method_distributions = defaultdict(list)

    # Lookup: sdf_file -> sample index -> molecule size ('medchem.size').
    sdf2sample2size = defaultdict(dict)
    for _, row in table.iterrows():
        sdf2sample2size[row['sdf_file']][int(row['sample'])] = row['medchem.size']

    for item in tqdm(data, desc=name):
        if item['medchem.valid'] is not True:
            continue

        # Fix: work on a shallow copy so the caller's dicts are not mutated
        # as a hidden side effect.
        item = dict(item)

        if 'interactions.HBAcceptor' in item and 'interactions.HBDonor' in item:
            item['interactions.HB'] = item['interactions.HBAcceptor'] + item['interactions.HBDonor']

        # Fix: the size lookup is invariant per item — hoist it out of the
        # per-key loop instead of repeating it for every interaction key.
        size = sdf2sample2size.get(item['sdf_file'], dict()).get(int(item['sample']))
        new_entries = {}
        if size is not None:
            for key, value in item.items():
                if key.startswith('interactions'):
                    new_entries[key + '.normalized'] = value / size
        item.update(new_entries)

        for key, value in item.items():
            if value is None:
                continue
            if key in aux:
                continue
            # Energies with |E| > 1000 are treated as outliers and dropped.
            if key == 'energy.energy' and abs(value) > 1000:
                continue

            # List-valued metrics (e.g. per-bond geometries) are concatenated;
            # scalar metrics are appended.
            if get_data_type(key, DATA_TYPES, default=type(value)) == list:
                method_distributions[key] += value
            else:
                method_distributions[key].append(value)

    return method_distributions
|
|
|
|
def prepare_baseline_data(root_path, baseline_name):
    """Load a baseline's metrics and build continuous + discrete distributions.

    Expects ``{root_path}/metrics_detailed.csv`` and
    ``{root_path}/metrics_data.pkl`` to exist.

    Args:
        root_path: directory containing the precomputed metric files.
        baseline_name: label used for progress bars.

    Returns:
        Tuple ``(distributions, discrete_distributions)`` where the first maps
        metric name -> list of values and the second holds the normalized
        'atom_types' / 'bond_types' arrays.
    """
    metrics_detailed = pd.read_csv(f'{root_path}/metrics_detailed.csv')
    # Keep only rows for valid molecules.
    metrics_detailed = metrics_detailed[metrics_detailed['medchem.valid']]
    # Fix: use a context manager — `pickle.load(open(...))` leaked the handle.
    with open(f'{root_path}/metrics_data.pkl', 'rb') as f:
        distributions = pickle.load(f)
    distributions = flatten_distribution(distributions, name=baseline_name, table=metrics_detailed)
    # Drop energy outliers (same |E| <= 1000 cutoff as in flatten_distribution).
    distributions['energy.energy'] = [v for v in distributions['energy.energy'] if -1000 <= v <= 1000]
    # Scalar medchem/docking properties come straight from the table.
    for prop in MEDCHEM_PROPS + DOCKING_PROPS:
        distributions[prop] = metrics_detailed[prop].dropna().values.tolist()

    smiles = metrics_detailed['representation.smiles']
    atom_distribution, bond_distribution = compute_discrete_distributions(smiles, name=baseline_name)
    discrete_distributions = {
        'atom_types': atom_distribution,
        'bond_types': bond_distribution,
    }

    return distributions, discrete_distributions
|
|
|
|
if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples')
    p.add_argument('--out_dir', type=str, required=True, help='Output directory')
    p.add_argument('--n_samples', type=int, required=False, default=None, help='N samples per target')
    p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)')
    p.add_argument('--crossdocked_dir', type=str, required=False, default=None, help='Crossdocked data dir for computing distances between distributions')
    args = p.parse_args()

    Path(args.out_dir).mkdir(parents=True, exist_ok=True)

    # ---- Merge per-shard raw metric data into a single pickle ----
    print('Combining data')
    data = []
    for file_path in tqdm(Path(args.in_dir).glob('metrics_data_*.pkl')):
        with open(file_path, 'rb') as f:
            d = pickle.load(f)
        if args.n_samples is not None:
            d = d[:args.n_samples]
        data += d
    with open(Path(args.out_dir, 'metrics_data.pkl'), 'wb') as f:
        pickle.dump(data, f)

    # ---- Merge per-shard detailed metric tables ----
    print('Combining detailed metrics')
    tables = []
    for file_path in tqdm(Path(args.in_dir).glob('metrics_detailed_*.csv')):
        table = pd.read_csv(file_path)
        if args.n_samples is not None:
            table = table.head(args.n_samples)
        tables.append(table)

    table_detailed = pd.concat(tables)
    table_detailed.to_csv(Path(args.out_dir, 'metrics_detailed.csv'), index=False)

    print('Computing aggregated metrics')
    evaluator = FullEvaluator(gnina='gnina', reduce='reduce')
    table_aggregated = aggregated_metrics(
        table_detailed,
        data_types=evaluator.dtypes,
        validity_metric_name=VALIDITY_METRIC_NAME
    )

    # Collection-level metrics (e.g. vs. reference SMILES) are optional.
    if args.reference_smiles is not None:
        reference_smiles = np.load(args.reference_smiles)
        col_metrics = collection_metrics(
            table=table_detailed,
            reference_smiles=reference_smiles,
            validity_metric_name=VALIDITY_METRIC_NAME,
            exclude_evaluators=[],
        )
        table_aggregated = pd.concat([table_aggregated, col_metrics])

    table_aggregated.to_csv(Path(args.out_dir, 'metrics_aggregated.csv'), index=False)

    # ---- Optional: distances to the CrossDocked reference distributions ----
    if args.crossdocked_dir is not None:
        precomputed_distr_path = f'{args.crossdocked_dir}/crossdocked_distributions.pkl'
        precomputed_discrete_distr_path = f'{args.crossdocked_dir}/crossdocked_discrete_distributions.pkl'
        if os.path.exists(precomputed_distr_path) and os.path.exists(precomputed_discrete_distr_path):
            # Reuse cached reference distributions.
            with open(precomputed_distr_path, 'rb') as f:
                crossdocked_distributions = pickle.load(f)
            with open(precomputed_discrete_distr_path, 'rb') as f:
                crossdocked_discrete_distributions = pickle.load(f)
        else:
            # Fix: `assert` is stripped under `python -O` — raise explicitly.
            for required in (f'{args.crossdocked_dir}/metrics_detailed.csv',
                             f'{args.crossdocked_dir}/metrics_data.pkl'):
                if not os.path.exists(required):
                    raise FileNotFoundError(required)
            crossdocked_distributions, crossdocked_discrete_distributions = prepare_baseline_data(
                root_path=args.crossdocked_dir,
                baseline_name='crossdocked'
            )
            # Cache for next time.
            with open(precomputed_distr_path, 'wb') as f:
                pickle.dump(crossdocked_distributions, f)
            with open(precomputed_discrete_distr_path, 'wb') as f:
                pickle.dump(crossdocked_discrete_distributions, f)

        def _top_geometry_keys(reference, n_letters, k=5):
            """Return the k most-populated 'geometry.<atoms>' keys whose atom
            part has `n_letters` letters (2 = bond lengths, 3 = bond angles)."""
            counts = [
                (key, len(values)) for key, values in reference.items()
                if key.startswith('geometry.') and sum(c.isalpha() for c in key.split('.')[1]) == n_letters
            ]
            counts.sort(key=lambda t: t[1], reverse=True)
            return [key for key, _ in counts[:k]]

        top_5_bonds = _top_geometry_keys(crossdocked_distributions, 2)
        top_5_angles = _top_geometry_keys(crossdocked_distributions, 3)

        distributions, discrete_distributions = prepare_baseline_data(args.out_dir, 'samples')

        # One-row table: Wasserstein distances for continuous metrics,
        # Jensen-Shannon divergence for discrete atom/bond-type distributions.
        distances = {'method': 'method',}
        relevant_columns = MEDCHEM_PROPS + DOCKING_PROPS + RELEVANT_INTERACTIONS + top_5_bonds + top_5_angles
        for metric, values in distributions.items():
            if metric not in relevant_columns:
                continue

            ref = crossdocked_distributions.get(metric)
            # NaNs would poison wasserstein_distance — drop them.
            cur = [x for x in values if not pd.isna(x)]

            # Fix: removed debugging leftover (bare `except:` dropping into
            # pdb) and the always-true `cur is not None` check on a list
            # comprehension; also dropped unused num_ref/num_cur locals.
            if ref is not None and len(cur) > 0:
                distances[f'WD.{metric}'] = wasserstein_distance(ref, cur)

        for metric, ref in crossdocked_discrete_distributions.items():
            cur = discrete_distributions.get(metric)
            if ref is not None and cur is not None:
                distances[f'JS.{metric}'] = jensenshannon(p=ref, q=cur)

        dist_table = pd.DataFrame([distances])
        dist_table.to_csv(Path(args.out_dir, 'metrics_distances.csv'), index=False)