# AetherV1 / app.py
import gc
import os
import random
import re
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import gradio as gr
import imageio.v3 as iio
import numpy as np
import PIL
import rootutils
import torch
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXTransformer3DModel,
)
from transformers import AutoTokenizer, T5EncoderModel
import spaces
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from aether.pipelines.aetherv1_pipeline_cogvideox import ( # noqa: E402
AetherV1PipelineCogVideoX,
AetherV1PipelineOutput,
)
from aether.utils.postprocess_utils import ( # noqa: E402
align_camera_extrinsics,
apply_transformation,
colorize_depth,
compute_scale,
get_intrinsics,
interpolate_poses,
postprocess_pointmap,
project,
raymap_to_poses,
)
from aether.utils.visualize_utils import predictions_to_glb # noqa: E402
def seed_all(seed: int = 0) -> None:
"""
Set random seeds of all components.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Global pipeline
cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
pipeline = AetherV1PipelineCogVideoX(
tokenizer=AutoTokenizer.from_pretrained(
cogvideox_pretrained_model_name_or_path,
subfolder="tokenizer",
),
text_encoder=T5EncoderModel.from_pretrained(
cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
),
vae=AutoencoderKLCogVideoX.from_pretrained(
cogvideox_pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch.bfloat16
),
scheduler=CogVideoXDPMScheduler.from_pretrained(
cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
),
transformer=CogVideoXTransformer3DModel.from_pretrained(
aether_pretrained_model_name_or_path, subfolder="transformer", torch_dtype=torch.bfloat16
),
)
pipeline.vae.enable_slicing()
pipeline.vae.enable_tiling()
# pipeline.to(device)
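# Note (explanatory, not from the original code): the pipeline is built once at import time and
# kept off the GPU; each @spaces.GPU-decorated handler moves it to the request device via
# build_pipeline(device). enable_slicing()/enable_tiling() above reduce VAE decode memory, and the
# VAE/transformer are loaded in bfloat16 to cut the footprint further.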
def build_pipeline(device: torch.device) -> AetherV1PipelineCogVideoX:
"""Move the globally constructed pipeline onto the target device and return it."""
pipeline.to(device)
return pipeline
def get_window_starts(
total_frames: int, sliding_window_size: int, temporal_stride: int
) -> List[int]:
"""Calculate window start indices."""
starts = list(
range(
0,
total_frames - sliding_window_size + 1,
temporal_stride,
)
)
if (
total_frames > sliding_window_size
and (total_frames - sliding_window_size) % temporal_stride != 0
):
starts.append(total_frames - sliding_window_size)
return starts
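# Illustrative example (values assumed, not from the original code):
# get_window_starts(total_frames=100, sliding_window_size=41, temporal_stride=24) -> [0, 24, 48, 59];
# the final start (59) is appended so the last window still ends exactly at the last frame.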
def blend_and_merge_window_results(
window_results: List[AetherV1PipelineOutput], window_indices: List[int], args: Dict
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Blend and merge window results."""
merged_rgb = None
merged_disparity = None
merged_poses = None
merged_focals = None
align_pointmaps = args.get("align_pointmaps", True)
smooth_camera = args.get("smooth_camera", True)
smooth_method = args.get("smooth_method", "kalman") if smooth_camera else "none"
if align_pointmaps:
merged_pointmaps = None
w1 = window_results[0].disparity
for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
t_end = t_start + window_result.rgb.shape[0]
if idx == 0:
merged_rgb = window_result.rgb
merged_disparity = window_result.disparity
pointmap_dict = postprocess_pointmap(
window_result.disparity,
window_result.raymap,
vae_downsample_scale=8,
ray_o_scale_inv=0.1,
smooth_camera=smooth_camera,
smooth_method=smooth_method if smooth_camera else "none",
)
merged_poses = pointmap_dict["camera_pose"]
merged_focals = (
pointmap_dict["intrinsics"][:, 0, 0]
+ pointmap_dict["intrinsics"][:, 1, 1]
) / 2
if align_pointmaps:
merged_pointmaps = pointmap_dict["pointmap"]
else:
overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start
window_disparity = window_result.disparity
# Align disparity
disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
scale = compute_scale(
window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
disp_mask.reshape(1, -1, w1.shape[-1]),
)
window_disparity = scale * window_disparity
# Blend disparity
result_disparity = np.ones((t_end, *w1.shape[1:]))
result_disparity[:t_start] = merged_disparity[:t_start]
result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
weight = np.linspace(1, 0, overlap_t)[:, None, None]
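# The linear weight favors the previously merged frames at the start of the overlap (weight 1)
# and the incoming window at its end (weight 0), giving a smooth cross-fade.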
result_disparity[t_start : t_start + overlap_t] = merged_disparity[
t_start : t_start + overlap_t
] * weight + window_disparity[:overlap_t] * (1 - weight)
merged_disparity = result_disparity
# Blend RGB
result_rgb = np.ones((t_end, *w1.shape[1:], 3))
result_rgb[:t_start] = merged_rgb[:t_start]
result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
result_rgb[t_start : t_start + overlap_t] = merged_rgb[
t_start : t_start + overlap_t
] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
merged_rgb = result_rgb
# Align poses
window_raymap = window_result.raymap
window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
window_raymap, ray_o_scale_inv=0.1
)
rel_r, rel_t, rel_s = align_camera_extrinsics(
torch.from_numpy(window_poses[:overlap_t]),
torch.from_numpy(merged_poses[-overlap_t:]),
)
aligned_window_poses = (
apply_transformation(
torch.from_numpy(window_poses),
rel_r,
rel_t,
rel_s,
return_extri=True,
)
.cpu()
.numpy()
)
result_poses = np.ones((t_end, 4, 4))
result_poses[:t_start] = merged_poses[:t_start]
result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]
# Interpolate poses in overlap region
weights = np.linspace(1, 0, overlap_t)
for t in range(overlap_t):
weight = weights[t]
pose1 = merged_poses[t_start + t]
pose2 = aligned_window_poses[t]
result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)
merged_poses = result_poses
# Align intrinsics
window_intrinsics, _ = get_intrinsics(
batch_size=window_poses.shape[0],
h=window_result.disparity.shape[1],
w=window_result.disparity.shape[2],
fovx=window_Fov_x,
fovy=window_Fov_y,
)
window_focals = (
window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
) / 2
scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
window_focals = scale * window_focals
result_focals = np.ones((t_end,))
result_focals[:t_start] = merged_focals[:t_start]
result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
weight = np.linspace(1, 0, overlap_t)
result_focals[t_start : t_start + overlap_t] = merged_focals[
t_start : t_start + overlap_t
] * weight + window_focals[:overlap_t] * (1 - weight)
merged_focals = result_focals
if align_pointmaps:
# Align pointmaps
window_pointmaps = postprocess_pointmap(
result_disparity[t_start:],
window_raymap,
vae_downsample_scale=8,
camera_pose=aligned_window_poses,
focal=window_focals,
ray_o_scale_inv=0.1,
smooth_camera=smooth_camera,
smooth_method=smooth_method if smooth_camera else "none",
)
result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
result_pointmaps[:t_start] = merged_pointmaps[:t_start]
result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
overlap_t:
]
weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
t_start : t_start + overlap_t
] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
merged_pointmaps = result_pointmaps
# project to pointmaps
height = args.get("height", 480)
width = args.get("width", 720)
intrinsics = [
np.array([[f, 0, 0.5 * width], [0, f, 0.5 * height], [0, 0, 1]])
for f in merged_focals
]
if align_pointmaps:
pointmaps = merged_pointmaps
else:
pointmaps = np.stack(
[
project(
1 / np.clip(merged_disparity[i], 1e-8, 1e8),
intrinsics[i],
merged_poses[i],
)
for i in range(merged_poses.shape[0])
]
)
return merged_rgb, merged_disparity, merged_poses, pointmaps
def process_video_to_frames(video_path: str, fps_sample: int = 12) -> Tuple[List[str], str]:
"""Process video into frames and save them locally."""
# Create a unique output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"temp_frames_{timestamp}"
os.makedirs(output_dir, exist_ok=True)
# Read video
video = iio.imread(video_path)
# Calculate frame interval based on original video fps
if isinstance(video, np.ndarray):
# For captured videos
total_frames = len(video)
frame_interval = max(
1, round(total_frames / (fps_sample * (total_frames / 30)))
)
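# Note (explanatory, not from the original code): the total_frames terms cancel, so this is
# effectively max(1, round(30 / fps_sample)), i.e. the source video is assumed to be ~30 fps.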
else:
# Default if can't determine
frame_interval = 2
frame_paths = []
for i, frame in enumerate(video[::frame_interval]):
frame_path = os.path.join(output_dir, f"frame_{i:04d}.jpg")
if isinstance(frame, np.ndarray):
iio.imwrite(frame_path, frame)
frame_paths.append(frame_path)
return frame_paths, output_dir
def save_output_files(
rgb: np.ndarray,
disparity: np.ndarray,
poses: Optional[np.ndarray] = None,
raymap: Optional[np.ndarray] = None,
pointmap: Optional[np.ndarray] = None,
task: str = "reconstruction",
output_dir: str = "outputs",
**kwargs,
) -> Dict[str, str]:
"""
Save outputs and return paths to saved files.
"""
os.makedirs(output_dir, exist_ok=True)
if pointmap is None and raymap is not None:
# Generate pointmap from raymap and disparity
smooth_camera = kwargs.get("smooth_camera", True)
smooth_method = (
kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
)
pointmap_dict = postprocess_pointmap(
disparity,
raymap,
vae_downsample_scale=8,
ray_o_scale_inv=0.1,
smooth_camera=smooth_camera,
smooth_method=smooth_method,
)
pointmap = pointmap_dict["pointmap"]
if poses is None and raymap is not None:
poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)
# Create a unique filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_filename = f"{task}_{timestamp}"
# Paths for saved files
paths = {}
# Save RGB video
rgb_path = os.path.join(output_dir, f"{base_filename}_rgb.mp4")
iio.imwrite(
rgb_path,
(np.clip(rgb, 0, 1) * 255).astype(np.uint8),
fps=kwargs.get("fps", 12),
)
paths["rgb"] = rgb_path
# Save depth/disparity video
depth_path = os.path.join(output_dir, f"{base_filename}_disparity.mp4")
iio.imwrite(
depth_path,
(colorize_depth(disparity) * 255).astype(np.uint8),
fps=kwargs.get("fps", 12),
)
paths["disparity"] = depth_path
# Save point cloud GLB files
if pointmap is not None and poses is not None:
pointcloud_save_frame_interval = kwargs.get(
"pointcloud_save_frame_interval", 10
)
max_depth = kwargs.get("max_depth", 100.0)
rtol = kwargs.get("rtol", 0.03)
glb_paths = []
# Determine which frames to save based on the interval
frames_to_save = list(
range(0, pointmap.shape[0], pointcloud_save_frame_interval)
)
# Always include the first and last frame
if 0 not in frames_to_save:
frames_to_save.insert(0, 0)
if pointmap.shape[0] - 1 not in frames_to_save:
frames_to_save.append(pointmap.shape[0] - 1)
# Sort the frames to ensure they're in order
frames_to_save = sorted(set(frames_to_save))
for frame_idx in frames_to_save:
if frame_idx >= pointmap.shape[0]:
continue
predictions = {
"world_points": pointmap[frame_idx : frame_idx + 1],
"images": rgb[frame_idx : frame_idx + 1],
"depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
"camera_poses": poses[frame_idx : frame_idx + 1],
}
glb_path = os.path.join(
output_dir, f"{base_filename}_pointcloud_frame_{frame_idx}.glb"
)
scene_3d = predictions_to_glb(
predictions,
filter_by_frames="all",
show_cam=True,
max_depth=max_depth,
rtol=rtol,
frame_rel_idx=float(frame_idx) / pointmap.shape[0],
)
scene_3d.export(glb_path)
glb_paths.append(glb_path)
paths["pointcloud_glbs"] = glb_paths
return paths
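# Illustrative usage (assumed inputs, not from the original code):
# save_output_files(rgb, disparity, raymap=raymap, task="prediction") returns something like
# {"rgb": "outputs/prediction_<ts>_rgb.mp4", "disparity": "outputs/prediction_<ts>_disparity.mp4",
#  "pointcloud_glbs": [...]}; the GLB list is present only when a pointmap and poses are available.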
@spaces.GPU(duration=300)
def process_reconstruction(
video_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
sliding_window_stride,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
progress=gr.Progress(),
):
"""
Process reconstruction task.
"""
try:
gc.collect()
torch.cuda.empty_cache()
# Set random seed
seed_all(seed)
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Check your environment.")
pipeline = build_pipeline(device)
progress(0.1, "Loading video")
# Check if video_file is a string or a file object
if isinstance(video_file, str):
video_path = video_file
else:
video_path = video_file.name
video = iio.imread(video_path).astype(np.float32) / 255.0
# Setup arguments
args = {
"height": height,
"width": width,
"num_frames": num_frames,
"sliding_window_stride": sliding_window_stride,
"smooth_camera": smooth_camera,
"smooth_method": "kalman" if smooth_camera else "none",
"align_pointmaps": align_pointmaps,
"max_depth": max_depth,
"rtol": rtol,
"pointcloud_save_frame_interval": pointcloud_save_frame_interval,
}
# Process in sliding windows
window_results = []
window_indices = get_window_starts(
len(video), num_frames, sliding_window_stride
)
progress(0.2, f"Processing video in {len(window_indices)} windows")
for i, start_idx in enumerate(window_indices):
progress_val = 0.2 + (0.6 * (i / len(window_indices)))
progress(progress_val, f"Processing window {i+1}/{len(window_indices)}")
output = pipeline(
task="reconstruction",
image=None,
goal=None,
video=video[start_idx : start_idx + num_frames],
raymap=None,
height=height,
width=width,
num_frames=num_frames,
fps=fps,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
use_dynamic_cfg=False,
generator=torch.Generator(device=device).manual_seed(seed),
)
window_results.append(output)
progress(0.8, "Merging results from all windows")
# Merge window results
(
merged_rgb,
merged_disparity,
merged_poses,
pointmaps,
) = blend_and_merge_window_results(window_results, window_indices, args)
progress(0.9, "Saving output files")
# Save output files
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)
output_paths = save_output_files(
rgb=merged_rgb,
disparity=merged_disparity,
poses=merged_poses,
pointmap=pointmaps,
task="reconstruction",
output_dir=output_dir,
fps=12,
**args,
)
progress(1.0, "Done!")
# Return paths for displaying
return (
output_paths["rgb"],
output_paths["disparity"],
output_paths.get("pointcloud_glbs", []),
)
except Exception:
import traceback
traceback.print_exc()
return None, None, []
@spaces.GPU(duration=240)
def process_prediction(
image_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
progress=gr.Progress(),
):
"""
Process prediction task.
"""
try:
gc.collect()
torch.cuda.empty_cache()
# Set random seed
seed_all(seed)
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Check your environment.")
# Build the pipeline
pipeline = build_pipeline(device)
progress(0.1, "Loading image")
# Check if image_file is a string or a file object
if isinstance(image_file, str):
image_path = image_file
else:
image_path = image_file.name
image = PIL.Image.open(image_path)
progress(0.2, "Running prediction")
# Run prediction
output = pipeline(
task="prediction",
image=image,
video=None,
goal=None,
raymap=np.load(f"assets/example_raymaps/raymap_{raymap_option}.npy"),
height=height,
width=width,
num_frames=num_frames,
fps=fps,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
use_dynamic_cfg=use_dynamic_cfg,
generator=torch.Generator(device=device).manual_seed(seed),
return_dict=True,
)
# Show RGB output immediately
rgb_output = output.rgb
# Setup arguments for saving
args = {
"height": height,
"width": width,
"smooth_camera": smooth_camera,
"smooth_method": "kalman" if smooth_camera else "none",
"align_pointmaps": align_pointmaps,
"max_depth": max_depth,
"rtol": rtol,
"pointcloud_save_frame_interval": pointcloud_save_frame_interval,
}
if post_reconstruction:
progress(0.5, "Running post-reconstruction for better quality")
recon_output = pipeline(
task="reconstruction",
video=output.rgb,
height=height,
width=width,
num_frames=num_frames,
fps=fps,
num_inference_steps=4,
guidance_scale=1.0,
use_dynamic_cfg=False,
generator=torch.Generator(device=device).manual_seed(seed),
)
disparity = recon_output.disparity
raymap = recon_output.raymap
else:
disparity = output.disparity
raymap = output.raymap
progress(0.8, "Saving output files")
# Save output files
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)
output_paths = save_output_files(
rgb=rgb_output,
disparity=disparity,
raymap=raymap,
task="prediction",
output_dir=output_dir,
fps=12,
**args,
)
progress(1.0, "Done!")
# Return paths for displaying
return (
output_paths["rgb"],
output_paths["disparity"],
output_paths.get("pointcloud_glbs", []),
)
except Exception:
import traceback
traceback.print_exc()
return None, None, []
@spaces.GPU(duration=240)
def process_planning(
image_file,
goal_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
use_dynamic_cfg,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
progress=gr.Progress(),
):
"""
Process planning task.
"""
try:
gc.collect()
torch.cuda.empty_cache()
# Set random seed
seed_all(seed)
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Check your environment.")
# Build the pipeline
pipeline = build_pipeline(device)
progress(0.1, "Loading images")
# Check if image_file and goal_file are strings or file objects
if isinstance(image_file, str):
image_path = image_file
else:
image_path = image_file.name
if isinstance(goal_file, str):
goal_path = goal_file
else:
goal_path = goal_file.name
image = PIL.Image.open(image_path)
goal = PIL.Image.open(goal_path)
progress(0.2, "Running planning")
# Run planning
output = pipeline(
task="planning",
image=image,
video=None,
goal=goal,
raymap=None,
height=height,
width=width,
num_frames=num_frames,
fps=fps,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
use_dynamic_cfg=use_dynamic_cfg,
generator=torch.Generator(device=device).manual_seed(seed),
return_dict=True,
)
# Show RGB output immediately
rgb_output = output.rgb
# Setup arguments for saving
args = {
"height": height,
"width": width,
"smooth_camera": smooth_camera,
"smooth_method": "kalman" if smooth_camera else "none",
"align_pointmaps": align_pointmaps,
"max_depth": max_depth,
"rtol": rtol,
"pointcloud_save_frame_interval": pointcloud_save_frame_interval,
}
if post_reconstruction:
progress(0.5, "Running post-reconstruction for better quality")
recon_output = pipeline(
task="reconstruction",
video=output.rgb,
height=height,
width=width,
num_frames=num_frames,
fps=12,
num_inference_steps=4,
guidance_scale=1.0,
use_dynamic_cfg=False,
generator=torch.Generator(device=device).manual_seed(seed),
)
disparity = recon_output.disparity
raymap = recon_output.raymap
else:
disparity = output.disparity
raymap = output.raymap
progress(0.8, "Saving output files")
# Save output files
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)
output_paths = save_output_files(
rgb=rgb_output,
disparity=disparity,
raymap=raymap,
task="planning",
output_dir=output_dir,
fps=fps,
**args,
)
progress(1.0, "Done!")
# Return paths for displaying
return (
output_paths["rgb"],
output_paths["disparity"],
output_paths.get("pointcloud_glbs", []),
)
except Exception:
import traceback
traceback.print_exc()
return None, None, []
def update_task_ui(task):
"""Update UI elements based on selected task."""
if task == "reconstruction":
return (
gr.update(visible=True), # video_input
gr.update(visible=False), # image_input
gr.update(visible=False), # goal_input
gr.update(visible=False), # image_preview
gr.update(visible=False), # goal_preview
gr.update(value=4), # num_inference_steps
gr.update(visible=True), # sliding_window_stride
gr.update(visible=False), # use_dynamic_cfg
gr.update(visible=False), # raymap_option
gr.update(visible=False), # post_reconstruction
gr.update(value=1.0), # guidance_scale
)
elif task == "prediction":
return (
gr.update(visible=False), # video_input
gr.update(visible=True), # image_input
gr.update(visible=False), # goal_input
gr.update(visible=True), # image_preview
gr.update(visible=False), # goal_preview
gr.update(value=50), # num_inference_steps
gr.update(visible=False), # sliding_window_stride
gr.update(visible=True), # use_dynamic_cfg
gr.update(visible=True), # raymap_option
gr.update(visible=True), # post_reconstruction
gr.update(value=3.0), # guidance_scale
)
elif task == "planning":
return (
gr.update(visible=False), # video_input
gr.update(visible=True), # image_input
gr.update(visible=True), # goal_input
gr.update(visible=True), # image_preview
gr.update(visible=True), # goal_preview
gr.update(value=50), # num_inference_steps
gr.update(visible=False), # sliding_window_stride
gr.update(visible=True), # use_dynamic_cfg
gr.update(visible=False), # raymap_option
gr.update(visible=True), # post_reconstruction
gr.update(value=3.0), # guidance_scale
)
def update_image_preview(image_file):
"""Update the image preview."""
if image_file:
return image_file.name
return None
def update_goal_preview(goal_file):
"""Update the goal preview."""
if goal_file:
return goal_file.name
return None
def get_download_link(selected_frame, all_paths):
"""Update the download button with the selected file path."""
if not selected_frame or not all_paths:
return gr.update(visible=False, value=None)
frame_num = int(re.search(r"Frame (\d+)", selected_frame).group(1))
for path in all_paths:
if f"frame_{frame_num}" in path:
# Make sure the file exists before setting it
if os.path.exists(path):
return gr.update(visible=True, value=path, interactive=True)
return gr.update(visible=False, value=None)
# Theme setup
theme = gr.themes.Default(
primary_hue="blue",
secondary_hue="cyan",
)
with gr.Blocks(
theme=theme,
css="""
.output-column {
min-height: 400px;
}
.warning {
color: #ff9800;
font-weight: bold;
}
.highlight {
background-color: rgba(0, 123, 255, 0.1);
padding: 10px;
border-radius: 8px;
border-left: 5px solid #007bff;
margin: 10px 0;
}
.task-header {
margin-top: 10px;
margin-bottom: 15px;
font-size: 1.2em;
font-weight: bold;
color: #007bff;
}
.flex-display {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
.output-subtitle {
font-size: 1.1em;
margin-top: 5px;
margin-bottom: 5px;
color: #505050;
}
.input-section, .params-section, .advanced-section {
border: 1px solid #ddd;
padding: 15px;
border-radius: 8px;
margin-bottom: 15px;
}
.logo-container {
display: flex;
justify-content: center;
margin-bottom: 20px;
}
.logo-image {
max-width: 300px;
height: auto;
}
""",
) as demo:
with gr.Row(elem_classes=["logo-container"]):
gr.Image("assets/logo.png", show_label=False, elem_classes=["logo-image"])
gr.Markdown(
"""
# Aether: Geometric-Aware Unified World Modeling
Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with
generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:
1. **4D Dynamic Reconstruction** - Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
2. **Action-Conditioned Video Prediction** - Predict future frames from an initial observation image, optionally conditioned on camera trajectory actions.
3. **Goal-Conditioned Visual Planning** - Generate planning paths from pairs of observation and goal images.
Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
"""
)
with gr.Row():
with gr.Column(scale=1):
task = gr.Radio(
["reconstruction", "prediction", "planning"],
label="Select Task",
value="reconstruction",
info="Choose the task you want to perform",
)
with gr.Group(elem_classes=["input-section"]):
# Input section - changes based on task
gr.Markdown("## 📥 Input", elem_classes=["task-header"])
# Task-specific inputs
video_input = gr.Video(
label="Upload Input Video",
sources=["upload"],
visible=True,
interactive=True,
elem_id="video_input",
)
image_input = gr.File(
label="Upload Start Image",
file_count="single",
file_types=["image"],
visible=False,
interactive=True,
elem_id="image_input",
)
goal_input = gr.File(
label="Upload Goal Image",
file_count="single",
file_types=["image"],
visible=False,
interactive=True,
elem_id="goal_input",
)
with gr.Row(visible=False) as preview_row:
image_preview = gr.Image(
label="Start Image Preview",
elem_id="image_preview",
visible=False,
)
goal_preview = gr.Image(
label="Goal Image Preview",
elem_id="goal_preview",
visible=False,
)
with gr.Group(elem_classes=["params-section"]):
gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])
with gr.Row():
with gr.Column(scale=1):
height = gr.Dropdown(
choices=[480],
value=480,
label="Height",
info="Height of the output video",
)
with gr.Column(scale=1):
width = gr.Dropdown(
choices=[720],
value=720,
label="Width",
info="Width of the output video",
)
with gr.Row():
with gr.Column(scale=1):
num_frames = gr.Dropdown(
choices=[17, 25, 33, 41],
value=41,
label="Number of Frames",
info="Number of frames to predict",
)
with gr.Column(scale=1):
fps = gr.Dropdown(
choices=[8, 10, 12, 15, 24],
value=12,
label="FPS",
info="Frames per second",
)
with gr.Row():
with gr.Column(scale=1):
num_inference_steps = gr.Slider(
minimum=1,
maximum=60,
value=4,
step=1,
label="Inference Steps",
info="Number of inference step",
)
sliding_window_stride = gr.Slider(
minimum=1,
maximum=40,
value=24,
step=1,
label="Sliding Window Stride",
info="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task",
visible=True,
)
use_dynamic_cfg = gr.Checkbox(
label="Use Dynamic CFG",
value=True,
info="Use dynamic CFG",
visible=False,
)
raymap_option = gr.Radio(
choices=["backward", "forward_right", "left_forward", "right"],
label="Camera Movement Direction",
value="forward_right",
info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
visible=False,
)
post_reconstruction = gr.Checkbox(
label="Post-Reconstruction",
value=True,
info="Run reconstruction after prediction for better quality",
visible=False,
)
with gr.Accordion(
"Advanced Options", open=False, visible=True
) as advanced_options:
with gr.Group(elem_classes=["advanced-section"]):
with gr.Row():
with gr.Column(scale=1):
guidance_scale = gr.Slider(
minimum=1.0,
maximum=10.0,
value=1.0,
step=0.1,
label="Guidance Scale",
info="Guidance scale (only for prediction / planning)",
)
with gr.Row():
with gr.Column(scale=1):
seed = gr.Number(
value=42,
label="Random Seed",
info="Set a seed for reproducible results",
precision=0,
minimum=0,
maximum=2147483647,
)
with gr.Row():
with gr.Column(scale=1):
smooth_camera = gr.Checkbox(
label="Smooth Camera",
value=True,
info="Apply smoothing to camera trajectory",
)
with gr.Column(scale=1):
align_pointmaps = gr.Checkbox(
label="Align Point Maps",
value=False,
info="Align point maps across frames",
)
with gr.Row():
with gr.Column(scale=1):
max_depth = gr.Slider(
minimum=10,
maximum=200,
value=60,
step=10,
label="Max Depth",
info="Maximum depth for point cloud (higher = more distant points)",
)
with gr.Column(scale=1):
rtol = gr.Slider(
minimum=0.01,
maximum=2.0,
value=0.03,
step=0.01,
label="Relative Tolerance",
info="Used for depth edge detection. Lower = remove more edges",
)
pointcloud_save_frame_interval = gr.Slider(
minimum=1,
maximum=20,
value=10,
step=1,
label="Point Cloud Frame Interval",
info="Save point cloud every N frames (higher = fewer files but less complete representation)",
)
run_button = gr.Button("Run Aether", variant="primary")
with gr.Column(scale=1, elem_classes=["output-column"]):
with gr.Group():
gr.Markdown("## 📤 Output", elem_classes=["task-header"])
gr.Markdown("### RGB Video", elem_classes=["output-subtitle"])
rgb_output = gr.Video(
label="RGB Output", interactive=False, elem_id="rgb_output"
)
gr.Markdown("### Depth Video", elem_classes=["output-subtitle"])
depth_output = gr.Video(
label="Depth Output", interactive=False, elem_id="depth_output"
)
gr.Markdown("### Point Clouds", elem_classes=["output-subtitle"])
with gr.Row(elem_classes=["flex-display"]):
pointcloud_frames = gr.Dropdown(
label="Select Frame",
choices=[],
value=None,
interactive=True,
elem_id="pointcloud_frames",
)
pointcloud_download = gr.DownloadButton(
label="Download Point Cloud",
visible=False,
elem_id="pointcloud_download",
)
model_output = gr.Model3D(
label="Point Cloud Viewer", interactive=True, elem_id="model_output"
)
with gr.Tab("About Results"):
gr.Markdown(
"""
### Understanding the Outputs
- **RGB Video**: Shows the predicted or reconstructed RGB frames
- **Depth Video**: Visualizes the disparity maps in color (closer = red, farther = blue)
- **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids
<p class="warning">Note: 3D point clouds take a long time to visualize, so only keyframes are shown.
You can control the keyframe interval via `pointcloud_save_frame_interval`.</p>
"""
)
# Event handlers
task.change(
fn=update_task_ui,
inputs=[task],
outputs=[
video_input,
image_input,
goal_input,
image_preview,
goal_preview,
num_inference_steps,
sliding_window_stride,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
guidance_scale,
],
)
image_input.change(
fn=update_image_preview, inputs=[image_input], outputs=[image_preview]
).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
goal_input.change(
fn=update_goal_preview, inputs=[goal_input], outputs=[goal_preview]
).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
def update_pointcloud_frames(pointcloud_paths):
"""Update the pointcloud frames dropdown with available frames."""
if not pointcloud_paths:
return gr.update(choices=[], value=None), None, gr.update(visible=False)
# Extract frame numbers from filenames
frame_info = []
for path in pointcloud_paths:
filename = os.path.basename(path)
match = re.search(r"frame_(\d+)", filename)
if match:
frame_num = int(match.group(1))
frame_info.append((f"Frame {frame_num}", path))
# Sort by frame number
frame_info.sort(key=lambda x: int(re.search(r"Frame (\d+)", x[0]).group(1)))
choices = [label for label, _ in frame_info]
paths = [path for _, path in frame_info]
if not choices:
return gr.update(choices=[], value=None), None, gr.update(visible=False)
# Make download button visible when we have point cloud files
return (
gr.update(choices=choices, value=choices[0]),
paths[0],
gr.update(visible=True),
)
def select_pointcloud_frame(frame_label, all_paths):
"""Select a specific pointcloud frame."""
if not frame_label or not all_paths:
return None
frame_num = int(re.search(r"Frame (\d+)", frame_label).group(1))
for path in all_paths:
if f"frame_{frame_num}" in path:
return path
return None
# Dispatcher used by the run button click handler below:
def process_task(task_type, *args):
"""Process selected task with appropriate function."""
if task_type == "reconstruction":
rgb_path, depth_path, pointcloud_paths = process_reconstruction(*args)
# Update the pointcloud frames dropdown
frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
pointcloud_paths
)
return (
rgb_path,
depth_path,
initial_path,
frame_dropdown,
pointcloud_paths,
download_visible,
)
elif task_type == "prediction":
rgb_path, depth_path, pointcloud_paths = process_prediction(*args)
frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
pointcloud_paths
)
return (
rgb_path,
depth_path,
initial_path,
frame_dropdown,
pointcloud_paths,
download_visible,
)
elif task_type == "planning":
rgb_path, depth_path, pointcloud_paths = process_planning(*args)
frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
pointcloud_paths
)
return (
rgb_path,
depth_path,
initial_path,
frame_dropdown,
pointcloud_paths,
download_visible,
)
return (
None,
None,
None,
gr.update(choices=[], value=None),
[],
gr.update(visible=False),
)
# Store all pointcloud paths for later use
all_pointcloud_paths = gr.State([])
run_button.click(
fn=lambda task_type,
video_file,
image_file,
goal_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
sliding_window_stride,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed: process_task(
task_type,
*(
[
video_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
sliding_window_stride,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
]
if task_type == "reconstruction"
else [
image_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
]
if task_type == "prediction"
else [
image_file,
goal_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
use_dynamic_cfg,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
]
),
),
inputs=[
task,
video_input,
image_input,
goal_input,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
sliding_window_stride,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
],
outputs=[
rgb_output,
depth_output,
model_output,
pointcloud_frames,
all_pointcloud_paths,
pointcloud_download,
],
)
pointcloud_frames.change(
fn=select_pointcloud_frame,
inputs=[pointcloud_frames, all_pointcloud_paths],
outputs=[model_output],
).then(
fn=get_download_link,
inputs=[pointcloud_frames, all_pointcloud_paths],
outputs=[pointcloud_download],
)
# Example Accordion
with gr.Accordion("Examples"):
gr.Markdown(
"""
### Examples will be added soon
Check back for example inputs for each task type.
"""
)
# Warm up the pipeline on CPU at startup (it is moved to the GPU per request)
demo.load(lambda: build_pipeline(torch.device("cpu")), inputs=None, outputs=None)
if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
demo.queue(max_size=20).launch(show_error=True, share=True)