Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| #! /usr/bin/env python3 | |
| """ | |
| This script is a simple example of how to fine-tune a Synthyra FastPLM model for a protein sequence regression or classification task. | |
| For regression we look at the binding affinity of two proteins (pkd) | |
| For classification we look at the solubility of a protein (membrane bound or not) | |
| """ | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from datasets import load_dataset | |
| from torch.utils.data import Dataset as TorchDataset | |
| from typing import List, Tuple, Dict, Union, Any | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| Trainer, | |
| TrainingArguments, | |
| EarlyStoppingCallback, | |
| EvalPrediction | |
| ) | |
| from peft import LoraConfig, get_peft_model | |
| from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay | |
| from scipy.stats import spearmanr | |
| # Shared arguments for the trainer | |
| BASE_TRAINER_KWARGS = { | |
| "warmup_steps": 500, | |
| "weight_decay": 0.01, | |
| "logging_steps": 100, | |
| "eval_strategy": "steps", | |
| "eval_steps": 500, | |
| "save_strategy": "steps", | |
| "save_steps": 500, | |
| "load_best_model_at_end": True, | |
| "metric_for_best_model": "eval_loss", | |
| "greater_is_better": False, | |
| "report_to": "none", | |
| "label_names": ["labels"] | |
| } | |
| # Dataset classes | |
| class PairDatasetHF(TorchDataset): | |
| """ | |
| Dataset class for protein pair data (e.g., protein-protein interactions). | |
| Args: | |
| data: The dataset containing protein sequences and labels | |
| col_a: Column name for the first protein sequence | |
| col_b: Column name for the second protein sequence | |
| label_col: Column name for the labels | |
| max_length: Maximum sequence length to consider | |
| """ | |
| def __init__(self, dataset: Any, col_a: str, col_b: str, label_col: str, max_length: int = 2048): | |
| self.seqs_a = dataset[col_a] | |
| self.seqs_b = dataset[col_b] | |
| self.labels = dataset[label_col] | |
| self.max_length = max_length | |
| def __len__(self) -> int: | |
| return len(self.seqs_a) | |
| def __getitem__(self, idx: int) -> Tuple[str, str, Union[float, int]]: | |
| seq_a = self.seqs_a[idx][:self.max_length] | |
| seq_b = self.seqs_b[idx][:self.max_length] | |
| label = self.labels[idx] | |
| return seq_a, seq_b, label | |
| class SequenceDatasetHF(TorchDataset): | |
| """ | |
| Dataset class for single protein sequence data. | |
| Args: | |
| dataset: The dataset containing protein sequences and labels | |
| col_name: Column name for the protein sequences | |
| label_col: Column name for the labels | |
| max_length: Maximum sequence length to consider | |
| """ | |
| def __init__(self, dataset: Any, col_name: str = 'seqs', label_col: str = 'labels', max_length: int = 2048): | |
| self.seqs = dataset[col_name] | |
| self.labels = dataset[label_col] | |
| self.max_length = max_length | |
| def __len__(self) -> int: | |
| return len(self.seqs) | |
| def __getitem__(self, idx: int) -> Tuple[str, Union[float, int]]: | |
| seq = self.seqs[idx][:self.max_length] | |
| label = self.labels[idx] | |
| return seq, label | |
| class PairCollator: | |
| """ | |
| Collator for protein pair data that handles tokenization and tensor conversion. | |
| Args: | |
| tokenizer: The tokenizer to use for encoding sequences | |
| regression: Whether this is a regression task (True) or classification (False) | |
| """ | |
| def __init__(self, tokenizer: Any, regression: bool = False): | |
| self.tokenizer = tokenizer | |
| self.regression = regression | |
| def __call__(self, batch: List[Tuple[str, str, Union[float, int]]]) -> Dict[str, torch.Tensor]: | |
| seqs_a, seqs_b, labels = zip(*batch) | |
| labels = torch.tensor(labels) | |
| if self.regression: | |
| labels = labels.float() | |
| else: | |
| labels = labels.long() | |
| tokenized = self.tokenizer( | |
| seqs_a, seqs_b, | |
| padding='longest', | |
| pad_to_multiple_of=8, | |
| return_tensors='pt' | |
| ) | |
| return { | |
| 'input_ids': tokenized['input_ids'], | |
| 'attention_mask': tokenized['attention_mask'], | |
| 'labels': labels | |
| } | |
| class SequenceCollator: | |
| """ | |
| Collator for single protein sequence data that handles tokenization and tensor conversion. | |
| Args: | |
| tokenizer: The tokenizer to use for encoding sequences | |
| regression: Whether this is a regression task (True) or classification (False) | |
| """ | |
| def __init__(self, tokenizer: Any, regression: bool = False): | |
| self.tokenizer = tokenizer | |
| self.regression = regression | |
| def __call__(self, batch: List[Tuple[str, Union[float, int]]]) -> Dict[str, torch.Tensor]: | |
| seqs, labels = zip(*batch) | |
| labels = torch.tensor(labels) | |
| if self.regression: | |
| labels = labels.float() | |
| else: | |
| labels = labels.long() | |
| tokenized = self.tokenizer( | |
| seqs, | |
| padding='longest', | |
| pad_to_multiple_of=8, | |
| return_tensors='pt' | |
| ) | |
| return { | |
| 'input_ids': tokenized['input_ids'], | |
| 'attention_mask': tokenized['attention_mask'], | |
| 'labels': labels | |
| } | |
| # Get the model ready, with or without LoRA | |
| def initialize_model(model_name: str, num_labels: int, use_lora: bool = True, lora_config: Any = None): | |
| """ | |
| Initialize a model with optional LoRA support | |
| Args: | |
| model_name: Name or path of the pretrained model | |
| num_labels: Number of labels for the task (1 for regression) | |
| use_lora: Whether to use LoRA for fine-tuning | |
| lora_config: Custom LoRA configuration (optional) | |
| Returns: | |
| model: The initialized model | |
| tokenizer: The model's tokenizer | |
| """ | |
| print(f"Loading model {model_name} with {num_labels} labels...") | |
| # Load base model | |
| model = AutoModelForSequenceClassification.from_pretrained( | |
| model_name, | |
| trust_remote_code=True, | |
| num_labels=num_labels | |
| ) | |
| tokenizer = model.tokenizer | |
| # Apply LoRA if requested | |
| if use_lora: | |
| # Default LoRA configuration if none provided | |
| if lora_config is None: | |
| # Target modules for ESM++ or ESM2 models | |
| target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"] | |
| lora_config = LoraConfig( | |
| r=8, | |
| lora_alpha=16, | |
| lora_dropout=0.01, | |
| bias="none", | |
| target_modules=target_modules, | |
| ) | |
| # Apply LoRA to the model | |
| model = get_peft_model(model, lora_config) | |
| # Unfreeze the classifier head | |
| for param in model.classifier.parameters(): | |
| param.requires_grad = True | |
| # Print parameter statistics | |
| total_params = sum(p.numel() for p in model.parameters()) | |
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| non_trainable_params = total_params - trainable_params | |
| print(f"Total parameters: {total_params}") | |
| print(f"Trainable parameters: {trainable_params}") | |
| print(f"Non-trainable parameters: {non_trainable_params}") | |
| print(f"Percentage of parameters being trained: {100 * trainable_params / total_params:.2f}%") | |
| return model, tokenizer | |
| # For computing performance metrics, it's fairly straightforward to add more metrics here | |
| def compute_metrics_regression(p: EvalPrediction) -> Dict[str, float]: | |
| """Compute Spearman correlation for regression tasks""" | |
| predictions, labels = p.predictions, p.label_ids | |
| predictions = predictions[0] if isinstance(predictions, tuple) else predictions | |
| # Calculate Spearman correlation | |
| correlation, p_value = spearmanr(predictions.flatten(), labels.flatten()) | |
| return { | |
| "spearman_correlation": correlation, | |
| "p_value": p_value | |
| } | |
| def compute_metrics_classification(p: EvalPrediction) -> Dict[str, float]: | |
| """Compute accuracy for classification tasks""" | |
| predictions, labels = p.predictions, p.label_ids | |
| predictions = predictions[0] if isinstance(predictions, tuple) else predictions | |
| predictions = np.argmax(predictions, axis=-1) | |
| accuracy = (predictions.flatten() == labels.flatten()).mean() | |
| return { | |
| "accuracy": accuracy | |
| } | |
| # For plotting the results, it's fairly straightforward to add more plots here | |
| def plot_regression_results(preds: np.ndarray, labels: np.ndarray, task_name: str = "Regression") -> float: | |
| """ | |
| Plot regression results with Spearman correlation | |
| Args: | |
| preds: Predicted values | |
| labels: True values | |
| task_name: Name of the task for plot title and filename | |
| Returns: | |
| correlation: Spearman correlation coefficient | |
| """ | |
| # Calculate Spearman correlation | |
| correlation, p_value = spearmanr(preds, labels) | |
| # Create scatter plot | |
| plt.figure(figsize=(10, 8)) | |
| sns.scatterplot(x=labels, y=preds, alpha=0.6) | |
| # Add regression line | |
| sns.regplot(x=labels, y=preds, scatter=False, color='red') | |
| plt.title(f'{task_name} - Spearman Correlation: {correlation:.3f} (p={p_value:.3e})') | |
| plt.xlabel('True Values') | |
| plt.ylabel('Predicted Values') | |
| # Add correlation text | |
| plt.annotate(f'ρ = {correlation:.3f}', xy=(0.05, 0.95), xycoords='axes fraction', | |
| fontsize=12, bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8)) | |
| plt.tight_layout() | |
| plt.savefig(f'{task_name.lower().replace(" ", "_")}_results.png') | |
| plt.show() | |
| return correlation | |
| def plot_classification_results(trainer: Trainer, test_dataset: Any, task_name: str = "Classification") -> float: | |
| """ | |
| Plot classification results with confusion matrix | |
| Args: | |
| trainer: The trained model trainer | |
| test_dataset: Dataset to evaluate on | |
| task_name: Name of the task for plot title and filename | |
| Returns: | |
| accuracy: Classification accuracy | |
| """ | |
| # Get predictions | |
| predictions, labels, _ = trainer.predict(test_dataset) | |
| preds = predictions[0] if isinstance(predictions, tuple) else predictions | |
| pred_values = np.argmax(preds, axis=1) | |
| # Calculate accuracy | |
| accuracy = (pred_values == labels).mean() | |
| # Create confusion matrix | |
| cm = confusion_matrix(labels, pred_values) | |
| # Plot confusion matrix | |
| plt.figure(figsize=(10, 8)) | |
| disp = ConfusionMatrixDisplay(confusion_matrix=cm) | |
| disp.plot(cmap=plt.cm.Blues) | |
| plt.title(f'{task_name} - Accuracy: {accuracy:.3f}') | |
| plt.tight_layout() | |
| plt.savefig(f'{task_name.lower().replace(" ", "_")}_results.png') | |
| plt.show() | |
| return accuracy | |
| # Training functions | |
| def train_regression_model( | |
| model_name: str = 'Synthyra/ESMplusplus_small', | |
| use_lora: bool = True, | |
| custom_lora_config: Any = None, | |
| batch_size: int = 8, | |
| learning_rate: float = 5e-5, | |
| num_epochs: int = 10, | |
| max_length: int = 1024, | |
| gradient_accumulation_steps: int = 1, | |
| patience: int = 3 | |
| ) -> Tuple[Trainer, Any]: | |
| """ | |
| Train a regression model for protein-protein affinity prediction | |
| Args: | |
| model_name: Name or path of the pretrained model | |
| use_lora: Whether to use LoRA for fine-tuning | |
| custom_lora_config: Custom LoRA configuration (optional) | |
| batch_size: Batch size for training | |
| learning_rate: Learning rate for training | |
| num_epochs: Number of epochs for training | |
| max_length: Maximum sequence length to consider | |
| gradient_accumulation_steps: Number of gradient accumulation steps | |
| patience: Number of evaluation calls with no improvement after which training will be stopped | |
| Returns: | |
| trainer: The trained model trainer | |
| test_dataset: The test dataset used for evaluation | |
| """ | |
| print("Loading datasets for regression task...") | |
| # Filter sequences that exceed max_length | |
| def _filter_pair_by_length(example: Any) -> bool: | |
| return len(example['SeqA']) + len(example['SeqB']) <= max_length | |
| # Load datasets | |
| train_data = load_dataset('Synthyra/ProteinProteinAffinity', split='train').filter(_filter_pair_by_length) | |
| valid_data = load_dataset('Synthyra/AffinityBenchmarkv5.5', split='train').filter(_filter_pair_by_length) | |
| test_data = load_dataset('Synthyra/haddock_benchmark', split='train').filter(_filter_pair_by_length) | |
| # Create datasets | |
| train_dataset = PairDatasetHF(train_data, 'SeqA', 'SeqB', 'labels', max_length=max_length) | |
| valid_dataset = PairDatasetHF(valid_data, 'SeqA', 'SeqB', 'labels', max_length=max_length) | |
| test_dataset = PairDatasetHF(test_data, 'SeqA', 'SeqB', 'labels', max_length=max_length) | |
| # Initialize model with modular function | |
| model, tokenizer = initialize_model( | |
| model_name=model_name, | |
| num_labels=1, # Regression task | |
| use_lora=use_lora, | |
| lora_config=custom_lora_config | |
| ) | |
| # Create data collator | |
| data_collator = PairCollator(tokenizer, regression=True) | |
| # Define training arguments | |
| output_dir = "./results_regression_lora" if use_lora else "./results_regression" | |
| logging_dir = "./logs_regression_lora" if use_lora else "./logs_regression" | |
| training_args = TrainingArguments( | |
| output_dir=output_dir, | |
| num_train_epochs=num_epochs, | |
| gradient_accumulation_steps=gradient_accumulation_steps, | |
| per_device_train_batch_size=batch_size, | |
| per_device_eval_batch_size=batch_size, | |
| logging_dir=logging_dir, | |
| learning_rate=learning_rate, | |
| **BASE_TRAINER_KWARGS | |
| ) | |
| # Create trainer | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=valid_dataset, | |
| data_collator=data_collator, | |
| compute_metrics=compute_metrics_regression, | |
| callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)] | |
| ) | |
| metrics = trainer.evaluate(test_dataset) | |
| print(f"Initial metrics: {metrics}") | |
| print("Training regression model...") | |
| trainer.train() | |
| # Evaluate and visualize results | |
| print("Evaluating and visualizing results...") | |
| predictions, labels, metrics = trainer.predict(test_dataset) | |
| preds = predictions[0] if isinstance(predictions, tuple) else predictions | |
| correlation = plot_regression_results(preds.flatten(), labels.flatten(), "Protein-Protein Affinity") | |
| print(f"Final Spearman correlation on test set: {correlation:.3f}") | |
| return trainer, test_dataset | |
| def train_classification_model( | |
| model_name: str = 'Synthyra/ESMplusplus_small', | |
| use_lora: bool = True, | |
| custom_lora_config: Any = None, | |
| batch_size: int = 8, | |
| learning_rate: float = 5e-5, | |
| num_epochs: int = 10, | |
| max_length: int = 512, | |
| gradient_accumulation_steps: int = 1, | |
| patience: int = 3 | |
| ) -> Tuple[Trainer, Any]: | |
| """ | |
| Train a classification model for protein solubility prediction | |
| Args: | |
| model_name: Name or path of the pretrained model | |
| use_lora: Whether to use LoRA for fine-tuning | |
| custom_lora_config: Custom LoRA configuration (optional) | |
| batch_size: Batch size for training | |
| learning_rate: Learning rate for training | |
| num_epochs: Number of epochs for training | |
| max_length: Maximum sequence length to consider | |
| gradient_accumulation_steps: Number of gradient accumulation steps | |
| patience: Number of evaluation calls with no improvement after which training will be stopped | |
| Returns: | |
| trainer: The trained model trainer | |
| """ | |
| print("Loading datasets for classification task...") | |
| # Filter sequences that exceed max_length | |
| def _filter_by_length(example: Any) -> bool: | |
| return len(example['seqs']) <= max_length | |
| # Load datasets | |
| data = load_dataset('GleghornLab/DL2_reg') | |
| train_data = data['train'].filter(_filter_by_length) | |
| valid_data = data['valid'].filter(_filter_by_length) | |
| test_data = data['test'].filter(_filter_by_length) | |
| # Create datasets | |
| train_dataset = SequenceDatasetHF(train_data, 'seqs', 'labels', max_length=max_length) | |
| valid_dataset = SequenceDatasetHF(valid_data, 'seqs', 'labels', max_length=max_length) | |
| test_dataset = SequenceDatasetHF(test_data, 'seqs', 'labels', max_length=max_length) | |
| # Get number of labels | |
| num_labels = len(set(train_data['labels'])) | |
| # Initialize model with modular function | |
| model, tokenizer = initialize_model( | |
| model_name=model_name, | |
| num_labels=num_labels, | |
| use_lora=use_lora, | |
| lora_config=custom_lora_config | |
| ) | |
| # Create data collator | |
| data_collator = SequenceCollator(tokenizer, regression=False) | |
| # Define training arguments | |
| output_dir = "./results_classification_lora" if use_lora else "./results_classification" | |
| logging_dir = "./logs_classification_lora" if use_lora else "./logs_classification" | |
| training_args = TrainingArguments( | |
| output_dir=output_dir, | |
| num_train_epochs=num_epochs, | |
| gradient_accumulation_steps=gradient_accumulation_steps, | |
| per_device_train_batch_size=batch_size, | |
| per_device_eval_batch_size=batch_size, | |
| logging_dir=logging_dir, | |
| learning_rate=learning_rate, | |
| **BASE_TRAINER_KWARGS | |
| ) | |
| # Create trainer | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=valid_dataset, | |
| data_collator=data_collator, | |
| compute_metrics=compute_metrics_classification, | |
| callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)] | |
| ) | |
| metrics = trainer.evaluate(test_dataset) | |
| print(f"Initial metrics: {metrics}") | |
| print("Training classification model...") | |
| trainer.train() | |
| # Evaluate and visualize results | |
| print("Evaluating and visualizing results...") | |
| accuracy = plot_classification_results(trainer, test_dataset, "Protein Solubility") | |
| print(f"Final accuracy on test set: {accuracy:.3f}") | |
| return trainer | |
| # Main function | |
| if __name__ == "__main__": | |
| """ | |
| With default arguments on 4070 laptop GPU | |
| py -m fine_tuning_example --task classification --batch_size 8 --epochs 2 | |
| Runs in 80 seconds with test accuracy of ~89% | |
| py -m fine_tuning_example --task regression --batch_size 2 --max_length 1024 --grad_accum 4 --epochs 2 | |
| Runs in 7 minutes with test Spearman correlation of ~0.72 | |
| """ | |
| import argparse | |
| # Examples of PLMs with efficient implementations offered by Synthyra | |
| MODEL_LIST = [ | |
| 'Synthyra/ESMplusplus_small', | |
| 'Synthyra/ESMplusplus_large', | |
| 'Synthyra/ESM2-8M', | |
| 'Synthyra/ESM2-35M', | |
| 'Synthyra/ESM2-150M', | |
| 'Synthyra/ESM2-650M', | |
| ] | |
| parser = argparse.ArgumentParser(description="Train models for protein tasks") | |
| parser.add_argument("--task", type=str, choices=["regression", "classification", "both"], | |
| default="both", help="Task to train model for") | |
| parser.add_argument("--model_path", type=str, default="Synthyra/ESM2-8M", | |
| help="Path to the model to train") | |
| parser.add_argument("--use_lora", action="store_true", default=True, | |
| help="Whether to use LoRA for fine-tuning") | |
| parser.add_argument("--batch_size", type=int, default=2, | |
| help="Batch size for training") | |
| parser.add_argument("--lr", type=float, default=5e-5, | |
| help="Learning rate for training") | |
| parser.add_argument("--epochs", type=float, default=1.0, | |
| help="Number of epochs for training") | |
| parser.add_argument("--max_length", type=int, default=512, | |
| help="Maximum length of input sequences") | |
| parser.add_argument("--grad_accum", type=int, default=1, | |
| help="Number of gradient accumulation steps") | |
| parser.add_argument("--patience", type=int, default=3, | |
| help="Early stopping patience - number of evaluation calls with no improvement after which training will be stopped") | |
| args = parser.parse_args() | |
| # Print training configuration | |
| print("\n" + "="*50) | |
| print("TRAINING CONFIGURATION") | |
| print("="*50) | |
| print(f"Task: {args.task}") | |
| print(f"Using LoRA: {args.use_lora}") | |
| print(f"Batch size: {args.batch_size}") | |
| print(f"Learning rate: {args.lr}") | |
| print(f"Number of epochs: {args.epochs}") | |
| print(f"Max sequence length: {args.max_length}") | |
| print(f"Gradient Accumulation Steps: {args.grad_accum}") | |
| print(f"Early stopping patience: {args.patience}") | |
| print("="*50 + "\n") | |
| # Train regression model if required | |
| if args.task in ["regression", "both"]: | |
| print("\n" + "="*50) | |
| print("TRAINING REGRESSION MODEL") | |
| print("="*50) | |
| regression_trainer, test_dataset = train_regression_model( | |
| model_name=args.model_path, | |
| use_lora=args.use_lora, | |
| batch_size=args.batch_size, | |
| learning_rate=args.lr, | |
| num_epochs=args.epochs, | |
| max_length=args.max_length, | |
| gradient_accumulation_steps=args.grad_accum, | |
| patience=args.patience | |
| ) | |
| # Train classification model if required | |
| if args.task in ["classification", "both"]: | |
| print("\n" + "="*50) | |
| print("TRAINING CLASSIFICATION MODEL") | |
| print("="*50) | |
| classification_trainer = train_classification_model( | |
| model_name=args.model_path, | |
| use_lora=args.use_lora, | |
| batch_size=args.batch_size, | |
| learning_rate=args.lr, | |
| num_epochs=args.epochs, | |
| max_length=args.max_length, | |
| gradient_accumulation_steps=args.grad_accum, | |
| patience=args.patience | |
| ) | |
| print("\nTraining completed!") | |