""" Evaluation on simulated test set with 30 random samplings. Implements evaluation protocol from Section 4.4. """ import torch from torch.utils.data import DataLoader from transformers import AutoTokenizer import numpy as np from config import Config from models.mmrm import MMRM from data_utils.dataset import MMRMDataset from evaluation.metrics import RestorationMetrics from utils.font_utils import FontManager from utils.tensorboard_tracker import TensorBoardTracker def evaluate_on_test_set( config: Config, checkpoint_path: str, num_samples: int = 30, num_masks: int = 1 ) -> dict: """ Evaluate model on test set with multiple random samplings. As per paper Section 4.4: "all simulation results are the averages obtained after randomly sampling the damaged characters on the test set 30 times" Args: config: Configuration object checkpoint_path: Path to model checkpoint num_samples: Number of random samplings (default: 30) num_masks: Number of masks per sample (1 for single, or use random 1-5) Returns: Dictionary of averaged metrics """ device = torch.device(config.device if torch.cuda.is_available() or config.device == "cuda" else "cpu") # Initialize TensorBoard tracker tb_tracker = TensorBoardTracker(config) # Load model model = MMRM(config).to(device) checkpoint = torch.load(checkpoint_path, map_location=device, weights_only = False) # Check if this is a Phase 1 checkpoint (separate state dicts) or Phase 2 (full model) if 'decoder_state_dict' in checkpoint: print("Detected Phase 1 checkpoint (separate encoder/decoder state dicts).") # Phase 1 saves context_encoder as 'model_state_dict' and text_decoder as 'decoder_state_dict' try: model.context_encoder.load_state_dict(checkpoint['model_state_dict']) model.text_decoder.load_state_dict(checkpoint['decoder_state_dict']) print("Successfully loaded Phase 1 weights into ContextEncoder and TextDecoder.") print("Warning: ImageEncoder and ImageDecoder are using random initialization (expected for Phase 1).") except RuntimeError as e: # Fallback or detail error reporting print(f"Error loading Phase 1 weights: {e}") raise e else: # Phase 2 saves the full MMRM model in 'model_state_dict' model.load_state_dict(checkpoint['model_state_dict']) model.eval() print(f"Loaded model from {checkpoint_path}") print(f"Evaluating with {num_samples} random samplings...") # Initialize tokenizer and font manager tokenizer = AutoTokenizer.from_pretrained(config.roberta_model) font_manager = FontManager(config.font_dir, config.image_size, config.min_black_pixels) # Start TensorBoard run with tb_tracker.start_run(run_name="Evaluation_Simulation"): # Log evaluation parameters tb_tracker.log_params({ "checkpoint_path": checkpoint_path, "num_samples": num_samples, "num_masks": num_masks, "evaluation_type": "simulation" }) tb_tracker.set_tags({ "evaluation": "simulation", "num_samplings": str(num_samples) }) # Run multiple samplings all_metrics = [] for sample_idx in range(num_samples): print(f"\nSampling {sample_idx + 1}/{num_samples}") # Create test dataset (randomness comes from mask selection and augmentation) test_dataset = MMRMDataset( config, 'test', tokenizer, font_manager, num_masks=num_masks, curriculum_epoch=None ) test_loader = DataLoader( test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=config.pin_memory ) # Evaluate metrics = RestorationMetrics(config.top_k_values) with torch.no_grad(): for batch in test_loader: input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) 
                    mask_positions = batch['mask_positions'].to(device)
                    damaged_images = batch['damaged_images'].to(device)
                    labels = batch['labels'].to(device)

                    # Forward pass
                    text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)

                    # Update metrics
                    metrics.update(text_logits, labels)

            sample_metrics = metrics.compute()
            all_metrics.append(sample_metrics)

            # Log individual sampling metrics to TensorBoard
            sample_metrics_prefixed = {
                f"eval/sampling_{sample_idx}_{k}": v for k, v in sample_metrics.items()
            }
            tb_tracker.log_metrics(sample_metrics_prefixed)

            print(f"  Acc={sample_metrics['accuracy']:.2f}%, "
                  f"Hit@5={sample_metrics['hit_5']:.2f}%, "
                  f"Hit@10={sample_metrics['hit_10']:.2f}%, "
                  f"Hit@20={sample_metrics['hit_20']:.2f}%, "
                  f"MRR={sample_metrics['mrr']:.2f}")

        # Compute average and standard deviation across samplings
        averaged_metrics = {}
        for key in all_metrics[0].keys():
            values = [m[key] for m in all_metrics]
            averaged_metrics[key] = np.mean(values)
            averaged_metrics[f'{key}_std'] = np.std(values)

        # Log averaged metrics to TensorBoard
        tb_tracker.log_metrics({
            "eval/avg_accuracy": averaged_metrics['accuracy'],
            "eval/avg_hit_5": averaged_metrics['hit_5'],
            "eval/avg_hit_10": averaged_metrics['hit_10'],
            "eval/avg_hit_20": averaged_metrics['hit_20'],
            "eval/avg_mrr": averaged_metrics['mrr'],
            "eval/std_accuracy": averaged_metrics['accuracy_std'],
            "eval/std_hit_5": averaged_metrics['hit_5_std'],
            "eval/std_hit_10": averaged_metrics['hit_10_std'],
            "eval/std_hit_20": averaged_metrics['hit_20_std'],
            "eval/std_mrr": averaged_metrics['mrr_std']
        })

        # Log all metrics as a JSON artifact
        tb_tracker.log_dict(averaged_metrics, "evaluation_results.json", artifact_path="metrics")

    print(f"\n{'=' * 70}")
    print(f"Final Results (averaged over {num_samples} samplings):")
    print(f"{'=' * 70}")
    print(f"Accuracy: {averaged_metrics['accuracy']:.2f} ± {averaged_metrics['accuracy_std']:.2f}%")
    print(f"Hit@5:    {averaged_metrics['hit_5']:.2f} ± {averaged_metrics['hit_5_std']:.2f}%")
    print(f"Hit@10:   {averaged_metrics['hit_10']:.2f} ± {averaged_metrics['hit_10_std']:.2f}%")
    print(f"Hit@20:   {averaged_metrics['hit_20']:.2f} ± {averaged_metrics['hit_20_std']:.2f}%")
    print(f"MRR:      {averaged_metrics['mrr']:.2f} ± {averaged_metrics['mrr_std']:.2f}")
    print(f"{'=' * 70}")

    return averaged_metrics


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python evaluate_simulation.py <checkpoint_path> [num_samples]")
        sys.exit(1)

    checkpoint_path = sys.argv[1]
    num_samples = int(sys.argv[2]) if len(sys.argv) > 2 else 30

    config = Config()
    results = evaluate_on_test_set(config, checkpoint_path, num_samples)
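
# Example invocations. The checkpoint path below is illustrative only and is not
# a file shipped with this repository:
#
#   python evaluate_simulation.py checkpoints/mmrm_phase2_best.pt 30
#
# Programmatic use from another script (same illustrative path):
#
#   from evaluate_simulation import evaluate_on_test_set
#   from config import Config
#   results = evaluate_on_test_set(Config(), "checkpoints/mmrm_phase2_best.pt", num_samples=30)
#   print(results["accuracy"], results["accuracy_std"])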