Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| from typing import Optional, Tuple | |
# Character alphabets for the sequence types handled in this module.
# AA_SET is the 20 standard amino-acid letters plus the characters listed in
# NONCANONICAL_AMINO_ACIDS; CODON_SET holds exactly the keys of CODON_TO_AA.
AA_SET = set('LAGVSERTIPDKQNFYMHWCXBUOZ*')
CODON_SET = set('aA@bB#$%rRnNdDcCeEqQ^G&ghHiIj+MmlJLkK(fFpPoO=szZwSXTtxWyYuvUV]})')
DNA_SET = set('ATCG')
RNA_SET = set('AUCG')
NONCANONICAL_AMINO_ACIDS = set('XBUOZ*')

# One DNA codon per standard amino acid. Every entry translates back to its
# key under DNA_CODON_TO_AA below.
# NOTE(review): the name suggests these are the preferred human codons; that
# choice cannot be verified from this file alone.
AMINO_ACID_TO_HUMAN_CODON = {
    'A': 'GCC',
    'R': 'CGC',
    'N': 'AAC',
    'D': 'GAC',
    'C': 'TGC',
    'Q': 'CAG',
    'E': 'GAG',
    'G': 'GGC',
    'H': 'CAC',
    'I': 'ATC',
    'L': 'CTG',
    'K': 'AAG',
    'M': 'ATG',
    'F': 'TTC',
    'P': 'CCC',
    'S': 'AGC',
    'T': 'ACC',
    'W': 'TGG',
    'Y': 'TAC',
    'V': 'GTG',
}

# An alanine codon (GCT -> 'A' in DNA_CODON_TO_AA).
# NOTE(review): presumably the stand-in codon for non-canonical residues;
# confirm against the call sites.
NONCANONICAL_ALANINE_CODON = 'GCT'

# One single-character codon token per standard amino acid. Each token must
# decode back to its amino acid through CODON_TO_AA.
AA_TO_CODON_TOKEN = {
    'A': 'A',
    'R': 'B',
    'N': 'N',
    'D': 'D',
    'C': 'C',
    'Q': 'Q',
    'E': 'E',
    'G': 'G',
    'H': 'H',
    'I': 'I',
    'L': 'L',
    'K': 'K',
    'M': '(',
    'F': 'F',
    'P': 'P',
    'S': 'S',
    'T': 'T',
    'W': 'W',
    'Y': 'Y',
    'V': 'V',
}

# Single-character codon token -> amino acid ('*' marks a stop), grouped by
# amino acid. There are 64 tokens, and the number of tokens per amino acid
# matches the number of codons per amino acid in DNA_CODON_TO_AA.
#
# Fix: the table previously had 'X' -> 'S' and 'W' -> 'T'. That left seven
# serine tokens, no tryptophan token at all, and made
# CODON_TO_AA[AA_TO_CODON_TOKEN['W']] decode to 'T'. 'W' -> 'W' is required by
# AA_TO_CODON_TOKEN; 'X' -> 'T' restores the 6 / 4 / 1 split for Ser / Thr / Trp.
# NOTE(review): 'X' is the serine entry adjacent to the threonine group (and
# pairs with 'x'); confirm against the tokenizer vocabulary that it is the
# token that belongs to threonine.
CODON_TO_AA = {
    'a': 'A', 'A': 'A', '@': 'A', 'b': 'A',
    'B': 'R', '#': 'R', '$': 'R', '%': 'R', 'r': 'R', 'R': 'R',
    'n': 'N', 'N': 'N',
    'd': 'D', 'D': 'D',
    'c': 'C', 'C': 'C',
    'e': 'E', 'E': 'E',
    'q': 'Q', 'Q': 'Q',
    '^': 'G', 'G': 'G', '&': 'G', 'g': 'G',
    'h': 'H', 'H': 'H',
    'i': 'I', 'I': 'I', 'j': 'I',
    '+': 'L', 'M': 'L', 'm': 'L', 'l': 'L', 'J': 'L', 'L': 'L',
    'k': 'K', 'K': 'K',
    '(': 'M',
    'f': 'F', 'F': 'F',
    'p': 'P', 'P': 'P', 'o': 'P', 'O': 'P',
    '=': 'S', 's': 'S', 'z': 'S', 'Z': 'S', 'w': 'S', 'S': 'S',
    'X': 'T', 'T': 'T', 't': 'T', 'x': 'T',
    'W': 'W',
    'y': 'Y', 'Y': 'Y',
    'u': 'V', 'v': 'V', 'U': 'V', 'V': 'V',
    ']': '*', '}': '*', ')': '*',
}

# DNA codon -> amino acid for all 64 triplets ('*' marks a stop codon).
DNA_CODON_TO_AA = {
    'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
    'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
    'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
    'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
    'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
    'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
    'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
    'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
    'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
    'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
    'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
    'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
    'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
}

# RNA version of the table above: same mapping with every T written as U.
RNA_CODON_TO_AA = {
    codon.replace('T', 'U'): aa for codon, aa in DNA_CODON_TO_AA.items()
}
def pad_and_concatenate_dimer(
        A: torch.Tensor,
        B: torch.Tensor,
        a_mask: Optional[torch.Tensor] = None,
        b_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Concatenate the valid parts of A and B per sample and zero-pad the result.

    For each sample, the first ``a_len`` positions of A are followed directly
    by the first ``b_len`` positions of B, where the lengths are the sums of
    the corresponding mask rows. Remaining positions up to the longest
    combined length in the batch are left as zeros.

    Args:
        A: Tensor of shape (batch, L_a, d).
        B: Tensor of shape (batch, L_b, d). L_b may differ from L_a.
        a_mask: Optional (batch, L_a) mask for A; its row sum is taken as the
            number of valid positions. Defaults to all positions valid.
        b_mask: Optional (batch, L_b) mask for B, same convention.

    Returns:
        A tuple ``(combined, combined_mask)``: ``combined`` has shape
        (batch, max_len, d) with A's dtype, and ``combined_mask`` has shape
        (batch, max_len) with 1 on filled positions and 0 on padding.

    Note:
        Only a prefix of each sequence is copied, so valid positions are
        assumed to come first (right-padded masks).
    """
    batch_size, _, d = A.size()
    if a_mask is None:
        a_mask = torch.ones(batch_size, A.size(1), device=A.device)
    if b_mask is None:
        # Size the default mask from B itself: B's length may differ from A's.
        b_mask = torch.ones(batch_size, B.size(1), device=A.device)

    # Valid lengths per sample, fetched once per mask instead of per element.
    a_lens = [int(n) for n in a_mask.flatten(1).sum(dim=1).tolist()]
    b_lens = [int(n) for n in b_mask.flatten(1).sum(dim=1).tolist()]

    # Longest combined length in the batch; 0 for an empty batch.
    max_len = max((a + b for a, b in zip(a_lens, b_lens)), default=0)

    # Keep A's dtype so the inputs are not silently cast to float32.
    combined = torch.zeros(batch_size, max_len, d, dtype=A.dtype, device=A.device)
    combined_mask = torch.zeros(batch_size, max_len, device=A.device)
    for i, (a_len, b_len) in enumerate(zip(a_lens, b_lens)):
        total = a_len + b_len
        combined[i, :a_len] = A[i, :a_len]
        combined[i, a_len:total] = B[i, :b_len]
        combined_mask[i, :total] = 1
    return combined, combined_mask