"""GenStereo inference pipeline: diffusion-based synthesis of a right view from a
left image and its disparity."""

from os.path import join
from typing import Union, Optional, List, Dict, Tuple, Any
from dataclasses import dataclass
import inspect

from omegaconf import OmegaConf, DictConfig
from jaxtyping import Float
from PIL import Image

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from einops import rearrange, repeat

from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from .models import (
    PoseGuider,
    UNet2DConditionModel,
    UNet3DConditionModel,
    ReferenceAttentionControl,
)
from .ops import (
    get_viewport_matrix,
    forward_warper,
    convert_left_to_right,
    convert_left_to_right_torch,
)

class AdaptiveFusionLayer(nn.Module):
    """Predicts a per-pixel blending weight from the generated image, the warped
    image, and the warp mask, then fuses the two images with that weight."""

    def __init__(self):
        super().__init__()
        # 7 input channels: 3 (generated RGB) + 3 (warped RGB) + 1 (mask).
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(7, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, generated_image, warped_image, mask):
        fusion_input = torch.cat([generated_image, warped_image, mask], dim=1)
        weights = self.fusion_layer(fusion_input)
        # Keep warped pixels where the mask and predicted weight agree,
        # otherwise fall back to the generated image.
        fused_output = (
            mask * weights * warped_image
            + (1 - mask * weights) * generated_image
        )
        return fused_output

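# Illustrative shape sketch (not executed on import; the 512x512 resolution is
# the module's default config, the random inputs are placeholders):
#
#     fusion = AdaptiveFusionLayer()
#     gen = torch.rand(1, 3, 512, 512)     # generated right view
#     warp = torch.rand(1, 3, 512, 512)    # forward-warped left view
#     mask = torch.ones(1, 1, 512, 512)    # valid-warp mask
#     out = fusion(gen, warp, mask)        # -> (1, 3, 512, 512)
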
class GenStereo:
    """Diffusion-based stereo view synthesis: given a left image and its
    disparity, generates the corresponding right view."""

    @dataclass
    class Config:
        pretrained_model_path: str = ''
        checkpoint_name: str = ''
        half_precision_weights: bool = False
        height: int = 512
        width: int = 512
        num_inference_steps: int = 50
        guidance_scale: float = 1.5

    cfg: Config

    class Embedder:
        """NeRF-style positional encoding: concatenates the (optional) raw input
        with sin/cos of the input at a set of frequencies."""

        def __init__(self, **kwargs) -> None:
            self.kwargs = kwargs
            self.create_embedding_fn()

        def create_embedding_fn(self) -> None:
            embed_fns = []
            d = self.kwargs['input_dims']
            out_dim = 0
            if self.kwargs['include_input']:
                # Pass the raw coordinates through unchanged.
                embed_fns.append(lambda x: x)
                out_dim += d

            max_freq = self.kwargs['max_freq_log2']
            N_freqs = self.kwargs['num_freqs']

            if self.kwargs['log_sampling']:
                freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
            else:
                freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)

            # Bind p_fn and freq as defaults so each lambda keeps its own values.
            for freq in freq_bands:
                for p_fn in self.kwargs['periodic_fns']:
                    embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                    out_dim += d

            self.embed_fns = embed_fns
            self.out_dim = out_dim

        def embed(self, inputs) -> Tensor:
            return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

    def __init__(
        self,
        cfg: Optional[Union[dict, DictConfig]] = None,
        device: Optional[str] = 'cuda:0'
    ) -> None:
        # Fall back to the default Config when no overrides are given.
        self.cfg = OmegaConf.structured(self.Config(**(cfg or {})))
        self.model_path = join(
            self.cfg.pretrained_model_path, self.cfg.checkpoint_name
        )
        self.device = device
        self.configure()
        # Map PIL images to [-1, 1] tensors as expected by the VAE.
        self.transform_pixels = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def configure(self) -> None:
        print("Loading GenStereo...")

        self.dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )
        self.viewport_mtx: Float[Tensor, 'B 4 4'] = get_viewport_matrix(
            self.cfg.width, self.cfg.height,
            batch_size=1, device=self.device
        ).to(self.dtype)

        self.load_models()

        # Fix the inference timestep schedule once; it is reused for every call.
        self.scheduler.set_timesteps(
            self.cfg.num_inference_steps, device=self.device
        )
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        print("Loaded GenStereo.")

    def load_models(self) -> None:
        # Frozen SD VAE for encoding/decoding latents.
        self.vae = AutoencoderKL.from_pretrained(
            join(self.cfg.pretrained_model_path, 'sd-vae-ft-mse')
        ).to(self.device, dtype=self.dtype)

        self.vae_scale_factor = \
            2 ** (len(self.vae.config.block_out_channels) - 1)
        self.vae_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.clip_image_processor = CLIPImageProcessor()

        # CLIP image encoder providing the image prompt embeddings.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            join(self.cfg.pretrained_model_path, 'image_encoder')
        ).to(self.device, dtype=self.dtype)

        # Reference UNet: writes attention features of the source view.
        self.reference_unet = UNet2DConditionModel.from_config(
            UNet2DConditionModel.load_config(
                join(self.model_path, 'config.json')
            )
        ).to(self.device, dtype=self.dtype)
        self.reference_unet.load_state_dict(
            torch.load(
                join(self.model_path, 'reference_unet.pth'),
                map_location='cpu'
            )
        )

        # Denoising UNet: reads the reference features while denoising.
        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            join(self.model_path, 'config.json'),
            join(self.model_path, 'denoising_unet.pth')
        ).to(self.device, dtype=self.dtype)
        self.unet_in_channels = self.denoising_unet.config.in_channels

        # Pose guider: injects the 14-channel coordinate/warp conditioning.
        self.pose_guider = PoseGuider(
            conditioning_embedding_channels=320,
            conditioning_channels=14,
        ).to(self.device, dtype=self.dtype)
        self.pose_guider.load_state_dict(
            torch.load(
                join(self.model_path, 'pose_guider.pth'),
                map_location='cpu'
            )
        )

        # DDIM scheduler with zero-SNR rescaling and v-prediction.
        sched_kwargs = OmegaConf.to_container(OmegaConf.create({
            'num_train_timesteps': 1000,
            'beta_start': 0.00085,
            'beta_end': 0.012,
            'beta_schedule': 'scaled_linear',
            'steps_offset': 1,
            'clip_sample': False
        }))
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing='trailing',
            prediction_type='v_prediction',
        )
        self.scheduler = DDIMScheduler(**sched_kwargs)

        # Inference only: freeze every module.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.reference_unet.requires_grad_(False)
        self.denoising_unet.requires_grad_(False)
        self.pose_guider.requires_grad_(False)

        # Positional encoder for pixel coordinates.
        self.embedder = self.get_embedder(2)

    def get_embedder(self, multires):
        embed_kwargs = {
            'include_input': True,
            'input_dims': 2,
            'max_freq_log2': multires - 1,
            'num_freqs': multires,
            'log_sampling': True,
            'periodic_fns': [torch.sin, torch.cos],
        }

        embedder_obj = self.Embedder(**embed_kwargs)
        embed = lambda x, eo=embedder_obj: eo.embed(x)
        return embed

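    # Channel-count sketch (illustrative): with multires=2 and input_dims=2 the
    # embedding has out_dim = 2 (raw xy) + 2 freqs * 2 fns * 2 dims = 10 channels.
    # Concatenated with the 1-channel warp mask and the 3-channel image in
    # prepare_conditions_stereogen, this gives the 14 conditioning channels
    # the PoseGuider is constructed with.
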
    def __call__(
        self,
        src_image: Image.Image,
        src_disparity: Float[Tensor, 'B C H W'],
        ratio
    ) -> Dict[str, Tensor]:
        """Perform novel-view synthesis (NVS) of the right view from a left
        image and its disparity."""

        src_image_pil = src_image
        src_image = self.transform_pixels(src_image).unsqueeze(0).to(
            self.device, dtype=self.dtype
        )
        batch_size = src_image.shape[0]

        # Resize inputs to the working resolution.
        src_image = self.preprocess_image(src_image)
        src_disparity = self.preprocess_image(src_disparity)

        pipe_args = dict(
            src_image=src_image,
            src_image_pil=src_image_pil,
            src_disparity=src_disparity,
            ratio=ratio
        )

        # Warp the left view and build the coordinate-embedding conditions.
        conditions, renders = self.prepare_conditions_stereogen(**pipe_args)

        # Run the diffusion sampling loop.
        latents_clean = self.perform_nvs(
            **pipe_args,
            **conditions,
            **renders
        )

        # Decode latents back to image space.
        synthesized = self.decode_latents(latents_clean)

        inference_out = {
            'synthesized': synthesized,
            'warped': renders['warped'],
            'mask': renders['mask'],
            'correspondence': conditions['correspondence']
        }

        return inference_out

    def preprocess_image(
        self,
        image: Float[Tensor, 'B C H W']
    ) -> Float[Tensor, 'B C H W']:
        image = F.interpolate(
            image, (self.cfg.height, self.cfg.width)
        )
        return image

    def get_image_prompt(
        self,
        src_image_pil: Image.Image
    ) -> Float[Tensor, '2 B L']:
        clip_image = self.clip_image_processor(
            images=src_image_pil, return_tensors="pt"
        ).pixel_values

        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=self.image_encoder.dtype)
        ).image_embeds

        image_prompt_embeds = clip_image_embeds.unsqueeze(1)
        uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)

        # Stack the unconditional (zero) and conditional embeddings for
        # classifier-free guidance.
        image_prompt_embeds = torch.cat(
            [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
        )

        return image_prompt_embeds

    def encode_images(
        self,
        rgb: Float[Tensor, 'B C H W']
    ) -> Float[Tensor, 'B C H W']:
        # Use the posterior mean and apply the SD latent scaling factor.
        latents = self.vae.encode(rgb).latent_dist.mean
        latents = latents * 0.18215
        return latents

    def decode_latents(
        self,
        latents: Float[Tensor, 'B C H W']
    ) -> Float[Tensor, 'B C H W']:
        latents = 1 / 0.18215 * latents
        rgb = []
        # Decode frame by frame to keep memory usage low.
        for frame_idx in range(latents.shape[0]):
            rgb.append(self.vae.decode(latents[frame_idx:frame_idx + 1]).sample)
        rgb = torch.cat(rgb)
        # Map from [-1, 1] back to [0, 1].
        rgb = (rgb / 2 + 0.5).clamp(0, 1)
        return rgb.squeeze(2)

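    # Shape sketch (illustrative): for a 512x512 input and the standard SD VAE
    # (vae_scale_factor == 8, 4 latent channels), encode_images maps
    # (B, 3, 512, 512) -> (B, 4, 64, 64) and decode_latents maps back to
    # (B, 3, 512, 512) in [0, 1].
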
    def get_reference_controls(
        self,
        batch_size: int
    ) -> Tuple[ReferenceAttentionControl, ReferenceAttentionControl]:
        # The denoising UNet reads reference features; the reference UNet writes them.
        reader = ReferenceAttentionControl(
            self.denoising_unet,
            do_classifier_free_guidance=True,
            mode='read',
            batch_size=batch_size,
            fusion_blocks='full',
            feature_fusion_type='attention_full_sharing'
        )
        writer = ReferenceAttentionControl(
            self.reference_unet,
            do_classifier_free_guidance=True,
            mode='write',
            batch_size=batch_size,
            fusion_blocks='full',
            feature_fusion_type='attention_full_sharing'
        )

        return reader, writer

    def prepare_extra_step_kwargs(
        self,
        generator,
        eta
    ) -> Dict[str, Any]:
        # Only pass eta/generator if the scheduler's step() accepts them
        # (eta is DDIM-specific; other schedulers ignore it).
        accepts_eta = 'eta' in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs['eta'] = eta

        accepts_generator = 'generator' in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs['generator'] = generator
        return extra_step_kwargs

    def get_pose_features(
        self,
        src_embed: Float[Tensor, 'B C H W'],
        trg_embed: Float[Tensor, 'B C H W'],
        do_classifier_guidance: bool = True
    ) -> Tuple[Tensor, Tensor]:
        # Add the frame dimension expected by the pose guider.
        pose_cond_tensor = src_embed.unsqueeze(2)
        pose_cond_tensor = pose_cond_tensor.to(
            device=self.device, dtype=self.pose_guider.dtype
        )
        pose_cond_tensor_2 = trg_embed.unsqueeze(2)
        pose_cond_tensor_2 = pose_cond_tensor_2.to(
            device=self.device, dtype=self.pose_guider.dtype
        )
        pose_fea = self.pose_guider(pose_cond_tensor)
        pose_fea_2 = self.pose_guider(pose_cond_tensor_2)

        if do_classifier_guidance:
            # Duplicate along the batch dimension for classifier-free guidance,
            # matching the [unconditional, conditional] order of the image prompt.
            pose_fea = torch.cat([pose_fea] * 2)
            pose_fea_2 = torch.cat([pose_fea_2] * 2)

        return pose_fea, pose_fea_2

    @torch.no_grad()
    def prepare_conditions_stereogen(
        self,
        src_image: Float[Tensor, 'B C H W'],
        src_image_pil,
        src_disparity: Float[Tensor, 'B C H W'],
        ratio
    ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
        B = src_image.shape[0]
        H, W = src_image.shape[2:4]

        # Pixel-coordinate grid (x, y) in image space.
        grid: Float[Tensor, 'H W C'] = torch.stack(torch.meshgrid(
            torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
        ).to(self.device, dtype=self.dtype)

        # Normalized coordinates, positionally encoded per pixel.
        coords = torch.stack((grid[..., 0] / H, grid[..., 1] / W), dim=-1)
        embed = repeat(self.embedder(coords), 'h w c -> b c h w', b=B)

        # Warp the coordinate embedding and the left image to the right view.
        warped_embed, mask, warped_image, disparity = convert_left_to_right_torch(
            embed.squeeze(0),
            src_disparity.squeeze(0),
            src_image.squeeze(0),
            ratio
        )
        warped_embed = warped_embed.unsqueeze(0)
        mask = mask.unsqueeze(0).unsqueeze(0)
        warped_image = warped_image.unsqueeze(0)

        # Source condition: embedding + empty mask + source image (14 channels).
        src_coords_embed = torch.cat(
            [embed, torch.zeros_like(mask), src_image], dim=1
        )
        # Target condition: warped embedding + warp mask + warped image.
        trg_coords_embed = torch.cat([warped_embed, mask, warped_image], dim=1)

        conditions = dict(
            src_coords_embed=src_coords_embed,
            trg_coords_embed=trg_coords_embed,
            correspondence=disparity
        )
        renders = dict(
            warped=warped_image,
            mask=1 - mask
        )

        return conditions, renders

    def perform_nvs(
        self,
        src_image,
        src_image_pil,
        src_coords_embed,
        trg_coords_embed,
        correspondence,
        warped,
        mask,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        **kwargs,
    ) -> Float[Tensor, 'B C H W']:
        batch_size = src_image.shape[0]

        # Attention controllers that share reference features between the
        # reference UNet (writer) and the denoising UNet (reader).
        reference_control_reader, reference_control_writer = \
            self.get_reference_controls(batch_size)

        extra_step_kwargs = self.prepare_extra_step_kwargs(
            generator, eta
        )

        with torch.no_grad():
            # Start from pure noise in latent space at the last timestep.
            latents = torch.randn(
                batch_size,
                self.unet_in_channels,
                self.cfg.height // self.vae_scale_factor,
                self.cfg.width // self.vae_scale_factor
            ).to(self.device, dtype=src_image.dtype)
            initial_t = torch.tensor(
                [self.num_train_timesteps - 1] * batch_size
            ).to(self.device, dtype=torch.long)

            noise = torch.randn_like(latents)
            latents_noisy_start = self.scheduler.add_noise(
                latents, noise, initial_t
            )
            latents_noisy_start = latents_noisy_start.unsqueeze(2)

            # CLIP image prompt (unconditional + conditional for CFG).
            image_prompt_embeds = self.get_image_prompt(src_image_pil)

            # Latents of the source view for the reference UNet.
            ref_image_latents = self.encode_images(src_image)

            # Pose-guider features for the source and target conditions.
            pose_fea, pose_fea_2 = self.get_pose_features(
                src_coords_embed, trg_coords_embed
            )

            # The 2D reference UNet takes features without a frame dimension.
            pose_fea = pose_fea[:, :, 0, ...]

            # Single pass through the reference UNet to write reference features.
            self.reference_unet(
                ref_image_latents.repeat(2, 1, 1, 1),
                torch.zeros(batch_size * 2).to(ref_image_latents),
                encoder_hidden_states=image_prompt_embeds,
                pose_cond_fea=pose_fea,
                return_dict=False,
            )

            # Make the written features (and the warp correspondence) available
            # to the denoising UNet.
            reference_control_reader.update(
                reference_control_writer,
                correspondence=correspondence
            )

            # DDIM sampling loop with classifier-free guidance.
            timesteps = self.scheduler.timesteps
            latents_noisy = latents_noisy_start
            for t in timesteps:
                latent_model_input = torch.cat([latents_noisy] * 2)
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                noise_pred = self.denoising_unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_prompt_embeds,
                    pose_cond_fea=pose_fea_2,
                    return_dict=False,
                )[0]

                # Classifier-free guidance.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

                latents_noisy = self.scheduler.step(
                    noise_pred, t, latents_noisy, **extra_step_kwargs,
                    return_dict=False
                )[0]

            latents_clean = latents_noisy

            reference_control_reader.clear()
            reference_control_writer.clear()

        return latents_clean.squeeze(2)
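

# Minimal usage sketch. The checkpoint layout, file names, and the meaning of
# `ratio` below are assumptions for illustration; the disparity map must come
# from your own estimator.
if __name__ == '__main__':
    genstereo = GenStereo(
        cfg=dict(
            pretrained_model_path='./checkpoints',  # hypothetical path
            checkpoint_name='genstereo',            # hypothetical checkpoint dir
        ),
        device='cuda:0',
    )

    left = Image.open('left.png').convert('RGB')    # placeholder input image
    # Disparity as a (B, C, H, W) float tensor on the same device as the model;
    # zeros here are only a placeholder.
    disparity = torch.zeros(1, 1, 512, 512, device='cuda:0')

    out = genstereo(left, disparity, ratio=1.0)     # ratio: warp/baseline scale (assumed)
    right = out['synthesized']                      # (B, 3, H, W) in [0, 1]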