import sys import time import os import csv import torch import json from util import Logger, printSet from validate import validate from networks.resnet import resnet50 from options.test_options import TestOptions import networks.resnet as resnet import numpy as np import random from data import create_dataloader from sklearn.metrics import roc_auc_score, accuracy_score from tqdm import tqdm import pandas as pd def seed_torch(seed=1029): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.backends.cudnn.enabled = False seed_torch(100) opt = TestOptions().parse(print_options=False) opt.model_path = os.path.join(f'./checkpoint/{opt.name}/weights/best.pt') print(f'Model_path {opt.model_path}') # get model model = resnet50(num_classes=1) model.load_state_dict(torch.load(opt.model_path, map_location='cpu'), strict=True) model.to(opt.device) model.eval() opt.no_resize = False opt.no_crop = True output_dir = f'./results/{opt.name}/data/{opt.data_keys}' os.makedirs(output_dir, exist_ok=True) test_dataloader = create_dataloader(opt, split='test') model.eval() # File paths csv_filename = os.path.join(output_dir, 'results.csv') metrics_filename = os.path.join(output_dir, 'metrics.json') image_results_filename = os.path.join(output_dir, 'image_results.json') # Extract training dataset keys from model name (format: "training_keys_freeze_down" or "training_keys") training_dataset_keys = [] model_name = opt.name if '_freeze_down' in model_name: training_name = model_name.replace('_freeze_down', '') else: training_name = model_name if '&' in training_name: training_dataset_keys = training_name.split('&') else: training_dataset_keys = [training_name] # Collect all results all_scores = [] all_labels = [] all_paths = [] image_results = [] start_time = time.time() # Write CSV header with open(csv_filename, 'w') as f: f.write(f"{','.join(['name', 'pro', 'flag'])}\n") with torch.no_grad(): with tqdm(test_dataloader, unit='batch', mininterval=0.5) as tbatch: tbatch.set_description(f'Validation') for (data, labels, paths) in tbatch: data = data.to(opt.device) labels = labels.to(opt.device) scores = model(data).squeeze(1) # Collect results for score, label, path in zip(scores, labels, paths): score_val = score.item() label_val = label.item() all_scores.append(score_val) all_labels.append(label_val) all_paths.append(path) image_results.append({ 'path': path, 'score': score_val, 'label': label_val }) # Write to CSV (maintain backward compatibility) with open(csv_filename, 'a') as f: for score, label, path in zip(scores, labels, paths): f.write(f"{path}, {score.item()}, {label.item()}\n") # Calculate metrics all_scores = np.array(all_scores) all_labels = np.array(all_labels) # Convert scores to probabilities using sigmoid (as done in validate.py) probabilities = torch.sigmoid(torch.tensor(all_scores)).numpy() # Convert probabilities to predictions using threshold 0.5 (as done in validate.py) predictions = (probabilities > 0.5).astype(int) # Calculate overall metrics total_accuracy = accuracy_score(all_labels, predictions) # TPR (True Positive Rate) = TP / (TP + FN) = accuracy on fake images (label==1) fake_mask = all_labels == 1 if fake_mask.sum() > 0: tpr = accuracy_score(all_labels[fake_mask], predictions[fake_mask]) else: tpr = 0.0 # Calculate TNR on real images (label==0) in the test set real_mask = all_labels == 0 if real_mask.sum() > 0: # Overall TNR calculated on all real images in the test set tnr = accuracy_score(all_labels[real_mask], predictions[real_mask]) else: tnr = 0.0 # AUC calculation (using probabilities) if len(np.unique(all_labels)) > 1: # Need both classes for AUC auc = roc_auc_score(all_labels, probabilities) else: auc = 0.0 execution_time = time.time() - start_time # Prepare metrics JSON metrics = { 'TPR': float(tpr), 'TNR': float(tnr), 'Acc total': float(total_accuracy), 'AUC': float(auc), 'execution time': float(execution_time) } # Write metrics JSON with open(metrics_filename, 'w') as f: json.dump(metrics, f, indent=2) # Write individual image results JSON with open(image_results_filename, 'w') as f: json.dump(image_results, f, indent=2) print(f'\nMetrics saved to {metrics_filename}') print(f'Image results saved to {image_results_filename}') print(f'\nMetrics:') print(f' TPR: {tpr:.4f}') print(f' TNR: {tnr:.4f}') print(f' Accuracy: {total_accuracy:.4f}') print(f' AUC: {auc:.4f}') print(f' Execution time: {execution_time:.2f} seconds')