🩺 UNet Model for COVID-19 CT Scan Segmentation

📌 Model Overview

This model automatically segments COVID-19 infected lung regions in CT scans. It extends the classic U-Net with attention gates, which help the network focus on infected regions.

  • Architecture: UNet + Attention Gates
  • Dataset: COVID-19 CT scans from Coronacases.org, Radiopaedia.org, and Zenodo Repository
  • Task: Image Segmentation (Lung Infection)
  • Metrics: Dice Coefficient, IoU, Hausdorff Distance, ASSD
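
As a rough illustration of the two overlap metrics listed above, the following NumPy sketch computes the Dice coefficient and IoU for a pair of binary masks. The function names and the toy masks are made up for this example; the smoothing term mirrors the `smooth=1e-6` default used by `dice_coef` in the loading code, and this is not necessarily how the reported scores were computed.

```python
import numpy as np


def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for binary masks."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)


def iou(y_true, y_pred, smooth=1e-6):
    """IoU = |A ∩ B| / |A ∪ B| for binary masks."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    union = y_true.sum() + y_pred.sum() - intersection
    return (intersection + smooth) / (union + smooth)


# Toy masks (illustrative only): 3 foreground pixels each, 2 of them shared.
truth = np.array([[1, 1, 0],
                  [0, 1, 0]])
pred = np.array([[1, 0, 0],
                 [0, 1, 1]])

print(f"Dice: {dice_coefficient(truth, pred):.4f}")
print(f"IoU:  {iou(truth, pred):.4f}")
```

Both metrics range from 0 (no overlap) to 1 (perfect overlap); Dice weights the intersection twice, so it is never lower than IoU for the same pair of masks.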

📊 Training Details

  • Dataset Size: 20 CT scans (512 × 512 × 301 slices)
  • Preprocessing:
    • Normalization of pixel intensities to [0, 1]
    • HU thresholding to the range [-1000, 1500]
    • Image resizing to 128 × 128 pixels
    • Binarization of masks (0 = background, 1 = infected regions)
  • Augmentation:
    • Rotations: ±5 degrees
    • Elastic transformations, Gaussian blur
    • Brightness/contrast variations
    • Final dataset: 2,252 CT slices
  • Training:
    • Optimizer: Adam (learning rate = 1e-4)
    • Loss Function: Weighted BCE-Dice Loss + Surface Loss
    • Batch Size: 16
    • Epochs: 25
    • Training Platform: NVIDIA Tesla T4 (Google Colab Pro)
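
The preprocessing steps listed above might look roughly like the sketch below. It is an assumption-laden illustration, not the original pipeline: the order of operations (clip to the HU window first, then scale to [0, 1]) is assumed, the helper names are hypothetical, and a simple nearest-neighbour resize written in NumPy stands in for whatever interpolation was actually used.

```python
import numpy as np

HU_MIN, HU_MAX = -1000.0, 1500.0  # HU window stated in the card
TARGET_SIZE = 128                 # output height and width stated in the card


def resize_nearest(image, size):
    """Nearest-neighbour resize of a 2-D array (stand-in for a library resize)."""
    rows = np.arange(size) * image.shape[0] // size
    cols = np.arange(size) * image.shape[1] // size
    return image[np.ix_(rows, cols)]


def preprocess_slice(slice_hu, size=TARGET_SIZE):
    """Clip a CT slice to the HU window, scale it to [0, 1], and resize it."""
    clipped = np.clip(np.asarray(slice_hu, dtype=np.float32), HU_MIN, HU_MAX)
    normalized = (clipped - HU_MIN) / (HU_MAX - HU_MIN)
    return resize_nearest(normalized, size)


def preprocess_mask(mask, size=TARGET_SIZE):
    """Binarize a mask (0 = background, 1 = infected region) and resize it."""
    binary = (np.asarray(mask) > 0).astype(np.uint8)
    return resize_nearest(binary, size)


# Synthetic data in place of a real CT slice and annotation.
rng = np.random.default_rng(0)
ct_slice = rng.uniform(-2000, 3000, size=(512, 512))
raw_mask = rng.integers(0, 3, size=(512, 512))

image = preprocess_slice(ct_slice)
mask = preprocess_mask(raw_mask)
print(image.shape, float(image.min()), float(image.max()))
print(mask.shape, np.unique(mask))
```

Nearest-neighbour resizing keeps the mask strictly binary; an interpolating resize would need a second thresholding step on the mask.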

🚀 Model Performance

Metric                    | Non-Augmented Model | Augmented Model
Dice Coefficient          | 0.8502              | 0.8658
IoU (Mean)                | 0.7445              | 0.8316
ASSD (Symmetric Distance) | 0.3907              | 0.3888
Hausdorff Distance        | 8.4853              | 9.8995
ROC AUC Score             | 0.91                | 1.00
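
Unlike Dice and IoU, the two distance metrics in the table are better when lower, and Hausdorff distance in particular is driven by the single worst-matched pixel. The sketch below shows one way to compute a symmetric Hausdorff distance between two binary masks in plain NumPy. It is an illustration under assumptions: it uses all foreground pixels rather than boundary pixels only, measures distance in pixel units, and the card does not state which variant or library produced the reported values.

```python
import numpy as np


def hausdorff_distance(mask_a, mask_b):
    """Symmetric Hausdorff distance between the foreground pixels of two masks.

    Builds the full pairwise distance matrix, so it is only suitable for
    small masks or sparse foregrounds.
    """
    points_a = np.argwhere(np.asarray(mask_a) > 0)
    points_b = np.argwhere(np.asarray(mask_b) > 0)
    if len(points_a) == 0 or len(points_b) == 0:
        raise ValueError("both masks must contain at least one foreground pixel")

    # Pairwise Euclidean distances, shape (len(points_a), len(points_b)).
    diffs = points_a[:, None, :] - points_b[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))

    a_to_b = dists.min(axis=1).max()  # farthest point of A from its nearest point in B
    b_to_a = dists.min(axis=0).max()  # farthest point of B from its nearest point in A
    return float(max(a_to_b, b_to_a))


# Toy masks (illustrative only): one foreground pixel each, offset by (3, 4).
mask_a = np.zeros((10, 10), dtype=np.uint8)
mask_b = np.zeros((10, 10), dtype=np.uint8)
mask_a[1, 1] = 1
mask_b[4, 5] = 1

print(hausdorff_distance(mask_a, mask_b))  # 5.0
```

Because the result is a maximum over pixels, one stray false-positive far from the lesion can raise the Hausdorff distance even when overlap metrics improve.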

📌 Key Findings:
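  • Augmentation raised the Dice coefficient from 0.8502 to 0.8658 and mean IoU from 0.7445 to 0.8316. ASSD was nearly unchanged (0.3907 vs. 0.3888), while Hausdorff distance increased from 8.4853 to 9.8995.
  • Attention U-Net is reported to outperform the other segmentation models it was compared against; those comparison results are not listed here.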


📥 How to Use the Model

1️⃣ Load the Model

TensorFlow/Keras

import os

# ✅ Select the Keras backend (optional). This must be set before Keras is
# imported to take effect; this example uses tensorflow.keras, so it selects TensorFlow.
os.environ["KERAS_BACKEND"] = "tensorflow"

from huggingface_hub import hf_hub_download
from keras.saving import register_keras_serializable
import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam

# ✅ Register and define the custom functions referenced by the saved model
@register_keras_serializable()
def dice_coef(y_true, y_pred, smooth=1e-6):
    """Dice coefficient over the flattened masks, smoothed to avoid division by zero."""
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

@register_keras_serializable()
def gl_sl(*args, **kwargs):
    # Placeholder so that the name `gl_sl` resolves during loading.
    # It is not used because the model is loaded with compile=False (update if needed).
    pass

# ✅ Download the model from the Hugging Face Hub
model_path = hf_hub_download(repo_id="amal90888/unet-segmentation-model", filename="unet_model.keras")

# ✅ Load the model with registered custom objects
unet = load_model(model_path, custom_objects={"dice_coef": dice_coef, "gl_sl": gl_sl}, compile=False)

# ✅ Recompile with a fresh optimizer. Binary cross-entropy is used here;
# the Weighted BCE-Dice + Surface Loss used in training is not restored.
unet.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy", dice_coef])

print("✅ Model loaded and recompiled successfully!")
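
Once the model is loaded, running it on a single slice could look like the sketch below. Several details are assumptions not confirmed by this card: that the network takes a batch of shape (1, 128, 128, 1), that it outputs per-pixel probabilities of the same spatial size, and that 0.5 is a sensible threshold. `predict_mask` is a hypothetical helper, and `EchoModel` is a stand-in with a Keras-like `predict()` so the example runs without downloading anything.

```python
import numpy as np


def predict_mask(model, slice_01, threshold=0.5):
    """Run a Keras-style model on one preprocessed slice and binarize the output.

    slice_01: 2-D array already scaled to [0, 1] and resized to the model's
    input size (assumed to be 128 x 128 here).
    """
    # Add batch and channel axes: (H, W) -> (1, H, W, 1). Assumed input layout.
    batch = np.asarray(slice_01, dtype=np.float32)[np.newaxis, ..., np.newaxis]
    probs = np.asarray(model.predict(batch))
    # Drop the batch and channel axes and threshold the probabilities.
    return (probs[0, ..., 0] >= threshold).astype(np.uint8)


class EchoModel:
    """Hypothetical stand-in for the loaded model: returns its input unchanged."""

    def predict(self, batch):
        return batch


# Synthetic "slice" whose values ramp from 0 to 1.
demo_slice = np.linspace(0.0, 1.0, 128 * 128, dtype=np.float32).reshape(128, 128)
demo_mask = predict_mask(EchoModel(), demo_slice)
print(demo_mask.shape, demo_mask.dtype, np.unique(demo_mask))
```

With the real model, the call would be `predict_mask(unet, preprocessed_slice)`, where `preprocessed_slice` has gone through the same preprocessing as the training data.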