🩺 UNet Model for COVID-19 CT Scan Segmentation
📌 Model Overview
This UNet-based segmentation model is designed for automated segmentation of COVID-19 infected lung regions in CT scans. It enhances the classic U-Net with attention mechanisms to improve focus on infected regions.
- Architecture: UNet + Attention Gates
- Dataset: COVID-19 CT scans from Coronacases.org, Radiopaedia.org, and Zenodo Repository
- Task: Image Segmentation (Lung Infection)
- Metrics: Dice Coefficient, IoU, Hausdorff Distance, ASSD
📊 Training Details
- Dataset Size: 20 CT scans (512 × 512 × 301 slices)
- Preprocessing:
  - Normalization of pixel intensities to [0, 1]
  - HU thresholding: [-1000, 1500]
  - Image resizing to 128 × 128 pixels
  - Binarization of masks (0 = background, 1 = infected regions)
- Augmentation:
  - Rotations: ±5 degrees
  - Elastic transformations, Gaussian blur
  - Brightness/contrast variations
  - Final dataset: 2,252 CT slices
- Training:
  - Optimizer: Adam (learning rate = 1e-4)
  - Loss Function: Weighted BCE-Dice Loss + Surface Loss
  - Batch Size: 16
  - Epochs: 25
  - Training Platform: NVIDIA Tesla T4 (Google Colab Pro)
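The preprocessing steps listed above can be sketched in NumPy as follows. This is a minimal illustration, not the authors' pipeline: the order of operations, the nearest-neighbour resizing, and the helper names `resize_nearest`, `preprocess_slice`, and `preprocess_mask` are all assumptions made for the example.

```python
import numpy as np

HU_MIN, HU_MAX = -1000.0, 1500.0  # HU window taken from the preprocessing list
TARGET_SIZE = 128                 # output side length in pixels


def resize_nearest(image, size=TARGET_SIZE):
    """Resize a 2-D array to (size, size) by nearest-neighbour sampling.

    Assumption: the interpolation method used for the model is not
    documented; nearest-neighbour keeps this sketch dependency-free.
    """
    rows = np.arange(size) * image.shape[0] // size
    cols = np.arange(size) * image.shape[1] // size
    return image[np.ix_(rows, cols)]


def preprocess_slice(slice_hu):
    """Clip a CT slice to the HU window, scale it to [0, 1], and resize it."""
    clipped = np.clip(np.asarray(slice_hu, dtype=np.float32), HU_MIN, HU_MAX)
    scaled = (clipped - HU_MIN) / (HU_MAX - HU_MIN)
    return resize_nearest(scaled)


def preprocess_mask(mask):
    """Resize a label mask and binarize it: 0 = background, 1 = infected."""
    return (resize_nearest(np.asarray(mask)) > 0).astype(np.uint8)
```

Calling `preprocess_slice` on a 512 × 512 slice in Hounsfield units returns a 128 × 128 `float32` array with values in [0, 1]; `preprocess_mask` returns the matching binary mask.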
🚀 Model Performance
| Metric | Non-Augmented Model | Augmented Model |
|---|---|---|
| Dice Coefficient | 0.8502 | 0.8658 |
| IoU (Mean) | 0.7445 | 0.8316 |
| ASSD (Average Symmetric Surface Distance) | 0.3907 | 0.3888 |
| Hausdorff Distance | 8.4853 | 9.8995 |
| ROC AUC Score | 0.91 | 1.00 |
📌 Key Findings:
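✔ Augmentation improved the overlap metrics: Dice rose from 0.8502 to 0.8658 and mean IoU from 0.7445 to 0.8316. ASSD was nearly unchanged, and Hausdorff Distance increased from 8.4853 to 9.8995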
✔ Attention U-Net outperformed other segmentation models
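For readers who want to compute the two overlap metrics on their own predictions, here is a minimal NumPy sketch of the Dice coefficient and IoU for binary masks. The function names and the `smooth` term are illustrative assumptions; the evaluation code behind the reported numbers is not part of this card and may differ (for example, in how scores are averaged across slices).

```python
import numpy as np


def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for two binary masks."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    intersection = np.logical_and(y_true, y_pred).sum()
    return (2.0 * intersection + smooth) / (y_true.sum() + y_pred.sum() + smooth)


def iou(y_true, y_pred, smooth=1e-6):
    """IoU = |A ∩ B| / |A ∪ B| for two binary masks."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    intersection = np.logical_and(y_true, y_pred).sum()
    union = np.logical_or(y_true, y_pred).sum()
    return (intersection + smooth) / (union + smooth)
```

The `smooth` term only guards against division by zero when both masks are empty; it has a negligible effect on non-empty masks.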
📥 How to Use the Model
1️⃣ Load the Model
TensorFlow/Keras
import os

# ✅ Set Keras backend (optional). This must be set before Keras is imported to take effect.
os.environ["KERAS_BACKEND"] = "jax"

from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from keras.saving import register_keras_serializable
import tensorflow.keras.backend as K

# ✅ Register and define the custom functions referenced by the saved model
@register_keras_serializable()
def dice_coef(y_true, y_pred, smooth=1e-6):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

@register_keras_serializable()
def gl_sl(*args, **kwargs):
    # Placeholder so the saved model deserializes; it is never called because
    # the model is loaded with compile=False (update if needed)
    pass

# ✅ Download the model from the Hugging Face Hub
model_path = hf_hub_download(repo_id="amal90888/unet-segmentation-model", filename="unet_model.keras")

# ✅ Load the model with registered custom objects
unet = load_model(model_path, custom_objects={"dice_coef": dice_coef, "gl_sl": gl_sl}, compile=False)

# ✅ Recompile with a fresh optimizer and a standard loss function
unet.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy", dice_coef])
print("✅ Model loaded and recompiled successfully!")
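Once the model is loaded, a prediction step might look like the sketch below. It is written under several assumptions that this card does not confirm: the model takes input of shape `(batch, 128, 128, 1)`, returns a single-channel probability map of the same spatial size, and a 0.5 threshold is appropriate. The helper name `predict_mask` is hypothetical.

```python
import numpy as np


def predict_mask(model, slice_01, threshold=0.5):
    """Run one preprocessed slice through the model and binarize the output.

    slice_01: 2-D array already scaled to [0, 1] and resized to 128 x 128.
    Assumption: model.predict returns probabilities shaped (1, H, W, 1).
    """
    # Add batch and channel axes: (H, W) -> (1, H, W, 1)
    batch = np.asarray(slice_01, dtype=np.float32)[np.newaxis, ..., np.newaxis]
    probs = np.asarray(model.predict(batch))
    # Drop batch and channel axes, then threshold (assumed value, not from the card)
    return (probs[0, ..., 0] > threshold).astype(np.uint8)
```

With the model loaded as `unet`, `predict_mask(unet, slice_01)` would return a 128 × 128 mask of 0s and 1s, provided the shape assumptions hold; check `unet.input_shape` and `unet.output_shape` first.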