| | import gzip |
| | import json |
| | import os |
| | import pickle |
| | from abc import abstractmethod |
| | from os.path import exists |
| | from typing import List |
| | import string |
| | import random |
| | import biotite.structure |
| | import numpy as np |
| | import pandas as pd |
| | import socket |
| | import torch |
| | from Bio.PDB import PDBParser |
| | from Bio.PDB.DSSP import DSSP |
| | from biopandas.pdb import PandasPdb |
| | from biotite.sequence import ProteinSequence |
| | from biotite.structure import get_chains |
| | from biotite.structure.io import pdbx, pdb |
| | from biotite.structure.residues import get_residues |
| | from torch_cluster import radius_graph, knn_graph |
| |
|
| | |
# ---------------------------------------------------------------------------
# Data locations. A value of "NA" appears to be a placeholder for a resource
# that is not configured in this setup (TODO confirm).
# ---------------------------------------------------------------------------
AF2_DATA_PATH = './data.files/af2.files/'
AF2_REP_DATA_PATH = "NA"

ESM_MODEL_SIZE = '650M'
ESM_DATA_PATH = './data.files/esm.files/'

MSA_DATA_PATH_ARCHIVE = './data.files/gMVP.MSA/'
MSA_DATA_PATH = './data.files/MSA/'

PAE_DATA_PATH = 'NA'

MSA_ATTN_DATA_PATH = './data.files/esm.MSA/'
# Passed as ``num_workers`` to the torch_cluster graph builders.
NUM_THREADS = 42


def _load_pickle(path):
    """Unpickle and return the object stored at *path* (trusted local files only)."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Lookup tables read once at import time from ./utils (relative to the CWD).
LANGUAGE_MODEL = _load_pickle(f'./utils/LANGUAGE_MODEL.{ESM_MODEL_SIZE}.pkl')
ALPHABET_CONVERTER = _load_pickle(f'./utils/ALPHABET_CONVERTER.{ESM_MODEL_SIZE}.pkl')
ESM_AA_EMBEDDING_DICT = _load_pickle(f'./utils/ESM_AA_EMBEDDING_DICT.{ESM_MODEL_SIZE}.pkl')
ESM1b_AA_EMBEDDING_DICT = _load_pickle('./utils/ESM_AA_EMBEDDING_DICT.esm1b.pkl')
AA_5DIM_EMBED = _load_pickle('./utils/AA_5_DIM_EMBED.pkl')

# The 20 standard amino acids, in the order shared by all vocabularies below.
AA_DICT_HUMAN = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
                 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
# Standard amino acids plus ambiguous/rare codes and a mask token.
AA_DICT = AA_DICT_HUMAN + ['X', 'B', 'U', 'Z', 'O', '<mask>']
# Token vocabulary; list position is the token id (see get_embedding_from_esm_onehot).
ESM_TOKENS = (['<cls>', '<pad>', '<eos>', '<unk>']
              + AA_DICT_HUMAN
              + ['X', 'B', 'U', 'Z', 'O', '.', '-', '<null_1>', '<mask>'])
# DSSP secondary-structure codes, in one-hot column order.
DSSP_DICT = ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-', 'P']
# PTM type code -> column index of the matrix built by get_ptm_from_mutation.
PTM_DICT = {'ac': 0, 'ga': 1, 'gl': 2, 'm1': 3, 'm2': 4, 'm3': 5, 'me': 6, 'p': 7, 'sm': 8, 'ub': 9}
| |
|
class Mutation:
    """
    A mutation object that stores the information of a (possibly multi-site) mutation.

    Construction runs four steps in order:
      1. ``set_af2_fragment_idx`` chooses the sequence window, the AlphaFold2
         (AF2) model file and the ESM / AF2-representation file prefixes;
      2. ``set_ref_alt_aa`` parses the reference / alternate residues;
      3. ``init_af2_file_idx`` sets ``af2_file`` to None when the file is missing;
      4. ``crop_fn`` crops the window to at most ``max_len`` residues around the
         first mutated position.

    If ``af2_file`` is given, that file is used as-is and the whole input
    sequence is kept (no AF2 fragment / window selection).

    All positions are 1-based and inclusive. ``pos`` holds one entry per
    ';'-separated site, relative to the current ``seq``. ``seq_start_orig`` and
    ``seq_end_orig`` locate the window within the input sequence.
    ``seq_start`` and ``seq_end`` locate the crop within that window.
    """

    def __init__(self, uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig, ref_aa, alt_aa,
                 max_len=2251, af2_file=None):
        self.seq = None
        self.seq_start = None
        self.seq_end = None
        self.seq_start_orig = None
        self.seq_end_orig = None
        self.pos = None
        self.uniprot_id = None
        self.af2_file = None
        self.af2_rep_file_prefix = None
        self.af2_seq_index = None
        self.msa_seq_index = None
        self.esm_seq_index = None
        self.af2_rep_index = None
        self.ref_aa = None
        self.alt_aa = None
        self.ESM_prefix = None
        self.crop = False
        self.seq_len = None
        self.seq_len_orig = None
        self.max_len = max_len
        self.half_max_len = max_len // 2
        self.set_af2_fragment_idx(seq_orig, seq_orig_len, uniprot_id, pos_orig, af2_file)
        self.transcript_id = transcript_id
        self.set_ref_alt_aa(ref_aa, alt_aa)
        self.init_af2_file_idx()
        self.crop_fn()

    def set_af2_fragment_idx(self, seq_orig, seq_orig_len, uniprot_id, pos_orig, af2_file):
        """Choose the sequence window and AF2 / ESM file names for this mutation.

        Without an explicit ``af2_file``:
          * an id of the form ``<ID>-F<k>`` names AF2 fragment ``k``; the input
            sequence is used whole;
          * otherwise, proteins longer than 2700 residues are cut to a window of
            up to 1400 residues on a 200-residue grid placed around
            ``pos_orig[0]`` and clamped to the protein. Shorter proteins use
            fragment 1 and the whole sequence.
        With ``af2_file`` the file is used as given and the whole sequence kept.
        """
        self.seq_len_orig = seq_orig_len
        # pos_orig is either a single position or a ';'-separated list of positions.
        if isinstance(pos_orig, str):
            pos_orig = np.array([int(p) for p in pos_orig.split(';')])
        else:
            pos_orig = np.array([int(pos_orig)])

        if af2_file is not None and not pd.isna(af2_file):
            # Explicit structure file: keep the whole input sequence.
            self.af2_file = af2_file
            self.ESM_prefix = uniprot_id
            self.seq = seq_orig
            self.seq_start = 1
            self.seq_end = seq_orig_len
            self.seq_start_orig = self.seq_start
            self.seq_end_orig = self.seq_end
            self.seq_len = seq_orig_len
            self.pos = pos_orig
            self.uniprot_id = uniprot_id
            return

        if '-F' in uniprot_id:
            # The id already names an AF2 fragment, e.g. "P12345-F3".
            idx = int(uniprot_id.split('-F')[-1])
            uniprot_id = uniprot_id.split('-F')[0]
            seq_start, seq_end = 1, seq_orig_len
            self.seq_start_orig, self.seq_end_orig = seq_start, seq_end
            seq, pos = seq_orig, pos_orig
            self.ESM_prefix = f'{uniprot_id}-F{idx}'
            # NOTE(review): fixed at the nominal AF2 fragment length instead of
            # seq_orig_len / len(seq). This is only consistent with seq_end when
            # the fragment is exactly 1400 residues long. Confirm intent.
            seq_len = 1400
            self.af2_rep_file_prefix = f'{AF2_REP_DATA_PATH}/{uniprot_id}-F{idx}/{uniprot_id}-F{idx}'
        else:
            self.ESM_prefix = f'{uniprot_id}'
            if seq_orig_len > 2700:
                # Window of AF2 fragment `idx`: residues (idx-1)*200+1 .. (idx+6)*200.
                idx = min(max(1, pos_orig[0] // 200 - 2), seq_orig_len // 200 - 5)
                seq_start = (idx - 1) * 200 + 1
                seq_end = min((idx + 6) * 200, seq_orig_len)
                self.seq_start_orig, self.seq_end_orig = seq_start, seq_end
                seq = seq_orig[seq_start - 1:seq_end]
                pos = pos_orig - seq_start + 1
                seq_len = seq_end - seq_start + 1
                # From here on, coordinates are relative to the window.
                seq_start, seq_end = 1, seq_len
            else:
                idx = 1
                seq_start, seq_end = 1, seq_orig_len
                self.seq_start_orig, self.seq_end_orig = seq_start, seq_end
                seq_len = seq_orig_len
                seq, pos = seq_orig, pos_orig
            if uniprot_id == "Q8WZ42":
                self.ESM_prefix = f'{uniprot_id}-F{idx}'
            if seq_orig_len >= 7000:
                self.af2_rep_file_prefix = f'{AF2_REP_DATA_PATH}/{uniprot_id}-F{idx}/{uniprot_id}-F{idx}'
            else:
                self.af2_rep_file_prefix = f'{AF2_REP_DATA_PATH}/{uniprot_id}/{uniprot_id}'

        self.seq = seq
        self.seq_start = seq_start
        self.seq_end = seq_end
        self.seq_len = seq_len
        self.pos = pos
        self.uniprot_id = uniprot_id
        self.af2_file = f'{AF2_DATA_PATH}/AF-{uniprot_id}-F{idx}-model_v4.pdb.gz'

    def set_ref_alt_aa(self, ref_aa, alt_aa):
        """Store ref / alt residues as arrays with one entry per ';'-separated site."""
        if ";" in ref_aa or ";" in alt_aa:
            self.ref_aa = np.array(ref_aa.split(';'))
            self.alt_aa = np.array(alt_aa.split(';'))
        else:
            self.ref_aa = np.array([ref_aa])
            self.alt_aa = np.array([alt_aa])

    def init_af2_file_idx(self):
        """Warn and set ``af2_file`` to None when the AF2 file does not exist."""
        if not exists(self.af2_file):
            print(f'Warning: {self.uniprot_id} AF2 file not found: {self.af2_file}')
            self.af2_file = None
        self.af2_seq_index = None

    def crop_fn(self):
        """Crop the window to ``max_len`` residues around the first position.

        Only applies when ``seq_len >= max_len``. The choice depends on ``pos[0]``:
          * near the start, keep the first ``max_len`` residues;
          * near the end, keep the last ``max_len`` residues;
          * otherwise, keep ``half_max_len`` residues before ``pos[0]`` and the
            rest from ``pos[0]`` on.
        Updates ``seq``, ``seq_start``/``seq_end`` (1-based, within the pre-crop
        window), ``seq_len`` and ``pos``, and sets ``crop``.
        """
        seq_len = self.seq_len
        if seq_len < self.max_len:
            return
        seq = self.seq
        pos = self.pos
        first_pos = pos[0]
        if first_pos <= self.half_max_len:
            seq_start, seq_end = 1, self.max_len
            seq = seq[:self.max_len]
        elif seq_len - first_pos <= self.max_len - self.half_max_len:
            seq_start, seq_end = seq_len - self.max_len + 1, seq_len
            seq = seq[seq_start - 1:]
            pos = pos - seq_start + 1
        else:
            seq_start = first_pos - self.half_max_len
            seq_end = first_pos + self.max_len - self.half_max_len - 1
            seq = seq[seq_start - 1:seq_end]
            pos = pos - seq_start + 1
        self.crop = True
        self.seq = seq
        self.seq_start = seq_start
        self.seq_end = seq_end
        self.seq_len = self.max_len
        self.pos = pos

    def set_af2_seq_index(self, idx):
        self.af2_seq_index = idx

    def set_msa_seq_index(self, idx):
        self.msa_seq_index = idx

    def set_esm_seq_index(self, idx):
        self.esm_seq_index = idx

    def set_af2_rep_index(self, idx):
        self.af2_rep_index = idx
| |
|
| |
|
class RandomPointMutation(Mutation):
    """Single-site mutation at a uniformly random position with a random substitute.

    The position is drawn from 1..seq_orig_len (1-based). The alternate residue
    is drawn from the 20 standard amino acids, excluding the reference residue,
    so the result is never a synonymous "mutation".
    """

    def __init__(self, uniprot_id, transcript_id, seq_orig, seq_orig_len, max_len=2251):
        pos_orig = np.random.randint(1, seq_orig_len + 1)
        ref_aa = seq_orig[pos_orig - 1]
        candidates = [aa for aa in "ACDEFGHIKLMNPQRSTVWY" if aa != ref_aa]
        alt_aa = np.random.choice(candidates)
        super().__init__(uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig, ref_aa, alt_aa, max_len)
| |
|
| |
|
class MaskPredictPointMutation(Mutation):
    """Mutation used for masked-residue prediction.

    Same as :class:`Mutation`, except that a missing position (``None``, ``0``
    or a float NaN) is replaced by a uniformly random 1-based position.
    ``Mutation.__init__`` itself initializes ``ESM_prefix``, ``max_len`` and
    ``half_max_len``, and the AF2 file check is inherited unchanged.
    """

    def __init__(self, uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig, ref_aa, alt_aa,
                 max_len=2251, af2_file=None):
        missing_pos = (
            pos_orig is None
            or (isinstance(pos_orig, float) and np.isnan(pos_orig))
            or pos_orig == 0
        )
        if missing_pos:
            pos_orig = np.random.randint(1, seq_orig_len + 1)
        super().__init__(uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig, ref_aa, alt_aa,
                         max_len=max_len, af2_file=af2_file)
| |
|
| |
|
def convert_to_onesite(dataset: pd.DataFrame):
    """Collapse variants sharing a site (``uniprotID``, ``pos.orig``) into one row.

    The first row of each site is kept. Its ``ref_aa``, ``alt_aa``, ``score*``
    and ``confidence.score*`` columns become ';'-joined strings of the values
    from all rows of that site, in dataset order.

    Side effects on ``dataset`` (kept for backward compatibility):
      * ``ref_aa`` / ``alt_aa`` are created from ``ref`` / ``alt`` if missing;
      * when no ``confidence.score*`` column exists, one
        ``confidence.score.<i>`` column of 1s is added per score column.

    Returns:
        A new DataFrame with one row per site, keeping the original index labels.
    """
    site_keys = ['uniprotID', 'pos.orig']
    if 'ref_aa' not in dataset.columns:
        dataset['ref_aa'] = dataset['ref']
    if 'alt_aa' not in dataset.columns:
        dataset['alt_aa'] = dataset['alt']
    dataset_onesite = dataset.drop_duplicates(subset=site_keys).copy(deep=True)

    score_cols = [col for col in dataset.columns if col.startswith('score')]
    confidence_cols = [col for col in dataset.columns if col.startswith('confidence.score')]
    if not confidence_cols:
        confidence_cols = [f'confidence.score.{i}' for i in range(len(score_cols))]
        for col in confidence_cols:
            dataset[col] = 1
            dataset_onesite[col] = 1

    # One grouping pass instead of re-filtering the whole frame for every site.
    grouped = dataset.groupby(site_keys, sort=False, dropna=False)
    site_index = pd.MultiIndex.from_frame(dataset_onesite[site_keys])
    for col in ['ref_aa', 'alt_aa'] + score_cols + confidence_cols:
        joined = grouped[col].agg(lambda values: ';'.join(values.astype(str)))
        # Whole-column assignment avoids writing strings cell-by-cell into numeric columns.
        dataset_onesite[col] = joined.reindex(site_index).to_numpy()
    return dataset_onesite
| |
|
| |
|
def load_structure(fpath, chain=None):
    """Load the first model of a structure file as a biotite ``AtomArray``.

    Args:
        fpath: path (str or path-like) ending in ``cif``, ``cif.gz``, ``pdb``
            or ``pdb.gz``.
        chain: ``None`` for all chains, a single chain id, or a
            list/tuple/set of chain ids to keep.
    Returns:
        biotite.structure.AtomArray restricted to the requested chains.
    Raises:
        ValueError: unknown extension, no chains in the file, or a requested
            chain is absent.
    """
    fpath = os.fspath(fpath)
    if fpath.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('cif.gz'):
        with gzip.open(fpath, 'rt') as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    elif fpath.endswith('pdb.gz'):
        with gzip.open(fpath, 'rt') as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        raise ValueError("Invalid file extension")

    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = list(all_chains)
    elif isinstance(chain, (list, tuple, set, frozenset)):
        chain_ids = list(chain)
    else:
        chain_ids = [chain]
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')
    # Vectorized mask instead of a Python loop over every atom.
    return structure[np.isin(structure.chain_id, chain_ids)]
| |
|
| |
|
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """Return per-residue backbone + CB coordinates of *structure*.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Array of shape (L, 5, 3), atoms ordered N, C, O, CA, CB. An atom a
        residue lacks (e.g. CB of glycine) is filled with NaN.
    """
    backbone_and_cb = ["N", "C", "O", "CA", "CB"]
    return get_atom_coords_residue_wise(backbone_and_cb, structure)
| |
|
| |
|
def extract_sidechain_from_structure(structure: biotite.structure.AtomArray):
    """Return per-residue side-chain heavy-atom coordinates of *structure*.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Array of shape (L, 31, 3). The 31 atom names are in alphabetical order
        (carbons, nitrogens, oxygens, sulfurs). Atoms a residue lacks are NaN.
    """
    carbons = ['CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3',
               'CG', 'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3']
    nitrogens = ['ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 'NH2', 'NZ']
    oxygens = ['OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH']
    sulfurs = ['SD', 'SG']
    return get_atom_coords_residue_wise(carbons + nitrogens + oxygens + sulfurs, structure)
| |
|
| |
|
def extract_residues_from_structure(structure: biotite.structure.AtomArray):
    """Return the one-letter amino-acid sequence of *structure*.

    Despite the name, only the sequence string is returned (no coordinates).
    Each residue name reported by ``get_residues`` is converted with
    ``ProteinSequence.convert_letter_3to1``. Residue names that function does
    not recognise are not handled here.
    """
    residue_names = get_residues(structure)[1]
    return ''.join(ProteinSequence.convert_letter_3to1(name) for name in residue_names)
| |
|
| |
|
def get_atom_coords_residue_wise(atoms: List[str], struct: biotite.structure.AtomArray):
    """Collect the coordinates of the named atoms for every residue of *struct*.

    Example for atoms argument: ["N", "O", "CA", "C", "CB"]

    Returns an array of shape (num_residues, len(atoms), 3). An atom absent from
    a residue yields a row of NaN.

    Raises:
        RuntimeError: a residue contains the same atom name more than once.
    """
    def per_residue(residue, axis=None):
        # matches[k, j] is True when atom k of this residue is named atoms[j].
        matches = np.stack([residue.atom_name == name for name in atoms], axis=1)
        hits = matches.sum(0)
        if np.any(hits > 1):
            raise RuntimeError("structure has multiple atoms with same name")
        picked = residue[matches.argmax(0)].coord
        picked[hits == 0] = float("nan")
        return picked

    return biotite.structure.apply_residue_wise(struct, struct, per_residue)
| |
|
| |
|
def get_mutations(uniprot_id, transcript_id, seq, seq_orig_len,
                  pos_orig, ref_aa, alt_aa, max_len=1400, af2_file=None):
    """Build a :class:`Mutation`, or return ``False`` (with a message) when it has no AF2 file."""
    mutation = Mutation(uniprot_id, transcript_id, seq, seq_orig_len, pos_orig, ref_aa, alt_aa, max_len, af2_file)
    if mutation.af2_file is not None:
        return mutation
    print(
        "No AF2 file found for this mutation "
        f"{mutation.uniprot_id}:{mutation.ref_aa}:{mutation.pos}:{mutation.alt_aa}. Skipping..."
    )
    return False
| |
|
| |
|
def get_random_point_mutations(uniprot_id, transcript_id, seq, seq_orig_len,
                               pos_orig, ref_aa, alt_aa, score):
    """Build a mutation, or return ``False`` when it has no AF2 file.

    When ``score == -1`` a :class:`RandomPointMutation` is created and
    ``pos_orig`` / ``ref_aa`` / ``alt_aa`` are ignored. Otherwise the given
    mutation is used.
    """
    if score == -1:
        candidate = RandomPointMutation(uniprot_id, transcript_id, seq, seq_orig_len)
    else:
        candidate = Mutation(uniprot_id, transcript_id, seq, seq_orig_len, pos_orig, ref_aa, alt_aa)
    return False if candidate.af2_file is None else candidate
| |
|
| |
|
def get_mask_predict_point_mutations(uniprot_id, transcript_id, seq, seq_orig_len,
                                     pos_orig, ref_aa, alt_aa, max_len=2251, af2_file=None):
    """Build a :class:`MaskPredictPointMutation`, or return ``False`` (with a message) without an AF2 file."""
    candidate = MaskPredictPointMutation(uniprot_id, transcript_id, seq, seq_orig_len,
                                         pos_orig, ref_aa, alt_aa, max_len, af2_file)
    if candidate.af2_file is not None:
        return candidate
    print(
        "No AF2 file found for this mutation "
        f"{candidate.uniprot_id}:{candidate.ref_aa}:{candidate.pos}:{candidate.alt_aa}. Skipping..."
    )
    return False
| |
|
| |
|
def get_coords_from_af2(af2_file, add_sidechain=False):
    """Return per-residue atom coordinates from an AF2 structure file.

    The result holds backbone + CB coordinates (see
    ``extract_coords_from_structure``). With ``add_sidechain`` the side-chain
    atoms from ``extract_sidechain_from_structure`` are appended along axis 1.
    """
    structure = load_structure(af2_file)
    backbone = extract_coords_from_structure(structure)
    if not add_sidechain:
        return backbone
    sidechain = extract_sidechain_from_structure(structure)
    return np.concatenate([backbone, sidechain], axis=1)
| |
|
| |
|
def get_plddt_from_af2(af2_file):
    """Return one B-factor value per residue number from an AF2 PDB file.

    AF2 models store per-residue pLDDT in the B-factor column. The value is taken
    from the first ATOM record of each ``residue_number``.
    """
    atom_records = PandasPdb().read_pdb(af2_file).df['ATOM']
    first_atom_per_residue = atom_records.drop_duplicates(subset=['residue_number'])
    return first_atom_per_residue['b_factor'].values
| |
|
| |
|
def get_dssp_from_af2(af2_file, tmp_dir=None,
                      dssp_bin="/share/descartes/Users/gz2294/miniconda3/bin/mkdssp"):
    """Compute per-residue DSSP features for an AF2 PDB file.

    Args:
        af2_file: path to a PDB file, plain or gzip-compressed (detected from
            the file's magic bytes).
        tmp_dir: directory for the temporary uncompressed copy handed to
            mkdssp. ``None`` uses the system temp directory.
        dssp_bin: path of the ``mkdssp`` executable.

    Returns:
        Array with, per residue: a one-hot secondary-structure code over
        ``DSSP_DICT``, followed by columns 3, 4 and 5 of Biopython's DSSP
        records (relative accessibility, phi, psi). The two angles are
        converted from degrees to radians.
    """
    import tempfile

    with open(af2_file, 'rb') as probe:
        is_gzip = probe.read(2) == b'\x1f\x8b'
    opener = gzip.open if is_gzip else open
    with opener(af2_file, 'rt') as handle:
        pdb_text = handle.read()

    # mkdssp needs an uncompressed file on disk. A unique temp file avoids name
    # collisions between concurrent workers and leaves the global `random` state
    # alone. `finally` guarantees cleanup even when DSSP fails.
    fd, tmpfile = tempfile.mkstemp(suffix='.pdb', dir=tmp_dir)
    try:
        with os.fdopen(fd, 'w') as handle:
            handle.write(pdb_text)
        model = PDBParser().get_structure('', tmpfile)[0]
        dssp = DSSP(model, tmpfile, file_type="PDB", dssp=dssp_bin)
    finally:
        os.remove(tmpfile)

    dssp = pd.DataFrame(dssp)
    sec_struc = np.eye(len(DSSP_DICT), dtype=np.float32)[[DSSP_DICT.index(code) for code in dssp.iloc[:, 2].values]]
    return np.concatenate([sec_struc,
                           dssp.iloc[:, 3].values[:, None],
                           dssp.iloc[:, 4].values[:, None] / 180 * np.pi,
                           dssp.iloc[:, 5].values[:, None] / 180 * np.pi], axis=1)
| |
|
| |
|
def get_ptm_from_mutation(mutation: Mutation, ptm_ref):
    """Build a per-residue post-translational-modification (PTM) indicator matrix.

    Args:
        mutation: supplies ``uniprot_id``, ``seq``, ``seq_len`` and the window
            offsets ``seq_start_orig`` / ``seq_start``.
        ptm_ref: DataFrame with columns ``uniprotID``, ``pos`` (presumably
            1-based in the original sequence), ``ref`` (modified residue) and
            ``type`` (a key of ``PTM_DICT``). It is not modified.

    Returns:
        Array of shape ``(mutation.seq_len, len(PTM_DICT))``. An entry is 1 where
        a PTM of that type is annotated at a residue inside the window whose
        amino acid equals ``ref``.
    """
    seq = mutation.seq
    # .copy() so the column update below works on an independent frame
    # (no SettingWithCopyWarning).
    sites = ptm_ref[ptm_ref['uniprotID'] == mutation.uniprot_id].copy()
    # Original position -> 0-based index into the (windowed, cropped) sequence.
    sites['pos'] = sites['pos'] - mutation.seq_start_orig - mutation.seq_start + 1
    # Also bound by len(seq) so an inconsistent seq_len cannot cause an IndexError.
    upper = min(mutation.seq_len, len(seq))
    sites = sites[(sites['pos'] >= 0) & (sites['pos'] < upper)]
    ptm_mat = np.zeros([mutation.seq_len, len(PTM_DICT)])
    # Iterate values rather than index labels: `.loc[label]` returns a Series when
    # labels are duplicated.
    for pos, ref, ptm_type in zip(sites['pos'], sites['ref'], sites['type']):
        if ref == seq[pos]:
            ptm_mat[pos, PTM_DICT[ptm_type]] = 1
    return ptm_mat
| |
|
| |
|
def get_knn_graphs_from_af2(af2_coords, radius=None, max_neighbors=None, loop=False, gpu_id=None):
    """Build a residue graph over CA atoms (index 3 along axis 1 of ``af2_coords``).

    * ``radius is None``: every ordered residue pair is connected, in row-major
      order. Self-loops are kept only when ``loop``.
    * otherwise: a k-nearest-neighbour graph via ``torch_cluster.knn_graph``
      with ``k = max_neighbors`` (default: number of residues + 1). ``radius``
      itself is not used as a distance cut-off.

    Runs on ``cuda:{gpu_id}`` when a GPU id is given and CUDA is available.
    Returns a (2, E) numpy array of edge indices.
    """
    ca = af2_coords[:, 3]
    n_res = af2_coords.shape[0]
    if radius is None:
        nodes = np.arange(n_res)
        sources = np.repeat(nodes, n_res)
        targets = np.tile(nodes, n_res)
        edge_index = np.stack([sources, targets])
        if loop:
            return edge_index
        return edge_index[:, sources != targets]

    k = n_res + 1 if max_neighbors is None else max_neighbors
    use_gpu = gpu_id is not None and torch.cuda.is_available()
    with torch.no_grad():
        x = torch.from_numpy(ca)
        if use_gpu:
            x = x.to(f'cuda:{gpu_id}')
        edges = knn_graph(x=x, loop=loop, k=k, num_workers=NUM_THREADS)
        edge_index = edges.detach().cpu().numpy()
    return edge_index
| |
|
| |
|
def get_radius_graphs_from_af2(af2_coords, radius, loop=False, gpu_id=None):
    """Radius graph over CA atoms (index 3 along axis 1 of ``af2_coords``).

    The neighbour cap passed to ``torch_cluster.radius_graph`` is set above the
    residue count, so no pair within ``radius`` is dropped. Runs on
    ``cuda:{gpu_id}`` when a GPU id is given and CUDA is available.
    Returns a (2, E) numpy array of edge indices.
    """
    n_res = af2_coords.shape[0]
    ca = af2_coords[:, 3]
    with torch.no_grad():
        x = torch.from_numpy(ca)
        if gpu_id is not None and torch.cuda.is_available():
            x = x.to(f'cuda:{gpu_id}')
        edges = radius_graph(
            x=x,
            r=radius,
            loop=loop,
            max_num_neighbors=n_res + 1,
            num_workers=NUM_THREADS,
        )
        edge_index = edges.detach().cpu().numpy()
    return edge_index
| |
|
| |
|
def get_radius_knn_graphs_from_af2(af2_coords, center_nodes, radius, max_neighbors, loop=False, gpu_id=None):
    """kNN graph restricted to the radius neighbourhood of ``center_nodes``.

    Works on CA atoms (index 3 along axis 1 of ``af2_coords``):
      1. Build a radius graph and keep the edges whose first-row entry is in
         ``center_nodes``. Every node on a kept edge forms the neighbourhood.
      2. Build a kNN graph (``k = max_neighbors``) and keep only edges with
         both endpoints in that neighbourhood.
    Runs on ``cuda:{gpu_id}`` when a GPU id is given and CUDA is available.
    Returns a (2, E) numpy array of edge indices.
    """
    n_res = af2_coords.shape[0]
    on_gpu = gpu_id is not None and torch.cuda.is_available()
    with torch.no_grad():
        x = torch.from_numpy(af2_coords[:, 3])
        if on_gpu:
            x = x.to(f'cuda:{gpu_id}')
        radius_edges = radius_graph(
            x=x,
            r=radius,
            loop=loop,
            max_num_neighbors=n_res + 1,
            num_workers=NUM_THREADS,
        ).detach().cpu().numpy()
        knn_edges = knn_graph(
            x=x,
            loop=loop,
            k=max_neighbors,
            num_workers=NUM_THREADS,
        ).detach().cpu().numpy()

    centred = radius_edges[:, np.isin(radius_edges[0], center_nodes)]
    neighbourhood = centred.flatten()
    keep = np.isin(knn_edges[0], neighbourhood) & np.isin(knn_edges[1], neighbourhood)
    return knn_edges[:, keep]
| |
|
| |
|
def get_graphs_from_neighbor(af2_coords, max_neighbors=None, loop=False):
    """Band graph linking residues that are close in sequence.

    Nodes ``i`` and ``j`` are connected when ``|i - j| <= floor(max_neighbors / 2)``.
    The default (``None``) makes the band cover every pair. Self-loops are
    dropped unless ``loop``. Only ``af2_coords.shape[0]`` is used; the
    coordinates themselves are ignored.

    Returns a (2, E) array of edge indices in row-major (i, then j) order.
    """
    n_nodes = af2_coords.shape[0]
    if max_neighbors is None:
        max_neighbors = n_nodes + 1
    half_width = int(np.floor(max_neighbors / 2))
    positions = np.arange(n_nodes)
    in_band = np.abs(positions[:, None] - positions[None, :]) <= half_width
    if not loop:
        np.fill_diagonal(in_band, False)
    return np.array(np.nonzero(in_band))
| |
|
| |
|
def get_embedding_from_esm2(protein, check_mode=True, seq_start=None, seq_end=None):
    """Fetch ESM-2 per-residue embeddings for residues ``seq_start..seq_end``.

    ``protein`` is either a UniProt id or an already loaded array. For an id,
    ``{ESM_DATA_PATH}/{protein}.representations.layer.48.npy`` is loaded. Rows
    are sliced as ``[max(0, seq_start), min(len - 1, seq_end + 1))``, so the
    last row is never returned (presumably an EOS token; TODO confirm).

    For an id with ``check_mode`` True, the function only reports whether the
    file exists. A missing file outside check mode yields zeros of width 5120
    (``ESM_MODEL_SIZE == "15B"``) or 1280. ``seq_start`` / ``seq_end`` must be
    given whenever data are returned.

    Raises:
        ValueError: ``protein`` is neither a ``str`` nor a ``np.ndarray``.
    """
    if isinstance(protein, np.ndarray):
        return protein[max(0, seq_start):min(protein.shape[0] - 1, seq_end + 1)]
    if not isinstance(protein, str):
        raise ValueError("protein must be either a string of uniprotID or a numpy array")

    file_path = f"{ESM_DATA_PATH}/{protein}.representations.layer.48.npy"
    if not os.path.exists(file_path):
        if check_mode:
            return False
        width = 5120 if ESM_MODEL_SIZE == "15B" else 1280
        return np.zeros([seq_end - seq_start + 1, width])
    if check_mode:
        return True
    embeddings = np.load(file_path)
    return embeddings[max(0, seq_start):min(embeddings.shape[0] - 1, seq_end + 1)]
| |
|
| |
|
def get_esm_dict_from_uniprot(uniprotID):
    """Load the full cached ESM embedding array for ``uniprotID`` from ``ESM_DATA_PATH``."""
    return np.load(f"{ESM_DATA_PATH}/{uniprotID}.representations.layer.48.npy")
| |
|
| |
|
def get_af2_single_rep_dict_from_prefix(uniprotID_prefix, filter=False):
    """Load the AF2 single (per-residue) representation saved under ``uniprotID_prefix``.

    ``filter`` is accepted for backward compatibility but is currently ignored.
    """
    suffix = "_single_repr_rank_001_alphafold2_ptm_model_1_seed_000.npy"
    return np.load(f"{uniprotID_prefix}{suffix}")
| |
|
| |
|
def get_af2_pairwise_rep_dict_from_prefix(uniprotID_prefix):
    """Load the AF2 pair (residue x residue) representation saved under ``uniprotID_prefix``."""
    suffix = "_pair_repr_rank_001_alphafold2_ptm_model_1_seed_000.npy"
    return np.load(f"{uniprotID_prefix}{suffix}")
| |
|
| |
|
def get_embedding_from_esm1b(protein, check_mode=True, seq_start=None, seq_end=None,
                             data_dir="/share/vault/Users/gz2294/Data/DMS/ClinVar.HGMD.PrimateAI.syn/esm1b.embedding.uniprotIDs"):
    """Fetch ESM-1b per-residue embeddings for residues ``seq_start..seq_end``.

    ``protein`` is either a UniProt id or an already loaded array. For an id,
    ``{data_dir}/{protein}.representations.layer.48.npy`` is loaded. Rows are
    sliced as ``[max(0, seq_start), min(len - 1, seq_end + 1))``, so the last
    row is never returned (presumably an EOS token; TODO confirm).

    For an id with ``check_mode`` True, the function only reports whether the
    file exists. A missing file outside check mode yields a zero array of
    width 1280. ``seq_start`` / ``seq_end`` must be given whenever data are
    returned.

    Raises:
        ValueError: ``protein`` is neither a ``str`` nor a ``np.ndarray``.
    """
    # ESM-1b embeddings are always 1280-dimensional. ESM_MODEL_SIZE describes the
    # ESM-2 model used elsewhere and must not decide this width.
    esm1b_dim = 1280
    if isinstance(protein, str):
        file_path = f"{data_dir}/{protein}.representations.layer.48.npy"
        if os.path.exists(file_path):
            if check_mode:
                return True
            wt_orig = np.load(file_path)
            batch_tokens = wt_orig[max(0, seq_start):min(wt_orig.shape[0] - 1, seq_end + 1)]
        else:
            if check_mode:
                return False
            batch_tokens = np.zeros([seq_end - seq_start + 1, esm1b_dim])
    elif isinstance(protein, np.ndarray):
        batch_tokens = protein[max(0, seq_start):min(protein.shape[0] - 1, seq_end + 1)]
    else:
        raise ValueError("protein must be either a string of uniprotID or a numpy array")
    return batch_tokens
| |
|
| |
|
def get_embedding_from_onehot(seq, seq_start=None, seq_end=None, return_idx=False, aa_dict=None, return_onehot_mat=False):
    """One-hot encode ``seq`` against a vocabulary list.

    Args:
        seq: iterable of tokens. Each is looked up with ``vocab.index`` (an
            unknown token raises ValueError).
        seq_start, seq_end: optional 1-based inclusive window. It is applied
            only when both are given.
        return_idx: also return the token indices of the full (unwindowed)
            sequence.
        aa_dict: vocabulary list; defaults to ``AA_DICT``.
        return_onehot_mat: also return the identity matrix used for encoding.

    Returns:
        ``tokens`` alone, or the tuple ``(tokens[, idx][, one_hot_mat])``.
    """
    vocab = AA_DICT if aa_dict is None else aa_dict
    idx = [vocab.index(token) for token in seq]
    one_hot_mat = np.eye(len(vocab))
    encoded = one_hot_mat[idx]
    if seq_start is not None and seq_end is not None:
        encoded = encoded[max(0, seq_start - 1):min(encoded.shape[0], seq_end)]

    extras = []
    if return_idx:
        extras.append(np.array(idx))
    if return_onehot_mat:
        extras.append(one_hot_mat)
    return (encoded, *extras) if extras else encoded
| |
|
| |
|
def get_embedding_from_esm_onehot(seq, seq_start=None, seq_end=None, return_idx=False, aa_dict=None, return_onehot_mat=False):
    """Convert ``seq`` to integer token ids (not one-hot vectors, despite the name).

    With the default vocabulary (``ESM_TOKENS``) the ids are wrapped in
    ``<cls>`` ... ``<eos>``. With a custom ``aa_dict`` no wrapping is added. The
    optional window ``[seq_start - 1, seq_end)`` is applied to the id array
    only when both bounds are given.

    Returns ``ids`` alone, or the tuple ``(ids[, idx][, None])``. ``None``
    stands in for the one-hot matrix so the return shape matches the other
    encoders.
    """
    if aa_dict is None:
        wrapped = ['<cls>', *seq, '<eos>']
        idx = [ESM_TOKENS.index(token) for token in wrapped]
    else:
        idx = [aa_dict.index(token) for token in seq]
    token_ids = np.array(idx)
    if seq_start is not None and seq_end is not None:
        token_ids = token_ids[max(0, seq_start - 1):min(token_ids.shape[0], seq_end)]

    extras = []
    if return_idx:
        extras.append(np.array(idx))
    if return_onehot_mat:
        extras.append(None)
    return (token_ids, *extras) if extras else token_ids
| |
|
| |
|
def get_embedding_from_5dim(seq, seq_start=None, seq_end=None):
    """Stack the ``AA_5DIM_EMBED`` vector of every token in ``seq``.

    The optional 1-based inclusive window ``seq_start..seq_end`` is applied
    only when both bounds are given.
    """
    embedded = np.array([AA_5DIM_EMBED[token] for token in seq])
    if seq_start is None or seq_end is None:
        return embedded
    return embedded[max(0, seq_start - 1):min(embedded.shape[0], seq_end)]
| |
|
| |
|
def get_embedding_from_onehot_nonzero(seq, seq_start=None, seq_end=None, return_idx=False,
                                      aa_dict=None, min_prob=0.001, return_onehot_mat=False):
    """Encode ``seq`` with a smoothed ("non-zero") one-hot matrix.

    Let V be the vocabulary size and S the number of special tokens
    (``'<mask>'``, ``'<pad>'``) present in it. Rows for ordinary tokens hold
    ``1 - min_prob * (V - S)`` on the diagonal, 0 in special-token columns and
    ``min_prob`` elsewhere. Special tokens get an exact one-hot row.

    Args:
        seq: iterable of tokens looked up with ``aa_dict.index``.
        seq_start, seq_end: optional 1-based inclusive window. It is applied
            only when both are given.
        return_idx: also return the token indices of the full sequence.
        aa_dict: vocabulary list; defaults to ``AA_DICT``.
        min_prob: off-diagonal probability mass for ordinary tokens.
        return_onehot_mat: also return the smoothed matrix.

    Returns:
        ``tokens`` alone, or the tuple ``(tokens[, idx][, one_hot_mat])``.
    """
    if aa_dict is None:
        aa_dict = AA_DICT
    vocab_size = len(aa_dict)
    special = [aa_dict.index(tok) for tok in ('<mask>', '<pad>') if tok in aa_dict]

    # Each region is assigned directly instead of rewriting sentinel values in
    # place. The sentinel approach corrupted the matrix whenever min_prob or the
    # diagonal value happened to equal a sentinel (0, 1, -1, 2).
    one_hot_mat = np.full((vocab_size, vocab_size), min_prob, dtype=float)
    # NOTE(review): only V - S - 1 off-diagonal entries receive min_prob, so
    # ordinary rows sum to 1 - min_prob rather than 1. Confirm this is intended.
    np.fill_diagonal(one_hot_mat, 1 - min_prob * (vocab_size - len(special)))
    one_hot_mat[special, :] = 0
    one_hot_mat[:, special] = 0
    one_hot_mat[special, special] = 1

    idx = [aa_dict.index(aa) for aa in seq]
    protein = one_hot_mat[idx]
    if seq_start is not None and seq_end is not None:
        batch_tokens = protein[max(0, seq_start - 1): min(protein.shape[0], seq_end)]
    else:
        batch_tokens = protein
    if return_idx:
        if return_onehot_mat:
            return batch_tokens, np.array(idx), one_hot_mat
        return batch_tokens, np.array(idx)
    if return_onehot_mat:
        return batch_tokens, one_hot_mat
    return batch_tokens
| |
|
| | |
def get_conservation_from_msa(mutation: Mutation, check_mode=False):
    """Slice per-residue conservation features for ``mutation`` from its MSA pickle.

    The pickle ``{MSA_DATA_PATH}/{transcript_id}.pickle`` holds a matrix whose
    column 0 encodes the residue (index into 'ACDEFGHIKLMNPQRSTVWYU'). Rows
    ``seq_start_orig..seq_end_orig`` (1-based, inclusive; defaulting to the
    whole ``seq``) are used, further cropped to ``seq_start..seq_end`` when
    ``mutation.crop`` is set. They count as a match only if the decoded
    residues equal ``mutation.seq``.

    Returns:
        In check mode, True/False for whether a matching MSA exists. Otherwise
        columns 1-40 of the matched rows, or zeros with 40 columns when there
        is no match. NOTE(review): the archive layout elsewhere treats columns
        1-20 as conservation and 21+ as MSA, so 1-40 spans both. Confirm.
    """
    seq = mutation.seq
    transcript = mutation.transcript_id
    start = 1 if mutation.seq_start_orig is None else mutation.seq_start_orig
    end = len(seq) if mutation.seq_end_orig is None else mutation.seq_end_orig

    matched = False
    if os.path.exists(f'{MSA_DATA_PATH}/{transcript}.pickle'):
        with open(os.path.join(MSA_DATA_PATH, transcript + '.pickle'), 'rb') as handle:
            msa_mat = pickle.load(handle)
        alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
        decoded = ''.join(alphabet[msa_mat[start - 1:end, 0].astype(int)])
        if mutation.crop:
            decoded = decoded[mutation.seq_start - 1:mutation.seq_end]
        matched = decoded == seq

    if check_mode:
        return True if matched else False

    if matched:
        conservation = msa_mat[start - 1:end, 1:41]
    else:
        conservation = np.zeros([end - start + 1, 40])
    if mutation.crop:
        conservation = conservation[mutation.seq_start - 1:mutation.seq_end]
    return conservation
| |
|
| |
|
def get_msa_dict_from_transcript_archive(transcript):
    """Load the whole MSA pickle of ``transcript`` and split it into its parts.

    Returns ``(msa_seq, conservation, msa)``:
      * ``msa_seq``: column 0 decoded via 'ACDEFGHIKLMNPQRSTVWYU';
      * ``conservation``: columns 1-20;
      * ``msa``: columns 21-220.
    A missing/NaN transcript or file yields ``('', zeros((0, 20)), zeros((0, 200)))``.
    NOTE(review): reads from MSA_DATA_PATH, not MSA_DATA_PATH_ARCHIVE. Confirm.
    """
    if pd.isna(transcript) or not os.path.exists(f'{MSA_DATA_PATH}/{transcript}.pickle'):
        return '', np.zeros([0, 20]), np.zeros([0, 200])
    with open(os.path.join(MSA_DATA_PATH, transcript + '.pickle'), 'rb') as handle:
        msa_mat = pickle.load(handle)
    alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    msa_seq = ''.join(alphabet[msa_mat[:, 0].astype(int)])
    return msa_seq, msa_mat[:, 1:21], msa_mat[:, 21:221]
| |
|
| |
|
def get_msa_dict_from_transcript(uniprotID):
    """Load ``{MSA_DATA_PATH}/{uniprotID}_MSA.npy`` and derive per-position features.

    The matrix has one row per position; column 0 is the query residue.
    Returns ``(msa_seq, conservation, msa)``:
      * ``msa_seq``: column 0 decoded via 'ACDEFGHIKLMNPQRSTVWYU';
      * ``conservation``: per row, the frequency of each of the first 20
        symbols across all columns (21-symbol one-hot averaged, last dropped);
      * ``msa``: the raw integer-coded matrix.
    A missing/NaN id or file yields ``('', zeros((0, 20)), zeros((0, 199)))``.
    """
    if pd.isna(uniprotID) or not os.path.exists(f'{MSA_DATA_PATH}/{uniprotID}_MSA.npy'):
        return '', np.zeros([0, 20]), np.zeros([0, 199])
    msa_mat = np.load(f'{MSA_DATA_PATH}/{uniprotID}_MSA.npy')
    alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    msa_seq = ''.join(alphabet[msa_mat[:, 0].astype(int)])
    symbol_freq = np.eye(21)[msa_mat.astype(int)].mean(axis=1)
    return msa_seq, symbol_freq[:, :20], msa_mat
| |
|
| |
|
def get_confidence_from_af2file(af2file, pLDDT):
    """Return the predicted-aligned-error (PAE) matrix for an AF2 model.

    The PAE file read is
    ``{PAE_DATA_PATH}/{id[3:6]}/{id}-predicted_aligned_error_v4.json.gz``.
    Here ``id`` is the file name of ``af2file`` up to its first '.' and before
    '-model' (e.g. ``AF-P12345-F1``).

    If ``af2file`` is missing (None/NaN) or that JSON does not exist, a proxy
    is derived from per-residue pLDDT: ``(200 - pLDDT_i - pLDDT_j) / 4``. In
    that case ``None`` is returned when ``pLDDT`` is also ``None``.
    """
    pae_path = None
    if isinstance(af2file, (str, os.PathLike)):
        af2_id = os.path.basename(os.fspath(af2file)).split('.')[0].split('-model')[0]
        candidate = f'{PAE_DATA_PATH}/{af2_id[3:6]}/{af2_id}-predicted_aligned_error_v4.json.gz'
        if os.path.exists(candidate):
            pae_path = candidate

    if pae_path is None:
        if pLDDT is None:
            return None
        # Confident at both residues -> small expected error.
        return (200 - pLDDT[None, :] - pLDDT[:, None]) / 4

    with gzip.open(pae_path, 'rt') as handle:
        pae = json.load(handle)
    return np.array(pae[0]['predicted_aligned_error'])
| |
|
| |
|
def get_msa(mutation: Mutation, check_mode=False):
    """Slice per-residue MSA features for ``mutation`` from its transcript MSA pickle.

    The pickle ``{MSA_DATA_PATH}/{transcript_id}.pickle`` holds a matrix whose
    column 0 encodes the residue (index into 'ACDEFGHIKLMNPQRSTVWYU'). Rows
    ``seq_start_orig..seq_end_orig`` (1-based, inclusive; defaulting to the
    whole ``seq``) are used, further cropped to ``seq_start..seq_end`` when
    ``mutation.crop`` is set. They count as a match only if the decoded
    residues equal ``mutation.seq``.

    Returns:
        In check mode, True/False for whether a matching MSA exists. Otherwise
        columns 21-220 of the matched rows, or zeros with 200 columns when
        there is no match.
    """
    seq = mutation.seq
    transcript = mutation.transcript_id
    start = 1 if mutation.seq_start_orig is None else mutation.seq_start_orig
    end = len(seq) if mutation.seq_end_orig is None else mutation.seq_end_orig

    matched = False
    if os.path.exists(f'{MSA_DATA_PATH}/{transcript}.pickle'):
        with open(os.path.join(MSA_DATA_PATH, transcript + '.pickle'), 'rb') as handle:
            msa_mat = pickle.load(handle)
        alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
        decoded = ''.join(alphabet[msa_mat[start - 1:end, 0].astype(int)])
        if mutation.crop:
            decoded = decoded[mutation.seq_start - 1:mutation.seq_end]
        matched = decoded == seq

    if check_mode:
        return True if matched else False

    if matched:
        msa = msa_mat[start - 1:end, 21:221]
    else:
        msa = np.zeros([end - start + 1, 200])
    if mutation.crop:
        msa = msa[mutation.seq_start - 1:mutation.seq_end]
    return msa
| |
|
| |
|
def get_logits_from_esm2(protein, check_mode=True, seq_start=None, seq_end=None):
    """Fetch cached ESM-2 per-residue logits for residues ``seq_start..seq_end``.

    ``protein`` is either a UniProt id or an already loaded array. For an id,
    ``{ESM_DATA_PATH}/{protein}.logits.npy`` is loaded. Rows are sliced as
    ``[max(0, seq_start), min(len - 1, seq_end + 1))``.

    For an id with ``check_mode`` True, the function only reports whether the
    file exists. A missing file outside check mode yields zeros with 32
    columns. NOTE(review): ESM_TOKENS lists 33 tokens; confirm the width of the
    cached logits.

    Raises:
        ValueError: ``protein`` is neither a ``str`` nor a ``np.ndarray``.
    """
    if isinstance(protein, np.ndarray):
        return protein[max(0, seq_start):min(protein.shape[0] - 1, seq_end + 1)]
    if not isinstance(protein, str):
        raise ValueError("protein must be either a string of uniprotID or a numpy array")

    file_path = f"{ESM_DATA_PATH}/{protein}.logits.npy"
    if not os.path.exists(file_path):
        if check_mode:
            return False
        return np.zeros([seq_end - seq_start + 1, 32])
    if check_mode:
        return True
    logits = np.load(file_path)
    return logits[max(0, seq_start):min(logits.shape[0] - 1, seq_end + 1)]
| |
|
| |
|
def _msa_attn_to_pairwise(msa_row_attns, msa_contacts, seq_start, seq_end, num_layers):
    """Crop MSA-transformer outputs to a residue window and stack them as channels-last pair features.

    Keeps entries ``12 - num_layers`` onward on axis 1 of ``msa_row_attns``. Crops both arrays to
    residues ``seq_start``..``seq_end`` (1-based, inclusive) on their last two axes. Flattens all
    leading attention axes into channels, appends the contact channel(s), and moves channels last.

    Presumably ``msa_row_attns`` is (1, 12 layers, 12 heads, N, N) and ``msa_contacts`` is
    (1, N, N), giving an (L, L, num_layers * 12 + 1) result. TODO confirm against the producer.
    """
    window = slice(seq_start - 1, seq_end)
    cropped_attns = msa_row_attns[:, (12 - num_layers):, :, window, window]
    cropped_contacts = msa_contacts[:, window, window]
    channels = np.concatenate(
        [cropped_attns.reshape(-1, cropped_attns.shape[-2], cropped_attns.shape[-1]),
         cropped_contacts],
        axis=0)
    return channels.transpose((1, 2, 0))


def get_attn_from_msa(transcript, seq, check_mode=False, seq_start=None, seq_end=None):
    """Build pairwise MSA-transformer features (row attentions + contacts) for a residue window.

    Args:
        transcript: one of three kinds:
            - A transcript ID string. Its cached MSA (``MSA_DATA_PATH``) and its row-attention and
              contact tensors (``MSA_ATTN_DATA_PATH``) are loaded.
            - A ``(row_attns, contacts)`` tuple of arrays; either element may be ``None``.
            - A missing value (``None``/NaN), treated like a transcript with no cached features.
        seq: sequence compared against the MSA's first column; ``len(seq)`` is the default
            ``seq_end``.
        check_mode: for string/missing transcripts only, return whether all cached files exist and
            the MSA sequence matches ``seq``.
        seq_start: first residue, 1-based. Defaults to 1.
        seq_end: last residue, 1-based inclusive. Defaults to ``len(seq)``.

    Returns:
        ``True``/``False`` in check mode. Otherwise an array of shape (L, L, C), with
        L = seq_end - seq_start + 1. When features are unavailable or the MSA sequence does not
        match, the result is zeros with C = NUM_LAYERS * 12 + 1.

    Raises:
        ValueError: if ``transcript`` is none of the accepted kinds.
    """
    NUM_LAYERS = 6
    # Resolve the window up front so the zero fallback and the tuple branch can be sized.
    # Previously the defaults were applied only when the MSA files existed.
    if seq_start is None:
        seq_start = 1
    if seq_end is None:
        seq_end = len(seq)
    n_res = seq_end - seq_start + 1
    empty_shape = [n_res, n_res, NUM_LAYERS * 12 + 1]

    if isinstance(transcript, tuple):
        msa_row_attns, msa_contacts = transcript[0], transcript[1]
        if msa_row_attns is None or msa_contacts is None:
            return np.zeros(empty_shape)
        return _msa_attn_to_pairwise(msa_row_attns, msa_contacts, seq_start, seq_end, NUM_LAYERS)

    # The NaN/None check used to sit inside the `isinstance(str)` branch, where it was always
    # False. Missing IDs (e.g. NaN from a pandas column) now take the "unavailable" path.
    is_missing = pd.api.types.is_scalar(transcript) and pd.isna(transcript)
    if not isinstance(transcript, str) and not is_missing:
        raise ValueError("transcript must be either a string of transcriptID"
                         " or a tuple of msa_row_attns and msa_contacts")

    matched_line = False
    if not is_missing:
        msa_file = os.path.join(MSA_DATA_PATH, transcript + '.pickle')
        row_attn_file = os.path.join(MSA_ATTN_DATA_PATH, transcript + '.row_attentions.pt')
        # contacts.pt is loaded below, so it must exist too (previously unchecked).
        contacts_file = os.path.join(MSA_ATTN_DATA_PATH, transcript + '.contacts.pt')
        if all(os.path.exists(path) for path in (msa_file, row_attn_file, contacts_file)):
            with open(msa_file, 'rb') as file:
                msa_mat = pickle.load(file)
            msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
            # Column 0 of the MSA matrix holds the sequence as alphabet indices.
            msa_seq = ''.join(msa_alphabet[msa_mat[seq_start - 1:seq_end, 0].astype(int)])
            matched_line = msa_seq == seq

    if check_mode:
        return True if matched_line else False
    if not matched_line:
        return np.zeros(empty_shape)
    msa_row_attns = torch.load(row_attn_file).detach().numpy()
    msa_contacts = torch.load(contacts_file).detach().numpy()
    return _msa_attn_to_pairwise(msa_row_attns, msa_contacts, seq_start, seq_end, NUM_LAYERS)
| |
|
| |
|
def get_contacts_from_msa(mutation: Mutation, check_mode=False):
    """Return a pairwise contact map for ``mutation``'s residue window, channels last.

    Reads ``mutation.transcript_id``, ``.seq``, ``.seq_start``, ``.seq_end`` and ``.ESM_prefix``.
    The source is chosen in this order:

    1. MSA-transformer contacts (``MSA_ATTN_DATA_PATH/{transcript}.contacts.pt``), used when the
       archived MSA (``MSA_DATA_PATH_ARCHIVE/{transcript}.pickle``) first column matches ``seq``.
    2. Otherwise, the ESM contact map ``{ESM_DATA_PATH}/{ESM_prefix}.contacts.npy`` if it exists.
    3. Otherwise, zeros of shape (L, L, 1).

    Here L = seq_end - seq_start + 1. The window is 1-based inclusive and defaults to the whole of
    ``seq``.

    Returns:
        ``True``/``False`` in check mode (whether source 1 or 2 is available). Otherwise an
        (L, L, C) array: C comes from the contacts tensor's first axis for source 1, and C = 1
        for sources 2 and 3.
    """
    transcript = mutation.transcript_id
    seq = mutation.seq
    # Resolve the window up front. Previously the defaults were applied only when the MSA files
    # existed, so the zero fallback crashed on None arithmetic.
    seq_start = 1 if mutation.seq_start is None else mutation.seq_start
    seq_end = len(seq) if mutation.seq_end is None else mutation.seq_end
    window = slice(seq_start - 1, seq_end)

    matched_line = False
    if not pd.isna(transcript):
        # NOTE(review): this validates against MSA_DATA_PATH_ARCHIVE, whereas get_attn_from_msa
        # uses MSA_DATA_PATH. Confirm this difference is intentional.
        msa_file = os.path.join(MSA_DATA_PATH_ARCHIVE, transcript + '.pickle')
        contacts_file = os.path.join(MSA_ATTN_DATA_PATH, transcript + '.contacts.pt')
        if os.path.exists(msa_file) and os.path.exists(contacts_file):
            with open(msa_file, 'rb') as file:
                msa_mat = pickle.load(file)
            msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
            msa_seq = ''.join(msa_alphabet[msa_mat[window, 0].astype(int)])
            matched_line = msa_seq == seq

    if matched_line:
        if check_mode:
            return True
        msa_contacts = torch.load(contacts_file).detach().numpy()
        return msa_contacts[:, window, window].transpose((1, 2, 0))

    # Fall back to the ESM contact map when the MSA is unavailable or does not match.
    esm_file = f'{ESM_DATA_PATH}/{mutation.ESM_prefix}.contacts.npy'
    if not os.path.exists(esm_file):
        if check_mode:
            return False
        return np.zeros([seq_end - seq_start + 1, seq_end - seq_start + 1, 1])
    if check_mode:
        return True
    esm_contacts = np.load(esm_file)
    return np.expand_dims(esm_contacts[window, window], axis=2)
| |
|
| | |
def get_contacts_from_msa_by_identifier(identifier):
    """Compute the contact map for a colon-separated identifier string.

    Expected format: ``transcript:seq:seq_start:seq_end``. An optional fifth field gives the ESM
    file prefix used by get_contacts_from_msa's ESM fallback. Features are always computed
    (check mode off).
    """
    from types import SimpleNamespace

    fields = identifier.split(":")
    transcript, seq = fields[0], fields[1]
    seq_start, seq_end = int(fields[2]), int(fields[3])
    # NOTE(review): without a fifth field there is no ESM prefix. The fallback then looks for
    # "None.contacts.npy", which presumably does not exist, so the result is zeros. Confirm this is
    # acceptable, or pass a prefix.
    esm_prefix = fields[4] if len(fields) > 4 else None
    # get_contacts_from_msa takes a single mutation-like object (the old positional call did not
    # match its signature and always raised TypeError). Supply just the attributes it reads.
    mutation = SimpleNamespace(transcript_id=transcript, seq=seq, seq_start=seq_start,
                               seq_end=seq_end, ESM_prefix=esm_prefix)
    return get_contacts_from_msa(mutation, check_mode=False)
| |
|
| | |
def load_embedding_from_esm2(protein):
    """Load the cached ESM2 representation array for ``protein``.

    Reads ``{ESM_DATA_PATH}/{protein}.representations.layer.48.npy``.

    Raises:
        FileNotFoundError: if the cached file does not exist.
    """
    file_path = f"{ESM_DATA_PATH}/{protein}.representations.layer.48.npy"
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"ESM2 embedding file not found: {file_path}")
    return np.load(file_path)
| |
|
| | |
def load_logits_from_esm2(protein):
    """Load the cached ESM2 logits array for ``protein``.

    Reads ``{ESM_DATA_PATH}/{protein}.logits.npy``.

    Raises:
        FileNotFoundError: if the cached file does not exist.
    """
    file_path = f"{ESM_DATA_PATH}/{protein}.logits.npy"
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"ESM2 logits file not found: {file_path}")
    return np.load(file_path)
| |
|
| | |
def load_attn_from_msa(transcript):
    """Load cached MSA-transformer row attentions and contacts for ``transcript``.

    Looks for ``{transcript}.row_attentions.pt`` and ``{transcript}.contacts.pt`` under
    ``MSA_ATTN_DATA_PATH``.

    Returns:
        ``(row_attentions, contacts)`` as numpy arrays, or ``(None, None)`` when either file is
        absent.
    """
    row_attn_path = os.path.join(MSA_ATTN_DATA_PATH, transcript + '.row_attentions.pt')
    contacts_path = os.path.join(MSA_ATTN_DATA_PATH, transcript + '.contacts.pt')
    if not (os.path.exists(row_attn_path) and os.path.exists(contacts_path)):
        return None, None
    row_attns, contacts = (torch.load(path).detach().numpy()
                           for path in (row_attn_path, contacts_path))
    return row_attns, contacts
| |
|
| |
|
def _test_load():
    """Smoke-test feature assembly on one hard-coded training example.

    Steps:

    1. Select the first row of the training CSV whose ``sequence.len.orig`` equals 4753.
    2. Build its point mutation.
    3. Gather AF2 coordinates, ESM2 embeddings and MSA attention features.
    4. Pack them into a torch_geometric ``Data`` graph over every ordered residue pair,
       excluding self-pairs.
    """
    training_csv = '/share/terra/Users/gz2294/ld1/Data/DMS/ClinVar.HGMD.PrimateAI.syn/training.csv'
    table = pd.read_csv(training_csv, index_col=0)

    row_idx = np.where(table['sequence.len.orig'] == 4753)[0][0]
    mutation_columns = ('uniprotID', 'ENST', 'wt.orig', 'sequence.len.orig',
                        'pos.orig', 'ref', 'alt')
    point_mutation = get_mutations(*[table[col].iloc[row_idx] for col in mutation_columns])
    coords = get_coords_from_af2(point_mutation.af2_file)

    ca_coord = coords[:, 3]
    embed_data = get_embedding_from_esm2(point_mutation.uniprot_id, False,
                                         point_mutation.seq_start, point_mutation.seq_end)
    coev_strength = get_attn_from_msa(point_mutation.transcript_id, point_mutation.seq, False,
                                      point_mutation.seq_start, point_mutation.seq_end)

    # All ordered (i, j) pairs with i != j; edge features come from the pairwise MSA tensor.
    n_nodes = coords.shape[0]
    all_pairs = np.indices((n_nodes, n_nodes)).reshape(2, -1)
    edge_index = all_pairs[:, all_pairs[0] != all_pairs[1]]
    edge_attr = coev_strength[edge_index[0], edge_index[1], :]

    # Offsets of atoms 4, 1, 2 and 0 from atom 3, stacked along axis 1. The original code named
    # these CA->CB, CA->C, CA->O and CA->N.
    anchor = coords[:, [3]]
    nodes_vector = np.concatenate([coords[:, [k]] - anchor for k in (4, 1, 2, 0)], axis=1)

    pos_t = torch.from_numpy(ca_coord)
    x_t = torch.from_numpy(embed_data)
    edge_index_t = torch.from_numpy(edge_index)
    edge_attr_t = torch.from_numpy(edge_attr).to(torch.float)
    node_vec_t = torch.from_numpy(nodes_vector).transpose(1, 2)

    from torch_geometric.data import Data

    return Data(pos=pos_t, x=x_t, edge_index=edge_index_t, edge_attr=edge_attr_t,
                node_vec_attr=node_vec_t)
| |
|
| |
|
if __name__ == '__main__':
    # Running the module directly prints the graph built by the smoke test.
    sample_graph = _test_load()
    print(sample_graph)
| |
|