Ahmed-El-Sharkawy commited on
Commit
7f4b1c2
·
1 Parent(s): 98c9750

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+ import os
8
+ import torch
9
+
10
+ # Set device
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ # Load the main classifier (Detector_best_model.pth)
14
+ main_model = models.resnet18(weights=None) # Updated: weights=None
15
+ num_ftrs = main_model.fc.in_features
16
+ main_model.fc = nn.Linear(num_ftrs, 2) # 2 classes: AI-generated_Image, Real_Image
17
+ main_model.load_state_dict(torch.load('best_model (5).pth', map_location=device, weights_only=True)) # Updated: weights_only=True
18
+ main_model = main_model.to(device)
19
+ main_model.eval()
20
+
21
+ # Define class names for the classifier based on the Folder structure
22
+ classes_name = ['AI-generated Image', 'Real Image']
23
+
24
+ def convert_to_rgb(image):
25
+ """
26
+ Converts 'P' mode images with transparency to 'RGBA', and then to 'RGB'.
27
+ This is to avoid transparency issues during model training.
28
+ """
29
+ if image.mode in ('P', 'RGBA'):
30
+ return image.convert('RGB')
31
+ return image
32
+
33
+ # Define preprocessing transformations (same used during training)
34
+ preprocess = transforms.Compose([
35
+ transforms.Lambda(convert_to_rgb),
36
+ transforms.Resize((224, 224)), # Resize here, no need for shape argument in gr.Image
37
+ transforms.ToTensor(),
38
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
39
+ ])
40
+
41
+ def classify_image(image):
42
+ # Open the image using PIL
43
+ image = Image.fromarray(image)
44
+
45
+ # Preprocess the image
46
+ input_image = preprocess(image).unsqueeze(0).to(device)
47
+
48
+ # Perform inference with the main classifier
49
+ with torch.no_grad():
50
+ output = main_model(input_image)
51
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
52
+ confidence, predicted_class = torch.max(probabilities, 0)
53
+
54
+ # Main classifier result
55
+ main_prediction = classes_name[predicted_class]
56
+ main_confidence = confidence.item()
57
+
58
+ return f"Image is : {main_prediction} (Confidence: {main_confidence:.4f})"
59
+
60
+ # Gradio interface (updated)
61
+ image_input = gr.Image(image_mode="RGB") # Removed shape argument
62
+ output_text = gr.Textbox()
63
+
64
+ gr.Interface(fn=classify_image, inputs=image_input, outputs=[output_text],
65
+ title="Detect AI-generated Image ",
66
+ description="Upload an image to Detected AI-generated Image .",
67
+ theme="default").launch()