Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Embedding Test Suite CLI | |
| Tests embedding quality by sampling sequences from EC dataset (default), | |
| embedding them with various pooling methods, and reporting statistics | |
| on distribution, NaNs, and sparsity. | |
| """ | |
| import json | |
| import argparse | |
| import math | |
| import random | |
| import numpy as np | |
| import torch | |
| from typing import Dict, List, Optional | |
| try: | |
| from data.data_mixin import DataMixin, DataArguments | |
| from embedder import Embedder, EmbeddingArguments | |
| from base_models.get_base_models import standard_models | |
| from seed_utils import set_global_seed, get_global_seed | |
| from utils import print_message | |
| except ImportError: | |
| from ..data.data_mixin import DataMixin, DataArguments | |
| from ..embedder import Embedder, EmbeddingArguments | |
| from ..base_models.get_base_models import standard_models | |
| from ..seed_utils import set_global_seed, get_global_seed | |
| from ..utils import print_message | |
# Datasets exercised when the caller does not name any explicitly.
DEFAULT_TEST_DATASETS = ['EC']  # EC is a multilabel dataset

# Seed the stdlib and NumPy generators at import time, but only when
# get_global_seed() reports a seed (i.e. returns something other than None).
seed = get_global_seed()
if seed is not None:
    random.seed(seed)
    np.random.seed(seed)
def load_and_sample_sequences(
    dataset_names: List[str],
    sample_frac: float = 0.1,
    max_length: int = 1024,
    trim: bool = False
) -> Dict[str, List[str]]:
    """
    Load datasets and draw a random subset of unique sequences from each.

    Args:
        dataset_names: List of dataset names to load
        sample_frac: Fraction of unique sequences to sample (default 0.1 = 10%);
            at least one sequence is drawn whenever any are available
        max_length: Maximum sequence length (forwarded to DataArguments)
        trim: Whether to trim sequences to max_length (forwarded to DataArguments)

    Returns:
        Dictionary mapping dataset names to lists of sampled sequences.
        Datasets that raise while loading, or that yield no sequences,
        are logged and left out of the result.
    """
    dataset_seqs = {}
    for dataset_name in dataset_names:
        print_message(f"Loading dataset: {dataset_name}")
        try:
            # Load dataset using DataMixin
            data_args = DataArguments(
                data_names=[dataset_name],
                max_length=max_length,
                trim=trim
            )
            datasets, all_seqs = DataMixin(data_args).get_data()

            # Gather sequences from all splits
            sequences = []
            if dataset_name in datasets:
                train_set, valid_set, test_set, _, _, ppi = datasets[dataset_name]
                # PPI datasets carry two sequence columns per row
                columns = ('SeqA', 'SeqB') if ppi else ('seqs',)
                for split in (train_set, valid_set, test_set):
                    for column in columns:
                        sequences.extend(split[column])
            else:
                # Fall back to every sequence the mixin returned
                sequences = list(all_seqs)

            # Deduplicate while keeping first-seen order. Iterating a set of
            # strings depends on per-process hash randomization, which would
            # make random.sample irreproducible even with a fixed seed.
            sequences = list(dict.fromkeys(sequences))
            if not sequences:
                print_message(f"Warning: dataset {dataset_name} contains no sequences; skipping")
                continue

            n_samples = max(1, math.ceil(len(sequences) * sample_frac))
            sampled = random.sample(sequences, min(n_samples, len(sequences)))
            dataset_seqs[dataset_name] = sampled
            print_message(f"Sampled {len(sampled)} sequences from {len(sequences)} total")
        except Exception as e:
            # Best-effort: one broken dataset must not abort the others
            print_message(f"Error loading dataset {dataset_name}: {e}")
            continue
    return dataset_seqs
def compute_diagnostics(embeddings: torch.Tensor, zero_eps: float = 1e-8) -> Dict[str, float]:
    """
    Summarize the value distribution, NaN/Inf counts and sparsity of embeddings.

    Args:
        embeddings: Tensor of embeddings; axis 0 is reported as ``n_samples``
            and axis 1 as ``embedding_dim``
        zero_eps: Finite values whose magnitude is below this threshold count
            as "near zero"

    Returns:
        Dictionary of diagnostics. The same keys are returned for every
        input: when no finite value exists, the distribution and L2
        statistics are NaN and the near-zero statistics are 0, so consumers
        can index the result without checking which keys are present.
    """
    emb = embeddings.detach().float().cpu().numpy()
    flat = emb.ravel()
    total = flat.size

    is_nan = np.isnan(flat)
    is_inf = np.isinf(flat)
    finite = flat[np.isfinite(flat)]

    def _fraction(count: int, denom: int) -> float:
        # Avoid 0/0 (and NumPy's empty-mean warning) on empty inputs
        return float(count / denom) if denom else 0.0

    nan_count = int(is_nan.sum())
    inf_count = int(is_inf.sum())
    finite_count = int(finite.size)

    diagnostics = {
        "n_samples": int(emb.shape[0]),
        "embedding_dim": int(emb.shape[1]),
        "finite_count": finite_count,
        "finite_fraction": _fraction(finite_count, total),
        "nan_count": nan_count,
        "nan_fraction": _fraction(nan_count, total),
        "inf_count": inf_count,
        "inf_fraction": _fraction(inf_count, total),
        "zero_eps": float(zero_eps),
    }

    stat_keys = (
        "mean", "std", "min", "max",
        "p25", "p50", "p75", "p95", "p99",
        "mean_l2", "std_l2", "p95_l2",
    )
    if finite_count == 0:
        # Everything is NaN/Inf (or the tensor is empty): no distribution
        # can be computed, so report NaN rather than omitting the keys.
        diagnostics["near_zero_count"] = 0
        diagnostics["near_zero_fraction"] = 0.0
        diagnostics.update({key: float("nan") for key in stat_keys})
        return diagnostics

    near_zero_count = int((np.abs(finite) < zero_eps).sum())
    sample_l2 = np.linalg.norm(emb, axis=1)
    # One percentile call instead of five separate passes over the data
    p25, p50, p75, p95, p99 = (
        float(value) for value in np.percentile(finite, [25, 50, 75, 95, 99])
    )

    diagnostics.update({
        "near_zero_count": near_zero_count,
        "near_zero_fraction": _fraction(near_zero_count, finite_count),
        "mean": float(np.mean(finite)),
        "std": float(np.std(finite)),
        "min": float(np.min(finite)),
        "max": float(np.max(finite)),
        "p25": p25,
        "p50": p50,
        "p75": p75,
        "p95": p95,
        "p99": p99,
        "mean_l2": float(np.mean(sample_l2)),
        "std_l2": float(np.std(sample_l2)),
        "p95_l2": float(np.percentile(sample_l2, 95)),
    })
    return diagnostics
def embed_and_diagnose(
    sequences: List[str],
    model_name: str,
    pooling_types: List[str],
    batch_size: int = 16,
    num_workers: int = 0
) -> Dict[str, Dict[str, float]]:
    """
    Embed sequences and compute diagnostics for each pooling type.

    Args:
        sequences: List of sequences to embed
        model_name: Name of the model to use
        pooling_types: List of pooling types to test; an entry containing
            commas (e.g. ``'mean,var'``) is treated as a combination
        batch_size: Batch size for embedding
        num_workers: Number of workers for data loading

    Returns:
        Dictionary mapping pooling types to their diagnostics. Pooling types
        that fail or yield no embeddings are logged and left out; an empty
        dictionary is returned when there is nothing to embed.
    """
    print_message(f"Embedding {len(sequences)} sequences with {model_name}")
    results = {}
    if not sequences:
        # Nothing to embed: skip the (expensive) model load entirely
        print_message(f"Warning: No sequences to embed for {model_name}")
        return results

    # Parse pooling types so combinations can be tested
    pooling_list = {}
    for pool_type in pooling_types:
        if ',' in pool_type:
            pooling_list[pool_type] = [p.strip() for p in pool_type.split(',')]
        else:
            pooling_list[pool_type] = [pool_type]

    # Load model once and reuse for all pooling types
    print_message(f"Loading model: {model_name}")
    # Mirror the module-level import fallback so this also works when the
    # file is imported as part of a package.
    try:
        from base_models.get_base_models import get_base_model
    except ImportError:
        from ..base_models.get_base_models import get_base_model
    model, tokenizer = get_base_model(model_name)

    for pool_type, pool_list in pooling_list.items():
        print_message(f"Testing pooling: {pool_type} (types: {pool_list})")
        # Set up embedder for this pooling type
        embedder_args = EmbeddingArguments(
            embedding_batch_size=batch_size,
            embedding_num_workers=num_workers,
            download_embeddings=False,
            matrix_embed=False,
            embedding_pooling_types=pool_list,
            save_embeddings=False,
            embed_dtype=torch.float32,
            sql=False,
            embedding_save_dir='embeddings'
        )
        embedder = Embedder(embedder_args, sequences)
        try:
            # read embeddings from disk if they exist
            to_embed, save_path, embeddings_dict = embedder._read_embeddings_from_disk(model_name)
            if len(to_embed) > 0:
                result = embedder._embed_sequences(
                    to_embed, save_path, model, tokenizer, embeddings_dict
                )
                if result is not None:
                    embeddings_dict = result
            if embeddings_dict is None or len(embeddings_dict) == 0:
                print_message(f"Warning: No embeddings returned for {model_name} with {pool_type}")
                continue

            embedding_tensors = [
                embeddings_dict[seq] for seq in sequences if seq in embeddings_dict
            ]
            if len(embedding_tensors) == 0:
                print_message(f"Error: No embeddings found for {pool_type}")
                continue
            missing = len(sequences) - len(embedding_tensors)
            if missing:
                # Surface partial coverage instead of silently shrinking the sample
                print_message(
                    f"Warning: {missing} of {len(sequences)} sequences have no embedding for {pool_type}"
                )

            embeddings = torch.stack(embedding_tensors)
            results[pool_type] = compute_diagnostics(embeddings)
        except Exception as e:
            # Best-effort: report the failure and move on to the next pooling type
            print_message(f"Error embedding with {model_name} using {pool_type}: {e}")
            import traceback
            traceback.print_exc()
            continue
    return results
def run_test_suite(
    dataset_names: Optional[List[str]] = None,
    model_names: Optional[List[str]] = None,
    pooling_methods: Optional[List[str]] = None,
    sample_frac: float = 0.1,
    batch_size: int = 16,
    num_workers: int = 0
) -> Dict:
    """
    Run the embedding test suite.

    Args:
        dataset_names: Datasets to sample from (default: DEFAULT_TEST_DATASETS)
        model_names: Models to embed with (default: standard_models)
        pooling_methods: Pooling methods to test (default: ['cls', 'mean,var'])
        sample_frac: Fraction of sequences to sample from each dataset
        batch_size: Batch size for embedding
        num_workers: Number of workers for data loading

    Returns:
        Nested dictionary ``{dataset: {model: {pooling: diagnostics}}}``,
        or an empty dictionary when no dataset could be loaded. The results
        are also printed as a table and as JSON.
    """
    if dataset_names is None:
        dataset_names = DEFAULT_TEST_DATASETS
    if model_names is None:
        model_names = standard_models
    if pooling_methods is None:
        # Built per call rather than as a shared mutable default argument
        pooling_methods = ['cls', 'mean,var']

    print_message("Running embedding test suite")
    print_message(f"Datasets: {dataset_names}")
    print_message(f"Models: {model_names}")
    print_message(f"Pooling methods: {pooling_methods}")
    print_message(f"Sample fraction: {sample_frac}")

    dataset_seqs = load_and_sample_sequences(dataset_names, sample_frac=sample_frac)
    if len(dataset_seqs) == 0:
        print_message("Error: No sequences loaded")
        return {}

    all_results = {}
    for dataset_name, sequences in dataset_seqs.items():
        print_message(f"\nProcessing dataset: {dataset_name}")
        all_results[dataset_name] = {}
        if not sequences:
            # Keep the (empty) dataset entry but do not load models for nothing
            print_message(f"Warning: No sequences sampled for {dataset_name}; skipping")
            continue
        for model_name in model_names:
            print_message(f"Model: {model_name}")
            model_results = embed_and_diagnose(
                sequences,
                model_name,
                pooling_methods,
                batch_size=batch_size,
                num_workers=num_workers
            )
            if model_results:
                all_results[dataset_name][model_name] = model_results

    print_table_results(all_results)
    print_json_results(all_results)
    return all_results
def print_table_results(results: Dict):
    """
    Print results in table format.

    Args:
        results: Nested dictionary ``{dataset: {model: {pooling: diagnostics}}}``.
            Each diagnostics dictionary must contain ``n_samples``,
            ``embedding_dim``, ``nan_count`` and ``inf_count``; any missing
            statistic (e.g. when no finite value was available) is shown
            as ``nan`` instead of raising ``KeyError``.
    """
    nan = float('nan')
    print("\n" + "="*100)
    print("EMBEDDING TEST SUITE RESULTS")
    print("="*100)
    for dataset_name, dataset_results in results.items():
        print(f"\nDataset: {dataset_name}")
        print("-" * 100)
        for model_name, model_results in dataset_results.items():
            print(f"\n Model: {model_name}")
            for pool_type, diagnostics in model_results.items():
                mean = diagnostics.get('mean', nan)
                std = diagnostics.get('std', nan)
                nan_pct = diagnostics.get('nan_fraction', nan) * 100
                inf_pct = diagnostics.get('inf_fraction', nan) * 100
                has_near_zero = 'near_zero_count' in diagnostics
                near_zero_fraction = diagnostics.get('near_zero_fraction', nan)

                print(f"\nPooling: {pool_type}")
                print(f"Samples: {diagnostics['n_samples']}, Dim: {diagnostics['embedding_dim']}")
                print(f"Mean: {mean:.6f}, Std: {std:.6f}")
                print(f"Min: {diagnostics.get('min', nan):.6f}, Max: {diagnostics.get('max', nan):.6f}")
                print(f"Percentiles: P25={diagnostics.get('p25', nan):.6f}, P50={diagnostics.get('p50', nan):.6f}, "
                      f"P75={diagnostics.get('p75', nan):.6f}, P95={diagnostics.get('p95', nan):.6f}, P99={diagnostics.get('p99', nan):.6f}")
                print(f"NaN: {diagnostics['nan_count']} ({nan_pct:.2f}%)")
                if has_near_zero:
                    print(f"Near zeros: {diagnostics['near_zero_count']} ({near_zero_fraction*100:.2f}%)")
                print(f"Inf: {diagnostics['inf_count']} ({inf_pct:.2f}%)")

                # Flag anomalies. Counts drive the NaN/Inf checks so they
                # still fire when the fraction keys are absent; every
                # comparison against a NaN placeholder is simply False.
                anomalies = []
                if diagnostics.get('finite_count') == 0:
                    anomalies.append("No finite values")
                if diagnostics['nan_count'] > 0:
                    anomalies.append(f"NaNs detected ({nan_pct:.2f}%)")
                if has_near_zero and near_zero_fraction > 0.2:
                    anomalies.append(f"High sparsity ({near_zero_fraction*100:.2f}%)")
                if diagnostics['inf_count'] > 0:
                    anomalies.append(f"Infs detected ({inf_pct:.2f}%)")
                if abs(mean) > 100:
                    anomalies.append(f"Extreme mean ({mean:.2f})")
                if std > 100:
                    anomalies.append(f"Extreme std ({std:.2f})")

                if anomalies:
                    print(f"Anomalies: {', '.join(anomalies)}")
                else:
                    print("No anomalies detected")
def print_json_results(results: Dict):
    """Dump the results dictionary to stdout as indented JSON under a banner."""
    rule = "=" * 50
    # Banner first, so it is already on screen if serialization fails
    print(f"\n{rule}")
    print("JSON RESULTS")
    print(rule)
    serialized = json.dumps(results, indent=2)
    print(serialized)
def main(argv: Optional[List[str]] = None):
    """
    Command-line entry point for the embedding test suite.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The results dictionary produced by ``run_test_suite``.
    """
    default_pooling = ['cls', 'mean,var']

    parser = argparse.ArgumentParser(
        description='Embedding Test Suite - Test embedding quality across datasets and models'
    )
    parser.add_argument(
        '--datasets',
        nargs='+',
        default=None,
        help=f"List of dataset names to test (default: {' '.join(DEFAULT_TEST_DATASETS)})"
    )
    parser.add_argument(
        '--model_names',
        nargs='+',
        default=None,
        help='List of model names to test (default: all standard_models)'
    )
    parser.add_argument(
        '--pooling_methods',
        nargs='+',
        default=default_pooling,
        help='List of pooling methods to test; join with commas to combine, '
             f"e.g. mean,var (default: {' '.join(default_pooling)})"
    )
    parser.add_argument(
        '--sample_frac',
        type=float,
        default=0.1,
        help='Fraction of sequences to sample from each dataset (default: 0.1)'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=16,
        help='Batch size for embedding (default: 16)'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=0,
        help='Number of workers for data loading (default: 0)'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=None,
        help='Random seed for reproducibility'
    )
    args = parser.parse_args(argv)

    # Set seed if provided
    if args.seed is not None:
        set_global_seed(args.seed)
        # The module-level seeding ran at import time, before --seed was
        # parsed, so seed the generators used for sampling explicitly.
        # NOTE(review): redundant but harmless if set_global_seed already
        # seeds random and numpy - confirm against seed_utils.
        random.seed(args.seed)
        np.random.seed(args.seed)

    # Run test suite
    results = run_test_suite(
        dataset_names=args.datasets,
        model_names=args.model_names,
        pooling_methods=args.pooling_methods,
        sample_frac=args.sample_frac,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )
    return results


if __name__ == '__main__':
    main()