# ---------------------------------------------------------------------------- # IMPORTS # ---------------------------------------------------------------------------- import os import argparse import json import time import yaml import torch from PIL import Image import torchvision.transforms.v2 as Tv2 from networks import ImageClassifier import sys project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(project_root) from support.detect_utils import format_result, save_result, get_device # ---------------------------------------------------------------------------- # IMAGE PREPROCESSING # ---------------------------------------------------------------------------- def preprocess_image(image_path): """ Load and preprocess a single image for model input. Uses the same normalization as test.py (ImageNet stats). """ # Load image image = Image.open(image_path).convert('RGB') # Apply transforms (same as test split without augmentation) transform = Tv2.Compose([ Tv2.ToImage(), Tv2.ToDtype(torch.float32, scale=True), Tv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Apply transforms and add batch dimension tensor = transform(image) tensor = tensor.unsqueeze(0) # Add batch dimension return tensor # ---------------------------------------------------------------------------- # CONFIG LOADING AND PARSING # ---------------------------------------------------------------------------- def load_config(config_path): """Load configuration from YAML file.""" with open(config_path, 'r') as f: config = yaml.safe_load(f) return config def parse_detector_args(detector_args, default_num_centers=1): """ Parse detector_args list (e.g., ["--arch", "nodown", "--prototype", "--freeze"]) into a settings object. """ class Settings: def __init__(self): self.arch = "nodown" self.freeze = False self.prototype = False self.num_centers = default_num_centers settings = Settings() i = 0 while i < len(detector_args): arg = detector_args[i] if arg == "--arch": if i + 1 < len(detector_args): settings.arch = detector_args[i + 1] i += 2 else: i += 1 elif arg == "--freeze": settings.freeze = True i += 1 elif arg == "--prototype": settings.prototype = True i += 1 elif arg == "--num_centers": if i + 1 < len(detector_args): settings.num_centers = int(detector_args[i + 1]) i += 2 else: i += 1 else: i += 1 return settings def resolve_config_path(config_path): """ Resolve config path. If relative, resolve relative to project root (two levels up from detect.py location). """ if os.path.isabs(config_path): return config_path # Get directory of detect.py (detectors/R50_TF/) detect_dir = os.path.dirname(os.path.abspath(__file__)) # Go two levels up to project root project_root = os.path.dirname(os.path.dirname(detect_dir)) # Join with config path return os.path.join(project_root, config_path) # ---------------------------------------------------------------------------- # INFERENCE # ---------------------------------------------------------------------------- def run_inference(model, image_path, device): """ Run inference on a single image. Returns: (probability, label, runtime_ms) """ start_time = time.time() # Preprocess image image_tensor = preprocess_image(image_path) image_tensor = image_tensor.to(device) # Run inference model.eval() with torch.no_grad(): raw_score_tensor = model(image_tensor).squeeze(1) # shape [1] # Convert to probability using sigmoid probability = torch.sigmoid(raw_score_tensor).item() # Determine label (fake if probability > 0.5, else real) label = "fake" if probability > 0.5 else "real" # Calculate runtime in milliseconds runtime_ms = int((time.time() - start_time) * 1000) return probability, label, runtime_ms # ---------------------------------------------------------------------------- # MAIN # ---------------------------------------------------------------------------- def main(): parser = argparse.ArgumentParser(description='Single image inference for R50_TF detector') parser.add_argument('--input', type=str, required=False, help='Path to input image (alias: --image)') parser.add_argument('--image', type=str, required=False, help='Path to input image (alias for --input)') parser.add_argument('--output', type=str, default='/tmp/result.json', help='Path to output JSON file') parser.add_argument('--checkpoint', type=str, required=False, help='Path to model checkpoint file') parser.add_argument('--model', type=str, required=False, help='Model name or checkpoint directory (alias for --checkpoint)') parser.add_argument('--config', type=str, default='configs/R50_TF.yaml', help='Path to YAML config file') parser.add_argument('--device', type=str, default=None, help='Device to use (cuda:0, cpu, etc.)') args = parser.parse_args() # Normalize image argument: prefer --image over --input if provided if args.image: args.input = args.image checkpoint_path = None if args.checkpoint: checkpoint_path = args.checkpoint elif getattr(args, 'model', None): detect_dir = os.path.dirname(os.path.abspath(__file__)) candidate = os.path.join(detect_dir, 'checkpoint', args.model, 'weights', 'best.pt') if os.path.exists(candidate): checkpoint_path = candidate else: # If model refers directly to a file path, accept it if os.path.isabs(args.model) and os.path.exists(args.model): checkpoint_path = args.model else: # Try resolving relative to project root project_root = os.path.dirname(os.path.dirname(detect_dir)) candidate2 = os.path.join(project_root, args.model) if os.path.exists(candidate2): checkpoint_path = candidate2 # If still not found, keep existing behavior (will raise later) if checkpoint_path: args.checkpoint = checkpoint_path # Resolve config path config_path = resolve_config_path(args.config) if not os.path.exists(config_path): raise FileNotFoundError(f"Configuration file not found: {config_path}") # Load config config = load_config(config_path) # Parse detector_args from config detector_args = config.get('detector_args', []) settings = parse_detector_args(detector_args) # Get device from config if available, else use argument # Prioritize argument if explicitly provided (we assume if it's not default, or if we trust the caller) # Since we want to support --device cpu override, we should prioritize args.device device_str = args.device # Only check config if args.device wasn't explicitly passed (but here it has a default) # Let's assume if the user passed --device, they want that. # But args.device has a default 'cuda:0'. # We should change the default to None to distinguish. if args.device is None: if config.get('global', {}).get('device_override'): device_override = config['global']['device_override'] if device_override and device_override != "null" and device_override != "": device_str = device_override else: device_str = 'cuda:0' else: device_str = args.device # Determine device if device_str.startswith('cuda') and not torch.cuda.is_available(): print(f"Warning: CUDA requested but not available. Using CPU.") device = torch.device('cpu') else: device = torch.device(device_str if torch.cuda.is_available() else 'cpu') # Load model print(f"Loading model from {args.checkpoint}") model = ImageClassifier(settings) model.load_state_dict(torch.load(args.checkpoint, map_location=device)) model.to(device) model.eval() # Run inference print(f"Running inference on {args.input}") probability, label, runtime_ms = run_inference(model, args.input, device) # Format result to match other detectors (prediction/confidence/elapsed_time) elapsed_time = runtime_ms / 1000.0 formatted = format_result(label, float(round(probability, 4)), elapsed_time) # Save using shared utility (if output path is provided) if args.output: save_result(formatted, args.output) print(f"Results saved to {args.output}") # Print concise output for user print(f"Prediction: {formatted['prediction']}") print(f"Confidence: {formatted['confidence']:.4f}") print(f"Time: {formatted['elapsed_time']:.3f}s") if __name__ == '__main__': main()