library_name: transformers
tags:
  - vision
  - vit
  - mae
  - retinal-imaging
  - image-classification
  - pytorch
  - OCT
license: cc-by-nc-4.0
base_model:
  - YukunZhou/RETFound_mae_natureOCT
pipeline_tag: image-feature-extraction

RETFound ViT-L/16 (MAE → Transformers) — natureOCT

Author of this fork: Dávid Isztl
Upstream project: RETFound_mae_natureOCT by Yukun Zhou et al.
Paper: A foundation model for generalizable disease detection from retinal images, Nature (2023)

This repository provides a Transformers-compatible export of the RETFound MAE encoder pretrained on a subset of the natureOCT dataset (optical coherence tomography).
It includes config.json, model.safetensors, and an AutoImageProcessor configuration, so you can load it directly with 🤗 AutoModel / AutoModelForImageClassification.


Model Details

Model Description

This is a ViT-Large/16 encoder pretrained with the Masked Autoencoder (MAE) objective on optical coherence tomography (OCT) scans.
This fork converts the original PyTorch .pth checkpoint into a standard 🤗 Transformers format and removes MAE-only components.

  • Developed by (upstream): Yukun Zhou et al.
  • Shared by (this fork): Dávid Isztl
  • Model type: Vision Transformer (encoder only)
  • License: CC BY-NC 4.0 (inherited from upstream)
  • Converted from: upstream RETFound MAE checkpoint (ViT-L/16); no additional training was performed in this fork

Architecture (ViT-L/16 @ 224):

  • hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, patch_size=16, image_size=224
  • add_pooling_layer=False (use CLS token or your own pooling)
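
For reference, the architecture above corresponds to a ViTConfig along these lines (a sketch only; the shipped config.json is authoritative, and intermediate_size is assumed to be the standard 4 × hidden_size for ViT-L):

from transformers import ViTConfig, ViTModel

config = ViTConfig(
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,   # assumed: standard ViT-L MLP width (4 × hidden_size)
    patch_size=16,
    image_size=224,
    layer_norm_eps=1e-6,
)
model = ViTModel(config, add_pooling_layer=False)  # randomly initialized; in practice load weights with from_pretrained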

Conversion notes:

  • Dropped MAE-only tensors: mask_token, decoder_*
  • Remapped fused qkv weights (timm-style) → separate Q/K/V matrices (Transformers style)
  • Set layer_norm_eps=1e-6 to match timm numerics
  • Positional embeddings sized for 224×224 (patch 16×16)
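
The qkv remapping in the notes above amounts to splitting the fused timm projection into three equal slices. An illustrative sketch (not the exact conversion script; the function name is hypothetical):

import torch

# timm stores attention as one fused Linear: weight [3*H, H], bias [3*H], with H = 1024 for ViT-L.
# Transformers ViT expects separate query/key/value Linear layers of shape [H, H] each.
def split_qkv(qkv_weight: torch.Tensor, qkv_bias: torch.Tensor):
    hidden = qkv_weight.shape[1]                     # 1024
    q_w, k_w, v_w = qkv_weight.split(hidden, dim=0)  # each [1024, 1024]
    q_b, k_b, v_b = qkv_bias.split(hidden, dim=0)    # each [1024]
    return (q_w, q_b), (k_w, k_b), (v_w, v_b)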

Model Sources

  • Upstream repository: YukunZhou/RETFound_mae_natureOCT (Hugging Face)
  • Paper: A foundation model for generalizable disease detection from retinal images, Nature (2023)

Uses

Direct Use

  • Feature extraction from retinal images for downstream tasks
  • Initial encoder for transfer learning on medical imaging research tasks (e.g., classification, retrieval)

Downstream Use

  • Fine-tuning for image classification and related tasks using AutoModelForImageClassification
  • Using CLS token or pooled features in custom pipelines

Out-of-Scope Use

  • Clinical decision-making without proper validation and regulatory approval
  • Commercial use beyond the CC BY-NC 4.0 license terms

Bias, Risks, and Limitations

  • Trained on specific retinal data (subset of natureOCT); distribution shifts (device, population, protocol) can degrade performance.
  • Not a medical device; requires independent validation before any real-world or clinical deployment.
  • Potential biases relate to dataset composition, imaging hardware, and labeling procedures.

Recommendations

  • Perform task- and population-specific validation.
  • Monitor for domain shift; consider domain adaptation where appropriate.
  • Document preprocessing and augmentation pipelines for reproducibility.

How to Get Started with the Model

Feature extraction (encoder)

from transformers import AutoModel, AutoImageProcessor
from PIL import Image
import torch

repo = "iszt/RETFound_mae_natureOCT"  # this fork

processor = AutoImageProcessor.from_pretrained(repo)
model = AutoModel.from_pretrained(repo)  # ViTModel with add_pooling_layer=False
model.eval()

img = Image.open("example_oct_bscan.png").convert("RGB")
inputs = processor(images=img, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)
    cls = out.last_hidden_state[:, 0]        # [B, 1024] — CLS embedding after final norm
    tokens = out.last_hidden_state[:, 1:, :] # [B, 196, 1024] — patch tokens (14×14 grid at 224×224)
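
If you prefer pooled patch features over the CLS token (see Downstream Use), a minimal continuation of the example above:

# Mean-pool the patch tokens (excluding the CLS token) as an alternative global embedding
pooled = tokens.mean(dim=1)  # [B, 1024]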

Classification fine-tune (use AutoModelForImageClassification)

from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification

repo = "iszt/RETFound_mae_natureOCT"
id2label = {0: "negative", 1: "positive"}  # example
label2id = {v: k for k, v in id2label.items()}

processor = AutoImageProcessor.from_pretrained(repo)

config = AutoConfig.from_pretrained(repo)
config.num_labels = len(id2label)
config.id2label = id2label
config.label2id = label2id

# Loads encoder weights from the repo and initializes a fresh classifier head
model = AutoModelForImageClassification.from_pretrained(
    repo,
    config=config,
    ignore_mismatched_sizes=True,  # replaces the classification head if shapes differ
)

# now train `model` with your dataloader or the 🤗 Trainer (see the sketch below)
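
A minimal Trainer sketch for that step; `train_ds` is a placeholder (not part of this repo) whose items are dicts with "pixel_values" (FloatTensor [3, 224, 224], produced by `processor`) and "labels" (int):

from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    output_dir="retfound-oct-cls",   # hypothetical output directory
    per_device_train_batch_size=16,
    learning_rate=5e-5,
    num_train_epochs=5,
)

trainer = Trainer(model=model, args=args, train_dataset=train_ds)
trainer.train()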

Training Details

Training Data

  • Upstream pretraining: OCT scans from a portion of the natureOCT dataset.

Training Procedure

  • Objective: Masked Autoencoder (MAE) pretraining.
  • This fork: no additional training; checkpoint conversion only.

Preprocessing

  • An AutoImageProcessor is provided for 224×224 inputs. If your dataset uses different normalization or resolution, adjust accordingly (and, if needed, interpolate positional embeddings; see the sketch below).
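
For higher-resolution inputs, a sketch continuing from the feature-extraction example above (assuming your installed Transformers version supports the interpolate_pos_encoding argument on ViT's forward):

# Run the encoder at 384×384 instead of 224×224
inputs_384 = processor(images=img, size={"height": 384, "width": 384}, return_tensors="pt")
with torch.no_grad():
    out_384 = model(**inputs_384, interpolate_pos_encoding=True)
print(out_384.last_hidden_state.shape)  # [1, 577, 1024] = [1, 1 + (384/16)**2, 1024]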

Training Hyperparameters

  • Not specified by upstream for this exact subset; see the paper and repository for general MAE settings.

Speeds, Sizes, Times

  • This fork only performs conversion; refer to upstream for compute details.

Evaluation

Testing Data, Factors & Metrics

  • No new evaluation performed in this fork.
  • For downstream tasks, report metrics relevant to the task (e.g., AUROC, accuracy, F1), and stratify by pertinent factors (device, demographics, pathology prevalence).
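
A small scikit-learn sketch of that kind of reporting (the arrays are placeholders; replace them with predictions on your held-out test set):

import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score

y_true = np.array([0, 1, 1, 0, 1])            # placeholder ground-truth labels
y_prob = np.array([0.2, 0.8, 0.6, 0.4, 0.9])  # placeholder positive-class probabilities
y_pred = (y_prob >= 0.5).astype(int)

print("AUROC:   ", roc_auc_score(y_true, y_prob))
print("Accuracy:", accuracy_score(y_true, y_pred))
print("F1:      ", f1_score(y_true, y_pred))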

Results

  • N/A for this fork; please cite/consult upstream results for baseline pretraining performance.

Summary

  • Use this encoder as initialization; measure and report results on your target dataset.

Environmental Impact

This repository performs a format conversion only. Upstream pretraining compute and emissions are described in the paper and may be estimated via tools like the ML CO2 calculator.

  • Hardware Type: N/A (conversion only)
  • Hours used: N/A (conversion only)
  • Cloud Provider / Region: N/A
  • Carbon Emitted: N/A

Technical Specifications

Model Architecture and Objective

  • Architecture: Vision Transformer Large, patch size 16, image size 224.
  • Objective: MAE pretraining (encoder-only kept in this fork).
  • Pooling: No pooling layer (add_pooling_layer=False).

Compute Infrastructure

  • This fork does not introduce new training; conversion was done locally.

Hardware

  • N/A for conversion.

Software

  • Conversion used PyTorch, timm, and 🤗 Transformers.

Citation

If you use this model, please cite the original RETFound paper:

BibTeX:

@article{zhou2023foundation,
  title={A foundation model for generalizable disease detection from retinal images},
  author={Zhou, Yukun and Chia, Mark A and Wagner, Siegfried K and Ayhan, Murat S and Williamson, Dominic J and Struyven, Robbert R and Liu, Timing and Xu, Moucheng and Lozano, Mateo G and Woodward-Court, Peter and others},
  journal={Nature},
  volume={622},
  number={7981},
  pages={156--163},
  year={2023},
  publisher={Nature Publishing Group UK London}
}

APA: Zhou, Y., Chia, M. A., Wagner, S. K., Ayhan, M. S., Williamson, D. J., Struyven, R. R., et al. (2023). A foundation model for generalizable disease detection from retinal images. Nature, 622(7981), 156–163.


Glossary

  • OCT: Optical Coherence Tomography
  • MAE: Masked Autoencoder
  • CLS token: Special token prepended to the patch sequence in ViT; often used as a global image representation.

More Information

  • See the upstream repository and the Nature (2023) paper for full details on pretraining data and methodology.

Model Card Authors

  • Dávid Isztl (fork & conversion)

Model Card Contact

  • Dávid Isztl (this fork)