| | """ |
| | Easy inference script for Fake Image Detection |
| | Usage: python inference.py --image path/to/image.jpg |
| | """ |
| |
|
| | import torch |
| | from torchvision import transforms |
| | from PIL import Image |
| | import pickle |
| | import json |
| | import argparse |
| | from huggingface_hub import hf_hub_download |
| | from model import EnhancedFreqVAE, EdgeNormalizingFlow, SemanticDeepSVDD, Ensemble |
| |
|
| |
|
def _download(repo_id, filename, revision=None):
    """Fetch *filename* from the Hugging Face repo and return its local path."""
    return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)


def _load_pickle(repo_id, filename, revision=None):
    """Download and unpickle one of the pickled ensemble members.

    SECURITY: unpickling runs arbitrary code embedded in the file. Only load
    from a repository you trust, ideally pinned via ``revision``.
    """
    with open(_download(repo_id, filename, revision), 'rb') as f:
        return pickle.load(f)


def _load_state_dict_model(model, repo_id, filename, device, revision=None):
    """Load a downloaded state dict into *model*, move it to *device*, set eval mode."""
    # torch.load is pickle-based too; the same trust caveat as _load_pickle applies.
    state = torch.load(_download(repo_id, filename, revision), map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model


def load_models(device='cuda', repo_id="ash12321/fake-image-detection-ensemble", revision=None):
    """Download every ensemble member from Hugging Face and assemble the Ensemble.

    Args:
        device: torch device the neural models (and their checkpoints) are loaded onto.
        repo_id: Hugging Face repository holding the weights and ``config.json``.
        revision: optional branch, tag or commit hash to pin the downloads to.
            ``None`` uses the repository's default revision.

    Returns:
        ``(ensemble, device)`` -- the configured ``Ensemble`` and the same ``device``.

    Raises:
        KeyError: if ``config.json`` lacks ``weights``/``norms``/``thresh`` or the
            SVDD checkpoint lacks ``model``/``center``.
    """
    print("📥 Downloading models from Hugging Face...")

    config_path = _download(repo_id, "config.json", revision)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    print("Loading Frequency VAE...")
    freq_vae = _load_state_dict_model(
        EnhancedFreqVAE(), repo_id, "freq_vae.pth", device, revision)

    print("Loading Edge Flow...")
    edge_flow = _load_state_dict_model(
        EdgeNormalizingFlow(), repo_id, "edge_flow.pth", device, revision)

    print("Loading Semantic SVDD...")
    # This checkpoint is a dict bundling the state dict ('model') with a separate
    # 'center' value, so it cannot go through _load_state_dict_model.
    semantic_svdd = SemanticDeepSVDD()
    checkpoint = torch.load(_download(repo_id, "semantic_svdd.pth", revision),
                            map_location=device)
    semantic_svdd.load_state_dict(checkpoint['model'])
    semantic_svdd.center = checkpoint['center']
    semantic_svdd.to(device)
    semantic_svdd.eval()

    print("Loading traditional ML models...")
    # Each of these is stored in the repo as "<name>.pkl".
    pickled = {
        name: _load_pickle(repo_id, f"{name}.pkl", revision)
        for name in ('texture_ocsvm', 'color_model', 'stat', 'iforest', 'lof', 'gmm')
    }

    # Key order is kept identical to the original in case Ensemble iterates it.
    models_dict = {
        'freq_vae': freq_vae,
        'texture_ocsvm': pickled['texture_ocsvm'],
        'color_model': pickled['color_model'],
        'edge_flow': edge_flow,
        'semantic_svdd': semantic_svdd,
        'stat': pickled['stat'],
        'iforest': pickled['iforest'],
        'lof': pickled['lof'],
        'gmm': pickled['gmm'],
    }

    ensemble = Ensemble(models_dict)
    # Ensemble combination parameters come from the repo's config.json.
    ensemble.wts = config['weights']
    ensemble.norms = config['norms']
    ensemble.thresh = config['thresh']

    print("✓ All models loaded!\n")
    return ensemble, device
| |
|
| |
|
def predict_image(image_path, ensemble, device):
    """Classify a single image file as real or fake with the loaded ensemble.

    Args:
        image_path: anything ``PIL.Image.open`` accepts (path or file object).
        ensemble: object whose ``predict(img_tensor, device)`` returns
            ``(is_fake, score, individual_scores)``.
        device: forwarded unchanged to ``ensemble.predict``.

    Returns:
        dict with keys ``prediction`` ('FAKE' or 'REAL'), ``confidence``
        (``abs(score)``), ``anomaly_score`` (the raw score) and
        ``individual_scores`` (as returned by ``ensemble.predict``).
    """
    # Context manager closes the underlying file handle (the original leaked it).
    # resize()/convert() return new, fully loaded images that outlive the block.
    with Image.open(image_path) as src:
        # NOTE(review): resizing before convert('RGB') may make Pillow use a
        # non-LANCZOS filter for palette/bilevel images; order kept to match the
        # existing preprocessing -- confirm against the training pipeline.
        img = src.resize((256, 256), Image.LANCZOS).convert('RGB')

    # Standard ImageNet mean/std normalisation.
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img_tensor = tfm(img)

    is_fake, score, individual_scores = ensemble.predict(img_tensor, device)

    return {
        'prediction': 'FAKE' if is_fake else 'REAL',
        'confidence': abs(score),
        'anomaly_score': score,
        'individual_scores': individual_scores,
    }
| |
|
| |
|
if __name__ == "__main__":
    from pathlib import Path

    parser = argparse.ArgumentParser(description='Detect fake images')
    parser.add_argument('--image', type=str, required=True, help='Path to image')
    parser.add_argument('--device', type=str, default='cuda', help='Device (cuda/cpu)')
    args = parser.parse_args()

    # Fail fast: validate the input before spending time downloading every model.
    if not Path(args.image).is_file():
        parser.error(f"image not found: {args.image}")

    # Whatever device was requested, fall back to CPU when CUDA is unavailable.
    device = args.device if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    ensemble, device = load_models(device)

    print(f"Analyzing: {args.image}")
    result = predict_image(args.image, ensemble, device)

    print("\n" + "=" * 50)
    print("RESULT")
    print("=" * 50)
    print(f"Prediction: {result['prediction']}")
    print(f"Confidence: {result['confidence']:.4f}")
    print(f"Anomaly Score: {result['anomaly_score']:.4f}")
    print("\nIndividual Model Scores:")
    for name, model_score in result['individual_scores'].items():
        print(f" {name}: {model_score:.4f}")
    print("=" * 50)
| |
|