# Hugging Face Space — status: Running
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# custom installation from this PR: https://github.com/huggingface/transformers/pull/34583
# !pip install git+https://github.com/geetu040/transformers.git@depth-pro-projects#egg=transformers
from transformers import DepthProConfig, DepthProImageProcessorFast, DepthProForDepthEstimation

# --- initialize model ---------------------------------------------------------
# Build DepthPro with use_fov_model=False (its FOV model is not created).
config = DepthProConfig(use_fov_model=False)
model = DepthProForDepthEstimation(config)

# Replace the inner module of `model.head` with a segmentation head that emits
# `num_labels` logit channel(s). The final ConvTranspose2d (kernel 2, stride 2)
# upsamples the fused features 2x spatially.
# NOTE: load_state_dict below is strict by default, so this architecture must
# match the one the checkpoint was trained with exactly.
features = config.fusion_hidden_size
semantic_classifier_dropout = 0.1
num_labels = 1
model.head.head = nn.Sequential(
    nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(features),
    nn.ReLU(),
    nn.Dropout(semantic_classifier_dropout),
    nn.Conv2d(features, features, kernel_size=1),
    nn.ConvTranspose2d(features, num_labels, kernel_size=2, stride=2, padding=0, bias=True),
)

# --- load weights -------------------------------------------------------------
weights_path = hf_hub_download(repo_id="geetu040/DepthPro_Segmentation_Human", filename="model_weights.pth")
# Map tensors to CPU first so loading works on machines without a GPU;
# weights_only=True avoids unpickling arbitrary objects from the checkpoint.
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'), weights_only=True))

# --- load to device -----------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# A module constructed directly (rather than via `from_pretrained`) starts in
# training mode. Switch to eval mode so the head's Dropout is disabled and its
# BatchNorm normalizes with the loaded running statistics instead of per-image
# batch statistics (which it would also write back into its buffers on every
# forward pass).
model.eval()

# --- load image processor -----------------------------------------------------
image_processor = DepthProImageProcessorFast()
def predict(image):
    """Run the segmentation model on ``image`` and return a binary mask.

    Uses the module-level ``model``, ``image_processor`` and ``device``; the
    weights come from the ``DepthPro_Segmentation_Human`` checkpoint.

    Args:
        image: A PIL image in any mode (it is converted to RGB). ``None``
            (e.g. an empty image input in a UI) yields ``None``.

    Returns:
        A single-channel ``PIL.Image`` (uint8, mode "L") with the same width
        and height as ``image``: 255 where sigmoid(logit) > 0.5, 0 elsewhere.
        Returns ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None
    image = image.convert("RGB")

    # preprocess: let the image processor build model-ready tensors, then move
    # them to the same device as the model
    inputs = image_processor(images=image, return_tensors="pt")
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    # inference
    with torch.no_grad():
        outputs = model(**inputs)

    # First output tensor holds the logits. F.interpolate with a 2-element
    # size needs a 4-D (N, C, H, W) input, so the 3-D logit map gets a leading
    # axis. Assumes a single image with one label -> (1, 1, h, w); TODO confirm.
    logits = outputs[0].unsqueeze(0)
    logits = F.interpolate(logits, size=(image.height, image.width))

    # Index explicitly instead of .squeeze(): squeeze() would also drop a
    # height or width of 1 and hand a 1-D array to Image.fromarray.
    logits = logits[0, 0]

    # binary segmentation: sigmoid, threshold at 0.5, scale {0, 1} -> {0, 255}
    mask = (logits.sigmoid() > 0.5).to(torch.uint8) * 255
    return Image.fromarray(mask.cpu().numpy())