# genstereo/GenStereo.py
from os.path import join
from typing import Union, Optional, List, Dict, Tuple, Any
from dataclasses import dataclass
import inspect
from omegaconf import OmegaConf, DictConfig
from jaxtyping import Float
from PIL import Image
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from einops import rearrange, repeat
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from .models import (
PoseGuider,
UNet2DConditionModel,
UNet3DConditionModel,
ReferenceAttentionControl
)
from .ops import get_viewport_matrix, forward_warper, convert_left_to_right, convert_left_to_right_torch
class AdaptiveFusionLayer(nn.Module):
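    """Predict per-pixel blending weights to fuse a diffusion-generated image
    with a forward-warped image, guided by the warping mask.
    """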
def __init__(self):
super().__init__()
self.fusion_layer = nn.Sequential(
nn.Conv2d(7, 1, kernel_size=3, stride=1, padding=1), # (B, 7, H, W) → (B, 1, H, W)
nn.Sigmoid() # Ensures weights are between 0 and 1
)
def forward(self, generated_image, warped_image, mask):
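        """Blend the warped and generated images.

        Shapes (as used below): generated_image and warped_image are
        (B, 3, H, W); mask is (B, 1, H, W), with 1 where a warped pixel exists.
        Returns mask * w * warped + (1 - mask * w) * generated, where w is the
        predicted per-pixel weight.
        """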
# Compute fusion weights
fusion_input = torch.cat([generated_image, warped_image, mask], dim=1) # (B, 7, H, W)
weights = self.fusion_layer(fusion_input) # (B, 1, H, W)
fused_output = mask * weights * warped_image + (1 - mask * weights) * generated_image
return fused_output
class GenStereo:
    @dataclass
    class Config:
pretrained_model_path: str = ''
checkpoint_name: str = ''
half_precision_weights: bool = False
height: int = 512
width: int = 512
num_inference_steps: int = 50
guidance_scale: float = 1.5
cfg: Config
    class Embedder:
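        """NeRF-style sinusoidal positional embedding: concatenates the
        (optional) raw input with sin/cos of the input at a set of
        frequency bands.
        """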
def __init__(self, **kwargs) -> None:
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self) -> None:
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
                embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
                    embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs) -> Tensor:
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
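        # Example: get_embedder(2) below uses input_dims=2, include_input=True,
        # num_freqs=2, so out_dim = 2 + 2 freqs * 2 fns * 2 dims = 10 channels.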
    def __init__(
        self,
        cfg: Optional[Union[dict, DictConfig]] = None,
        device: str = 'cuda:0'
    ) -> None:
        cfg = cfg or {}
        self.cfg = OmegaConf.structured(self.Config(**cfg))
self.model_path = join(
self.cfg.pretrained_model_path, self.cfg.checkpoint_name
)
self.device = device
self.configure()
        self.transform_pixels = transforms.Compose([
            transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]
            transforms.Normalize([0.5], [0.5])  # [0, 1] -> [-1, 1]
        ])
def configure(self) -> None:
print(f"Loading GenStereo...")
# Configurations.
self.dtype = (
torch.float16 if self.cfg.half_precision_weights else torch.float32
)
self.viewport_mtx: Float[Tensor, 'B 4 4'] = get_viewport_matrix(
self.cfg.width, self.cfg.height,
batch_size=1, device=self.device
).to(self.dtype)
# Load models.
self.load_models()
# Timestep
self.scheduler.set_timesteps(
self.cfg.num_inference_steps, device=self.device)
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
print(f"Loaded GenStereo.")
def load_models(self) -> None:
# VAE.
self.vae = AutoencoderKL.from_pretrained(
join(self.cfg.pretrained_model_path, 'sd-vae-ft-mse')
).to(self.device, dtype=self.dtype)
# Image processor.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.vae_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
)
self.clip_image_processor = CLIPImageProcessor()
# Image encoder.
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
join(self.cfg.pretrained_model_path, 'image_encoder')
).to(self.device, dtype=self.dtype)
# Reference Unet.
self.reference_unet = UNet2DConditionModel.from_config(
UNet2DConditionModel.load_config(
join(self.model_path, 'config.json')
)).to(self.device, dtype=self.dtype)
        self.reference_unet.load_state_dict(torch.load(
            join(self.model_path, 'reference_unet.pth'),
            map_location='cpu'
        ))
# Denoising Unet.
self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
join(self.model_path, 'config.json'),
join(self.model_path, 'denoising_unet.pth')
).to(self.device, dtype=self.dtype)
self.unet_in_channels = self.denoising_unet.config.in_channels
# Pose guider.
self.pose_guider = PoseGuider(
conditioning_embedding_channels=320,
            # 14 condition channels: 10-dim coord embedding + 1 mask + 3 RGB.
            conditioning_channels=14,
).to(self.device, dtype=self.dtype)
        self.pose_guider.load_state_dict(torch.load(
            join(self.model_path, 'pose_guider.pth'),
            map_location='cpu'
        ))
# Noise scheduler
        sched_kwargs = dict(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule='scaled_linear',
            steps_offset=1,
            clip_sample=False,
        )
sched_kwargs.update(
rescale_betas_zero_snr=True,
timestep_spacing='trailing',
prediction_type='v_prediction',
)
self.scheduler = DDIMScheduler(**sched_kwargs)
self.vae.requires_grad_(False)
self.image_encoder.requires_grad_(False)
self.reference_unet.requires_grad_(False)
self.denoising_unet.requires_grad_(False)
self.pose_guider.requires_grad_(False)
# Coordinates embedding.
self.embedder = self.get_embedder(2)
def get_embedder(self, multires):
embed_kwargs = {
'include_input' : True,
'input_dims' : 2,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = self.Embedder(**embed_kwargs)
        embed = lambda x, eo=embedder_obj: eo.embed(x)
return embed
def __call__(
self,
        src_image: Image.Image,
src_disparity: Float[Tensor, 'B C H W'],
ratio
) -> Dict[str, Tensor]:
""" Perform NVS.
"""
src_image_pil = src_image
src_image = self.transform_pixels(src_image).unsqueeze(0).to(self.device, dtype=self.dtype)
batch_size = src_image.shape[0]
# Rearrange and resize.
src_image = self.preprocess_image(src_image)
src_disparity = self.preprocess_image(src_disparity)
pipe_args = dict(
src_image=src_image,
src_image_pil=src_image_pil,
src_disparity=src_disparity,
ratio=ratio
)
# Prepare inputs.
conditions, renders = self.prepare_conditions_stereogen(**pipe_args)
# NVS.
latents_clean = self.perform_nvs(
**pipe_args,
**conditions,
**renders
)
# Decode to images.
synthesized = self.decode_latents(latents_clean)
inference_out = {
'synthesized': synthesized,
'warped': renders['warped'],
'mask': renders['mask'],
'correspondence': conditions['correspondence']
}
return inference_out
def preprocess_image(
self,
image: Float[Tensor, 'B C H W']
) -> Float[Tensor, 'B C H W']:
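        """Resize to the configured (height, width). Note that F.interpolate
        defaults to nearest-neighbor interpolation.
        """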
image = F.interpolate(
image, (self.cfg.height, self.cfg.width)
)
return image
def get_image_prompt(
self,
        src_image_pil: Image.Image
) -> Float[Tensor, '2 B L']:
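        """Encode the source image with CLIP and stack [uncond, cond]
        embeddings for classifier-free guidance (zeros as the uncond branch).
        """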
clip_image = self.clip_image_processor(
images=src_image_pil, return_tensors="pt"
).pixel_values
clip_image_embeds = self.image_encoder(
clip_image.to(self.device, dtype=self.image_encoder.dtype)
).image_embeds
image_prompt_embeds = clip_image_embeds.unsqueeze(1)
uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
image_prompt_embeds = torch.cat(
[uncond_image_prompt_embeds, image_prompt_embeds], dim=0
)
return image_prompt_embeds
def encode_images(
self,
rgb: Float[Tensor, 'B C H W']
) -> Float[Tensor, 'B C H W']:
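        """Encode images in [-1, 1] into VAE latents, scaled by 0.18215,
        the Stable Diffusion v1 VAE scaling factor.
        """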
latents = self.vae.encode(rgb).latent_dist.mean # rgb [-1, 1]
latents = latents * 0.18215
return latents
def decode_latents(
self,
latents: Float[Tensor, 'B C H W']
) -> Float[Tensor, 'B C H W']:
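        """Decode latents one sample at a time and map the output to [0, 1]."""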
latents = 1 / 0.18215 * latents
rgb = []
for frame_idx in range(latents.shape[0]):
rgb.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
rgb = torch.cat(rgb)
rgb = (rgb / 2 + 0.5).clamp(0, 1)
return rgb.squeeze(2)
def get_reference_controls(
self,
batch_size: int
) -> Tuple[ReferenceAttentionControl, ReferenceAttentionControl]:
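        """Create the reference-attention controls: the writer records
        attention features from the reference UNet, and the reader injects
        them into the denoising UNet.
        """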
reader = ReferenceAttentionControl(
self.denoising_unet,
do_classifier_free_guidance=True,
mode='read',
batch_size=batch_size,
fusion_blocks='full',
feature_fusion_type='attention_full_sharing'
)
writer = ReferenceAttentionControl(
self.reference_unet,
do_classifier_free_guidance=True,
mode='write',
batch_size=batch_size,
fusion_blocks='full',
feature_fusion_type='attention_full_sharing'
)
return reader, writer
def prepare_extra_step_kwargs(
self,
generator,
eta
) -> Dict[str, Any]:
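        """Build scheduler.step() kwargs, passing eta (DDIM's η) and
        generator only if the scheduler's signature accepts them, mirroring
        the diffusers pipelines.
        """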
accepts_eta = 'eta' in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs['eta'] = eta
# check if the scheduler accepts generator
accepts_generator = 'generator' in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs['generator'] = generator
return extra_step_kwargs
def get_pose_features(
self,
src_embed: Float[Tensor, 'B C H W'],
trg_embed: Float[Tensor, 'B C H W'],
do_classifier_guidance: bool = True
) -> Tuple[Tensor, Tensor]:
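        """Run the pose guider on the source and target condition embeddings,
        duplicating features along the batch for classifier-free guidance.
        """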
pose_cond_tensor = src_embed.unsqueeze(2)
pose_cond_tensor = pose_cond_tensor.to(
device=self.device, dtype=self.pose_guider.dtype
)
pose_cond_tensor_2 = trg_embed.unsqueeze(2)
pose_cond_tensor_2 = pose_cond_tensor_2.to(
device=self.device, dtype=self.pose_guider.dtype
)
pose_fea = self.pose_guider(pose_cond_tensor)
pose_fea_2 = self.pose_guider(pose_cond_tensor_2)
if do_classifier_guidance:
pose_fea = torch.cat([pose_fea] * 2)
pose_fea_2 = torch.cat([pose_fea_2] * 2)
return pose_fea, pose_fea_2
@torch.no_grad()
def prepare_conditions_stereogen(
self,
src_image: Float[Tensor, 'B C H W'],
        src_image_pil: Image.Image,
src_disparity: Float[Tensor, 'B C H W'],
ratio
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
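        """Build the conditioning inputs: a sinusoidal embedding of the pixel
        grid, its left-to-right warp, the occlusion mask, and the warped
        image. The disparity is returned as the correspondence map.
        """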
# Prepare inputs.
B = src_image.shape[0]
H, W = src_image.shape[2:4]
# Coordinate grids.
grid: Float[Tensor, 'H W C'] = torch.stack(torch.meshgrid(
torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
).to(self.device, dtype=self.dtype)
# Coordinates embedding.
        # Normalize x by W and y by H so coordinates lie in [0, 1].
        coords = torch.stack((grid[..., 0] / W, grid[..., 1] / H), dim=-1)
embed = repeat(self.embedder(coords), 'h w c -> b c h w', b=B)
        warped_embed, mask, warped_image, disparity = convert_left_to_right_torch(
            embed.squeeze(0), src_disparity.squeeze(0), src_image.squeeze(0), ratio
        )
        warped_embed = warped_embed.unsqueeze(0)
        mask = mask.unsqueeze(0).unsqueeze(0)
        warped_image = warped_image.unsqueeze(0)
        # mask == 1 where the warped view has no source pixel (occlusion).
        # 14-channel conditions: 10-dim coordinate embedding + 1 mask + 3 RGB.
        src_coords_embed = torch.cat(
            [embed, torch.zeros_like(mask, device=mask.device), src_image], dim=1)
        trg_coords_embed = torch.cat([warped_embed, mask, warped_image], dim=1)
        # Alternative 11-channel conditions (embedding + mask only):
        # src_coords_embed = torch.cat(
        #     [embed, torch.zeros_like(mask, device=mask.device)], dim=1)
        # trg_coords_embed = torch.cat([warped_embed, mask], dim=1)
        # Alternative 4-channel conditions (mask + image only):
        # src_coords_embed = torch.cat(
        #     [torch.zeros_like(mask, device=mask.device), src_image], dim=1)
        # trg_coords_embed = torch.cat([mask, warped_image], dim=1)
# Outputs.
conditions = dict(
src_coords_embed=src_coords_embed,
trg_coords_embed=trg_coords_embed,
correspondence=disparity
)
renders = dict(
warped=warped_image,
            mask=1 - mask  # flipped: mask == 1 where a source pixel exists
)
return conditions, renders
def perform_nvs(
self,
src_image,
src_image_pil,
src_coords_embed,
trg_coords_embed,
correspondence,
warped,
mask,
eta: float=0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]]=None,
**kwargs,
) -> Float[Tensor, 'B C H W']:
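        """Run the diffusion sampling loop: encode the reference image, write
        its attention features once, then iteratively denoise pure-noise
        latents with classifier-free guidance. Returns the clean latents.
        """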
batch_size = src_image.shape[0]
# For the cross attention.
reference_control_reader, reference_control_writer = \
self.get_reference_controls(batch_size)
# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(
generator, eta
)
with torch.no_grad():
            # Create starting latents; adding noise at the final timestep
            # below makes them effectively pure noise.
latents = torch.randn(
batch_size,
self.unet_in_channels,
self.cfg.height // self.vae_scale_factor,
self.cfg.width // self.vae_scale_factor
).to(self.device, dtype=src_image.dtype)
initial_t = torch.tensor(
[self.num_train_timesteps - 1] * batch_size
).to(self.device, dtype=torch.long)
# Add noise.
noise = torch.randn_like(latents)
latents_noisy_start = self.scheduler.add_noise(
latents, noise, initial_t
)
latents_noisy_start = latents_noisy_start.unsqueeze(2)
# Prepare clip image embeds.
image_prompt_embeds = self.get_image_prompt(src_image_pil) # torch.Size([2, 1, 768])
# Prepare ref image latents.
ref_image_latents = self.encode_images(src_image)
# Prepare pose condition image.
pose_fea, pose_fea_2 = self.get_pose_features(
src_coords_embed, trg_coords_embed
)
pose_fea = pose_fea[:, :, 0, ...]
# Forward reference images.
self.reference_unet(
ref_image_latents.repeat(2, 1, 1, 1), # torch.Size([bs*2, 4, 64, 64])
torch.zeros(batch_size * 2).to(ref_image_latents), # torch.Size([bs*2])
encoder_hidden_states=image_prompt_embeds, # torch.Size([bs*2, 1, 768])
pose_cond_fea=pose_fea, # torch.Size([bs*2, 320, 64, 64])
return_dict=False,
)
            # Update the denoising net with reference features.
reference_control_reader.update(
reference_control_writer,
correspondence=correspondence # torch.Size([bs, 1, 512, 512])
)
timesteps = self.scheduler.timesteps
latents_noisy = latents_noisy_start
for t in timesteps:
# Prepare latents.
latent_model_input = torch.cat([latents_noisy] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
# Denoise.
noise_pred = self.denoising_unet(
latent_model_input, # torch.Size([bs*2, 4, 1, 64, 64])
t, # torch.Size([])
encoder_hidden_states=image_prompt_embeds, # torch.Size([bs*2, 1, 768])
pose_cond_fea=pose_fea_2, # torch.Size([bs*2, 320, 1, 64, 64])
return_dict=False,
)[0] # torch.Size([bs*2, 4, 1, 64, 64])
# CFG.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# t -> t-1
latents_noisy = self.scheduler.step(
noise_pred, t, latents_noisy, **extra_step_kwargs,
return_dict=False
)[0]
            # After the full schedule, the latents are fully denoised.
            latents_clean = latents_noisy
reference_control_reader.clear()
reference_control_writer.clear()
return latents_clean.squeeze(2)
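
# Minimal usage sketch (hypothetical paths and values; `disparity` must be a
# (1, 1, H, W) float tensor prepared by the caller):
#
#     model = GenStereo(
#         cfg={'pretrained_model_path': './checkpoints',
#              'checkpoint_name': 'genstereo'},
#         device='cuda:0',
#     )
#     left = Image.open('left.png').convert('RGB')
#     out = model(left, disparity, ratio=1.0)
#     right_view = out['synthesized']  # (B, 3, H, W) in [0, 1]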