Fixed incorrect safetensors API call in README
Browse files
README.md
CHANGED
@@ -49,12 +49,12 @@ import torch
|
|
49 |
import torchvision.transforms as transforms
|
50 |
from torchvision.utils import save_image
|
51 |
from omegaconf import OmegaConf
|
52 |
-
from safetensors.torch import
|
53 |
from olvae.utils import instantiate_from_config
|
54 |
|
55 |
def load_model_from_config(config_path, ckpt_path, device=torch.device("cuda")):
|
56 |
config = OmegaConf.load(config_path)
|
57 |
-
sd =
|
58 |
model = instantiate_from_config(config.model)
|
59 |
model.load_state_dict(sd, strict=False)
|
60 |
model = model.to(device).eval()
|
|
|
49 |
import torchvision.transforms as transforms
|
50 |
from torchvision.utils import save_image
|
51 |
from omegaconf import OmegaConf
|
52 |
+
from safetensors.torch import load_file
|
53 |
from olvae.utils import instantiate_from_config
|
54 |
|
55 |
def load_model_from_config(config_path, ckpt_path, device=torch.device("cuda")):
|
56 |
config = OmegaConf.load(config_path)
|
57 |
+
sd = load_file(ckpt_path)
|
58 |
model = instantiate_from_config(config.model)
|
59 |
model.load_state_dict(sd, strict=False)
|
60 |
model = model.to(device).eval()
|