import os

import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Select GPU if available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the main classifier checkpoint (best_model3_mobilenetv3_large.pth)
main_model = models.mobilenet_v3_large(weights=None)  # architecture only; weights come from the checkpoint

# Replace the final classifier layer with a 2-class head: AI-generated Image vs. Real Image
num_ftrs = main_model.classifier[3].in_features
main_model.classifier[3] = nn.Linear(num_ftrs, 2)

main_model.load_state_dict(
    torch.load('best_model3_mobilenetv3_large.pth', map_location=device, weights_only=True)
)
main_model = main_model.to(device)
main_model.eval()
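# Optional sanity check (illustrative sketch, not part of the original app): a dummy
# forward pass would confirm that the checkpoint matches the modified 2-class head.
# with torch.no_grad():
#     assert main_model(torch.randn(1, 3, 224, 224, device=device)).shape == (1, 2)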
# Class names, ordered to match the training folder structure
classes_name = ['AI-generated Image', 'Real Image']
def convert_to_rgb(image):
    """
    Convert 'P' (palette) and 'RGBA' images to 'RGB' so that transparency
    does not cause issues; the same conversion was applied during training.
    """
    if image.mode in ('P', 'RGBA'):
        return image.convert('RGB')
    return image
# Preprocessing transforms (same as those used during training)
preprocess = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),  # resize here; gr.Image no longer takes a shape argument
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet normalization
])
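# Illustrative note (assumption, not in the original app): applying `preprocess` to any
# PIL image yields a float tensor of shape [3, 224, 224]; classify_image then adds a
# batch dimension, giving [1, 3, 224, 224] as expected by MobileNetV3.
# example = preprocess(Image.new("RGB", (640, 480)))  # -> torch.Size([3, 224, 224])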
def classify_image(image):
    # Gradio passes the upload as a numpy array; convert it to a PIL image
    image = Image.fromarray(image)

    # Preprocess and add a batch dimension
    input_image = preprocess(image).unsqueeze(0).to(device)

    # Run inference with the main classifier
    with torch.no_grad():
        output = main_model(input_image)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)

    main_prediction = classes_name[predicted_class]
    main_confidence = confidence.item()

    return f"Image is: {main_prediction} (Confidence: {main_confidence:.4f})"
# Gradio interface components
image_input = gr.Image(image_mode="RGB")  # the old `shape` argument is no longer supported
output_text = gr.Textbox()
gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=[output_text],
    title="Detect AI-generated Image",
    description=(
        "Upload an art image to check whether it is AI-generated or real. "
        "The model was trained on images collected from six websites. "
        "Please upload JPG or PNG images only.\n\n"
        "### Main Dataset Used:\n"
        "- [AI-generated Images vs Real Images (Kaggle)](https://www.kaggle.com/datasets/tristanzhang32/ai-generated-images-vs-real-images)\n\n"
        "**Fake Images Collected From:**\n"
        "- 10,000 from [Stable Diffusion (OpenArt AI)](https://www.openart.ai)\n"
        "- 10,000 from [MidJourney (Imagine.Art)](https://www.imagine.art)\n"
        "- 10,000 from [DALL·E (OpenAI)](https://openai.com/dall-e-2)\n\n"
        "**Real Images Collected From:**\n"
        "- 7,500 from [WikiArt](https://www.wikiart.org)\n"
        "- 22,500 from [Pexels](https://www.pexels.com) and [Unsplash](https://unsplash.com)\n"
    ),
    theme="default",
).launch()
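# Deployment note (assumption based on standard Gradio behaviour): on Hugging Face
# Spaces, launch() binds to the host/port provided by the platform; share=True is only
# useful for creating a temporary public link when running the script locally.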