Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from torchvision import transforms | |
| import lightning as L | |
| import kornia as K | |
| import numpy as np | |
| import random | |
class XrayDataset(Dataset):
    """Grayscale image dataset driven by a label DataFrame.

    Each row of ``data_frame`` is read positionally (``iloc``) and must
    provide a ``"file_name"`` column (path relative to ``root_dir``) and a
    ``"value"`` column (the target).

    Args:
        data_frame: Table of samples.
        root_dir: Directory that the ``"file_name"`` entries are relative to.
        transform: Optional callable applied to the grayscale (mode ``"L"``)
            PIL image.
        apply_equalization: If True, apply kornia CLAHE after ``transform``.

    Items are ``(image, label, file_name)`` tuples, where ``label`` is a
    ``torch.float`` tensor built from the row's ``"value"``.
    """

    def __init__(self,
                 data_frame,
                 root_dir,
                 transform=None,
                 apply_equalization=False):
        self.data_frame = data_frame
        self.root_dir = root_dir
        self.transform = transform
        self.apply_equalization = apply_equalization

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        row = self.data_frame.iloc[idx]
        img_path = os.path.join(self.root_dir, row["file_name"])
        # Open inside a context manager so the file handle is released.
        # convert() returns a new, fully loaded image, so it stays valid
        # after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("L")
        if self.transform:
            img = self.transform(img)
        # Apply CLAHE if the flag is set.
        if self.apply_equalization:
            # equalize_clahe operates on tensors. If no transform produced one
            # (e.g. transform=None), convert the PIL image first instead of
            # failing on `.unsqueeze`.
            if not isinstance(img, torch.Tensor):
                img = transforms.ToTensor()(img)
            # Add and remove a batch dimension around the kornia call.
            img = K.enhance.equalize_clahe(img.unsqueeze(0)).squeeze(0)
        label = torch.tensor(row["value"],
                             dtype=torch.float)  # Ensure label is float
        return img, label, row["file_name"]
class XrayData(L.LightningDataModule):
    """Lightning data module that splits one labelled CSV into train/val sets.

    The CSV at ``label_csv`` is shuffled with a fixed seed. A ``val_split``
    fraction of rows (truncated toward zero) is then held out for validation
    using a seeded ``random_split``. Images are loaded through
    :class:`XrayDataset`.

    Args:
        root_dir: Directory containing the image files named in the CSV.
        label_csv: Path passed to ``pd.read_csv``. Rows are handed to
            :class:`XrayDataset`, which reads ``"file_name"`` and ``"value"``.
        batch_size: Batch size for every DataLoader.
        val_split: Fraction of rows held out for validation.
        apply_equalization: Forwarded to :class:`XrayDataset` (CLAHE).
        num_workers: Worker processes per DataLoader (defaults to 4).
    """

    common_seed = 42

    @staticmethod
    def seed_worker(worker_id):
        """Seed ``numpy`` and ``random`` inside a DataLoader worker.

        Must be a staticmethod: it is passed as ``self.seed_worker`` to
        ``worker_init_fn``, and DataLoader calls it with only ``worker_id``.
        As a plain function in the class body, the bound ``self`` would
        occupy ``worker_id`` and the call would raise ``TypeError``.
        """
        # torch already gives each worker a distinct base seed; derive the
        # numpy/random seeds from it (numpy requires a 32-bit value).
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def __init__(
        self,
        root_dir,
        label_csv,
        batch_size=32,
        val_split=0.2,
        apply_equalization=False,
        num_workers=4,
    ):
        super().__init__()
        self.root_dir = root_dir
        self.label_csv = label_csv
        self.batch_size = batch_size
        self.val_split = val_split
        self.apply_equalization = apply_equalization
        self.num_workers = num_workers

        # Process-wide seeding for reproducibility (affects all torch RNG use,
        # not only this module).
        torch.manual_seed(self.common_seed)
        torch.cuda.manual_seed_all(self.common_seed)
        torch.backends.cudnn.deterministic = True

        # Training-only augmentations (e.g. RandomHorizontalFlip,
        # RandomRotation(20)) can be added here. setup() keeps them
        # separate from the validation pipeline.
        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.full_dataset = None

    def setup(self, stage=None):
        """Read and shuffle the CSV, then build the train/val subsets.

        Args:
            stage: Lightning stage name. Unused; every stage gets the same
                split.
        """
        from torch.utils.data import Subset

        data_frame = pd.read_csv(self.label_csv)
        data_frame = data_frame.sample(
            frac=1, random_state=self.common_seed).reset_index(drop=True)
        dataset_size = len(data_frame)
        val_size = int(dataset_size * self.val_split)
        train_size = dataset_size - val_size

        def make_dataset(transform):
            return XrayDataset(
                data_frame,
                self.root_dir,
                transform=transform,
                apply_equalization=self.apply_equalization,
            )

        # Whole (unsplit) dataset with the deterministic validation transform.
        self.full_dataset = make_dataset(self.val_transform)
        train_part, val_part = random_split(
            self.full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(self.common_seed),
        )
        # random_split returns Subsets that share ONE underlying dataset, so
        # setting a transform via `subset.dataset` would affect both sides.
        # Reuse the seeded indices, but give each side its own dataset so
        # the train and val transforms stay independent.
        self.train_dataset = Subset(
            make_dataset(self.train_transform), train_part.indices)
        self.val_dataset = Subset(
            make_dataset(self.val_transform), val_part.indices)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with this module's batch size, workers and seeding."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            worker_init_fn=self.seed_worker,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        # NOTE: there is no separate test split; this reuses the validation set.
        return self._make_loader(self.val_dataset, shuffle=False)