Spaces:
Sleeping
Sleeping
| # ---------------------------------------------------------------------------- | |
| # IMPORTS | |
| # ---------------------------------------------------------------------------- | |
| import os | |
| import torch | |
| import pandas as pd | |
| from tqdm import tqdm | |
| import json | |
| import time | |
| import numpy as np | |
| from sklearn.metrics import roc_auc_score, accuracy_score | |
| from networks import ImageClassifier | |
| from parser import get_parser | |
| from dataset import create_dataloader | |
def test(loader, model, settings, device):
    """Run ``model`` over ``loader`` and persist per-image scores and metrics.

    Three files are written under
    ``./results/{settings.name}/{settings.data_keys}/data/``:

    * ``results.csv``        -- one ``path, score, label`` row per image,
    * ``metrics.json``       -- TPR, TNR, accuracy, AUC and wall-clock time,
    * ``image_results.json`` -- list of ``{'path', 'score', 'label'}`` dicts.

    The same metrics are also printed to stdout.

    Args:
        loader: Iterable yielding ``(data, labels, paths)`` batches.
        model: Callable module; its output is squeezed on dim 1 to obtain one
            raw score per image.
        settings: Object exposing ``name`` and ``data_keys`` (used only to
            build the output directory).
        device: Device the input batch is moved to before the forward pass.

    Returns:
        None. Results are reported through the files and stdout.
    """
    model.eval()
    start_time = time.time()

    # File paths
    output_dir = f'./results/{settings.name}/{settings.data_keys}/data/'
    os.makedirs(output_dir, exist_ok=True)
    csv_filename = os.path.join(output_dir, 'results.csv')
    metrics_filename = os.path.join(output_dir, 'metrics.json')
    image_results_filename = os.path.join(output_dir, 'image_results.json')

    # Collect all results
    all_scores = []
    all_labels = []
    image_results = []

    # The CSV is opened once for the whole run instead of being reopened in
    # append mode for every batch; rows are flushed after each batch so partial
    # results are still on disk if the run is interrupted.
    with open(csv_filename, 'w') as csv_file:
        csv_file.write(f"{','.join(['name', 'pro', 'flag'])}\n")

        with torch.no_grad():
            with tqdm(loader, unit='batch', mininterval=0.5) as tbatch:
                tbatch.set_description(f'Validation')
                for (data, labels, paths) in tbatch:
                    data = data.to(device)
                    # Bring the whole batch of scores back to the host in one
                    # transfer rather than synchronising once per element.
                    scores = model(data).squeeze(1).cpu()

                    # Labels are only read back as Python scalars, so they do
                    # not need to be moved to the compute device.
                    for score, label, path in zip(scores, labels, paths):
                        score_val = score.item()
                        label_val = label.item()
                        all_scores.append(score_val)
                        all_labels.append(label_val)
                        image_results.append({
                            'path': path,
                            'score': score_val,
                            'label': label_val
                        })
                        # Row format kept as-is for backward compatibility.
                        csv_file.write(f"{path}, {score_val}, {label_val}\n")
                    csv_file.flush()

    # Calculate metrics
    all_scores = np.array(all_scores)
    all_labels = np.array(all_labels)

    # Convert scores to predictions (threshold at 0 on the raw score)
    predictions = (all_scores > 0).astype(int)

    # Overall accuracy; falls back to 0.0 on an empty test set, matching the
    # fallbacks used for the other metrics below.
    if all_labels.size > 0:
        total_accuracy = accuracy_score(all_labels, predictions)
    else:
        total_accuracy = 0.0

    # TPR = TP / (TP + FN): accuracy restricted to samples with label == 1
    fake_mask = all_labels == 1
    if fake_mask.sum() > 0:
        tpr = accuracy_score(all_labels[fake_mask], predictions[fake_mask])
    else:
        tpr = 0.0

    # TNR = TN / (TN + FP): accuracy restricted to samples with label == 0
    real_mask = all_labels == 0
    if real_mask.sum() > 0:
        tnr = accuracy_score(all_labels[real_mask], predictions[real_mask])
    else:
        tnr = 0.0

    # AUC is only defined when both classes are present; otherwise report 0.0.
    if len(np.unique(all_labels)) > 1:
        # Scores are squashed through a sigmoid before ranking.
        probabilities = torch.sigmoid(torch.tensor(all_scores)).numpy()
        auc = roc_auc_score(all_labels, probabilities)
    else:
        auc = 0.0

    execution_time = time.time() - start_time

    # Prepare metrics JSON
    metrics = {
        'TPR': float(tpr),
        'TNR': float(tnr),
        'Acc total': float(total_accuracy),
        'AUC': float(auc),
        'execution time': float(execution_time)
    }

    # Write metrics JSON
    with open(metrics_filename, 'w') as f:
        json.dump(metrics, f, indent=2)

    # Write individual image results JSON
    with open(image_results_filename, 'w') as f:
        json.dump(image_results, f, indent=2)

    print(f'\nMetrics saved to {metrics_filename}')
    print(f'Image results saved to {image_results_filename}')
    print(f'\nMetrics:')
    print(f' TPR: {tpr:.4f}')
    print(f' TNR: {tnr:.4f}')
    print(f' Accuracy: {total_accuracy:.4f}')
    print(f' AUC: {auc:.4f}')
    print(f' Execution time: {execution_time:.2f} seconds')
# Script entry point: build the test dataloader and model from the parsed
# settings, load the best checkpoint for `settings.name`, and run evaluation.
if __name__ == "__main__":
    parser = get_parser()
    settings = parser.parse_args()

    # Fall back to CPU when CUDA is not available on this machine.
    device = torch.device(settings.device if torch.cuda.is_available() else 'cpu')

    test_dataloader = create_dataloader(settings, split='test')

    model = ImageClassifier(settings)
    model.to(device)

    path_weight = f'./checkpoint/{settings.name}/weights/best.pt'
    # map_location remaps the saved tensors onto the selected device; without
    # it, a checkpoint saved from CUDA cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(path_weight, map_location=device)
    model.load_state_dict(state_dict)

    test(test_dataloader, model, settings, device)