---
pipeline_tag: text-to-image
license: agpl-3.0
---

# Open-LiteVAE

[GitHub](https://github.com/RGenDiff/open-litevae)

This repository contains a LiteVAE model trained with the open-litevae codebase, based on the paper "LiteVAE: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models" (Sadat et al., NeurIPS 2024).

Note: This model is intended for demonstration purposes only; we do not recommend using it in production.

License: AGPL-3.0


## Configuration Details

| Parameter | Value |
|---|---|
| Downscale Factor | 8x |
| Latent Z dim | 12 |
| Encoder Size (params) | B (6.2M) |
| Decoder Size (params) | M (54M) |
| Discriminator | UNetGAN-L |
| Training Set | ImageNet-1k |
| Training Resolution | 128x128 → 256x256 |
| Training Steps | 100k → 50k |
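
The downscale factor and latent Z dim together fix the size of the latent. The sketch below works through that arithmetic, assuming the encoder divides each spatial side by exactly 8 and that input sides are multiples of 8. `latent_shape` and `compression_ratio` are hypothetical helpers for illustration, not part of open-litevae.

```python
def latent_shape(height: int, width: int, z_dim: int = 12, factor: int = 8) -> tuple[int, int, int]:
    """Return the (channels, height, width) latent shape for one image.

    Hypothetical helper: assumes an encoder that divides each spatial side
    by `factor` and emits `z_dim` latent channels.
    """
    if height % factor or width % factor:
        raise ValueError(f"image sides must be multiples of {factor}")
    return (z_dim, height // factor, width // factor)


def compression_ratio(height: int, width: int, z_dim: int = 12,
                      factor: int = 8, in_channels: int = 3) -> float:
    """Number of input pixel values per latent value."""
    c, h, w = latent_shape(height, width, z_dim, factor)
    return (in_channels * height * width) / (c * h * w)


print(latent_shape(256, 256))       # (12, 32, 32)
print(compression_ratio(256, 256))  # 16.0
```

The ratio reduces to `in_channels * factor**2 / z_dim`, so with the same 8x factor a 4-channel latent (SD1-VAE) gives 48.0 and a 16-channel latent (SD3-VAE) gives 12.0.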

## Metric Comparison

| Model | Z dim | rFID ↓ | LPIPS ↓ | PSNR ↑ | SSIM ↑ |
|---|---|---|---|---|---|
| SD1-VAE | 4 | 0.75 | 0.138 | 25.70 | 0.72 |
| SD3-VAE | 16 | 0.22 | 0.069 | 29.59 | 0.86 |
| olvf8c12 (this repo) | 12 | 0.24 | 0.084 | 28.74 | 0.84 |
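
Of these metrics, PSNR is the simplest to compute by hand: PSNR = 10 · log10(MAX² / MSE), where MAX is the largest possible pixel value. The sketch below assumes both images are flattened and scaled to [0, 1] (for example after the `* 0.5 + 0.5` step in the Usage section). It is not necessarily the exact evaluation protocol behind the numbers above, and `psnr` is a hypothetical helper.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[float], reconstruction: Sequence[float], max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two flattened images.

    Assumes both inputs hold pixel values in [0, max_value] and have equal length.
    """
    if len(reference) != len(reconstruction) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value ** 2 / mse)


print(psnr([0.0, 0.5, 1.0], [0.0, 0.5, 1.0]))  # inf
print(round(psnr([0.0, 0.0], [0.1, 0.1]), 6))  # 20.0
```

For 8-bit images, pass `max_value=255` instead of rescaling.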

## Usage

```python
# Install open-litevae first: https://github.com/RGenDiff/open-litevae

from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from omegaconf import OmegaConf
from safetensors.torch import load_file
from olvae.utils import instantiate_from_config

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model_from_config(config_path, ckpt_path, device=torch.device("cuda")):
    config = OmegaConf.load(config_path)
    sd = load_file(ckpt_path)  # loads tensors on the CPU
    model = instantiate_from_config(config.model)
    # strict=False skips missing or unexpected keys instead of raising
    model.load_state_dict(sd, strict=False)
    model = model.to(device).eval()
    return model

# load the model
olitevae = load_model_from_config(config_path="configs/olitevaeB_im_f8c12.yaml",
                                  ckpt_path="olitevaeB_im_f8c12.safetensors",
                                  device=device)

# convert to a tensor in [0, 1], then normalize to [-1, 1]
img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

with torch.no_grad():
    # encode (replace the path with your own image)
    image = img_transforms(Image.open("path/to/your_image.png").convert("RGB")).to(device)
    latent = olitevae.encode(image.unsqueeze(0)).sample()
    print(latent.shape)  # for an H x W input, this should be [1, 12, H/8, W/8]

    # decode
    y = olitevae.decode(latent)

# map the output from [-1, 1] back to [0, 1] before saving
save_image(y[0] * 0.5 + 0.5, "decoded_image.png")
```
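
The example above feeds the image at its native size. Assuming the model, like most 8x autoencoders, expects height and width to be multiples of 8 (an assumption, not something this card states), a small helper such as the hypothetical `center_crop_box` below computes a crop box in the `(left, upper, right, lower)` format that PIL's `Image.crop` accepts.

```python
def center_crop_box(width: int, height: int, multiple: int = 8) -> tuple[int, int, int, int]:
    """Return a (left, upper, right, lower) box that center-crops an image so
    both sides become the largest multiples of `multiple` that fit.

    Hypothetical helper; the box format matches PIL's Image.crop.
    """
    new_w = (width // multiple) * multiple
    new_h = (height // multiple) * multiple
    if new_w == 0 or new_h == 0:
        raise ValueError(f"image is smaller than {multiple} pixels on one side")
    left = (width - new_w) // 2
    upper = (height - new_h) // 2
    return (left, upper, left + new_w, upper + new_h)


print(center_crop_box(500, 375))  # (2, 3, 498, 371)
```

With PIL, `img.crop(center_crop_box(*img.size))` applies it, since `Image.size` is `(width, height)`. Cropping rather than resizing keeps the pixels unchanged, which matters when comparing the decoded output against the input.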

## Please Cite the Original Paper

```bibtex
@inproceedings{sadat2024litevae,
  title={Lite{VAE}: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models},
  author={Seyedmorteza Sadat and Jakob Buhmann and Derek Bradley and Otmar Hilliges and Romann M. Weber},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=mTAbl8kUzq}
}
```