# In [None]:
#!/usr/bin/env python3
"""
Advanced Eye Disease Detection Training Script - Fixed Version
"""

import os
import sys
import json
import numpy as np
import cv2
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Check that the required packages are available. Any missing one aborts the
# script with an installation hint.
try:
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    from tensorflow.keras.applications import EfficientNetB3, ResNet152V2, DenseNet201
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
    from sklearn.model_selection import train_test_split, StratifiedKFold
    from sklearn.preprocessing import LabelEncoder
    from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
    from sklearn.utils.class_weight import compute_class_weight
    import seaborn as sns
    from imblearn.over_sampling import SMOTE
    from imblearn.combine import SMOTETomek
except ImportError as e:
    print(f"Missing required package: {e}")
    print("Please run: pip install tensorflow opencv-python scikit-learn matplotlib tqdm seaborn imbalanced-learn")
    sys.exit(1)

# Optional: Kaggle API (install if needed).
# Importing the kaggle package triggers authentication, which raises OSError
# when no credentials file is configured. Treat that like a missing package so
# an optional dependency can never abort the script.
try:
    from kaggle.api.kaggle_api_extended import KaggleApi
    KAGGLE_AVAILABLE = True
except (ImportError, OSError):
    KAGGLE_AVAILABLE = False
    print("⚠️ Kaggle API not available. Manual dataset download required.")

print("✅ All required packages are available!")

class AdvancedFundusPreprocessor:
    """Preprocessing pipeline for fundus images.

    Reads an image from disk, applies contrast enhancement (CLAHE on the LAB
    lightness channel, histogram equalisation of the green channel, gamma
    correction), resizes it and scales it to float32 values in [0, 1].

    Args:
        image_size: Target size as ``(height, width)``. This matches how the
            rest of the file builds Keras input shapes
            (``(*image_size, 3)``).
    """

    # Gamma lookup tables keyed by gamma value. The table only depends on
    # gamma, so it is built once and reused for every image.
    _gamma_tables = {}

    def __init__(self, image_size=(224, 224)):
        self.image_size = image_size
        self.setup_clahe_variants()

    def setup_clahe_variants(self):
        """Create the CLAHE operators used by the enhancement step.

        Only ``clahe_normal`` is used by :meth:`enhance_fundus_image`; the
        aggressive and gentle variants are created but not referenced in this
        class.
        """
        self.clahe_normal = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        self.clahe_aggressive = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(4, 4))
        self.clahe_gentle = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(16, 16))

    def advanced_preprocess_image(self, image_path):
        """Load and preprocess a single image.

        Args:
            image_path: Path (or path-like) of the image file.

        Returns:
            A float32 RGB array with values in [0, 1] and spatial size
            ``image_size``, or ``None`` if the file cannot be read or
            processing fails (a message is printed in both cases).
        """
        # Best-effort by design: one corrupt file must not abort a whole
        # dataset load, so any failure is reported and mapped to None.
        try:
            image = cv2.imread(str(image_path))
            if image is None:
                print(f"⚠️ Could not read image: {image_path}")
                return None

            # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.enhance_fundus_image(image)

            # cv2.resize expects dsize as (width, height), whereas image_size
            # is (height, width); swap so non-square sizes are not transposed.
            height, width = self.image_size
            image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LANCZOS4)

            # Normalize to [0, 1]
            return image.astype(np.float32) / 255.0
        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            return None

    def enhance_fundus_image(self, image):
        """Apply contrast enhancement to an 8-bit RGB image.

        Steps, in order: CLAHE on the L channel of the LAB representation,
        histogram equalisation of the green channel, gamma correction with
        ``gamma=1.2``.

        Args:
            image: RGB image array (as produced by ``cv2.cvtColor`` from an
                image read with ``cv2.imread``).

        Returns:
            The enhanced RGB image.
        """
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        lab[:, :, 0] = self.clahe_normal.apply(lab[:, :, 0])
        image_enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        # Green channel enhancement (important for fundus images)
        image_enhanced[:, :, 1] = cv2.equalizeHist(image_enhanced[:, :, 1])

        return self.adjust_gamma(image_enhanced, gamma=1.2)

    def adjust_gamma(self, image, gamma=1.0):
        """Apply gamma correction through a 256-entry lookup table.

        Args:
            image: 8-bit image array.
            gamma: Positive gamma value; each pixel ``v`` is mapped to
                ``255 * (v / 255) ** (1 / gamma)``.

        Returns:
            The gamma-corrected image.

        Raises:
            ValueError: If ``gamma`` is not positive.
        """
        return cv2.LUT(image, self._gamma_table(gamma))

    @classmethod
    def _gamma_table(cls, gamma):
        """Return the cached uint8 lookup table for ``gamma``, building it once."""
        if gamma <= 0:
            raise ValueError(f"gamma must be positive, got {gamma}")

        key = float(gamma)
        table = cls._gamma_tables.get(key)
        if table is None:
            inv_gamma = 1.0 / key
            table = np.array(
                [((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]
            ).astype("uint8")
            cls._gamma_tables[key] = table
        return table

class AdvancedEnsembleModel:
    """Factory for the three transfer-learning classifiers of the ensemble.

    Every model takes images with pixel values in [0, 1] (the range produced
    by the preprocessing step in this file) and converts them internally to
    the range its ImageNet backbone expects, so callers use one input
    convention for all three models.

    Args:
        image_size: Input size as ``(height, width)``.
        num_classes: Number of output classes. It is read when a
            ``create_*`` method runs, so it must be set before building.
    """

    def __init__(self, image_size=(224, 224), num_classes=7):
        self.image_size = image_size
        self.num_classes = num_classes

    def _partially_frozen_backbone(self, backbone_cls, trainable_tail):
        """Load an ImageNet backbone with only its last layers trainable.

        Args:
            backbone_cls: Keras application constructor.
            trainable_tail: Number of trailing layers left trainable.
        """
        base_model = backbone_cls(
            weights='imagenet',
            include_top=False,
            input_shape=(*self.image_size, 3)
        )
        base_model.trainable = True
        for layer in base_model.layers[:-trainable_tail]:
            layer.trainable = False
        return base_model

    def _classification_head(self, x, hidden_layers):
        """Stack BatchNorm/Dropout, the hidden Dense blocks and the softmax.

        Args:
            x: Pooled feature tensor.
            hidden_layers: Sequence of ``(units, dropout_rate)`` pairs, one
                per hidden Dense block.
        """
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(0.5)(x)
        for units, dropout_rate in hidden_layers:
            x = layers.Dense(units, activation='relu')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Dropout(dropout_rate)(x)
        return layers.Dense(self.num_classes, activation='softmax')(x)

    def create_efficientnet_model(self):
        """Create the EfficientNetB3 based model.

        Returns:
            An uncompiled ``keras.Model`` named ``EfficientNetB3_Model``.
        """
        base_model = self._partially_frozen_backbone(EfficientNetB3, 30)

        inputs = keras.Input(shape=(*self.image_size, 3))

        # Keras EfficientNet embeds its own normalisation and expects raw
        # pixels in [0, 255]; map the [0, 1] inputs back to that range.
        x = layers.Rescaling(255.0)(inputs)

        # training=False keeps the backbone's BatchNorm layers in inference
        # mode while the unfrozen layers are fine-tuned.
        x = base_model(x, training=False)

        # Concatenate average- and max-pooled features.
        gap = layers.GlobalAveragePooling2D()(x)
        gmp = layers.GlobalMaxPooling2D()(x)
        x = layers.Concatenate()([gap, gmp])

        outputs = self._classification_head(x, [(512, 0.4), (256, 0.3)])

        return keras.Model(inputs, outputs, name='EfficientNetB3_Model')

    def create_resnet_model(self):
        """Create the ResNet152V2 based model.

        Returns:
            An uncompiled ``keras.Model`` named ``ResNet152V2_Model``.
        """
        base_model = self._partially_frozen_backbone(ResNet152V2, 40)

        inputs = keras.Input(shape=(*self.image_size, 3))

        # ResNetV2 ImageNet weights expect inputs scaled to [-1, 1].
        x = layers.Rescaling(2.0, offset=-1.0)(inputs)
        x = base_model(x, training=False)

        x = layers.GlobalAveragePooling2D()(x)
        outputs = self._classification_head(x, [(512, 0.4)])

        return keras.Model(inputs, outputs, name='ResNet152V2_Model')

    def create_densenet_model(self):
        """Create the DenseNet201 based model.

        Returns:
            An uncompiled ``keras.Model`` named ``DenseNet201_Model``.
        """
        base_model = self._partially_frozen_backbone(DenseNet201, 50)

        inputs = keras.Input(shape=(*self.image_size, 3))

        # DenseNet ImageNet weights expect [0, 1] pixels standardised with the
        # ImageNet per-channel mean and standard deviation.
        x = layers.Normalization(
            axis=-1,
            mean=[0.485, 0.456, 0.406],
            variance=[0.229 ** 2, 0.224 ** 2, 0.225 ** 2]
        )(inputs)
        x = base_model(x, training=False)

        x = layers.GlobalAveragePooling2D()(x)
        outputs = self._classification_head(x, [(512, 0.4)])

        return keras.Model(inputs, outputs, name='DenseNet201_Model')

class AdvancedFundusTrainer:
    """End-to-end pipeline: load data, train three models, evaluate, save.

    Paths and hyper-parameters are fixed in ``__init__``: the dataset is read
    from ``fundus_dataset/``, models go to ``models/`` and figures to
    ``plots/``.
    """

    # Soft-voting weights in model creation order
    # (EfficientNetB3, ResNet152V2, DenseNet201).
    ENSEMBLE_WEIGHTS = (0.4, 0.35, 0.25)

    # Case-insensitive glob patterns for the supported image extensions.
    IMAGE_PATTERNS = ('*.[jJ][pP][gG]', '*.[jJ][pP][eE][gG]', '*.[pP][nN][gG]')

    # Bounds of the random brightness factor used for augmentation.
    BRIGHTNESS_RANGE = (0.8, 1.2)

    def __init__(self):
        self.dataset_path = "fundus_dataset"
        self.model_path = "models"
        self.image_size = (224, 224)
        self.batch_size = 16

        # Initialize components
        self.preprocessor = AdvancedFundusPreprocessor(self.image_size)
        self.ensemble_model = AdvancedEnsembleModel(self.image_size)

        # Create output directories
        for directory in (self.model_path, "logs", "plots"):
            Path(directory).mkdir(exist_ok=True)

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @classmethod
    def _find_image_files(cls, root):
        """Return all image files below ``root``, sorted for reproducibility."""
        found = []
        for pattern in cls.IMAGE_PATTERNS:
            found.extend(root.rglob(pattern))
        # Directory traversal order is filesystem-dependent; sorting makes the
        # seeded train/val/test split reproducible across machines.
        return sorted(found)

    def _ensemble_weights(self, n_models):
        """Return voting weights for ``n_models`` models.

        The configured ``ENSEMBLE_WEIGHTS`` are used when their count matches;
        otherwise every model gets the same weight.
        """
        if n_models == len(self.ENSEMBLE_WEIGHTS):
            return list(self.ENSEMBLE_WEIGHTS)
        return [1.0 / n_models] * n_models

    @staticmethod
    def _split(X, y, test_size):
        """Stratified split that falls back to a random split when impossible.

        Stratification needs at least two samples of every class; rare classes
        would otherwise abort the pipeline with a ValueError.
        """
        try:
            return train_test_split(
                X, y, test_size=test_size, random_state=42, stratify=y
            )
        except ValueError as e:
            print(f"⚠️ Stratified split not possible ({e}); using a random split")
            return train_test_split(X, y, test_size=test_size, random_state=42)

    # ------------------------------------------------------------------
    # Dataset handling
    # ------------------------------------------------------------------
    def download_dataset(self):
        """Download the Kaggle dataset into ``self.dataset_path``.

        Returns:
            ``True`` on success, ``False`` if the Kaggle API is unavailable or
            the download fails (instructions are printed).
        """
        if not KAGGLE_AVAILABLE:
            print("❌ Kaggle API not available. Please download dataset manually.")
            print("Dataset: https://www.kaggle.com/datasets/linchundan/fundusimage1000")
            print(f"Extract to: {self.dataset_path}/")
            return False

        print("📥 Downloading dataset from Kaggle...")

        # Network and credential problems surface as assorted exception
        # types, so the boundary is deliberately broad and reports the cause.
        try:
            api = KaggleApi()
            api.authenticate()

            api.dataset_download_files(
                "linchundan/fundusimage1000",
                path=self.dataset_path,
                unzip=True
            )
            print("✅ Dataset downloaded successfully!")
            return True

        except Exception as e:
            print(f"❌ Download failed: {e}")
            print("Please download dataset manually and extract to fundus_dataset/")
            return False

    def load_dataset(self):
        """Load, preprocess and label every image found in the dataset.

        The class label of an image is the name of its parent directory;
        images lying directly in the dataset root are labelled ``"unknown"``.

        Returns:
            ``(X, y_encoded, class_names)`` where ``X`` is the stacked image
            array, ``y_encoded`` the integer labels and ``class_names`` the
            label encoder's classes, or ``(None, None, None)`` when no image
            could be loaded.
        """
        print("📁 Loading dataset...")

        dataset_path = Path(self.dataset_path)

        # Debug aid: list directories only, since listing every file floods
        # the console for datasets with thousands of images.
        print(f"\nDirectory structure at {dataset_path}:")
        for item in sorted(p for p in dataset_path.rglob('*') if p.is_dir()):
            print(f" {item.relative_to(dataset_path)}")

        # Check for common dataset structures
        possible_paths = [
            dataset_path,
            dataset_path / "1000images",
            dataset_path / "fundusimage1000",
            dataset_path / "images",
            dataset_path / "data"
        ]

        actual_path = None
        image_files = []
        for path in possible_paths:
            if not path.exists():
                continue
            print(f"\nChecking path: {path}")
            candidates = self._find_image_files(path)
            print(f" Found {len(candidates)} images")
            if candidates:
                actual_path, image_files = path, candidates
                break

        if actual_path is None:
            print(f"\n❌ No valid dataset found in {dataset_path}")
            print("Please ensure the dataset contains image files in one of these structures:")
            print("1. Directly in fundus_dataset/")
            print("2. In fundus_dataset/1000images/")
            print("3. In fundus_dataset/fundusimage1000/")
            print("4. In subdirectories by class")
            return None, None, None

        print(f"\n✅ Using dataset path: {actual_path}")
        print(f"\nFound {len(image_files)} image files")

        # Process images
        images = []
        labels = []
        class_counts = Counter()
        for image_file in tqdm(image_files, desc="Processing images"):
            processed_image = self.preprocessor.advanced_preprocess_image(image_file)
            if processed_image is None:
                continue

            # Compare paths, not names: a class folder that merely shares the
            # root's name must still count as a class.
            if image_file.parent == actual_path:
                class_label = "unknown"
            else:
                class_label = image_file.parent.name

            images.append(processed_image)
            labels.append(class_label)
            class_counts[class_label] += 1

        if not images:
            print("❌ No images loaded successfully")
            return None, None, None

        # Print class distribution
        print("\nClass distribution:")
        for class_name, count in class_counts.most_common():
            print(f" {class_name}: {count} images")

        X = np.array(images)
        y = np.array(labels)

        # Encode labels
        label_encoder = LabelEncoder()
        y_encoded = label_encoder.fit_transform(y)

        print(f"\n✅ Dataset loaded: {len(X)} images, {len(label_encoder.classes_)} classes")

        return X, y_encoded, label_encoder.classes_

    def balance_dataset(self, X, y):
        """Oversample minority classes with SMOTE.

        Args:
            X: Image array whose first axis indexes samples.
            y: Integer class labels, one per sample.

        Returns:
            ``(X_balanced, y_balanced)``. The inputs are returned unchanged
            when SMOTE cannot be applied (fewer than two classes, a class with
            a single sample) or when it fails.
        """
        print("⚖️ Balancing dataset...")

        class_counts = Counter(np.asarray(y).tolist())
        smallest = min(class_counts.values(), default=0)

        # SMOTE interpolates between a sample and its same-class neighbours,
        # so k_neighbors is bounded by the smallest class size, not by the
        # number of classes.
        if len(class_counts) < 2 or smallest < 2:
            print("SMOTE skipped: needs at least 2 classes with 2 samples each, using original dataset")
            return X, y

        try:
            # SMOTE works on 2-D feature matrices.
            X_flattened = X.reshape(X.shape[0], -1)

            smote = SMOTE(random_state=42, k_neighbors=min(5, smallest - 1))
            X_balanced, y_balanced = smote.fit_resample(X_flattened, y)

            # Restore the original image layout and dtype.
            X_balanced = X_balanced.reshape(-1, *X.shape[1:]).astype(X.dtype, copy=False)

            print(f"Dataset balanced: {len(X_balanced)} samples")
            return X_balanced, y_balanced

        except Exception as e:
            # Best-effort: balancing is an optimisation, never a hard failure.
            print(f"SMOTE failed: {e}, using original dataset")
            return X, y

    def create_data_generators(self, X_train, y_train, X_val, y_val):
        """Create the augmenting training generator and the validation generator.

        Args:
            X_train, y_train: Training images (values in [0, 1]) and labels.
            X_val, y_val: Validation images and labels.

        Returns:
            ``(train_generator, val_generator)``.
        """
        low, high = self.BRIGHTNESS_RANGE

        def random_brightness(image):
            # ImageDataGenerator's brightness_range round-trips through an
            # 8-bit PIL image, which does not preserve [0, 1] float data.
            # Scaling in the float domain keeps training and validation
            # inputs in the same value range.
            factor = np.random.uniform(low, high)
            return np.clip(image * factor, 0.0, 1.0)

        # Training data generator with augmentation
        train_datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.15,
            height_shift_range=0.15,
            horizontal_flip=True,
            vertical_flip=True,
            zoom_range=0.15,
            shear_range=0.1,
            fill_mode='reflect',
            preprocessing_function=random_brightness
        )

        # Validation data generator (no augmentation)
        val_datagen = ImageDataGenerator()

        train_generator = train_datagen.flow(
            X_train, y_train,
            batch_size=self.batch_size,
            shuffle=True
        )

        val_generator = val_datagen.flow(
            X_val, y_val,
            batch_size=self.batch_size,
            shuffle=False
        )

        return train_generator, val_generator

    # ------------------------------------------------------------------
    # Model creation and training
    # ------------------------------------------------------------------
    def create_ensemble(self, num_classes):
        """Build and compile the three ensemble members.

        Args:
            num_classes: Number of output classes for every model.

        Returns:
            ``(models, model_names)`` in matching order.
        """
        print("🧠 Creating ensemble model...")

        # Must be set before building: the builders read num_classes when
        # they create the output layer.
        self.ensemble_model.num_classes = num_classes

        builders = [
            ('EfficientNetB3', self.ensemble_model.create_efficientnet_model),
            ('ResNet152V2', self.ensemble_model.create_resnet_model),
            ('DenseNet201', self.ensemble_model.create_densenet_model),
        ]

        models = []
        model_names = []
        for model_name, build in builders:
            model = build()
            # One optimizer per model: an optimizer instance keeps state for
            # the variables of the model it trains and must not be shared.
            model.compile(
                optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy']
            )
            models.append(model)
            model_names.append(model_name)

        print(f"✅ Created ensemble with {len(models)} models")
        return models, model_names

    def train_ensemble(self, models, model_names, train_gen, val_gen, class_names, epochs=30):
        """Train every model on the same generators.

        Args:
            models: Compiled Keras models.
            model_names: Display names, parallel to ``models``.
            train_gen: Training data generator.
            val_gen: Validation data generator.
            class_names: Class names (not used here; kept for the interface).
            epochs: Maximum number of epochs per model.

        Returns:
            List of Keras ``History`` objects, one per model.
        """
        print("🚀 Training ensemble...")

        # Class weights for imbalanced data. Prefer the labels stored on the
        # generator; drawing them batch by batch would run a full augmented
        # pass over the training set just to read labels.
        y_train_full = getattr(train_gen, 'y', None)
        if y_train_full is None:
            y_train_full = []
            for _ in range(len(train_gen)):
                _, y_batch = next(train_gen)
                y_train_full.extend(y_batch)
            train_gen.reset()
        y_train_full = np.asarray(y_train_full).ravel()

        present_classes = np.unique(y_train_full)
        class_weights = compute_class_weight(
            'balanced',
            classes=present_classes,
            y=y_train_full
        )
        # Key by class id, not by position, so the mapping stays correct when
        # some class ids are missing from the training split.
        class_weight_dict = {
            int(class_id): float(weight)
            for class_id, weight in zip(present_classes, class_weights)
        }

        # Train each model
        histories = []
        for model, model_name in zip(models, model_names):
            print(f"\n{'='*50}")
            print(f"Training {model_name}")
            print(f"{'='*50}")

            callbacks = [
                EarlyStopping(
                    monitor='val_accuracy',
                    patience=10,
                    restore_best_weights=True,
                    verbose=1
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    patience=5,
                    factor=0.5,
                    min_lr=1e-7,
                    verbose=1
                ),
                ModelCheckpoint(
                    f"{self.model_path}/best_{model_name.lower()}_model.h5",
                    monitor='val_accuracy',
                    save_best_only=True,
                    verbose=1
                )
            ]

            history = model.fit(
                train_gen,
                epochs=epochs,
                validation_data=val_gen,
                callbacks=callbacks,
                class_weight=class_weight_dict,
                verbose=1
            )

            histories.append(history)
            train_gen.reset()
            val_gen.reset()

        return histories

    # ------------------------------------------------------------------
    # Evaluation and reporting
    # ------------------------------------------------------------------
    def evaluate_ensemble(self, models, model_names, X_test, y_test, class_names):
        """Evaluate each model and their weighted-average ensemble.

        Args:
            models: Trained Keras models.
            model_names: Display names, parallel to ``models``.
            X_test, y_test: Held-out images and integer labels.
            class_names: Class names indexed by label id.

        Returns:
            ``(ensemble_accuracy, individual_scores, report)`` where
            ``individual_scores`` is a list of dicts with ``model``,
            ``accuracy`` and ``loss`` and ``report`` is the dict form of the
            classification report.
        """
        print("📊 Evaluating ensemble...")

        all_predictions = []
        individual_scores = []

        for model, model_name in zip(models, model_names):
            print(f"\nEvaluating {model_name}:")

            test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
            all_predictions.append(model.predict(X_test, verbose=0))

            # Plain floats keep the scores JSON-serialisable.
            individual_scores.append({
                'model': model_name,
                'accuracy': float(test_accuracy),
                'loss': float(test_loss)
            })

            print(f" Accuracy: {test_accuracy:.4f}")

        # Ensemble predictions (weighted average of class probabilities)
        weights = self._ensemble_weights(len(models))
        ensemble_predictions = np.average(all_predictions, axis=0, weights=weights)
        ensemble_pred_classes = np.argmax(ensemble_predictions, axis=1)

        ensemble_accuracy = accuracy_score(y_test, ensemble_pred_classes)

        print(f"\n🎯 ENSEMBLE RESULTS:")
        print(f"Ensemble Accuracy: {ensemble_accuracy:.4f} ({ensemble_accuracy*100:.2f}%)")

        # Passing every label id explicitly keeps the report aligned with
        # class_names even when a class is absent from the test split.
        label_ids = np.arange(len(class_names))

        print("\nEnsemble Classification Report:")
        report = classification_report(
            y_test, ensemble_pred_classes,
            labels=label_ids,
            target_names=class_names,
            output_dict=True,
            zero_division=0
        )
        print(classification_report(
            y_test, ensemble_pred_classes,
            labels=label_ids,
            target_names=class_names,
            zero_division=0
        ))

        self.plot_confusion_matrix(y_test, ensemble_pred_classes, class_names)

        return ensemble_accuracy, individual_scores, report

    def plot_confusion_matrix(self, y_true, y_pred, class_names):
        """Plot the confusion matrix and save it to ``plots/confusion_matrix.png``.

        Args:
            y_true: True integer labels.
            y_pred: Predicted integer labels.
            class_names: Class names indexed by label id, used as tick labels.
        """
        # Fix the label set so the matrix always has one row/column per tick.
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            annot_kws={'size': 10}
        )
        plt.title('Ensemble Confusion Matrix', fontsize=16)
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        plt.savefig('plots/confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close()

        print("✅ Confusion matrix saved to plots/confusion_matrix.png")

    def plot_training_history(self, histories, model_names):
        """Plot accuracy and loss curves and save them to ``plots/training_history.png``.

        Args:
            histories: Keras ``History`` objects, one per model.
            model_names: Display names, parallel to ``histories``.
        """
        colors = ['blue', 'red', 'green']
        panels = (
            ('accuracy', 'Model Accuracy', 'Accuracy'),
            ('loss', 'Model Loss', 'Loss'),
        )

        # One row with two panels; a 2x2 grid would leave the bottom half empty.
        fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))

        for ax, (metric, title, ylabel) in zip(axes, panels):
            for index, (history, model_name) in enumerate(zip(histories, model_names)):
                # Cycle colours so extra models are still drawn.
                color = colors[index % len(colors)]
                ax.plot(history.history[metric], label=f'{model_name} Train', color=color, linestyle='-')
                ax.plot(history.history[f'val_{metric}'], label=f'{model_name} Val', color=color, linestyle='--')

            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)

        fig.tight_layout()
        fig.savefig('plots/training_history.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

        print("✅ Training history saved to plots/training_history.png")

    def save_ensemble(self, models, model_names, class_names, ensemble_accuracy, individual_scores):
        """Save the models, ensemble metadata and class names under ``self.model_path``.

        Args:
            models: Trained Keras models.
            model_names: Display names, parallel to ``models``.
            class_names: Class names indexed by label id.
            ensemble_accuracy: Accuracy of the weighted ensemble.
            individual_scores: Per-model score dicts from evaluation.
        """
        print("💾 Saving ensemble...")

        # Save individual models
        model_files = []
        for model, model_name in zip(models, model_names):
            model_file = f"{self.model_path}/{model_name.lower()}_model.keras"
            model.save(model_file)
            model_files.append(model_file)
            print(f"✅ {model_name} saved to: {model_file}")

        # Accept any sequence of names, not only numpy arrays.
        class_list = np.asarray(class_names).tolist()

        metadata = {
            'ensemble_accuracy': float(ensemble_accuracy),
            'individual_models': individual_scores,
            'model_files': model_files,
            'num_classes': len(class_list),
            'image_size': list(self.image_size),
            'class_names': class_list,
            'ensemble_weights': self._ensemble_weights(len(models))
        }

        metadata_file = f"{self.model_path}/ensemble_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        classes_file = f"{self.model_path}/classes.json"
        with open(classes_file, 'w', encoding='utf-8') as f:
            json.dump(class_list, f, indent=2)

        print(f"✅ Metadata saved to: {metadata_file}")
        print(f"✅ Classes saved to: {classes_file}")

    # ------------------------------------------------------------------
    # Pipeline
    # ------------------------------------------------------------------
    def run_training(self):
        """Run the complete training pipeline.

        Returns:
            ``True`` when training, evaluation and saving completed, ``False``
            when the dataset is missing, unreadable or too small.
        """
        print("🎯 Eye Disease Detection - Advanced Training Pipeline")
        print("=" * 60)

        # Step 1: Download dataset if needed
        if not os.path.exists(self.dataset_path):
            if not self.download_dataset():
                print("Please download and extract the dataset manually.")
                return False

        # Step 2: Load dataset
        X, y, class_names = self.load_dataset()
        if X is None:
            print("❌ Failed to load dataset. Exiting.")
            return False

        if len(X) < 50:
            print(f"❌ Not enough images ({len(X)}). Need at least 50 for training.")
            return False

        if len(class_names) < 2:
            print(f"❌ Need at least 2 classes for training, found {len(class_names)}.")
            return False

        # Step 3: Split dataset (70% train, 15% validation, 15% test)
        print("✂️ Splitting dataset...")
        X_train, X_temp, y_train, y_temp = self._split(X, y, test_size=0.3)
        X_val, X_test, y_val, y_test = self._split(X_temp, y_temp, test_size=0.5)

        print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

        # Step 4: Balance the training split only. Oversampling before the
        # split would place synthetic samples, interpolated from training
        # images, into validation and test and inflate the reported accuracy.
        X_train, y_train = self.balance_dataset(X_train, y_train)

        # Step 5: Create data generators
        train_gen, val_gen = self.create_data_generators(X_train, y_train, X_val, y_val)

        # Step 6: Create and train ensemble
        models, model_names = self.create_ensemble(len(class_names))
        histories = self.train_ensemble(models, model_names, train_gen, val_gen, class_names)

        # Step 7: Evaluate ensemble
        ensemble_accuracy, individual_scores, report = self.evaluate_ensemble(
            models, model_names, X_test, y_test, class_names
        )

        # Step 8: Save results
        self.save_ensemble(models, model_names, class_names, ensemble_accuracy, individual_scores)
        self.plot_training_history(histories, model_names)

        print("\n🎉 TRAINING COMPLETED!")
        print(f"🎯 Final Ensemble Accuracy: {ensemble_accuracy:.4f} ({ensemble_accuracy*100:.2f}%)")

        return True

def main():
    """Print system information, configure GPUs and run the training pipeline.

    Returns:
        ``True`` when training completed successfully, ``False`` otherwise.
    """
    print("🚀 Advanced Eye Disease Detection Training")
    print("=" * 50)

    # Query the GPU list once and reuse it for reporting and configuration.
    gpus = tf.config.list_physical_devices('GPU')

    # System information
    print("🔧 System Information:")
    print(f" TensorFlow version: {tf.__version__}")
    print(f" GPU available: {len(gpus) > 0}")

    # GPU setup: enable on-demand memory allocation per GPU.
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            print(" ✅ GPU memory growth enabled")
        except RuntimeError as e:
            # Reported as a warning only; training still proceeds.
            print(f" ⚠️ GPU setup warning: {e}")

    print()

    # Initialize and run trainer
    trainer = AdvancedFundusTrainer()
    success = trainer.run_training()

    if success:
        print("\n" + "="*50)
        print("🎉 TRAINING SUCCESSFULLY COMPLETED!")
        print("="*50)
        print("\n📦 Generated Assets:")
        print(" 🤖 Ensemble models (EfficientNetB3 + ResNet152V2 + DenseNet201)")
        print(" 📊 Performance analysis and visualizations")
        print(" 📋 Metadata for deployment")
    else:
        print("\n❌ TRAINING FAILED")
        print("Please check the error messages above.")

    return success


if __name__ == "__main__":
    # Report failure through the exit status so shells and schedulers can
    # detect an unsuccessful run.
    if not main():
        sys.exit(1)
