Updated README
Browse files
README.md
CHANGED
@@ -1,3 +1,91 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Open-LiteVAE
|
2 |
+
[[github]](https://github.com/RGenDiff/open-litevae)
|
3 |
+
|
4 |
+
This repository contains a LiteVAE model trained with the open-litevae codebase, based on the paper "[LiteVAE: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models](https://openreview.net/forum?id=mTAbl8kUzq)" [2024].
|
5 |
+
|
6 |
+
**Note:** This model is intended for demonstration purposes, we do not recommend using it in production.
|
7 |
+
|
8 |
+
**license:** AGPL-3.0
|
9 |
+
|
10 |
+
---
|
11 |
+
|
12 |
+
## Configuration Details
|
13 |
+
|
14 |
+
|
15 |
+
| Parameter | Value |
|
16 |
+
|------------------|----|
|
17 |
+
| Downscale Factor | 8x |
|
18 |
+
| Latent Z dim | 12 |
|
19 |
+
| Encoder Size (params) | B (6.2M) |
|
20 |
+
| Decoder Size (params) | M (54M) |
|
21 |
+
| Discriminator | UNetGAN-L |
|
22 |
+
| Training Set | ImageNet-1k |
|
23 |
+
| Training Resolution | 128x128 --> 256x256|
|
24 |
+
| Training Steps | 100k --> 50k |
|
25 |
+
|
26 |
+
|
27 |
+
## Metric Comparison
|
28 |
+
|
29 |
+
|
30 |
+
| Model | Z dim | rFID | LPIPS | PSNR | SSIM |
|
31 |
+
|-------|-------|------|-------|------|------|
|
32 |
+
| SD1-VAE | 4 | 0.75 | 0.138 | 25.70 | 0.72 |
|
33 |
+
| SD3-VAE | 16 | 0.22 | 0.069 | 29.59 | 0.86 |
|
34 |
+
| olvf8c12 (this repo) | 12 | 0.24 | 0.084 | 28.74 | 0.84 |
|
35 |
+
|
36 |
+
## Usage
|
37 |
+
|
38 |
+
|
39 |
+
```python
|
40 |
+
# install open-litevae https://github.com/RGenDiff/open-litevae
|
41 |
+
#
|
42 |
+
|
43 |
+
from PIL import Image
|
44 |
+
import torch
|
45 |
+
import torchvision.transforms as transforms
|
46 |
+
from torchvision.utils import save_image
|
47 |
+
from omegaconf import OmegaConf
|
48 |
+
from safetensors.torch import load_model
|
49 |
+
from olvae.utils import instantiate_from_config
|
50 |
+
|
51 |
+
def load_model_from_config(config_path, ckpt_path, device=torch.device("cuda")):
|
52 |
+
config = OmegaConf.load(config_path)
|
53 |
+
sd = load_model(ckpt_path)
|
54 |
+
model = instantiate_from_config(config.model)
|
55 |
+
model.load_state_dict(sd, strict=False)
|
56 |
+
model = model.to(device).eval()
|
57 |
+
return model
|
58 |
+
|
59 |
+
# load the model
|
60 |
+
olitevae = load_model_from_config(config_path="configs/olitevaeB_im_f8c12.yaml",
|
61 |
+
ckpt_path="olitevaeB_im_f8c12.safetensors")
|
62 |
+
|
63 |
+
img_transforms = transforms.Compose([
|
64 |
+
transforms.ToTensor(),
|
65 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
66 |
+
])
|
67 |
+
|
68 |
+
# encode
|
69 |
+
image = img_transforms(Image.open(<your image>)).to(device)
|
70 |
+
latent = olitevae.encode(image.unsqueeze(0)).sample()
|
71 |
+
print(latent.shape)
|
72 |
+
|
73 |
+
# decode
|
74 |
+
y = olitevae.decode(latent)
|
75 |
+
save_image(y[0]*0.5 + 0.5, "decoded_image.png")
|
76 |
+
|
77 |
+
```
|
78 |
+
|
79 |
+
|
80 |
+
## Please Cite the Original Paper
|
81 |
+
|
82 |
+
```
|
83 |
+
@inproceedings{
|
84 |
+
sadat2024litevae,
|
85 |
+
title={Lite{VAE}: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models},
|
86 |
+
author={Seyedmorteza Sadat and Jakob Buhmann and Derek Bradley and Otmar Hilliges and Romann M. Weber},
|
87 |
+
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
|
88 |
+
year={2024},
|
89 |
+
url={https://openreview.net/forum?id=mTAbl8kUzq}
|
90 |
+
}
|
91 |
+
```
|