---
library_name: transformers
tags:
- vision
- vit
- mae
- retinal-imaging
- image-classification
- pytorch
- OCT
license: cc-by-nc-4.0
base_model:
- YukunZhou/RETFound_mae_natureOCT
pipeline_tag: image-feature-extraction
---

# RETFound ViT-L/16 (MAE → Transformers) — natureOCT

**Author of this fork:** *Dávid Isztl*
**Upstream project:** [RETFound_mae_natureOCT](https://huggingface.co/YukunZhou/RETFound_mae_natureOCT) by *Yukun Zhou* et al.
**Paper:** *A foundation model for generalizable disease detection from retinal images*, Nature (2023)

> This repository provides a **Transformers-compatible** export of the RETFound **MAE** encoder trained on a subset of **natureOCT** (OCT).
> It includes `config.json`, `model.safetensors`, and an `AutoImageProcessor`, so you can load it directly with 🤗 `AutoModel` / `AutoModelForImageClassification`.

---

## Model Details

### Model Description

This is a ViT-Large/16 encoder pretrained with the **Masked Autoencoder (MAE)** objective on Optical Coherence Tomography (OCT) images. This fork converts the original **PyTorch `.pth`** checkpoint into a standard 🤗 **Transformers** format and removes MAE-only components.

- **Developed by (upstream):** Yukun Zhou et al.
- **Shared by (this fork):** Dávid Isztl
- **Model type:** Vision Transformer (encoder only)
- **License:** CC BY-NC 4.0 (inherited from upstream)
- **Finetuned from:** Upstream RETFound MAE checkpoint (ViT-L/16)

**Architecture (ViT-L/16 @ 224):**

- `hidden_size=1024`, `num_hidden_layers=24`, `num_attention_heads=16`, `patch_size=16`, `image_size=224`
- `add_pooling_layer=False` (use the CLS token or your own pooling)

**Conversion notes:**

- Dropped **MAE-only** tensors: `mask_token`, `decoder_*`
- Remapped fused **qkv** weights (timm-style) → separate Q/K/V matrices (Transformers style); a hedged sketch of this remapping appears under *Technical Specifications* below
- Set `layer_norm_eps=1e-6` to match timm numerics
- Positional embeddings sized for 224×224 inputs (16×16 patches)

### Model Sources

- **Repository (upstream):** https://github.com/rmaphoh/RETFound
- **Paper:** https://www.nature.com/articles/s41586-023-06555-x

---

## Uses

### Direct Use

- **Feature extraction** from retinal images for downstream tasks
- Initial encoder for transfer learning on medical imaging research tasks (e.g., classification, retrieval)

### Downstream Use

- **Fine-tuning** for image classification and related tasks using `AutoModelForImageClassification`
- Using the CLS token or pooled features in custom pipelines

### Out-of-Scope Use

- Clinical decision-making without proper validation and regulatory approval
- Commercial use beyond the **CC BY-NC 4.0** license terms

---

## Bias, Risks, and Limitations

- Trained on specific retinal data (a subset of natureOCT); distribution shifts (device, population, protocol) can degrade performance.
- Not a medical device; requires independent validation before any real-world or clinical deployment.
- Potential biases relate to dataset composition, imaging hardware, and labeling procedures.

### Recommendations

- Perform task- and population-specific validation.
- Monitor for domain shift; consider domain adaptation where appropriate.
- Document preprocessing and augmentation pipelines for reproducibility (one option for the preprocessing side is sketched below).
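
As a minimal, hedged example of that last recommendation: the image processor shipped with this repo can be saved alongside your experiment outputs so the exact resize and normalization settings are recorded. The output path below is only a placeholder.

```python
from transformers import AutoImageProcessor

repo = "iszt/RETFound_mae_natureOCT"
processor = AutoImageProcessor.from_pretrained(repo)

# Persist the exact preprocessing configuration (image size, mean/std, resampling)
# next to your experiment artifacts; "runs/exp01/preprocessor" is just an example path.
processor.save_pretrained("runs/exp01/preprocessor")
print(processor)
```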
---

## How to Get Started with the Model

### Feature extraction (encoder)

```python
from transformers import AutoModel, AutoImageProcessor
from PIL import Image
import torch

repo = "iszt/RETFound_mae_natureOCT"  # this fork

processor = AutoImageProcessor.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, add_pooling_layer=False)  # ViTModel without the pooling head
model.eval()

img = Image.open("example_oct.jpg").convert("RGB")
inputs = processor(images=img, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)

cls = out.last_hidden_state[:, 0]         # [B, 1024] — CLS embedding after final norm
tokens = out.last_hidden_state[:, 1:, :]  # [B, N, 1024] — patch tokens
```

### Classification fine-tuning (use **AutoModelForImageClassification**)

```python
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification

repo = "iszt/RETFound_mae_natureOCT"

id2label = {0: "negative", 1: "positive"}  # example
label2id = {v: k for k, v in id2label.items()}

processor = AutoImageProcessor.from_pretrained(repo)

config = AutoConfig.from_pretrained(repo)
config.num_labels = len(id2label)
config.id2label = id2label
config.label2id = label2id

# Loads encoder weights from the repo and initializes a fresh classifier head
model = AutoModelForImageClassification.from_pretrained(
    repo,
    config=config,
    ignore_mismatched_sizes=True,  # replaces the classification head if shapes differ
)

# now train `model` with your dataloader/Trainer
```

---

## Training Details

### Training Data

* Upstream pretraining: **OCT** images from a portion of **natureOCT**.

### Training Procedure

* **Objective:** Masked Autoencoder (MAE) pretraining.
* **This fork:** no additional training; checkpoint conversion only.

#### Preprocessing

* An `AutoImageProcessor` is provided for 224×224 inputs. If your dataset uses different normalization or resolution, adjust accordingly (and, if needed, interpolate positional embeddings, e.g. by passing `interpolate_pos_encoding=True` to the model's forward pass).

#### Training Hyperparameters

* Not specified by upstream for this exact subset; see the paper and repository for general MAE settings.

#### Speeds, Sizes, Times

* This fork only performs conversion; refer to upstream for compute details.

---

## Evaluation

### Testing Data, Factors & Metrics

* No new evaluation was performed in this fork.
* For downstream tasks, report metrics relevant to the task (e.g., AUROC, accuracy, F1) and stratify by pertinent factors (device, demographics, pathology prevalence).

### Results

* N/A for this fork; please cite/consult upstream results for baseline pretraining performance.

#### Summary

* Use this encoder as initialization; measure and report results on your target dataset.

---

## Environmental Impact

This repository performs a format conversion only. Upstream pretraining compute and emissions are described in the paper and may be estimated via tools like the ML CO2 calculator.

* **Hardware Type:** N/A (conversion only)
* **Hours used:** N/A (conversion only)
* **Cloud Provider / Region:** N/A
* **Carbon Emitted:** N/A

---

## Technical Specifications

### Model Architecture and Objective

* **Architecture:** Vision Transformer Large, patch size 16, image size 224.
* **Objective:** MAE pretraining (only the encoder is kept in this fork).
* **Pooling:** No pooling layer (`add_pooling_layer=False`).

### Compute Infrastructure

* This fork does not introduce new training; the conversion was done locally.

#### Hardware

* N/A for conversion.

#### Software

* Conversion used PyTorch, timm, and 🤗 Transformers.
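
For illustration, here is a minimal sketch of the fused-qkv split mentioned in the conversion notes. The state-dict key names are assumptions for a generic timm-style ViT-L/16 checkpoint and a 🤗 Transformers `ViTModel`; this is not the exact script used to produce this fork.

```python
import torch

def split_fused_qkv(timm_state: dict[str, torch.Tensor], num_layers: int = 24) -> dict[str, torch.Tensor]:
    """Split timm-style fused qkv tensors into separate Q/K/V tensors
    (hypothetical key names; adapt to the actual checkpoints)."""
    hf_state = {}
    for i in range(num_layers):
        qkv_w = timm_state[f"blocks.{i}.attn.qkv.weight"]  # [3*hidden, hidden]
        qkv_b = timm_state[f"blocks.{i}.attn.qkv.bias"]    # [3*hidden]
        q_w, k_w, v_w = qkv_w.chunk(3, dim=0)
        q_b, k_b, v_b = qkv_b.chunk(3, dim=0)
        prefix = f"encoder.layer.{i}.attention.attention"
        hf_state[f"{prefix}.query.weight"], hf_state[f"{prefix}.query.bias"] = q_w, q_b
        hf_state[f"{prefix}.key.weight"], hf_state[f"{prefix}.key.bias"] = k_w, k_b
        hf_state[f"{prefix}.value.weight"], hf_state[f"{prefix}.value.bias"] = v_w, v_b
    return hf_state

# MAE-only tensors (`mask_token`, `decoder_*`) are simply not copied into the new state dict.
```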
---

## Citation

If you use this model, please cite the original RETFound paper:

**BibTeX:**

```bibtex
@article{zhou2023foundation,
  title={A foundation model for generalizable disease detection from retinal images},
  author={Zhou, Yukun and Chia, Mark A and Wagner, Siegfried K and Ayhan, Murat S and Williamson, Dominic J and Struyven, Robbert R and Liu, Timing and Xu, Moucheng and Lozano, Mateo G and Woodward-Court, Peter and others},
  journal={Nature},
  volume={622},
  number={7981},
  pages={156--163},
  year={2023},
  publisher={Nature Publishing Group UK London}
}
```

**APA:**

Zhou, Y., Chia, M. A., Wagner, S. K., Ayhan, M. S., Williamson, D. J., Struyven, R. R., et al. (2023). *A foundation model for generalizable disease detection from retinal images*. **Nature**, 622(7981), 156–163.

---

## Glossary

* **OCT:** Optical Coherence Tomography
* **CFP:** Color Fundus Photography
* **MAE:** Masked Autoencoder
* **CLS token:** Special token prepended to the patch sequence in ViT; often used as a global image representation.

---

## More Information

* Upstream code and instructions: [https://github.com/rmaphoh/RETFound](https://github.com/rmaphoh/RETFound)
* Nature paper: [https://www.nature.com/articles/s41586-023-06555-x](https://www.nature.com/articles/s41586-023-06555-x)

---

## Model Card Authors

* **Dávid Isztl** (fork & conversion)

---

## Model Card Contact

* For this fork/conversion: contact *Dávid Isztl* via Hugging Face.
* For upstream model/training code: **[ykzhoua@gmail.com](mailto:ykzhoua@gmail.com)** or **[yukun.zhou.19@ucl.ac.uk](mailto:yukun.zhou.19@ucl.ac.uk)**.