{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c1f2a7e",
   "metadata": {},
   "source": [
    "# Advanced Eye Disease Detection — Training\n",
    "\n",
    "Trains an ensemble of three ImageNet-pretrained CNNs (EfficientNetB3, ResNet152V2, DenseNet201) on the Kaggle `linchundan/fundusimage1000` fundus-image dataset. The class of each image is the name of the folder that contains it.\n",
    "\n",
    "Pipeline: download → load & enhance images → train/val/test split → SMOTE on the training split only → train each model → weighted-average ensemble evaluation → save models (`models/`) and plots (`plots/`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30e80151",
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "#!/usr/bin/env python3\n",
    "\"\"\"\n",
    "Advanced Eye Disease Detection Training Script - Fixed Version\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "import numpy as np\n",
    "import cv2\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Check if required packages are available\n",
    "try:\n",
    "    import tensorflow as tf\n",
    "    from tensorflow import keras\n",
    "    from tensorflow.keras import layers\n",
    "    from tensorflow.keras.applications import EfficientNetB3, ResNet152V2, DenseNet201\n",
    "    from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "    from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint\n",
    "    from sklearn.model_selection import train_test_split, StratifiedKFold\n",
    "    from sklearn.preprocessing import LabelEncoder\n",
    "    from sklearn.metrics import classification_report, confusion_matrix, accuracy_score\n",
    "    from sklearn.utils.class_weight import compute_class_weight\n",
    "    import seaborn as sns\n",
    "    from imblearn.over_sampling import SMOTE\n",
    "    from imblearn.combine import SMOTETomek\n",
    "    # Optional: Kaggle API (install if needed)\n",
    "    try:\n",
    "        from kaggle.api.kaggle_api_extended import KaggleApi\n",
    "        KAGGLE_AVAILABLE = True\n",
    "    except (ImportError, OSError):\n",
    "        # OSError: importing the kaggle package may try to authenticate right\n",
    "        # away and fail when no credentials are configured.\n",
    "        KAGGLE_AVAILABLE = False\n",
    "        print(\"⚠️ Kaggle API not available. Manual dataset download required.\")\n",
    "except ImportError as e:\n",
    "    print(f\"Missing required package: {e}\")\n",
    "    print(\"Please run: pip install tensorflow opencv-python scikit-learn matplotlib tqdm seaborn imbalanced-learn\")\n",
    "    sys.exit(1)\n",
    "\n",
    "print(\"✅ All required packages are available!\")\n",
    "\n",
    "# Seed shared by every random step (splits, SMOTE, augmentation, weight init).\n",
    "SEED = 42\n",
    "\n",
    "# Ensemble weights, in model creation order: EfficientNetB3, ResNet152V2, DenseNet201.\n",
    "ENSEMBLE_WEIGHTS = [0.4, 0.35, 0.25]\n",
    "\n",
    "# File suffixes (compared case-insensitively) that are loaded as images.\n",
    "IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}\n",
    "\n",
    "\n",
    "class AdvancedFundusPreprocessor:\n",
    "    \"\"\"Preprocessing pipeline for fundus images.\n",
    "\n",
    "    Reads an image, enhances contrast (CLAHE on the LAB L channel, green-channel\n",
    "    histogram equalization, gamma correction), resizes it and returns a float32\n",
    "    RGB array scaled to [0, 1].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_size=(224, 224)):\n",
    "        # (height, width) -- the models use input_shape=(*image_size, 3).\n",
    "        self.image_size = image_size\n",
    "        self.setup_clahe_variants()\n",
    "\n",
    "    def setup_clahe_variants(self):\n",
    "        \"\"\"Setup multiple CLAHE variants for different image characteristics\"\"\"\n",
    "        self.clahe_normal = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))\n",
    "        # NOTE: only clahe_normal is currently used by enhance_fundus_image.\n",
    "        self.clahe_aggressive = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(4, 4))\n",
    "        self.clahe_gentle = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(16, 16))\n",
    "\n",
    "    def advanced_preprocess_image(self, image_path):\n",
    "        \"\"\"Load and preprocess a single image.\n",
    "\n",
    "        Args:\n",
    "            image_path: str or Path of the image file.\n",
    "\n",
    "        Returns:\n",
    "            float32 array of shape (height, width, 3) with values in [0, 1],\n",
    "            or None if the image could not be read or processed.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            image = cv2.imread(str(image_path))\n",
    "            if image is None:\n",
    "                print(f\"⚠️ Could not read image: {image_path}\")\n",
    "                return None\n",
    "\n",
    "            # OpenCV decodes to BGR; the rest of the pipeline works in RGB.\n",
    "            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "            image = self.enhance_fundus_image(image)\n",
    "\n",
    "            # cv2.resize expects dsize=(width, height) whereas image_size is\n",
    "            # (height, width); swap so non-square sizes match the model input.\n",
    "            height, width = self.image_size\n",
    "            image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LANCZOS4)\n",
    "\n",
    "            return image.astype(np.float32) / 255.0\n",
    "        except Exception as e:\n",
    "            print(f\"Error processing {image_path}: {e}\")\n",
    "            return None\n",
    "\n",
    "    def enhance_fundus_image(self, image):\n",
    "        \"\"\"Fundus-specific contrast enhancement of a uint8 RGB image.\"\"\"\n",
    "        # CLAHE on the lightness channel only, so hues are left untouched.\n",
    "        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)\n",
    "        lab[:, :, 0] = self.clahe_normal.apply(lab[:, :, 0])\n",
    "        image_enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)\n",
    "\n",
    "        # Green channel enhancement (important for fundus images)\n",
    "        image_enhanced[:, :, 1] = cv2.equalizeHist(image_enhanced[:, :, 1])\n",
    "\n",
    "        return self.adjust_gamma(image_enhanced, gamma=1.2)\n",
    "\n",
    "    def adjust_gamma(self, image, gamma=1.0):\n",
    "        \"\"\"Apply gamma correction to a uint8 image through a 256-entry lookup table.\"\"\"\n",
    "        inv_gamma = 1.0 / gamma\n",
    "        table = (((np.arange(256) / 255.0) ** inv_gamma) * 255).astype(\"uint8\")\n",
    "        return cv2.LUT(image, table)\n",
    "\n",
    "\n",
    "class AdvancedEnsembleModel:\n",
    "    \"\"\"Builds the transfer-learning classifiers that form the ensemble.\n",
    "\n",
    "    Every model takes float RGB input in [0, 1] (the preprocessor's output)\n",
    "    and converts it to the input range its ImageNet backbone was trained on.\n",
    "    \"\"\"\n",
    "\n",
    "    # ImageNet channel statistics used by DenseNet's \"torch\"-style preprocessing.\n",
    "    IMAGENET_MEAN = [0.485, 0.456, 0.406]\n",
    "    IMAGENET_STD = [0.229, 0.224, 0.225]\n",
    "\n",
    "    def __init__(self, image_size=(224, 224), num_classes=7):\n",
    "        self.image_size = image_size\n",
    "        # Read by every create_*_model method to size the softmax output.\n",
    "        self.num_classes = num_classes\n",
    "\n",
    "    @staticmethod\n",
    "    def _unfreeze_top_layers(base_model, num_trainable):\n",
    "        \"\"\"Freeze every layer of base_model except the last num_trainable ones.\"\"\"\n",
    "        base_model.trainable = True\n",
    "        for layer in base_model.layers[:-num_trainable]:\n",
    "            layer.trainable = False\n",
    "\n",
    "    def _classification_head(self, x, hidden_layers):\n",
    "        \"\"\"Shared classifier head on top of pooled backbone features.\n",
    "\n",
    "        BatchNorm + Dropout(0.5), then Dense(relu)/BatchNorm/Dropout for each\n",
    "        (units, dropout_rate) pair in hidden_layers, then a softmax layer with\n",
    "        self.num_classes units.\n",
    "        \"\"\"\n",
    "        x = layers.BatchNormalization()(x)\n",
    "        x = layers.Dropout(0.5)(x)\n",
    "        for units, rate in hidden_layers:\n",
    "            x = layers.Dense(units, activation='relu')(x)\n",
    "            x = layers.BatchNormalization()(x)\n",
    "            x = layers.Dropout(rate)(x)\n",
    "        return layers.Dense(self.num_classes, activation='softmax')(x)\n",
    "\n",
    "    def create_efficientnet_model(self):\n",
    "        \"\"\"Create EfficientNetB3 based model (last 30 backbone layers trainable).\"\"\"\n",
    "        input_shape = (*self.image_size, 3)\n",
    "        base_model = EfficientNetB3(\n",
    "            weights='imagenet',\n",
    "            include_top=False,\n",
    "            input_shape=input_shape\n",
    "        )\n",
    "        self._unfreeze_top_layers(base_model, 30)\n",
    "\n",
    "        inputs = keras.Input(shape=input_shape)\n",
    "        # Keras' EfficientNet rescales and normalizes internally and expects raw\n",
    "        # [0, 255] pixels, so undo the preprocessor's division by 255.\n",
    "        x = layers.Rescaling(255.0)(inputs)\n",
    "        # training=False keeps the backbone's BatchNorm layers in inference mode.\n",
    "        x = base_model(x, training=False)\n",
    "\n",
    "        # Concatenate global average- and max-pooled features.\n",
    "        x = layers.Concatenate()([\n",
    "            layers.GlobalAveragePooling2D()(x),\n",
    "            layers.GlobalMaxPooling2D()(x),\n",
    "        ])\n",
    "        outputs = self._classification_head(x, [(512, 0.4), (256, 0.3)])\n",
    "\n",
    "        return keras.Model(inputs, outputs, name='EfficientNetB3_Model')\n",
    "\n",
    "    def create_resnet_model(self):\n",
    "        \"\"\"Create ResNet152V2 based model (last 40 backbone layers trainable).\"\"\"\n",
    "        input_shape = (*self.image_size, 3)\n",
    "        base_model = ResNet152V2(\n",
    "            weights='imagenet',\n",
    "            include_top=False,\n",
    "            input_shape=input_shape\n",
    "        )\n",
    "        self._unfreeze_top_layers(base_model, 40)\n",
    "\n",
    "        inputs = keras.Input(shape=input_shape)\n",
    "        # ResNetV2 weights expect inputs in [-1, 1]: map [0, 1] -> [-1, 1].\n",
    "        x = layers.Rescaling(2.0, offset=-1.0)(inputs)\n",
    "        x = base_model(x, training=False)\n",
    "\n",
    "        x = layers.GlobalAveragePooling2D()(x)\n",
    "        outputs = self._classification_head(x, [(512, 0.4)])\n",
    "\n",
    "        return keras.Model(inputs, outputs, name='ResNet152V2_Model')\n",
    "\n",
    "    def create_densenet_model(self):\n",
    "        \"\"\"Create DenseNet201 based model (last 50 backbone layers trainable).\"\"\"\n",
    "        input_shape = (*self.image_size, 3)\n",
    "        base_model = DenseNet201(\n",
    "            weights='imagenet',\n",
    "            include_top=False,\n",
    "            input_shape=input_shape\n",
    "        )\n",
    "        self._unfreeze_top_layers(base_model, 50)\n",
    "\n",
    "        inputs = keras.Input(shape=input_shape)\n",
    "        # DenseNet weights expect ImageNet mean/std-normalized [0, 1] inputs.\n",
    "        x = layers.Normalization(\n",
    "            mean=self.IMAGENET_MEAN,\n",
    "            variance=[std ** 2 for std in self.IMAGENET_STD],\n",
    "        )(inputs)\n",
    "        x = base_model(x, training=False)\n",
    "\n",
    "        x = layers.GlobalAveragePooling2D()(x)\n",
    "        outputs = self._classification_head(x, [(512, 0.4)])\n",
    "\n",
    "        return keras.Model(inputs, outputs, name='DenseNet201_Model')\n",
    "\n",
    "\n",
    "class AdvancedFundusTrainer:\n",
    "    \"\"\"End-to-end pipeline: download, load, split, balance, train, evaluate, save.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.dataset_path = \"fundus_dataset\"\n",
    "        self.model_path = \"models\"\n",
    "        self.image_size = (224, 224)\n",
    "        self.batch_size = 16\n",
    "\n",
    "        # Initialize components\n",
    "        self.preprocessor = AdvancedFundusPreprocessor(self.image_size)\n",
    "        self.ensemble_model = AdvancedEnsembleModel(self.image_size)\n",
    "\n",
    "        # Create output directories\n",
    "        Path(self.model_path).mkdir(exist_ok=True)\n",
    "        Path(\"logs\").mkdir(exist_ok=True)\n",
    "        Path(\"plots\").mkdir(exist_ok=True)\n",
    "\n",
    "    def download_dataset(self):\n",
    "        \"\"\"Download the Kaggle dataset into self.dataset_path.\n",
    "\n",
    "        Returns:\n",
    "            True on success, False if the Kaggle API is unavailable or fails.\n",
    "        \"\"\"\n",
    "        if not KAGGLE_AVAILABLE:\n",
    "            print(\"❌ Kaggle API not available. Please download dataset manually.\")\n",
    "            print(\"Dataset: https://www.kaggle.com/datasets/linchundan/fundusimage1000\")\n",
    "            print(f\"Extract to: {self.dataset_path}/\")\n",
    "            return False\n",
    "\n",
    "        print(\"📥 Downloading dataset from Kaggle...\")\n",
    "\n",
    "        try:\n",
    "            api = KaggleApi()\n",
    "            api.authenticate()\n",
    "            api.dataset_download_files(\n",
    "                \"linchundan/fundusimage1000\",\n",
    "                path=self.dataset_path,\n",
    "                unzip=True\n",
    "            )\n",
    "            print(\"✅ Dataset downloaded successfully!\")\n",
    "            return True\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"❌ Download failed: {e}\")\n",
    "            print(f\"Please download dataset manually and extract to {self.dataset_path}/\")\n",
    "            return False\n",
    "\n",
    "    @staticmethod\n",
    "    def _find_image_files(root):\n",
    "        \"\"\"Sorted list of files under root whose suffix is in IMAGE_EXTENSIONS.\"\"\"\n",
    "        return sorted(\n",
    "            path for path in Path(root).rglob('*')\n",
    "            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS\n",
    "        )\n",
    "\n",
    "    def load_dataset(self):\n",
    "        \"\"\"Load and preprocess every image under self.dataset_path.\n",
    "\n",
    "        The class label of an image is the name of its parent directory;\n",
    "        images lying directly in the chosen dataset root are labelled \"unknown\".\n",
    "\n",
    "        Returns:\n",
    "            (X, y_encoded, class_names): X is a float32 array of shape\n",
    "            (n, height, width, 3), y_encoded the integer labels and class_names\n",
    "            the LabelEncoder classes; (None, None, None) on failure.\n",
    "        \"\"\"\n",
    "        print(\"📁 Loading dataset...\")\n",
    "\n",
    "        dataset_path = Path(self.dataset_path)\n",
    "\n",
    "        # Debug: print the directory structure (directories only -- listing\n",
    "        # every image file would flood the output).\n",
    "        print(f\"\\nDirectory structure at {dataset_path}:\")\n",
    "        for item in sorted(dataset_path.rglob('*')):\n",
    "            if item.is_dir():\n",
    "                print(f\" {item.relative_to(dataset_path)}\")\n",
    "\n",
    "        # Check for common dataset structures; the first one holding images wins.\n",
    "        possible_paths = [\n",
    "            dataset_path,\n",
    "            dataset_path / \"1000images\",\n",
    "            dataset_path / \"fundusimage1000\",\n",
    "            dataset_path / \"images\",\n",
    "            dataset_path / \"data\"\n",
    "        ]\n",
    "\n",
    "        actual_path = None\n",
    "        image_files = []\n",
    "        for path in possible_paths:\n",
    "            if path.exists():\n",
    "                print(f\"\\nChecking path: {path}\")\n",
    "                image_files = self._find_image_files(path)\n",
    "                print(f\" Found {len(image_files)} images\")\n",
    "                if image_files:\n",
    "                    actual_path = path\n",
    "                    break\n",
    "\n",
    "        if actual_path is None:\n",
    "            print(f\"\\n❌ No valid dataset found in {dataset_path}\")\n",
    "            print(\"Please ensure the dataset contains image files in one of these structures:\")\n",
    "            print(\"1. Directly in fundus_dataset/\")\n",
    "            print(\"2. In fundus_dataset/1000images/\")\n",
    "            print(\"3. In fundus_dataset/fundusimage1000/\")\n",
    "            print(\"4. In subdirectories by class\")\n",
    "            return None, None, None\n",
    "\n",
    "        print(f\"\\n✅ Using dataset path: {actual_path}\")\n",
    "        print(f\"\\nFound {len(image_files)} image files\")\n",
    "\n",
    "        # Process images\n",
    "        images = []\n",
    "        labels = []\n",
    "        class_counts = Counter()\n",
    "        for image_file in tqdm(image_files, desc=\"Processing images\"):\n",
    "            processed_image = self.preprocessor.advanced_preprocess_image(image_file)\n",
    "            if processed_image is None:\n",
    "                continue\n",
    "            # Parent directory name is the class label. Compare full paths (not\n",
    "            # names) so a class folder named like the root is not mislabelled.\n",
    "            if image_file.parent == actual_path:\n",
    "                class_label = \"unknown\"\n",
    "            else:\n",
    "                class_label = image_file.parent.name\n",
    "            images.append(processed_image)\n",
    "            labels.append(class_label)\n",
    "            class_counts[class_label] += 1\n",
    "\n",
    "        if len(images) == 0:\n",
    "            print(\"❌ No images loaded successfully\")\n",
    "            return None, None, None\n",
    "\n",
    "        # Print class distribution\n",
    "        print(\"\\nClass distribution:\")\n",
    "        for class_name, count in class_counts.most_common():\n",
    "            print(f\" {class_name}: {count} images\")\n",
    "\n",
    "        X = np.array(images)\n",
    "        y = np.array(labels)\n",
    "\n",
    "        label_encoder = LabelEncoder()\n",
    "        y_encoded = label_encoder.fit_transform(y)\n",
    "\n",
    "        print(f\"\\n✅ Dataset loaded: {len(X)} images, {len(label_encoder.classes_)} classes\")\n",
    "\n",
    "        return X, y_encoded, label_encoder.classes_\n",
    "\n",
    "    def balance_dataset(self, X, y):\n",
    "        \"\"\"Oversample minority classes with SMOTE.\n",
    "\n",
    "        Apply this to the training split only: SMOTE interpolates between\n",
    "        samples, so balancing before the split would leak validation/test\n",
    "        images into training.\n",
    "\n",
    "        Returns:\n",
    "            (X_balanced, y_balanced); X and y unchanged when SMOTE cannot run.\n",
    "        \"\"\"\n",
    "        print(\"⚖️ Balancing dataset...\")\n",
    "\n",
    "        _, class_sizes = np.unique(y, return_counts=True)\n",
    "        if len(class_sizes) < 2 or class_sizes.min() < 2:\n",
    "            print(\"SMOTE skipped: need at least 2 classes with 2+ samples each, using original dataset\")\n",
    "            return X, y\n",
    "\n",
    "        try:\n",
    "            # SMOTE works on 2-D feature matrices, so flatten each image.\n",
    "            X_flattened = X.reshape(X.shape[0], -1)\n",
    "\n",
    "            # k_neighbors must be smaller than the smallest class size.\n",
    "            smote = SMOTE(random_state=SEED, k_neighbors=min(5, int(class_sizes.min()) - 1))\n",
    "            X_balanced, y_balanced = smote.fit_resample(X_flattened, y)\n",
    "\n",
    "            # Restore the original image shape and float32 dtype.\n",
    "            X_balanced = X_balanced.reshape(-1, *X.shape[1:]).astype(np.float32, copy=False)\n",
    "\n",
    "            print(f\"Dataset balanced: {len(X_balanced)} samples\")\n",
    "            return X_balanced, y_balanced\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"SMOTE failed: {e}, using original dataset\")\n",
    "            return X, y\n",
    "\n",
    "    @staticmethod\n",
    "    def _random_brightness(image, low=0.8, high=1.2):\n",
    "        \"\"\"Scale brightness by a random factor in [low, high], clipped to [0, 1].\n",
    "\n",
    "        Replaces ImageDataGenerator's brightness_range, which goes through PIL\n",
    "        and assumes 0-255 pixels, corrupting images already scaled to [0, 1].\n",
    "        \"\"\"\n",
    "        return np.clip(image * np.random.uniform(low, high), 0.0, 1.0)\n",
    "\n",
    "    def create_data_generators(self, X_train, y_train, X_val, y_val):\n",
    "        \"\"\"Create data generators; only the training generator augments.\"\"\"\n",
    "        train_datagen = ImageDataGenerator(\n",
    "            rotation_range=20,\n",
    "            width_shift_range=0.15,\n",
    "            height_shift_range=0.15,\n",
    "            horizontal_flip=True,\n",
    "            vertical_flip=True,\n",
    "            zoom_range=0.15,\n",
    "            shear_range=0.1,\n",
    "            fill_mode='reflect',\n",
    "            preprocessing_function=self._random_brightness\n",
    "        )\n",
    "\n",
    "        # Validation data generator (no augmentation)\n",
    "        val_datagen = ImageDataGenerator()\n",
    "\n",
    "        train_generator = train_datagen.flow(\n",
    "            X_train, y_train,\n",
    "            batch_size=self.batch_size,\n",
    "            shuffle=True\n",
    "        )\n",
    "\n",
    "        val_generator = val_datagen.flow(\n",
    "            X_val, y_val,\n",
    "            batch_size=self.batch_size,\n",
    "            shuffle=False\n",
    "        )\n",
    "\n",
    "        return train_generator, val_generator\n",
    "\n",
    "    def create_ensemble(self, num_classes):\n",
    "        \"\"\"Create and compile the three ensemble members for num_classes classes.\"\"\"\n",
    "        print(\"🧠 Creating ensemble model...\")\n",
    "\n",
    "        # Must be set BEFORE building: every builder reads num_classes to size\n",
    "        # its softmax output layer.\n",
    "        self.ensemble_model.num_classes = num_classes\n",
    "\n",
    "        models = [\n",
    "            self.ensemble_model.create_efficientnet_model(),\n",
    "            self.ensemble_model.create_resnet_model(),\n",
    "            self.ensemble_model.create_densenet_model(),\n",
    "        ]\n",
    "        model_names = ['EfficientNetB3', 'ResNet152V2', 'DenseNet201']\n",
    "\n",
    "        for model in models:\n",
    "            # One optimizer per model: optimizers keep per-variable state and\n",
    "            # cannot be shared between models with different weights.\n",
    "            model.compile(\n",
    "                optimizer=keras.optimizers.Adam(learning_rate=1e-4),\n",
    "                loss='sparse_categorical_crossentropy',\n",
    "                metrics=['accuracy']\n",
    "            )\n",
    "\n",
    "        print(f\"✅ Created ensemble with {len(models)} models\")\n",
    "        return models, model_names\n",
    "\n",
    "    def train_ensemble(self, models, model_names, train_gen, val_gen, class_names, epochs=30):\n",
    "        \"\"\"Train every ensemble member on train_gen, validating on val_gen.\n",
    "\n",
    "        Returns:\n",
    "            List of Keras History objects, one per model.\n",
    "        \"\"\"\n",
    "        print(\"🚀 Training ensemble...\")\n",
    "\n",
    "        # Class weights for imbalanced data. Labels are read straight from the\n",
    "        # iterator instead of drawing every batch, which would run the whole\n",
    "        # augmentation pipeline once just to collect them.\n",
    "        y_train_full = np.asarray(train_gen.y).astype(int)\n",
    "        present_classes = np.unique(y_train_full)\n",
    "        class_weights = compute_class_weight(\n",
    "            'balanced',\n",
    "            classes=present_classes,\n",
    "            y=y_train_full\n",
    "        )\n",
    "        # Key weights by the actual class index (enumerate() would misalign them\n",
    "        # when a class is missing) and cover every class so keys are contiguous.\n",
    "        class_weight_dict = {i: 1.0 for i in range(len(class_names))}\n",
    "        class_weight_dict.update(\n",
    "            {int(c): float(w) for c, w in zip(present_classes, class_weights)}\n",
    "        )\n",
    "\n",
    "        histories = []\n",
    "        for model, model_name in zip(models, model_names):\n",
    "            print(f\"\\n{'='*50}\")\n",
    "            print(f\"Training {model_name}\")\n",
    "            print(f\"{'='*50}\")\n",
    "\n",
    "            callbacks = [\n",
    "                EarlyStopping(\n",
    "                    monitor='val_accuracy',\n",
    "                    patience=10,\n",
    "                    restore_best_weights=True,\n",
    "                    verbose=1\n",
    "                ),\n",
    "                ReduceLROnPlateau(\n",
    "                    monitor='val_loss',\n",
    "                    patience=5,\n",
    "                    factor=0.5,\n",
    "                    min_lr=1e-7,\n",
    "                    verbose=1\n",
    "                ),\n",
    "                ModelCheckpoint(\n",
    "                    # Native .keras format, as in save_ensemble (Keras 3 may\n",
    "                    # reject .h5 for full-model checkpoints).\n",
    "                    f\"{self.model_path}/best_{model_name.lower()}_model.keras\",\n",
    "                    monitor='val_accuracy',\n",
    "                    save_best_only=True,\n",
    "                    verbose=1\n",
    "                )\n",
    "            ]\n",
    "\n",
    "            history = model.fit(\n",
    "                train_gen,\n",
    "                epochs=epochs,\n",
    "                validation_data=val_gen,\n",
    "                callbacks=callbacks,\n",
    "                class_weight=class_weight_dict,\n",
    "                verbose=1\n",
    "            )\n",
    "\n",
    "            histories.append(history)\n",
    "            train_gen.reset()\n",
    "            val_gen.reset()\n",
    "\n",
    "        return histories\n",
    "\n",
    "    def evaluate_ensemble(self, models, model_names, X_test, y_test, class_names):\n",
    "        \"\"\"Evaluate each model and their weighted-average ensemble on the test set.\n",
    "\n",
    "        Returns:\n",
    "            (ensemble_accuracy, individual_scores, classification_report_dict)\n",
    "        \"\"\"\n",
    "        print(\"📊 Evaluating ensemble...\")\n",
    "\n",
    "        all_predictions = []\n",
    "        individual_scores = []\n",
    "\n",
    "        for model, model_name in zip(models, model_names):\n",
    "            print(f\"\\nEvaluating {model_name}:\")\n",
    "\n",
    "            test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)\n",
    "            predictions = model.predict(X_test, verbose=0)\n",
    "            all_predictions.append(predictions)\n",
    "\n",
    "            # Plain floats so the scores stay JSON-serializable.\n",
    "            individual_scores.append({\n",
    "                'model': model_name,\n",
    "                'accuracy': float(test_accuracy),\n",
    "                'loss': float(test_loss)\n",
    "            })\n",
    "\n",
    "            print(f\" Accuracy: {test_accuracy:.4f}\")\n",
    "\n",
    "        # Ensemble predictions: weighted average of the per-model softmax outputs.\n",
    "        ensemble_predictions = np.average(all_predictions, axis=0, weights=ENSEMBLE_WEIGHTS)\n",
    "        ensemble_pred_classes = np.argmax(ensemble_predictions, axis=1)\n",
    "\n",
    "        ensemble_accuracy = accuracy_score(y_test, ensemble_pred_classes)\n",
    "\n",
    "        print(f\"\\n🎯 ENSEMBLE RESULTS:\")\n",
    "        print(f\"Ensemble Accuracy: {ensemble_accuracy:.4f} ({ensemble_accuracy*100:.2f}%)\")\n",
    "\n",
    "        # Pass the full label set: the test split may not contain every class,\n",
    "        # and target_names must line up with the labels.\n",
    "        all_labels = np.arange(len(class_names))\n",
    "        print(\"\\nEnsemble Classification Report:\")\n",
    "        report = classification_report(\n",
    "            y_test, ensemble_pred_classes,\n",
    "            labels=all_labels,\n",
    "            target_names=class_names,\n",
    "            output_dict=True,\n",
    "            zero_division=0\n",
    "        )\n",
    "        print(classification_report(\n",
    "            y_test, ensemble_pred_classes,\n",
    "            labels=all_labels,\n",
    "            target_names=class_names,\n",
    "            zero_division=0\n",
    "        ))\n",
    "\n",
    "        self.plot_confusion_matrix(y_test, ensemble_pred_classes, class_names)\n",
    "\n",
    "        return ensemble_accuracy, individual_scores, report\n",
    "\n",
    "    def plot_confusion_matrix(self, y_true, y_pred, class_names):\n",
    "        \"\"\"Plot the confusion matrix and save it to plots/confusion_matrix.png.\"\"\"\n",
    "        # Fix the label set so the matrix has one row/column per class name,\n",
    "        # even when some classes never occur in y_true or y_pred.\n",
    "        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(10, 8))\n",
    "        sns.heatmap(\n",
    "            cm,\n",
    "            annot=True,\n",
    "            fmt='d',\n",
    "            cmap='Blues',\n",
    "            xticklabels=class_names,\n",
    "            yticklabels=class_names,\n",
    "            annot_kws={'size': 10},\n",
    "            ax=ax\n",
    "        )\n",
    "        ax.set_title('Ensemble Confusion Matrix', fontsize=16)\n",
    "        ax.set_ylabel('True Label', fontsize=12)\n",
    "        ax.set_xlabel('Predicted Label', fontsize=12)\n",
    "        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')\n",
    "        plt.setp(ax.get_yticklabels(), rotation=0)\n",
    "        fig.tight_layout()\n",
    "\n",
    "        fig.savefig('plots/confusion_matrix.png', dpi=300, bbox_inches='tight')\n",
    "        plt.close(fig)\n",
    "\n",
    "        print(\"✅ Confusion matrix saved to plots/confusion_matrix.png\")\n",
    "\n",
    "    def plot_training_history(self, histories, model_names):\n",
    "        \"\"\"Plot train/validation accuracy and loss curves of every model.\"\"\"\n",
    "        fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(15, 5))\n",
    "        colors = ['blue', 'red', 'green']\n",
    "\n",
    "        for ax, metric, title in ((ax_acc, 'accuracy', 'Model Accuracy'),\n",
    "                                  (ax_loss, 'loss', 'Model Loss')):\n",
    "            for history, model_name, color in zip(histories, model_names, colors):\n",
    "                ax.plot(history.history[metric], label=f'{model_name} Train', color=color, linestyle='-')\n",
    "                ax.plot(history.history[f'val_{metric}'], label=f'{model_name} Val', color=color, linestyle='--')\n",
    "            ax.set_title(title)\n",
    "            ax.set_xlabel('Epoch')\n",
    "            ax.set_ylabel(metric.capitalize())\n",
    "            ax.legend()\n",
    "            ax.grid(True)\n",
    "\n",
    "        fig.tight_layout()\n",
    "        fig.savefig('plots/training_history.png', dpi=300, bbox_inches='tight')\n",
    "        plt.close(fig)\n",
    "\n",
    "        print(\"✅ Training history saved to plots/training_history.png\")\n",
    "\n",
    "    def save_ensemble(self, models, model_names, class_names, ensemble_accuracy, individual_scores):\n",
    "        \"\"\"Save every model (.keras) plus ensemble_metadata.json and classes.json.\"\"\"\n",
    "        print(\"💾 Saving ensemble...\")\n",
    "\n",
    "        model_files = []\n",
    "        for model, model_name in zip(models, model_names):\n",
    "            model_file = f\"{self.model_path}/{model_name.lower()}_model.keras\"\n",
    "            model.save(model_file)\n",
    "            model_files.append(model_file)\n",
    "            print(f\"✅ {model_name} saved to: {model_file}\")\n",
    "\n",
    "        # Plain Python types only, so json.dump also works for numpy inputs.\n",
    "        class_name_list = [str(name) for name in class_names]\n",
    "        metadata = {\n",
    "            'ensemble_accuracy': float(ensemble_accuracy),\n",
    "            'individual_models': individual_scores,\n",
    "            'model_files': model_files,\n",
    "            'num_classes': len(class_names),\n",
    "            'image_size': list(self.image_size),\n",
    "            'class_names': class_name_list,\n",
    "            'ensemble_weights': ENSEMBLE_WEIGHTS\n",
    "        }\n",
    "\n",
    "        metadata_file = f\"{self.model_path}/ensemble_metadata.json\"\n",
    "        with open(metadata_file, 'w') as f:\n",
    "            json.dump(metadata, f, indent=2)\n",
    "\n",
    "        classes_file = f\"{self.model_path}/classes.json\"\n",
    "        with open(classes_file, 'w') as f:\n",
    "            json.dump(class_name_list, f, indent=2)\n",
    "\n",
    "        print(f\"✅ Metadata saved to: {metadata_file}\")\n",
    "        print(f\"✅ Classes saved to: {classes_file}\")\n",
    "\n",
    "    @staticmethod\n",
    "    def _stratify_or_none(y, test_size):\n",
    "        \"\"\"Return y if train_test_split can stratify on it with test_size, else None.\n",
    "\n",
    "        Mirrors scikit-learn's requirements for stratification: every class\n",
    "        needs at least 2 samples and both parts need one sample per class.\n",
    "        \"\"\"\n",
    "        _, counts = np.unique(y, return_counts=True)\n",
    "        n_test = int(np.ceil(test_size * len(y)))\n",
    "        n_train = len(y) - n_test\n",
    "        if counts.min() < 2 or n_test < len(counts) or n_train < len(counts):\n",
    "            return None\n",
    "        return y\n",
    "\n",
    "    def run_training(self):\n",
    "        \"\"\"Run the complete training pipeline; returns True on success.\"\"\"\n",
    "        print(\"🎯 Eye Disease Detection - Advanced Training Pipeline\")\n",
    "        print(\"=\" * 60)\n",
    "\n",
    "        # Step 1: Download dataset if needed\n",
    "        if not os.path.exists(self.dataset_path):\n",
    "            if not self.download_dataset():\n",
    "                print(\"Please download and extract the dataset manually.\")\n",
    "                return False\n",
    "\n",
    "        # Step 2: Load dataset\n",
    "        X, y, class_names = self.load_dataset()\n",
    "        if X is None:\n",
    "            print(\"❌ Failed to load dataset. Exiting.\")\n",
    "            return False\n",
    "\n",
    "        if len(X) < 50:\n",
    "            print(f\"❌ Not enough images ({len(X)}). Need at least 50 for training.\")\n",
    "            return False\n",
    "\n",
    "        if len(class_names) < 2:\n",
    "            print(f\"❌ Need at least 2 classes for training, found {len(class_names)}.\")\n",
    "            return False\n",
    "\n",
    "        # Step 3: Split BEFORE balancing, so SMOTE only interpolates between\n",
    "        # training images and no synthetic sample leaks into validation/test.\n",
    "        print(\"✂️ Splitting dataset...\")\n",
    "        X_train, X_temp, y_train, y_temp = train_test_split(\n",
    "            X, y, test_size=0.3, random_state=SEED,\n",
    "            stratify=self._stratify_or_none(y, 0.3)\n",
    "        )\n",
    "        X_val, X_test, y_val, y_test = train_test_split(\n",
    "            X_temp, y_temp, test_size=0.5, random_state=SEED,\n",
    "            stratify=self._stratify_or_none(y_temp, 0.5)\n",
    "        )\n",
    "\n",
    "        # Step 4: Balance the training split only\n",
    "        X_train, y_train = self.balance_dataset(X_train, y_train)\n",
    "\n",
    "        print(f\"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}\")\n",
    "\n",
    "        # Step 5: Create data generators\n",
    "        train_gen, val_gen = self.create_data_generators(X_train, y_train, X_val, y_val)\n",
    "\n",
    "        # Step 6: Create and train ensemble\n",
    "        models, model_names = self.create_ensemble(len(class_names))\n",
    "        histories = self.train_ensemble(models, model_names, train_gen, val_gen, class_names)\n",
    "\n",
    "        # Step 7: Evaluate ensemble on the untouched test split\n",
    "        ensemble_accuracy, individual_scores, report = self.evaluate_ensemble(\n",
    "            models, model_names, X_test, y_test, class_names\n",
    "        )\n",
    "\n",
    "        # Step 8: Save results\n",
    "        self.save_ensemble(models, model_names, class_names, ensemble_accuracy, individual_scores)\n",
    "        self.plot_training_history(histories, model_names)\n",
    "\n",
    "        print(\"\\n🎉 TRAINING COMPLETED!\")\n",
    "        print(f\"🎯 Final Ensemble Accuracy: {ensemble_accuracy:.4f} ({ensemble_accuracy*100:.2f}%)\")\n",
    "\n",
    "        return True\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Main function\"\"\"\n",
    "    print(\"🚀 Advanced Eye Disease Detection Training\")\n",
    "    print(\"=\" * 50)\n",
    "\n",
    "    # System information\n",
    "    print(f\"🔧 System Information:\")\n",
    "    print(f\" TensorFlow version: {tf.__version__}\")\n",
    "    print(f\" GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}\")\n",
    "\n",
    "    # GPU setup: memory growth must be configured before the GPUs are initialized.\n",
    "    gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "    if gpus:\n",
    "        try:\n",
    "            for gpu in gpus:\n",
    "                tf.config.experimental.set_memory_growth(gpu, True)\n",
    "            print(\" ✅ GPU memory growth enabled\")\n",
    "        except RuntimeError as e:\n",
    "            print(f\" ⚠️ GPU setup warning: {e}\")\n",
    "\n",
    "    # Seed the Python, NumPy and TensorFlow RNGs for reproducible runs.\n",
    "    keras.utils.set_random_seed(SEED)\n",
    "\n",
    "    print()\n",
    "\n",
    "    # Initialize and run trainer\n",
    "    trainer = AdvancedFundusTrainer()\n",
    "    success = trainer.run_training()\n",
    "\n",
    "    if success:\n",
    "        print(\"\\n\" + \"=\"*50)\n",
    "        print(\"🎉 TRAINING SUCCESSFULLY COMPLETED!\")\n",
    "        print(\"=\"*50)\n",
    "        print(\"\\n📦 Generated Assets:\")\n",
    "        print(\" 🤖 Ensemble models (EfficientNetB3 + ResNet152V2 + DenseNet201)\")\n",
    "        print(\" 📊 Performance analysis and visualizations\")\n",
    "        print(\" 📋 Metadata for deployment\")\n",
    "    else:\n",
    "        print(\"\\n❌ TRAINING FAILED\")\n",
    "        print(\"Please check the error messages above.\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}