
RETINA Benchmark

Overview

Hi, good to see you here! 👋

Thanks for checking out the code for the RETINA Benchmark, part of the Uncertainty Baselines project.

See our 2021 NeurIPS Datasets and Benchmarks paper introducing this benchmark in detail here.

This codebase will allow you to reproduce experiments from the paper (see citation here) as well as use the benchmarking utilities for predictive performance, robustness, and uncertainty quantification (evaluation and plotting) for your own Bayesian deep learning methods.

We would greatly appreciate a citation if you use this code in your own work.

Prediction Task Overview

In this benchmark, models try to predict the presence or absence of diabetic retinopathy (a binary classification task) using data from the Kaggle Diabetic Retinopathy Detection challenge and the APTOS 2019 Blindness Detection competition. Please see these pages for details on data collection, etc.

Models are trained on high-resolution images of the retina, as described in the TensorFlow Datasets description.

Abstract

Bayesian deep learning seeks to equip deep neural networks with the ability to precisely quantify their predictive uncertainty, and has promised to make deep learning more reliable for safety-critical real-world applications. Yet, existing Bayesian deep learning methods fall short of this promise; new methods continue to be evaluated on unrealistic test beds that do not reflect the complexities of downstream real-world tasks that would benefit most from reliable uncertainty quantification. We propose a set of real-world tasks that accurately reflect such complexities and are designed to assess the reliability of predictive models in safety-critical scenarios. Specifically, we curate two publicly available datasets of high-resolution human retina images exhibiting varying degrees of diabetic retinopathy, a medical condition that can lead to blindness, and use them to design a suite of automated diagnosis tasks that require reliable predictive uncertainty quantification. We use these tasks to benchmark well-established and state-of-the-art Bayesian deep learning methods on task-specific evaluation metrics. We provide an easy-to-use codebase for fast and easy benchmarking following reproducibility and software design principles. We provide implementations of all methods included in the benchmark as well as results computed over 100 TPU days, 20 GPU days, 400 hyperparameter configurations, and evaluation on at least 6 random seeds each.

Environment Installation

Set up and activate the Python environment by executing

conda create -n ub python=3.8
conda activate ub
python3 -m pip install -e '.[models,jax,tensorflow,torch,retinopathy]'  # In uncertainty-baselines root directory
pip install "git+https://github.com/google-research/robustness_metrics.git#egg=robustness_metrics"
pip install 'git+https://github.com/google/edward2.git'

Data Installation

Because the data is distributed through Kaggle, it must be downloaded manually.

  1. Download from Kaggle: https://www.kaggle.com/c/diabetic-retinopathy-detection

  2. Extract everything to $DATA_DIR/downloads/manual; your directory should look like

sample/ sampleSubmission.csv test/ train/ trainLabels.csv

  3. Confirm that the files downloaded successfully

You should have 35,126 training images and 53,576 test images, which should be located in manual/train and manual/test.

You can check this by running the following command inside each of the two directories:

ls -1 | wc -l

  4. Manual loading: this step is not (yet) part of the standard execution of the diabetic retinopathy models.

I suggest running this loading step in a screen session in case it fails, because it takes a while.

I suggest doing this in an ipython shell!

$ ipython

Train loading:

First, we initialize a DiabeticRetinopathyDetectionDataset object.

import os

import uncertainty_baselines as ub

# $DATA_DIR must be exported in the shell that launched ipython.
data_dir = os.environ['DATA_DIR']

dataset_train_builder = ub.datasets.get(
    "ub_diabetic_retinopathy_detection",
    split='train',
    data_dir=data_dir, download_data=True)

We then need to shuffle the data and package it into TensorFlow Datasets (TFRecord) files:

dataset_train_builder._dataset_builder.download_and_prepare(download_dir=f'{data_dir}/downloads/')

Rinse and repeat for test data:

dataset_test_builder = ub.datasets.get(
    "ub_diabetic_retinopathy_detection",
    split='test',
    data_dir=data_dir, download_data=True)
dataset_test_builder._dataset_builder.download_and_prepare(download_dir=f'{data_dir}/downloads/')
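The image counts from step 3 can also be checked from Python, which is handy inside the same ipython session. The sketch below assumes the layout from step 2 ($DATA_DIR/downloads/manual/train and $DATA_DIR/downloads/manual/test) and that DATA_DIR is exported as an environment variable. count_files and check_manual_download are illustrative helpers, not part of Uncertainty Baselines.

```python
import os
from pathlib import Path

# Expected image counts for the Kaggle Diabetic Retinopathy Detection data (step 3).
EXPECTED_COUNTS = {"train": 35_126, "test": 53_576}


def count_files(directory: Path) -> int:
    """Count regular files directly inside `directory` (like `ls -1 | wc -l`, minus subfolders)."""
    return sum(1 for entry in directory.iterdir() if entry.is_file())


def check_manual_download(manual_dir: Path, expected=EXPECTED_COUNTS) -> dict:
    """Return {split: (found, expected, ok)} for each split folder under `manual_dir`."""
    report = {}
    for split, n_expected in expected.items():
        split_dir = manual_dir / split
        n_found = count_files(split_dir) if split_dir.is_dir() else 0  # missing folder counts as 0
        report[split] = (n_found, n_expected, n_found == n_expected)
    return report


if __name__ == "__main__":
    # Assumes DATA_DIR points at the same $DATA_DIR used in step 2.
    manual = Path(os.environ.get("DATA_DIR", ".")) / "downloads" / "manual"
    for split, (found, expected, ok) in check_manual_download(manual).items():
        print(f"{split}: {found}/{expected} {'OK' if ok else 'MISMATCH'}")
```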

Install / Download for Severity and Country Shifts

Severity Shift depends on precisely the same data as the original Diabetic Retinopathy dataset, so we do not need to go back to step 1.

We can package the Severity Shift splits into TensorFlow Datasets format by substituting "ub_diabetic_retinopathy_detection" with "diabetic_retinopathy_severity_shift_mild" and using the following arguments for split:

train
in_domain_validation
ood_validation
in_domain_test
ood_test
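The substitution described above can be scripted as a loop over the five splits. This is a minimal sketch that assumes the same ub.datasets.get call and private _dataset_builder attribute used for the train and test splits; prepare_severity_shift is a hypothetical helper. TensorFlow Datasets typically prepares all splits of a dataset in one download_and_prepare call, so later iterations should mostly reuse the files already written.

```python
# Severity Shift dataset name and splits, as listed above.
SEVERITY_SHIFT_DATASET = "diabetic_retinopathy_severity_shift_mild"
SEVERITY_SHIFT_SPLITS = (
    "train",
    "in_domain_validation",
    "ood_validation",
    "in_domain_test",
    "ood_test",
)


def prepare_severity_shift(data_dir: str, splits=SEVERITY_SHIFT_SPLITS) -> None:
    """Build each Severity Shift split, mirroring the train/test loading steps above."""
    # Imported lazily so this module can be loaded without Uncertainty Baselines installed.
    import uncertainty_baselines as ub

    for split in splits:
        builder = ub.datasets.get(
            SEVERITY_SHIFT_DATASET,
            split=split,
            data_dir=data_dir,
            download_data=True,
        )
        # Same private-attribute call as for the original dataset.
        builder._dataset_builder.download_and_prepare(download_dir=f"{data_dir}/downloads/")
```

In the ipython session above, this would be called as prepare_severity_shift(data_dir).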

In contrast, to download the (much smaller) APTOS dataset used for the Country Shift task, we do need to repeat the steps starting from step 1, downloading from https://www.kaggle.com/c/aptos2019-blindness-detection. Note that APTOS only includes "validation" and "test" splits.

Additional Splits for Exploration

There are several additional splits available for experimenting with other partitions of the severity levels into binary classification, and with other preprocessing configurations. See the following files for details on available splits:

Tuning Scripts

All hyperparameter tuning and fine-tuning (i.e., retraining with 6 different training seeds) scripts are located in baselines/diabetic_retinopathy_detection/experiments/tuning and baselines/diabetic_retinopathy_detection/experiments/top_config, respectively.

Train a Model

Tuning scripts accept hyperparameters as command-line flags. We also implement logging with TensorBoard and Weights & Biases across all uncertainty quantification methods for the convenience of the user.

Execute a tuning script as follows. All tuning scripts are located in baselines/diabetic_retinopathy_detection, and their default arguments are set to the configuration that achieved the highest AUC on the in-domain validation set for the Country Shift task.

python baselines/diabetic_retinopathy_detection/deterministic.py --data_dir='gs://ub-data/retinopathy' --use_gpu=True --output_dir='gs://ub-data/retinopathy-out/deterministic'
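To retrain one configuration with several training seeds (the role of the top_config scripts), one option is to generate one command per seed. This sketch only builds the argument lists. The --seed flag name and the per-seed output_dir layout are assumptions, so check the flag definitions in deterministic.py before running the commands.

```python
import shlex
from typing import Iterable, List

SCRIPT = "baselines/diabetic_retinopathy_detection/deterministic.py"


def seed_commands(data_dir: str, output_root: str, seeds: Iterable[int] = range(6),
                  use_gpu: bool = True) -> List[List[str]]:
    """One argv list per training seed, each writing to its own output subdirectory."""
    return [
        [
            "python", SCRIPT,
            f"--data_dir={data_dir}",
            f"--use_gpu={use_gpu}",
            f"--output_dir={output_root}/seed_{seed}",  # assumed per-seed layout
            f"--seed={seed}",  # assumed flag name; confirm it in deterministic.py
        ]
        for seed in seeds
    ]


if __name__ == "__main__":
    for argv in seed_commands("gs://ub-data/retinopathy",
                              "gs://ub-data/retinopathy-out/deterministic"):
        print(shlex.join(argv))  # or launch with subprocess.run(argv, check=True)
```

Returning argument lists rather than shell strings lets subprocess.run execute them without shell quoting issues.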

Select Top Performing Models

Model selection utilities are provided in baselines/diabetic_retinopathy_detection/model_selection.

First, follow the steps detailed in parse_tensorboards.py to convert TensorFlow event files to a public TensorBoard, and then parse this into a DataFrame of results (per-epoch metric logs and hyperparameter details). The script expects each run's TensorFlow event files to be in a folder that identifies the run, such as

dr_tuning/
   |--> 1/
        |--> tuning-run-seed-1.out.tfevents...
   |--> 2/
        |--> tuning-run-seed-2.out.tfevents...
   |--> 3/
        |--> tuning-run-seed-3.out.tfevents...
  ...
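Before running parse_tensorboards.py, a quick check that the event files follow this layout can catch empty run folders early. A minimal sketch with pathlib; it assumes event file names contain "tfevents", as in the tree above, and the helper names are illustrative.

```python
from pathlib import Path
from typing import Dict, List


def index_event_files(root) -> Dict[str, List[str]]:
    """Map each run folder under `root` (e.g. dr_tuning/1) to the event files it contains."""
    runs = {}
    for run_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        runs[run_dir.name] = sorted(
            f.name for f in run_dir.iterdir() if f.is_file() and "tfevents" in f.name
        )
    return runs


def runs_missing_events(runs: Dict[str, List[str]]) -> List[str]:
    """Run folders that contain no TensorFlow event file."""
    return [run_id for run_id, events in runs.items() if not events]


if __name__ == "__main__":
    runs = index_event_files("dr_tuning") if Path("dr_tuning").is_dir() else {}
    print(f"{len(runs)} run folders; missing event files: {runs_missing_events(runs)}")
```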

Following the steps in parse_tensorboards.py produces a file results.tsv. We can parse this file to obtain a ranking of the models based on our two tuning criteria: in-domain validation AUC, and area under the balanced accuracy referral curve (see paper), by executing python analyze_tensorboards.py in the directory containing the results.tsv file. This ranking allows the user to select top performing checkpoints.
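analyze_tensorboards.py is the supported way to produce the ranking. For a quick manual look at results.tsv, the sketch below ranks runs by their best logged value of one criterion. The column names run_id and epoch, and the metric name used in the example, are hypothetical placeholders (match them to the header of your results.tsv), and taking each run's best epoch is an assumed selection rule, not necessarily the one used in the paper.

```python
import csv
import os
from typing import List, Optional, Tuple


def rank_runs(tsv_path: str, metric: str, run_col: str = "run_id",
              epoch_col: str = "epoch") -> List[Tuple[str, float, Optional[str]]]:
    """Rank runs by their best logged value of `metric` (higher is better).

    Returns (run_id, best_value, epoch_of_best_value) tuples, best run first.
    """
    best = {}  # run_id -> (best_value, epoch)
    with open(tsv_path, newline="") as f:
        for row in csv.DictReader(f, delimiter="\t"):
            raw = row.get(metric) or ""
            if not raw:
                continue  # metric not logged for this row
            value, run_id = float(raw), row[run_col]
            if run_id not in best or value > best[run_id][0]:
                best[run_id] = (value, row.get(epoch_col))
    ranked = sorted(best.items(), key=lambda item: item[1][0], reverse=True)
    return [(run_id, value, epoch) for run_id, (value, epoch) in ranked]


if __name__ == "__main__":
    if os.path.exists("results.tsv"):
        # "in_domain_validation/auc" is a placeholder column name.
        for run_id, value, epoch in rank_runs("results.tsv", "in_domain_validation/auc")[:5]:
            print(run_id, value, epoch)
```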

Accessing Model Checkpoints

For each method, task (Country or Severity Shift), and tuning criterion (see model selection details above), we release the six best-performing checkpoints here.

For more details on the models, see the accompanying Model Card along with method implementation and modification details provided in Section 5 of the benchmark whitepaper located here.

Evaluate a Model

Evaluation Sweep Scripts

Scripts for the evaluation sweeps used for the paper are located in baselines/diabetic_retinopathy_detection/experiments/eval.

.py sweep files are used with XManager, a framework for launching experiments on Google Cloud Platform.

.yaml sweep files are sweep configurations used with Weights & Biases.

Selective Prediction and Referral Curves

In Selective Prediction, a model's predictive uncertainty is used to choose a subset of the test set on which predictions will be evaluated. Specifically, test inputs are ranked by their predictive uncertainty. The X% of test inputs with the highest uncertainty are referred to a specialist, and the model's performance is evaluated on the remaining (100 - X)% of inputs. Standard evaluation corresponds to a referral fraction of 0, i.e., the full test set is retained.

We may wish to use a predictive model of diabetic retinopathy to ease the burden on clinical practitioners. Under Selective Prediction, the model refers the examples on which it is least confident to specialists. We can tune the referral fraction parameter based on practitioner availability, and a model with well-calibrated uncertainty will have high performance on metrics such as AUC/accuracy on the retained (non-referred) evaluation data, because its uncertainty and predictive performance are (negatively) correlated.
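The referral procedure above can be sketched in a few lines of NumPy. Here accuracy is computed on the retained inputs for a given referral fraction, with the predictive entropy of the binary prediction as one possible uncertainty measure. The function names, the 0.5 decision threshold, and the choice of entropy are illustrative assumptions; other metrics such as AUC can be computed on the retained inputs in the same way.

```python
import numpy as np


def binary_entropy(p, eps=1e-12):
    """Predictive entropy (in nats) of a Bernoulli prediction with P(y=1) = p."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))


def retained_accuracy(y_true, y_prob, uncertainty, referral_fraction, threshold=0.5):
    """Accuracy on the inputs kept after referring the most uncertain fraction."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    n = len(y_true)
    n_retained = n - int(np.floor(referral_fraction * n))
    if n_retained == 0:
        return float("nan")  # everything was referred; nothing left to score
    # Rank from least to most uncertain and keep the n_retained most confident inputs.
    order = np.argsort(np.asarray(uncertainty), kind="stable")
    kept = order[:n_retained]
    return float(np.mean(y_pred[kept] == y_true[kept]))


def referral_curve(y_true, y_prob, uncertainty, fractions):
    """Retained accuracy at each referral fraction (0 = standard full-test-set evaluation)."""
    return [retained_accuracy(y_true, y_prob, uncertainty, f) for f in fractions]


if __name__ == "__main__":
    y_true = [1, 0, 1, 0]
    y_prob = [0.9, 0.2, 0.4, 0.6]
    print(referral_curve(y_true, y_prob, binary_entropy(y_prob), [0.0, 0.5]))  # prints [0.5, 1.0]
```

Sweeping the referral fraction from 0 upward traces a referral curve. When uncertainty is informative, the two wrong predictions in this toy example carry the highest entropy and are referred first, so retained accuracy rises.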

Using Evaluation Utilities

Once you have trained a few models and placed the top-performing checkpoints in a checkpoint_bucket, you can evaluate the methods and store both scalar results on predictive performance and uncertainty quantification metrics (e.g., accuracy, AUC, expected calibration error) and the results needed for selective prediction and receiver operating characteristic (ROC) plots, as follows.

Single-model evaluation (e.g., X different training seeds for a deterministic model):

python baselines/diabetic_retinopathy_detection/eval_model.py --checkpoint_bucket='bucket-name' --output_bucket='results-bucket-name' --dr_decision_threshold='moderate' --model_type='deterministic' --single_model_multi_train_seeds=True

Ensemble evaluation, where each ensemble is formed by sampling without replacement from all available checkpoints in the directory, with sample size k_ensemble_members = 3 and number of sampling repetitions ensemble_sampling_repetitions = 6 (as in paper):

python baselines/diabetic_retinopathy_detection/eval_model.py --checkpoint_bucket='bucket-name' --output_bucket='results-bucket-name' --dr_decision_threshold='moderate' --model_type='deterministic' --k_ensemble_members=3 --ensemble_sampling_repetitions=6

Details on eval_model.py flags

single_model_multi_train_seeds evaluates a single model trained with multiple random training seeds, rather than with multiple evaluation seeds. Specify the evaluation seed with the seed flag. If you enable this option, all checkpoints in the checkpoint directory are loaded and evaluated one at a time, and performance metrics are averaged across the single models.

To evaluate ensembles instead (i.e., average predictions and produce ensemble-based uncertainty estimates), we construct ensembles by randomly sampling model checkpoints without replacement from a checkpoint directory. k_ensemble_members specifies the size of each ensemble, i.e., the number of checkpoints sampled without replacement in each repetition. ensemble_sampling_repetitions specifies the number of repetitions in which we sample and evaluate an ensemble composed of a subset of all checkpoints.
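The ensemble sampling described above can be sketched with the standard library. Each repetition draws k_ensemble_members distinct checkpoints. The sketch assumes that repetitions are drawn independently, so two ensembles may share members, and that an ensemble's prediction is the mean of its members' probabilities. The helper names are hypothetical and not taken from eval_model.py.

```python
import random
from typing import List, Sequence


def sample_ensembles(checkpoints: Sequence[str], k_ensemble_members: int,
                     ensemble_sampling_repetitions: int, seed: int = 0) -> List[List[str]]:
    """Draw one ensemble of k distinct checkpoints per repetition (repetitions are independent)."""
    if not 0 < k_ensemble_members <= len(checkpoints):
        raise ValueError("k_ensemble_members must be between 1 and the number of checkpoints")
    rng = random.Random(seed)  # fixed seed so the sampled ensembles are reproducible
    return [rng.sample(list(checkpoints), k_ensemble_members)  # without replacement
            for _ in range(ensemble_sampling_repetitions)]


def ensemble_probs(member_probs: Sequence[Sequence[float]]) -> List[float]:
    """Average each example's predicted probability across ensemble members."""
    n_members = len(member_probs)
    return [sum(per_example) / n_members for per_example in zip(*member_probs)]


if __name__ == "__main__":
    checkpoints = [f"checkpoint_{i}" for i in range(6)]  # hypothetical checkpoint names
    for members in sample_ensembles(checkpoints, k_ensemble_members=3,
                                    ensemble_sampling_repetitions=6):
        print(members)
```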

Plot ROC and Selective Prediction Curves

Now we can generate the same ROC and selective prediction plots as appear in the paper (e.g., if you have run the above training and evaluation for many different Bayesian deep learning methods).

Use the distribution_shift flag to specify the distribution shift for which to generate outputs. See plot_results.py for the expected directory structure.

python baselines/diabetic_retinopathy_detection/plot_results.py --results_dir='gs://results-bucket-name' --output_dir='gs://plot-outputs' --distribution_shift=aptos

Previous Tuning Details

The tuning below was done for the initial Uncertainty Baselines release. The corresponding tuning scripts are in baselines/diabetic_retinopathy_detection/experiments/initial_tuning, and the trained model checkpoints are available here.

Model Checkpoints

For each method we release the best-performing checkpoints. These checkpoints were trained on the combined training and validation set, using hyperparameters selected from the best validation performance. Each checkpoint was taken from the training step with the best test AUC (averaged across the 10 random seeds). This was epoch 63 for the deterministic model, epoch 72 for the MC-Dropout method, epoch 31 for the Variational Inference method, and epoch 61 for the Radial BNNs method. For more details on the models, see the accompanying Model Card, which covers all the models below because the dataset is the same for all of them and the models differ only in minor calibration improvements. The checkpoints can be browsed here.

Tuning

For this baseline, two rounds of quasirandom search were conducted on the hyperparameters listed below. The first round used a heuristically picked, larger search space, and the second round used a hand-tuned, smaller range around the better-performing values. Each round consisted of 50 trials, and the final hyperparameters were selected using the final validation AUC from the second tuning round. These best hyperparameters were then used to retrain on the combined training and validation sets over 10 seeds. We note that the learning rate schedules could likely be tuned for improved performance, but leave this to future work. All our intermediate and final tuning results are available below, hosted on tensorboard.dev.
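Quasirandom search replaces independent random draws with a low-discrepancy sequence that covers the search space more evenly. The sketch below uses an unscrambled Halton sequence as a simple stand-in (the tuning tooling used for these baselines may use a different sequence or scrambling). It maps each coordinate log-uniformly onto its range and uses the initial-round deterministic ranges listed under Search spaces. The parameter names and the log scaling are assumptions.

```python
import math
from typing import Dict, List, Tuple

PRIMES = (2, 3, 5, 7, 11, 13)  # one Halton base per hyperparameter


def van_der_corput(index: int, base: int) -> float:
    """Radical inverse of `index` in `base`: the 1-D building block of a Halton sequence."""
    result, denom = 0.0, 1.0
    while index > 0:
        index, digit = divmod(index, base)
        denom *= base
        result += digit / denom
    return result


def halton_search(space: Dict[str, Tuple[float, float]], n_trials: int) -> List[Dict[str, float]]:
    """Sample `n_trials` points, mapping each Halton coordinate log-uniformly onto its range."""
    names = list(space)
    if len(names) > len(PRIMES):
        raise ValueError("add more primes for higher-dimensional search spaces")
    trials = []
    for i in range(1, n_trials + 1):  # skip index 0, which maps every coordinate to 0
        point = {}
        for name, base in zip(names, PRIMES):
            lo, hi = space[name]
            u = van_der_corput(i, base)
            point[name] = math.exp(math.log(lo) + u * (math.log(hi) - math.log(lo)))
        trials.append(point)
    return trials


# Initial-round ranges for the deterministic method (hypothetical key names).
DETERMINISTIC_INITIAL = {
    "learning_rate": (1e-3, 0.1),
    "one_minus_momentum": (1e-2, 0.1),
    "l2": (1e-5, 1e-3),
}

if __name__ == "__main__":
    for trial in halton_search(DETERMINISTIC_INITIAL, n_trials=50)[:3]:
        print(trial)
```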

Below are links to tensorboard.dev TensorBoards for each baseline method that contain the metric values of the various tuning runs as well as the hyperparameter points sampled in the HPARAMS tab at the top of the page.

Deterministic

[First Tuning Round] [Final Tuning Round] [Best Hyperparameters 10 seeds]


Monte-Carlo Dropout

[First Tuning Round] [Final Tuning Round] [Best Hyperparameters 10 seeds]


Radial Bayesian Neural Networks

[First Tuning Round] [Final Tuning Round] [Best Hyperparameters 10 seeds]


Variational Inference

[First Tuning Round] [Final Tuning Round] [Best Hyperparameters 10 seeds]


Search spaces

Search space for the initial and final rounds of tuning on the deterministic method. We used a stepwise decay for the initial round but switched to a linear decay for the final round to alleviate overfitting, where we tuned the linear decay factor on the grid [1e-3, 1e-2, 0.1].

Round     Learning Rate   1 - momentum    L2
Initial   [1e-3, 0.1]     [1e-2, 0.1]     [1e-5, 1e-3]
Final     [0.03, 0.5]     [5e-3, 0.05]    [1e-6, 2e-4]
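One plausible reading of the two schedules mentioned above is sketched below. Step decay multiplies the learning rate by a fixed ratio at set boundaries, while linear decay interpolates from the base learning rate down to base_lr * decay_factor by the end of training. This interpretation of the "linear decay factor", and the function names, are assumptions rather than the tuning scripts' exact implementation.

```python
from typing import Sequence


def linear_decay_lr(step: int, total_steps: int, base_lr: float, decay_factor: float) -> float:
    """Linearly interpolate from base_lr at step 0 to base_lr * decay_factor at total_steps."""
    frac = min(max(step / total_steps, 0.0), 1.0)  # clamp so steps past the end stay at the floor
    return base_lr * (1.0 - frac * (1.0 - decay_factor))


def step_decay_lr(step: int, base_lr: float, boundaries: Sequence[int], decay_ratio: float) -> float:
    """Piecewise-constant schedule: multiply by decay_ratio at each boundary already passed."""
    n_passed = sum(step >= b for b in boundaries)
    return base_lr * decay_ratio ** n_passed


if __name__ == "__main__":
    for factor in (1e-3, 1e-2, 0.1):  # the decay-factor grid from the text
        print(factor, [round(linear_decay_lr(s, 100, 0.1, factor), 5) for s in (0, 50, 100)])
```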

Search space for the initial and final rounds of tuning on the Monte Carlo Dropout method.

Round     Learning Rate   1 - momentum    L2              dropout
Initial   [1e-3, 0.1]     [1e-2, 0.1]     [1e-5, 1e-3]    [0.01, 0.25]
Final     [1e-2, 0.5]     [1e-2, 0.04]    [1e-5, 1e-3]    [0.01, 0.2]

Search space for the initial and final rounds of tuning on the Radial BNN method.

Round     Learning Rate   1 - momentum    L2              stddev_mean_init   stddev_stddev_init
Initial   [1e-3, 0.1]     [1e-2, 0.1]     [1e-5, 1e-3]    [1e-5, 1e-1]       [1e-2, 1]
Final     [0.15, 1]       [1e-2, 0.05]    [1e-4, 1e-3]    [1e-5, 2e-2]       [1e-2, 0.2]

Search space for the initial and final rounds of tuning on the Variational Inference method.

Round     Learning Rate   1 - momentum    L2              stddev_mean_init   stddev_stddev_init
Initial   [1e-3, 0.1]     [1e-2, 0.1]     [1e-5, 1e-3]    [1e-5, 1e-1]       [1e-2, 1]
Final     [0.02, 5]       [0.02, 0.1]     [1e-5, 2e-4]    [1e-5, 2e-3]       [1e-2, 1]

Cite

Please cite our paper if you use this code in your own work:

@inproceedings{
    band2021benchmarking,
    title={Benchmarking Bayesian Deep Learning on Diabetic Retinopathy Detection Tasks},
    author={Neil Band and Tim G. J. Rudner and Qixuan Feng and Angelos Filos and Zachary Nado and Michael W Dusenberry and Ghassen Jerfel and Dustin Tran and Yarin Gal},
    booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
    year={2021},
    url={https://openreview.net/forum?id=jyd4Lyjr2iB}
}

Acknowledgements

The Diabetic Retinopathy Detection baseline was contributed through collaboration with the Oxford Applied and Theoretical Machine Learning (OATML) group, with sponsorship from: