Segmentation using a UNet2D

0. What is UNet2D?

  • UNet2D is a deep-learning architecture from the family of convolutional neural networks, and more specifically from the sub-family of auto-encoders.

  • It is trained through supervised learning, which means that training requires pairs of input image + expected segmentation (== ground-truth).

  • After training, the model produces a probability map at inference time. This probability map has to be thresholded to turn it into a mask.

  • UNet2D produces a semantic segmentation rather than an instance segmentation. This means that each pixel answers the question “is this pixel part of a microglia?”, but the cells are not given individual IDs.
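
As an illustration only, here is a minimal, hypothetical sketch (not the plugin's actual code) of how a probability map produced at inference can be thresholded into a mask, and how connected components can then assign individual IDs when instances are needed. The function names and the 0.5 threshold are assumptions.

```python
import numpy as np
from skimage.measure import label  # connected-component labeling

def proba_to_mask(proba_map: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Binarize the probability map produced by the model's inference."""
    return (proba_map >= threshold).astype(np.uint8)

def mask_to_instances(mask: np.ndarray) -> np.ndarray:
    """Turn a semantic mask into instances by giving each connected component an ID."""
    return label(mask, connectivity=2)
```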

1. Get your data ready

  • If you have annotated data, you can retrain the model using the provided script “src/dl/unet2d_training.py”. It contains the entire workflow to produce a bundled model ready for deployment.

  • Before starting, create a folder named “models” to store all the new model versions you create.

  • You also need a “working_dir” where the script will export its temporary data.

  • To train the UNet model, you need two distinct folders; you can name them as you like (a small data-preparation sketch is given at the end of this section).
    • The first folder, referred to as “inputs”, contains “.tif” images whose values are globally normalized to the range [0.0, 1.0].

    • The second folder, referred to as “masks”, also contains “.tif” images, but these are binary masks. They are thresholded to everything above 0 upon opening, so there is no restriction on whether they use 0 and 1 or 0 and 255.

    • Images in both folders must be named the same way.

  • The model bundles produced by this script include:
    • “version.txt”: The version index of this model, used to detect whether the model should be re-downloaded from the internet.

    • “training_history.png”: A figure containing two plots (laid out in a four-slot grid).
      • The first plot contains the loss and the validation loss.

      • The second plot contains the training and validation precision, as well as the training and validation recall.

    • “last.keras”: The weights generated at the last training epoch.

    • “best.keras”: The weights that achieved the best validation loss. These weights are used in the segmentation step.

    • “data_usage.json”: Information on which files were used for training and which were used for validation.

    • “augmentations_preview.png”: A sample of data after passing through the data augmentation pipeline.

    • “architecture.png”: A graph representing the UNet2D architecture used by this model.

    • “predictions”: A folder containing the validation data and the model’s predictions.
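
Below is a minimal sketch of how the “inputs” and “masks” folders could be prepared. It is not part of the training script, the helper names are hypothetical, and it assumes the tifffile package is available for reading and writing “.tif” files.

```python
import os
import numpy as np
import tifffile

def normalize_inputs(src_dir: str, dst_dir: str) -> None:
    """Globally normalize raw ".tif" images into [0.0, 1.0] for the "inputs" folder."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in sorted(os.listdir(src_dir)):
        if not name.lower().endswith(".tif"):
            continue
        img = tifffile.imread(os.path.join(src_dir, name)).astype(np.float32)
        img = (img - img.min()) / max(img.max() - img.min(), 1e-8)  # global normalization
        tifffile.imwrite(os.path.join(dst_dir, name), img)

def check_pairs(inputs_dir: str, masks_dir: str) -> bool:
    """Verify that images in both folders are named the same way."""
    inputs = {n for n in os.listdir(inputs_dir) if n.lower().endswith(".tif")}
    masks  = {n for n in os.listdir(masks_dir)  if n.lower().endswith(".tif")}
    return inputs == masks
```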

2. Data augmentation

The images we work with are highly variable, so a solid data-augmentation workflow is needed to avoid annotating an unreasonable number of images and to ensure that the model generalizes well to different types of data.

The data augmentation pipeline includes the following transformations:

  • Random rotations: The images are randomly rotated in a range from -90 to 90 degrees.

  • Random flips: The images are randomly flipped horizontally and/or vertically.

  • Random gamma adjustment: A random gamma correction is applied to every patch, which spreads or compresses the intensity histogram.

  • Random noise addition: Random Gaussian noise is added to the images.

  • Filament ruptures: In many images, the filaments pass through another Z-plane before coming back. To simulate that, some filamentous areas are randomly discarded (blurred).

These augmentations are applied on-the-fly at loading time so that each epoch sees a different set of augmented images, which improves the robustness and generalization of the model.
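
For illustration only, here is a hypothetical sketch of such an on-the-fly augmentation for one image/mask patch. The actual pipeline lives in “src/dl/unet2d_training.py”; the parameter names follow the settings table below, the default values are assumptions, and the filament-rupture augmentation is omitted.

```python
import numpy as np
from scipy.ndimage import rotate

def augment(image, mask, angle_range=(-90, 90), gamma_range=0.2, noise_scale=0.05, rng=None):
    """Apply random rotation, flips, gamma correction and Gaussian noise to a patch."""
    rng = rng or np.random.default_rng()
    # Random rotation (same angle applied to the image and its mask).
    angle = rng.uniform(angle_range[0], angle_range[1])
    image = rotate(image, angle, reshape=False, order=1, mode="reflect")
    mask  = rotate(mask,  angle, reshape=False, order=0, mode="reflect")
    # Random horizontal and/or vertical flips.
    if rng.random() < 0.5:
        image, mask = np.fliplr(image), np.fliplr(mask)
    if rng.random() < 0.5:
        image, mask = np.flipud(image), np.flipud(mask)
    # Random gamma correction (image only), centered on the neutral value 1.0.
    gamma = rng.uniform(1.0 - gamma_range, 1.0 + gamma_range)
    image = np.clip(image, 0.0, 1.0) ** gamma
    # Additive Gaussian noise (image only).
    image = image + rng.normal(0.0, noise_scale, image.shape)
    return np.clip(image, 0.0, 1.0), mask
```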

3. Filament extraction

  • Images containing filaments are actually about 96% background, which represents a massive imbalance between the background and foreground classes.

  • Using a standard loss function would result in the model predicting only solid black patches, since that would already be 96% correct.

  • To address this problem, we re-implemented the clDice loss as described in this paper: https://doi.org/10.48550/arXiv.2404.00130.
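
The actual clDice implementation (with its soft-skeleton computation) lives in “src/dl/unet2d_training.py”. The following is only a simplified, hypothetical sketch showing how a weighted combination of binary cross-entropy and a Dice-style term can be assembled in Keras, using the bce_coef and skeleton_coef settings described below.

```python
import tensorflow as tf

def soft_dice_loss(y_true, y_pred, eps=1e-6):
    """1 - Dice coefficient, computed on soft (non-thresholded) predictions."""
    y_true = tf.cast(y_true, y_pred.dtype)
    inter = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return 1.0 - (2.0 * inter + eps) / (union + eps)

def make_composite_loss(bce_coef=0.5, skeleton_coef=0.5):
    """Weighted sum of binary cross-entropy and the Dice-style term."""
    bce = tf.keras.losses.BinaryCrossentropy()
    def loss(y_true, y_pred):
        return bce_coef * bce(y_true, y_pred) + skeleton_coef * soft_dice_loss(y_true, y_pred)
    return loss

# Illustrative usage: model.compile(optimizer="adam", loss=make_composite_loss())
```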

4. Setup

  • If you already have a Python environment in which “Microglia Analyzer” is installed, it already contains everything you need to train a model.

  • To launch the training, you just have to fill in the settings in the first section (listed below, with an example at the end of this section) and run the script.

  • data_folder: Parent folder of the “inputs” and “masks” folders.
  • qc_folder: Parent folder of the “inputs” and “masks” folders used only for quality control (not for training); these images are only used to compute performance metrics at the end of training.
  • inputs_name: Name of the folder containing the input images (name of the folder in data_folder and qc_folder).
  • masks_name: Name of the folder containing the masks (name of the folder in data_folder and qc_folder).
  • models_path: Folder in which the models will be saved. They will be saved as “{model_name_prefix}-V{version_number}”.
  • working_directory: Folder in which the training, validation, and testing folders will be created. This folder and its content can be deleted once the training is done.
  • model_name_prefix: Prefix of the model name. Will be part of the folder name in models_path.
  • reset_local_data: If True, the locally copied training, validation, and testing folders will be re-imported.
  • remove_wrong_data: If True, unusable data will be deleted from the data folder. This is a destructive operation; review the first run of the sanity check before activating it.
  • data_usage: Path to a JSON file describing how each input file should be used (for training, validation, or testing).
  • validation_percentage: Percentage of the data used for validation. This data will be moved to the validation folder.
  • batch_size: Number of images per batch.
  • epochs: Number of training epochs.
  • unet_depth: Depth of the UNet model, i.e., the number of layers in the encoder part (equal to the number of layers in the decoder part).
  • num_filters_start: Number of filters in the first layer of the UNet.
  • dropout_rate: Dropout rate, i.e., the percentage of neurons randomly disabled at each epoch, which improves generalization.
  • optimizer: Optimizer used for the training.
  • learning_rate: Learning rate at which the optimizer is initialized.
  • skeleton_coef: Coefficient of the skeleton loss.
  • bce_coef: Coefficient of the binary cross-entropy loss.
  • early_stop_patience: Number of epochs without improvement before stopping the training.
  • dilation_kernel: Kernel used for the dilation of the skeleton.
  • loss: Loss function used for the training.
  • use_data_augmentation: If True, data augmentation will be used.
  • use_mirroring: If True, random mirroring will be used.
  • use_gaussian_noise: If True, random Gaussian noise will be used.
  • noise_scale: Scale of the Gaussian noise (range of values).
  • use_random_rotations: If True, random rotations will be used.
  • angle_range: Range of the random rotations. The angle will be in [angle_range[0], angle_range[1]].
  • use_gamma_correction: If True, random gamma correction will be used.
  • gamma_range: Range of the gamma correction. The gamma will be in [1 - gamma_range, 1 + gamma_range] (1.0 == neutral).
  • use_holes: If True, holes will be created in the input images to teach the network to fill them.
  • export_aug_sample: If True, an augmented sample will be exported to the working directory as a preview.
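
For reference, here is what the settings block at the top of “src/dl/unet2d_training.py” might look like once filled in. The names come from the list above; every path and value below is purely illustrative and must be adapted to your own data.

```python
data_folder           = "/path/to/annotated_data"  # contains the "inputs" and "masks" folders
qc_folder             = "/path/to/qc_data"         # quality-control data, never used for training
inputs_name           = "inputs"
masks_name            = "masks"
models_path           = "/path/to/models"          # saved as "{model_name_prefix}-V{version_number}"
working_directory     = "/path/to/working_dir"     # temporary data, can be deleted after training
model_name_prefix     = "unet2d"
reset_local_data      = True
validation_percentage = 0.15
batch_size            = 8
epochs                = 100
unet_depth            = 4
num_filters_start     = 16
dropout_rate          = 0.25
learning_rate         = 1e-4
use_data_augmentation = True
export_aug_sample     = True
```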

5. Usage

  • This model consumes patches of 512×512 pixels as input, with an overlap of 128 pixels.

  • The merging is performed with the alpha-blending technique described on the page explaining patch creation.

  • The output is labeled by connected components and filtered by pixel count (computed from a minimal area in µm²) before being presented to the user.
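
As an illustration of this last step, here is a minimal, hypothetical sketch (not the plugin's actual code) of labeling a mask by connected components and discarding objects below a minimal area expressed in µm². The pixel size is an assumed parameter.

```python
import numpy as np
from skimage.measure import label

def filter_by_area(mask, min_area_um2, pixel_size_um):
    """Label by connected components and discard objects smaller than min_area_um2."""
    min_pixels = int(round(min_area_um2 / (pixel_size_um ** 2)))  # µm² converted to a pixel count
    labels = label(mask > 0, connectivity=2)
    counts = np.bincount(labels.ravel())       # pixel count per label (index 0 = background)
    keep = np.flatnonzero(counts >= min_pixels)
    keep = keep[keep != 0]                     # never keep the background label
    return np.where(np.isin(labels, keep), labels, 0)
```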