"""
Evaluation on simulated test set with 30 random samplings.

Implements evaluation protocol from Section 4.4.
"""

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np

from config import Config
from models.mmrm import MMRM
from data_utils.dataset import MMRMDataset
from evaluation.metrics import RestorationMetrics
from utils.font_utils import FontManager
from utils.tensorboard_tracker import TensorBoardTracker


def evaluate_on_test_set(
    config: Config,
    checkpoint_path: str,
    num_samples: int = 30,
    num_masks: int = 1
) -> dict:
    """
    Evaluate model on test set with multiple random samplings.

    As per paper Section 4.4: "all simulation results are the averages obtained
    after randomly sampling the damaged characters on the test set 30 times".

    Args:
        config: Configuration object
        checkpoint_path: Path to model checkpoint
        num_samples: Number of random samplings (default: 30)
        num_masks: Number of masks per sample (1 for single, or use random 1-5)

    Returns:
        Dictionary of averaged metrics
    """
    # Use the configured device; fall back to CPU when CUDA is unavailable.
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    tb_tracker = TensorBoardTracker(config)

    model = MMRM(config).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
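
    # weights_only=False: the checkpoint dict may hold non-tensor objects saved
    # during training (an assumption about how checkpoints are written).
    # Phase 1 checkpoints keep the ContextEncoder weights under 'model_state_dict'
    # and the TextDecoder weights under 'decoder_state_dict', while full-model
    # (Phase 2) checkpoints store everything under 'model_state_dict'.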
    if 'decoder_state_dict' in checkpoint:
        print("Detected Phase 1 checkpoint (separate encoder/decoder state dicts).")

        try:
            model.context_encoder.load_state_dict(checkpoint['model_state_dict'])
            model.text_decoder.load_state_dict(checkpoint['decoder_state_dict'])
            print("Successfully loaded Phase 1 weights into ContextEncoder and TextDecoder.")
            print("Warning: ImageEncoder and ImageDecoder are using random initialization (expected for Phase 1).")
        except RuntimeError as e:
            print(f"Error loading Phase 1 weights: {e}")
            raise
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    print(f"Loaded model from {checkpoint_path}")
    print(f"Evaluating with {num_samples} random samplings...")

    tokenizer = AutoTokenizer.from_pretrained(config.roberta_model)
    font_manager = FontManager(config.font_dir, config.image_size, config.min_black_pixels)
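    # (The font manager supplies rendered character glyphs; the dataset is
    # assumed to use them when simulating the damaged-character images.)
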
    with tb_tracker.start_run(run_name="Evaluation_Simulation"):

        tb_tracker.log_params({
            "checkpoint_path": checkpoint_path,
            "num_samples": num_samples,
            "num_masks": num_masks,
            "evaluation_type": "simulation"
        })
        tb_tracker.set_tags({
            "evaluation": "simulation",
            "num_samplings": str(num_samples)
        })

        all_metrics = []

        for sample_idx in range(num_samples):
            print(f"\nSampling {sample_idx + 1}/{num_samples}")
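
            # Rebuild the dataset each sampling so the damaged characters are
            # re-drawn; this assumes MMRMDataset picks fresh random mask positions
            # at construction time. curriculum_epoch=None is taken to disable
            # curriculum masking during evaluation.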
            test_dataset = MMRMDataset(
                config,
                'test',
                tokenizer,
                font_manager,
                num_masks=num_masks,
                curriculum_epoch=None
            )

            test_loader = DataLoader(
                test_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory
            )

            metrics = RestorationMetrics(config.top_k_values)
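            # Accumulates accuracy, Hit@k for each k in config.top_k_values, and
            # MRR over the whole test set (see evaluation/metrics.py).
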
            with torch.no_grad():
                for batch in test_loader:
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    mask_positions = batch['mask_positions'].to(device)
                    damaged_images = batch['damaged_images'].to(device)
                    labels = batch['labels'].to(device)

                    text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)
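                    # Only the text logits are scored here; the model's second
                    # output (its image branch) is not used for these metrics.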
                    metrics.update(text_logits, labels)

            sample_metrics = metrics.compute()
            all_metrics.append(sample_metrics)

            sample_metrics_prefixed = {f"eval/sampling_{sample_idx}_{k}": v for k, v in sample_metrics.items()}
            tb_tracker.log_metrics(sample_metrics_prefixed)

            print(f" Acc={sample_metrics['accuracy']:.2f}%, "
                  f"Hit@5={sample_metrics['hit_5']:.2f}%, "
                  f"Hit@10={sample_metrics['hit_10']:.2f}%, "
                  f"Hit@20={sample_metrics['hit_20']:.2f}%, "
                  f"MRR={sample_metrics['mrr']:.2f}")
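
        # Average each metric across the independent samplings; the standard
        # deviation is reported as the ± spread in the final summary.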
        averaged_metrics = {}
        for key in all_metrics[0].keys():
            values = [m[key] for m in all_metrics]
            averaged_metrics[key] = np.mean(values)
            averaged_metrics[f'{key}_std'] = np.std(values)

        tb_tracker.log_metrics({
            "eval/avg_accuracy": averaged_metrics['accuracy'],
            "eval/avg_hit_5": averaged_metrics['hit_5'],
            "eval/avg_hit_10": averaged_metrics['hit_10'],
            "eval/avg_hit_20": averaged_metrics['hit_20'],
            "eval/avg_mrr": averaged_metrics['mrr'],
            "eval/std_accuracy": averaged_metrics['accuracy_std'],
            "eval/std_hit_5": averaged_metrics['hit_5_std'],
            "eval/std_hit_10": averaged_metrics['hit_10_std'],
            "eval/std_hit_20": averaged_metrics['hit_20_std'],
            "eval/std_mrr": averaged_metrics['mrr_std']
        })

        tb_tracker.log_dict(averaged_metrics, "evaluation_results.json", artifact_path="metrics")
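        # The averaged metrics are also persisted as a JSON artifact
        # (evaluation_results.json) alongside the tracker run.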

    print(f"\n{'='*70}")
    print(f"Final Results (averaged over {num_samples} samplings):")
    print(f"{'='*70}")
    print(f"Accuracy: {averaged_metrics['accuracy']:.2f} ± {averaged_metrics['accuracy_std']:.2f}%")
    print(f"Hit@5: {averaged_metrics['hit_5']:.2f} ± {averaged_metrics['hit_5_std']:.2f}%")
    print(f"Hit@10: {averaged_metrics['hit_10']:.2f} ± {averaged_metrics['hit_10_std']:.2f}%")
    print(f"Hit@20: {averaged_metrics['hit_20']:.2f} ± {averaged_metrics['hit_20_std']:.2f}%")
    print(f"MRR: {averaged_metrics['mrr']:.2f} ± {averaged_metrics['mrr_std']:.2f}")
    print(f"{'='*70}")

    return averaged_metrics


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python evaluate_simulation.py <checkpoint_path> [num_samples]")
        sys.exit(1)
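
    # Example (the checkpoint path below is illustrative only):
    #   python evaluate_simulation.py checkpoints/phase2_best.pt 30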
    checkpoint_path = sys.argv[1]
    num_samples = int(sys.argv[2]) if len(sys.argv) > 2 else 30

    config = Config()
    results = evaluate_on_test_set(config, checkpoint_path, num_samples)
|