import gc
import os
import random
import re
import traceback
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import imageio.v3 as iio
import numpy as np
import PIL.Image  # binds ``PIL`` and guarantees the ``PIL.Image`` submodule is loaded
import rootutils
import torch
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXTransformer3DModel,
)
from transformers import AutoTokenizer, T5EncoderModel

import spaces

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from aether.pipelines.aetherv1_pipeline_cogvideox import (  # noqa: E402
    AetherV1PipelineCogVideoX,
    AetherV1PipelineOutput,
)
from aether.utils.postprocess_utils import (  # noqa: E402
    align_camera_extrinsics,
    apply_transformation,
    colorize_depth,
    compute_scale,
    get_intrinsics,
    interpolate_poses,
    postprocess_pointmap,
    project,
    raymap_to_poses,
)
from aether.utils.visualize_utils import predictions_to_glb  # noqa: E402


def seed_all(seed: int = 0) -> None:
    """
    Set random seeds of all components (Python, NumPy, torch CPU and CUDA).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Global pipeline: constructed once at import time; build_pipeline() only moves
# it to the requested device.
cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
pipeline = AetherV1PipelineCogVideoX(
    tokenizer=AutoTokenizer.from_pretrained(
        cogvideox_pretrained_model_name_or_path,
        subfolder="tokenizer",
    ),
    text_encoder=T5EncoderModel.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
    ),
    vae=AutoencoderKLCogVideoX.from_pretrained(
        cogvideox_pretrained_model_name_or_path,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ),
    scheduler=CogVideoXDPMScheduler.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
    ),
    transformer=CogVideoXTransformer3DModel.from_pretrained(
        aether_pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ),
)
pipeline.vae.enable_slicing()
pipeline.vae.enable_tiling()


def build_pipeline(device: torch.device) -> AetherV1PipelineCogVideoX:
    """Move the module-level pipeline to ``device`` and return it.

    The pipeline itself is created once at import time; this function does not
    construct a new one.
    """
    pipeline.to(device)
    return pipeline


def _get_cuda_device() -> torch.device:
    """Return the CUDA device, raising ``ValueError`` when CUDA is unavailable."""
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")
    return torch.device("cuda")


def _resolve_file_path(file_obj) -> Optional[str]:
    """Return a filesystem path for a Gradio upload value.

    Gradio may deliver either a plain path string or a tempfile-like object
    exposing ``.name``; both are accepted. A falsy value yields ``None``.
    """
    if not file_obj:
        return None
    if isinstance(file_obj, str):
        return file_obj
    return file_obj.name


def _find_frame_path(
    frame_label: Optional[str], all_paths: Optional[List[str]]
) -> Optional[str]:
    """Return the GLB path in ``all_paths`` that belongs to a "Frame N" label.

    Matches the exact ``_frame_N.glb`` suffix written by ``save_output_files`` so
    that e.g. "Frame 1" never resolves to ``..._frame_10.glb``.
    """
    if not frame_label or not all_paths:
        return None
    match = re.search(r"Frame (\d+)", frame_label)
    if match is None:
        return None
    suffix = f"_frame_{int(match.group(1))}.glb"
    for path in all_paths:
        if os.path.basename(path).endswith(suffix):
            return path
    return None


def get_window_starts(
    total_frames: int, sliding_window_size: int, temporal_stride: int
) -> List[int]:
    """Calculate window start indices.

    Windows start every ``temporal_stride`` frames; if the last stride-aligned
    window does not reach the end of the video, one extra window aligned to the
    final frame is appended. Returns an empty list when the video is shorter
    than one window.
    """
    starts = list(
        range(
            0,
            total_frames - sliding_window_size + 1,
            temporal_stride,
        )
    )
    if (
        total_frames > sliding_window_size
        and (total_frames - sliding_window_size) % temporal_stride != 0
    ):
        starts.append(total_frames - sliding_window_size)
    return starts


def blend_and_merge_window_results(
    window_results: List[AetherV1PipelineOutput], window_indices: List[int], args: Dict
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Blend and merge per-window results into one sequence.

    Each new window is aligned to the already-merged prefix on their overlapping
    frames (disparity scale, camera extrinsics, focal scale) and then linearly
    cross-faded over the overlap.

    Args:
        window_results: pipeline outputs, one per window, in temporal order.
        window_indices: start frame of each window (same order).
        args: options; reads ``align_pointmaps``, ``smooth_camera``,
            ``smooth_method``, ``height`` and ``width``.

    Returns:
        ``(rgb, disparity, poses, pointmaps)`` for the merged sequence.

    Raises:
        ValueError: if two consecutive windows do not overlap by at least one frame.
    """
    merged_rgb = None
    merged_disparity = None
    merged_poses = None
    merged_focals = None
    align_pointmaps = args.get("align_pointmaps", True)
    smooth_camera = args.get("smooth_camera", True)
    smooth_method = args.get("smooth_method", "kalman") if smooth_camera else "none"

    if align_pointmaps:
        merged_pointmaps = None

    w1 = window_results[0].disparity

    for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
        t_end = t_start + window_result.rgb.shape[0]
        if idx == 0:
            merged_rgb = window_result.rgb
            merged_disparity = window_result.disparity
            pointmap_dict = postprocess_pointmap(
                window_result.disparity,
                window_result.raymap,
                vae_downsample_scale=8,
                ray_o_scale_inv=0.1,
                smooth_camera=smooth_camera,
                smooth_method=smooth_method,
            )
            merged_poses = pointmap_dict["camera_pose"]
            merged_focals = (
                pointmap_dict["intrinsics"][:, 0, 0]
                + pointmap_dict["intrinsics"][:, 1, 1]
            ) / 2
            if align_pointmaps:
                merged_pointmaps = pointmap_dict["pointmap"]
        else:
            overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start
            # With no overlap, slices like ``merged[-0:]`` silently select the
            # whole array and the alignment below becomes garbage.
            if overlap_t < 1:
                raise ValueError(
                    "Consecutive windows must overlap by at least one frame "
                    f"(window {idx} overlaps its predecessor by {overlap_t})."
                )

            window_disparity = window_result.disparity

            # Align disparity
            disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
            scale = compute_scale(
                window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
                merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
                disp_mask.reshape(1, -1, w1.shape[-1]),
            )
            window_disparity = scale * window_disparity

            # Blend disparity
            result_disparity = np.ones((t_end, *w1.shape[1:]))
            result_disparity[:t_start] = merged_disparity[:t_start]
            result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)[:, None, None]
            result_disparity[t_start : t_start + overlap_t] = merged_disparity[
                t_start : t_start + overlap_t
            ] * weight + window_disparity[:overlap_t] * (1 - weight)
            merged_disparity = result_disparity

            # Blend RGB
            result_rgb = np.ones((t_end, *w1.shape[1:], 3))
            result_rgb[:t_start] = merged_rgb[:t_start]
            result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
            weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
            result_rgb[t_start : t_start + overlap_t] = merged_rgb[
                t_start : t_start + overlap_t
            ] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
            merged_rgb = result_rgb

            # Align poses
            window_raymap = window_result.raymap
            window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
                window_raymap, ray_o_scale_inv=0.1
            )
            rel_r, rel_t, rel_s = align_camera_extrinsics(
                torch.from_numpy(window_poses[:overlap_t]),
                torch.from_numpy(merged_poses[-overlap_t:]),
            )
            aligned_window_poses = (
                apply_transformation(
                    torch.from_numpy(window_poses),
                    rel_r,
                    rel_t,
                    rel_s,
                    return_extri=True,
                )
                .cpu()
                .numpy()
            )

            result_poses = np.ones((t_end, 4, 4))
            result_poses[:t_start] = merged_poses[:t_start]
            result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]

            # Interpolate poses in overlap region
            weights = np.linspace(1, 0, overlap_t)
            for t in range(overlap_t):
                weight = weights[t]
                pose1 = merged_poses[t_start + t]
                pose2 = aligned_window_poses[t]
                result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)

            merged_poses = result_poses

            # Align intrinsics
            window_intrinsics, _ = get_intrinsics(
                batch_size=window_poses.shape[0],
                h=window_result.disparity.shape[1],
                w=window_result.disparity.shape[2],
                fovx=window_Fov_x,
                fovy=window_Fov_y,
            )
            window_focals = (
                window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
            ) / 2
            scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
            window_focals = scale * window_focals
            result_focals = np.ones((t_end,))
            result_focals[:t_start] = merged_focals[:t_start]
            result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)
            result_focals[t_start : t_start + overlap_t] = merged_focals[
                t_start : t_start + overlap_t
            ] * weight + window_focals[:overlap_t] * (1 - weight)
            merged_focals = result_focals

            if align_pointmaps:
                # Align pointmaps
                window_pointmaps = postprocess_pointmap(
                    result_disparity[t_start:],
                    window_raymap,
                    vae_downsample_scale=8,
                    camera_pose=aligned_window_poses,
                    focal=window_focals,
                    ray_o_scale_inv=0.1,
                    smooth_camera=smooth_camera,
                    smooth_method=smooth_method,
                )
                result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
                result_pointmaps[:t_start] = merged_pointmaps[:t_start]
                result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
                    overlap_t:
                ]
                weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
                result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
                    t_start : t_start + overlap_t
                ] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
                merged_pointmaps = result_pointmaps

    # Project to pointmaps
    height = args.get("height", 480)
    width = args.get("width", 720)
    intrinsics = [
        np.array([[f, 0, 0.5 * width], [0, f, 0.5 * height], [0, 0, 1]])
        for f in merged_focals
    ]
    if align_pointmaps:
        pointmaps = merged_pointmaps
    else:
        pointmaps = np.stack(
            [
                project(
                    1 / np.clip(merged_disparity[i], 1e-8, 1e8),
                    intrinsics[i],
                    merged_poses[i],
                )
                for i in range(merged_poses.shape[0])
            ]
        )

    return merged_rgb, merged_disparity, merged_poses, pointmaps


def process_video_to_frames(
    video_path: str, fps_sample: int = 12
) -> Tuple[List[str], str]:
    """Process video into frames and save them locally.

    Returns:
        ``(frame_paths, output_dir)`` — the paths of the JPEG frames actually
        written and the directory they were written to.
    """
    # Create a unique output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"temp_frames_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    # Read video
    video = iio.imread(video_path)

    # Calculate frame interval based on original video fps (assumed 30 fps)
    if isinstance(video, np.ndarray) and len(video) > 0:
        # For captured videos
        total_frames = len(video)
        frame_interval = max(
            1, round(total_frames / (fps_sample * (total_frames / 30)))
        )
    else:
        # Default if can't determine
        frame_interval = 2

    frame_paths = []
    for i, frame in enumerate(video[::frame_interval]):
        frame_path = os.path.join(output_dir, f"frame_{i:04d}.jpg")
        if isinstance(frame, np.ndarray):
            iio.imwrite(frame_path, frame)
            # Only report paths that were actually written.
            frame_paths.append(frame_path)

    return frame_paths, output_dir


def save_output_files(
    rgb: np.ndarray,
    disparity: np.ndarray,
    poses: Optional[np.ndarray] = None,
    raymap: Optional[np.ndarray] = None,
    pointmap: Optional[np.ndarray] = None,
    task: str = "reconstruction",
    output_dir: str = "outputs",
    **kwargs,
) -> Dict[str, Union[str, List[str]]]:
    """
    Save outputs and return paths to saved files.

    Writes an RGB video, a colorized disparity video and — when a pointmap and
    poses are available (directly or derived from ``raymap``) — one GLB point
    cloud per selected keyframe.

    Keyword options read from ``kwargs``: ``smooth_camera``, ``smooth_method``,
    ``fps``, ``pointcloud_save_frame_interval``, ``max_depth``, ``rtol``.

    Returns:
        Dict with ``"rgb"`` and ``"disparity"`` video paths and, if written,
        ``"pointcloud_glbs"`` (list of GLB paths in frame order).
    """
    os.makedirs(output_dir, exist_ok=True)

    if pointmap is None and raymap is not None:
        # Generate pointmap from raymap and disparity
        smooth_camera = kwargs.get("smooth_camera", True)
        smooth_method = (
            kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
        )

        pointmap_dict = postprocess_pointmap(
            disparity,
            raymap,
            vae_downsample_scale=8,
            ray_o_scale_inv=0.1,
            smooth_camera=smooth_camera,
            smooth_method=smooth_method,
        )
        pointmap = pointmap_dict["pointmap"]

    if poses is None and raymap is not None:
        poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)

    # Create a unique filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_filename = f"{task}_{timestamp}"

    # Paths for saved files
    paths = {}

    # Save RGB video
    rgb_path = os.path.join(output_dir, f"{base_filename}_rgb.mp4")
    iio.imwrite(
        rgb_path,
        (np.clip(rgb, 0, 1) * 255).astype(np.uint8),
        fps=kwargs.get("fps", 12),
    )
    paths["rgb"] = rgb_path

    # Save depth/disparity video
    depth_path = os.path.join(output_dir, f"{base_filename}_disparity.mp4")
    iio.imwrite(
        depth_path,
        (colorize_depth(disparity) * 255).astype(np.uint8),
        fps=kwargs.get("fps", 12),
    )
    paths["disparity"] = depth_path

    # Save point cloud GLB files
    if pointmap is not None and poses is not None:
        pointcloud_save_frame_interval = max(
            1, int(kwargs.get("pointcloud_save_frame_interval", 10))
        )
        max_depth = kwargs.get("max_depth", 100.0)
        rtol = kwargs.get("rtol", 0.03)

        num_saved_frames = pointmap.shape[0]
        # Every N-th frame, always including the first and last frame.
        frames_to_save = set(range(0, num_saved_frames, pointcloud_save_frame_interval))
        if num_saved_frames > 0:
            frames_to_save |= {0, num_saved_frames - 1}

        glb_paths = []
        for frame_idx in sorted(frames_to_save):
            predictions = {
                "world_points": pointmap[frame_idx : frame_idx + 1],
                "images": rgb[frame_idx : frame_idx + 1],
                "depths": 1
                / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
                "camera_poses": poses[frame_idx : frame_idx + 1],
            }

            glb_path = os.path.join(
                output_dir, f"{base_filename}_pointcloud_frame_{frame_idx}.glb"
            )

            scene_3d = predictions_to_glb(
                predictions,
                filter_by_frames="all",
                show_cam=True,
                max_depth=max_depth,
                rtol=rtol,
                frame_rel_idx=float(frame_idx) / num_saved_frames,
            )
            scene_3d.export(glb_path)
            glb_paths.append(glb_path)

        paths["pointcloud_glbs"] = glb_paths

    return paths


@spaces.GPU(duration=300)
def process_reconstruction(
    video_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    sliding_window_stride,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Process reconstruction task.

    Runs the pipeline over sliding windows of the input video, merges the
    per-window results and saves the outputs.

    Returns:
        ``(rgb_video_path, disparity_video_path, glb_paths)``; on any error the
        traceback is printed and ``(None, None, [])`` is returned.
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()

        # Set random seed
        seed_all(seed)

        # Check whether CUDA is available
        device = _get_cuda_device()

        pipeline = build_pipeline(device)

        progress(0.1, "Loading video")
        video_path = _resolve_file_path(video_file)
        if video_path is None:
            raise ValueError("No input video was provided.")

        video = iio.imread(video_path).astype(np.float32) / 255.0

        # Consecutive windows must share at least one frame so the merge step
        # has something to align them on.
        stride = max(1, min(int(sliding_window_stride), int(num_frames) - 1))

        # Setup arguments
        args = {
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "sliding_window_stride": stride,
            "smooth_camera": smooth_camera,
            "smooth_method": "kalman" if smooth_camera else "none",
            "align_pointmaps": align_pointmaps,
            "max_depth": max_depth,
            "rtol": rtol,
            "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
        }

        # Process in sliding windows
        window_results = []
        window_indices = get_window_starts(len(video), num_frames, stride)
        if not window_indices:
            raise ValueError(
                f"The input video has {len(video)} frames, but at least "
                f"{num_frames} are required for one reconstruction window."
            )
        progress(0.2, f"Processing video in {len(window_indices)} windows")

        for i, start_idx in enumerate(window_indices):
            progress_val = 0.2 + (0.6 * (i / len(window_indices)))
            progress(progress_val, f"Processing window {i+1}/{len(window_indices)}")

            output = pipeline(
                task="reconstruction",
                image=None,
                goal=None,
                video=video[start_idx : start_idx + num_frames],
                raymap=None,
                height=height,
                width=width,
                num_frames=num_frames,
                fps=fps,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )
            window_results.append(output)

        progress(0.8, "Merging results from all windows")
        # Merge window results
        (
            merged_rgb,
            merged_disparity,
            merged_poses,
            pointmaps,
        ) = blend_and_merge_window_results(window_results, window_indices, args)

        progress(0.9, "Saving output files")
        # Save output files
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_paths = save_output_files(
            rgb=merged_rgb,
            disparity=merged_disparity,
            poses=merged_poses,
            pointmap=pointmaps,
            task="reconstruction",
            output_dir=output_dir,
            fps=12,
            **args,
        )

        progress(1.0, "Done!")

        # Return paths for displaying
        return (
            output_paths["rgb"],
            output_paths["disparity"],
            output_paths.get("pointcloud_glbs", []),
        )

    except Exception:
        # Best-effort UI handler: report and return empty outputs.
        traceback.print_exc()
        return None, None, []


@spaces.GPU(duration=240)
def process_prediction(
    image_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    use_dynamic_cfg,
    raymap_option,
    post_reconstruction,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Process prediction task.

    Generates a video from a start image conditioned on a pre-defined camera
    raymap, optionally re-runs reconstruction on it, and saves the outputs.

    Returns:
        ``(rgb_video_path, disparity_video_path, glb_paths)``; on any error the
        traceback is printed and ``(None, None, [])`` is returned.
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()

        # Set random seed
        seed_all(seed)

        # Check whether CUDA is available (``device`` was previously undefined
        # here, which made every prediction request fail).
        device = _get_cuda_device()

        # Build the pipeline
        pipeline = build_pipeline(device)

        progress(0.1, "Loading image")
        image_path = _resolve_file_path(image_file)
        if image_path is None:
            raise ValueError("No start image was provided.")

        image = PIL.Image.open(image_path)

        progress(0.2, "Running prediction")
        # Run prediction
        output = pipeline(
            task="prediction",
            image=image,
            video=None,
            goal=None,
            raymap=np.load(f"assets/example_raymaps/raymap_{raymap_option}.npy"),
            height=height,
            width=width,
            num_frames=num_frames,
            fps=fps,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_dynamic_cfg=use_dynamic_cfg,
            generator=torch.Generator(device=device).manual_seed(seed),
            return_dict=True,
        )

        # Show RGB output immediately
        rgb_output = output.rgb

        # Setup arguments for saving
        args = {
            "height": height,
            "width": width,
            "smooth_camera": smooth_camera,
            "smooth_method": "kalman" if smooth_camera else "none",
            "align_pointmaps": align_pointmaps,
            "max_depth": max_depth,
            "rtol": rtol,
            "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
        }

        if post_reconstruction:
            progress(0.5, "Running post-reconstruction for better quality")
            recon_output = pipeline(
                task="reconstruction",
                video=output.rgb,
                height=height,
                width=width,
                num_frames=num_frames,
                fps=fps,
                num_inference_steps=4,
                guidance_scale=1.0,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )

            disparity = recon_output.disparity
            raymap = recon_output.raymap
        else:
            disparity = output.disparity
            raymap = output.raymap

        progress(0.8, "Saving output files")
        # Save output files
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_paths = save_output_files(
            rgb=rgb_output,
            disparity=disparity,
            raymap=raymap,
            task="prediction",
            output_dir=output_dir,
            fps=12,
            **args,
        )

        progress(1.0, "Done!")

        # Return paths for displaying
        return (
            output_paths["rgb"],
            output_paths["disparity"],
            output_paths.get("pointcloud_glbs", []),
        )

    except Exception:
        # Best-effort UI handler: report and return empty outputs.
        traceback.print_exc()
        return None, None, []


@spaces.GPU(duration=240)
def process_planning(
    image_file,
    goal_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    use_dynamic_cfg,
    post_reconstruction,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Process planning task.

    Generates a video that goes from a start image to a goal image, optionally
    re-runs reconstruction on it, and saves the outputs.

    Returns:
        ``(rgb_video_path, disparity_video_path, glb_paths)``; on any error the
        traceback is printed and ``(None, None, [])`` is returned.
    """
    try:
        gc.collect()
        torch.cuda.empty_cache()

        # Set random seed
        seed_all(seed)

        # Check whether CUDA is available
        device = _get_cuda_device()

        # Build the pipeline
        pipeline = build_pipeline(device)

        progress(0.1, "Loading images")
        image_path = _resolve_file_path(image_file)
        goal_path = _resolve_file_path(goal_file)
        if image_path is None or goal_path is None:
            raise ValueError("Both a start image and a goal image are required.")

        image = PIL.Image.open(image_path)
        goal = PIL.Image.open(goal_path)

        progress(0.2, "Running planning")
        # Run planning
        output = pipeline(
            task="planning",
            image=image,
            video=None,
            goal=goal,
            raymap=None,
            height=height,
            width=width,
            num_frames=num_frames,
            fps=fps,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_dynamic_cfg=use_dynamic_cfg,
            generator=torch.Generator(device=device).manual_seed(seed),
            return_dict=True,
        )

        # Show RGB output immediately
        rgb_output = output.rgb

        # Setup arguments for saving
        args = {
            "height": height,
            "width": width,
            "smooth_camera": smooth_camera,
            "smooth_method": "kalman" if smooth_camera else "none",
            "align_pointmaps": align_pointmaps,
            "max_depth": max_depth,
            "rtol": rtol,
            "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
        }

        if post_reconstruction:
            progress(0.5, "Running post-reconstruction for better quality")
            recon_output = pipeline(
                task="reconstruction",
                video=output.rgb,
                height=height,
                width=width,
                num_frames=num_frames,
                # Same fps conditioning as the generation call above (and as
                # process_prediction's post-reconstruction).
                fps=fps,
                num_inference_steps=4,
                guidance_scale=1.0,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )

            disparity = recon_output.disparity
            raymap = recon_output.raymap
        else:
            disparity = output.disparity
            raymap = output.raymap

        progress(0.8, "Saving output files")
        # Save output files
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_paths = save_output_files(
            rgb=rgb_output,
            disparity=disparity,
            raymap=raymap,
            task="planning",
            output_dir=output_dir,
            fps=fps,
            **args,
        )

        progress(1.0, "Done!")

        # Return paths for displaying
        return (
            output_paths["rgb"],
            output_paths["disparity"],
            output_paths.get("pointcloud_glbs", []),
        )

    except Exception:
        # Best-effort UI handler: report and return empty outputs.
        traceback.print_exc()
        return None, None, []


def update_task_ui(task):
    """Update UI elements based on selected task."""
    if task == "reconstruction":
        return (
            gr.update(visible=True),  # video_input
            gr.update(visible=False),  # image_input
            gr.update(visible=False),  # goal_input
            gr.update(visible=False),  # image_preview
            gr.update(visible=False),  # goal_preview
            gr.update(value=4),  # num_inference_steps
            gr.update(visible=True),  # sliding_window_stride
            gr.update(visible=False),  # use_dynamic_cfg
            gr.update(visible=False),  # raymap_option
            gr.update(visible=False),  # post_reconstruction
            gr.update(value=1.0),  # guidance_scale
        )
    elif task == "prediction":
        return (
            gr.update(visible=False),  # video_input
            gr.update(visible=True),  # image_input
            gr.update(visible=False),  # goal_input
            gr.update(visible=True),  # image_preview
            gr.update(visible=False),  # goal_preview
            gr.update(value=50),  # num_inference_steps
            gr.update(visible=False),  # sliding_window_stride
            gr.update(visible=True),  # use_dynamic_cfg
            gr.update(visible=True),  # raymap_option
            gr.update(visible=True),  # post_reconstruction
            gr.update(value=3.0),  # guidance_scale
        )
    elif task == "planning":
        return (
            gr.update(visible=False),  # video_input
            gr.update(visible=True),  # image_input
            gr.update(visible=True),  # goal_input
            gr.update(visible=True),  # image_preview
            gr.update(visible=True),  # goal_preview
            gr.update(value=50),  # num_inference_steps
            gr.update(visible=False),  # sliding_window_stride
            gr.update(visible=True),  # use_dynamic_cfg
            gr.update(visible=False),  # raymap_option
            gr.update(visible=True),  # post_reconstruction
            gr.update(value=3.0),  # guidance_scale
        )


def update_image_preview(image_file):
    """Update the image preview (accepts a path string or a file object)."""
    return _resolve_file_path(image_file)


def update_goal_preview(goal_file):
    """Update the goal preview (accepts a path string or a file object)."""
    return _resolve_file_path(goal_file)


def get_download_link(selected_frame, all_paths):
    """Update the download button with the selected file path."""
    path = _find_frame_path(selected_frame, all_paths)
    # Make sure the file exists before offering it for download.
    if path is not None and os.path.exists(path):
        return gr.update(visible=True, value=path, interactive=True)
    return gr.update(visible=False, value=None)


# Theme setup
theme = gr.themes.Default(
    primary_hue="blue",
    secondary_hue="cyan",
)

with gr.Blocks(
    theme=theme,
    css="""
.output-column {
    min-height: 400px;
}
.warning {
    color: #ff9800;
    font-weight: bold;
}
.highlight {
    background-color: rgba(0, 123, 255, 0.1);
    padding: 10px;
    border-radius: 8px;
    border-left: 5px solid #007bff;
    margin: 10px 0;
}
.task-header {
    margin-top: 10px;
    margin-bottom: 15px;
    font-size: 1.2em;
    font-weight: bold;
    color: #007bff;
}
.flex-display {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
}
.output-subtitle {
    font-size: 1.1em;
    margin-top: 5px;
    margin-bottom: 5px;
    color: #505050;
}
.input-section, .params-section, .advanced-section {
    border: 1px solid #ddd;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.logo-container {
    display: flex;
    justify-content: center;
    margin-bottom: 20px;
}
.logo-image {
    max-width: 300px;
    height: auto;
}
""",
) as demo:
    with gr.Row(elem_classes=["logo-container"]):
        gr.Image("assets/logo.png", show_label=False, elem_classes=["logo-image"])

    gr.Markdown(
        """
# Aether: Geometric-Aware Unified World Modeling

Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:

1. **4D dynamic reconstruction** - Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
2. **Action-Conditioned Video Prediction** - Predict future frames based on initial observation images, with optional conditions of camera trajectory actions.
3. **Goal-Conditioned Visual Planning** - Generate planning paths from pairs of observation and goal images.

Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            task = gr.Radio(
                ["reconstruction", "prediction", "planning"],
                label="Select Task",
                value="reconstruction",
                info="Choose the task you want to perform",
            )

            with gr.Group(elem_classes=["input-section"]):
                # Input section - changes based on task
                gr.Markdown("## 📥 Input", elem_classes=["task-header"])

                # Task-specific inputs
                video_input = gr.Video(
                    label="Upload Input Video",
                    sources=["upload"],
                    visible=True,
                    interactive=True,
                    elem_id="video_input",
                )

                image_input = gr.File(
                    label="Upload Start Image",
                    file_count="single",
                    file_types=["image"],
                    visible=False,
                    interactive=True,
                    elem_id="image_input",
                )

                goal_input = gr.File(
                    label="Upload Goal Image",
                    file_count="single",
                    file_types=["image"],
                    visible=False,
                    interactive=True,
                    elem_id="goal_input",
                )

                with gr.Row(visible=False) as preview_row:
                    image_preview = gr.Image(
                        label="Start Image Preview",
                        elem_id="image_preview",
                        visible=False,
                    )
                    goal_preview = gr.Image(
                        label="Goal Image Preview",
                        elem_id="goal_preview",
                        visible=False,
                    )

            with gr.Group(elem_classes=["params-section"]):
                gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])

                with gr.Row():
                    with gr.Column(scale=1):
                        height = gr.Dropdown(
                            choices=[480],
                            value=480,
                            label="Height",
                            info="Height of the output video",
                        )

                    with gr.Column(scale=1):
                        width = gr.Dropdown(
                            choices=[720],
                            value=720,
                            label="Width",
                            info="Width of the output video",
                        )

                with gr.Row():
                    with gr.Column(scale=1):
                        num_frames = gr.Dropdown(
                            choices=[17, 25, 33, 41],
                            value=41,
                            label="Number of Frames",
                            info="Number of frames to predict",
                        )

                    with gr.Column(scale=1):
                        fps = gr.Dropdown(
                            choices=[8, 10, 12, 15, 24],
                            value=12,
                            label="FPS",
                            info="Frames per second",
                        )

                with gr.Row():
                    with gr.Column(scale=1):
                        num_inference_steps = gr.Slider(
                            minimum=1,
                            maximum=60,
                            value=4,
                            step=1,
                            label="Inference Steps",
                            info="Number of inference step",
                        )

                sliding_window_stride = gr.Slider(
                    minimum=1,
                    maximum=40,
                    value=24,
                    step=1,
                    label="Sliding Window Stride",
                    info="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task",
                    visible=True,
                )

                use_dynamic_cfg = gr.Checkbox(
                    label="Use Dynamic CFG",
                    value=True,
                    info="Use dynamic CFG",
                    visible=False,
                )

                raymap_option = gr.Radio(
                    choices=["backward", "forward_right", "left_forward", "right"],
                    label="Camera Movement Direction",
                    value="forward_right",
                    info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
                    visible=False,
                )

                post_reconstruction = gr.Checkbox(
                    label="Post-Reconstruction",
                    value=True,
                    info="Run reconstruction after prediction for better quality",
                    visible=False,
                )

                with gr.Accordion(
                    "Advanced Options", open=False, visible=True
                ) as advanced_options:
                    with gr.Group(elem_classes=["advanced-section"]):
                        with gr.Row():
                            with gr.Column(scale=1):
                                guidance_scale = gr.Slider(
                                    minimum=1.0,
                                    maximum=10.0,
                                    value=1.0,
                                    step=0.1,
                                    label="Guidance Scale",
                                    info="Guidance scale (only for prediction / planning)",
                                )

                        with gr.Row():
                            with gr.Column(scale=1):
                                seed = gr.Number(
                                    value=42,
                                    label="Random Seed",
                                    info="Set a seed for reproducible results",
                                    precision=0,
                                    minimum=0,
                                    maximum=2147483647,
                                )

                        with gr.Row():
                            with gr.Column(scale=1):
                                smooth_camera = gr.Checkbox(
                                    label="Smooth Camera",
                                    value=True,
                                    info="Apply smoothing to camera trajectory",
                                )

                            with gr.Column(scale=1):
                                align_pointmaps = gr.Checkbox(
                                    label="Align Point Maps",
                                    value=False,
                                    info="Align point maps across frames",
                                )

                        with gr.Row():
                            with gr.Column(scale=1):
                                max_depth = gr.Slider(
                                    minimum=10,
                                    maximum=200,
                                    value=60,
                                    step=10,
                                    label="Max Depth",
                                    info="Maximum depth for point cloud (higher = more distant points)",
                                )

                            with gr.Column(scale=1):
                                rtol = gr.Slider(
                                    minimum=0.01,
                                    maximum=2.0,
                                    value=0.03,
                                    step=0.01,
                                    label="Relative Tolerance",
                                    info="Used for depth edge detection. Lower = remove more edges",
                                )

                        pointcloud_save_frame_interval = gr.Slider(
                            minimum=1,
                            maximum=20,
                            value=10,
                            step=1,
                            label="Point Cloud Frame Interval",
                            info="Save point cloud every N frames (higher = fewer files but less complete representation)",
                        )

            run_button = gr.Button("Run Aether", variant="primary")

        with gr.Column(scale=1, elem_classes=["output-column"]):
            with gr.Group():
                gr.Markdown("## 📤 Output", elem_classes=["task-header"])

                gr.Markdown("### RGB Video", elem_classes=["output-subtitle"])
                rgb_output = gr.Video(
                    label="RGB Output", interactive=False, elem_id="rgb_output"
                )

                gr.Markdown("### Depth Video", elem_classes=["output-subtitle"])
                depth_output = gr.Video(
                    label="Depth Output", interactive=False, elem_id="depth_output"
                )

                gr.Markdown("### Point Clouds", elem_classes=["output-subtitle"])
                with gr.Row(elem_classes=["flex-display"]):
                    pointcloud_frames = gr.Dropdown(
                        label="Select Frame",
                        choices=[],
                        value=None,
                        interactive=True,
                        elem_id="pointcloud_frames",
                    )
                    pointcloud_download = gr.DownloadButton(
                        label="Download Point Cloud",
                        visible=False,
                        elem_id="pointcloud_download",
                    )

                model_output = gr.Model3D(
                    label="Point Cloud Viewer", interactive=True, elem_id="model_output"
                )

            with gr.Tab("About Results"):
                gr.Markdown(
                    """
### Understanding the Outputs

- **RGB Video**: Shows the predicted or reconstructed RGB frames
- **Depth Video**: Visualizes the disparity maps in color (closer = red, further = blue)
- **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids

Note: 3D point clouds take a long time to visualize, and we show the keyframes only. You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`.

"""
                )

    # Event handlers
    task.change(
        fn=update_task_ui,
        inputs=[task],
        outputs=[
            video_input,
            image_input,
            goal_input,
            image_preview,
            goal_preview,
            num_inference_steps,
            sliding_window_stride,
            use_dynamic_cfg,
            raymap_option,
            post_reconstruction,
            guidance_scale,
        ],
    )

    image_input.change(
        fn=update_image_preview, inputs=[image_input], outputs=[image_preview]
    ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])

    goal_input.change(
        fn=update_goal_preview, inputs=[goal_input], outputs=[goal_preview]
    ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])

    def update_pointcloud_frames(pointcloud_paths):
        """Update the pointcloud frames dropdown with available frames.

        Returns ``(dropdown_update, first_path, download_button_update)``.
        """
        if not pointcloud_paths:
            return gr.update(choices=[], value=None), None, gr.update(visible=False)

        # Extract frame numbers from filenames
        frame_info = []
        for path in pointcloud_paths:
            match = re.search(r"frame_(\d+)", os.path.basename(path))
            if match:
                frame_info.append((int(match.group(1)), path))

        if not frame_info:
            return gr.update(choices=[], value=None), None, gr.update(visible=False)

        # Sort by frame number
        frame_info.sort(key=lambda item: item[0])
        choices = [f"Frame {frame_num}" for frame_num, _ in frame_info]
        paths = [path for _, path in frame_info]

        # Make download button visible when we have point cloud files
        return (
            gr.update(choices=choices, value=choices[0]),
            paths[0],
            gr.update(visible=True),
        )

    def select_pointcloud_frame(frame_label, all_paths):
        """Select a specific pointcloud frame."""
        return _find_frame_path(frame_label, all_paths)

    # Task name -> handler; each handler returns (rgb, depth, glb_paths).
    _TASK_HANDLERS = {
        "reconstruction": process_reconstruction,
        "prediction": process_prediction,
        "planning": process_planning,
    }

    def process_task(task_type, *args):
        """Process selected task with appropriate function."""
        handler = _TASK_HANDLERS.get(task_type)
        if handler is None:
            return (
                None,
                None,
                None,
                gr.update(choices=[], value=None),
                [],
                gr.update(visible=False),
            )

        rgb_path, depth_path, pointcloud_paths = handler(*args)
        # Update the pointcloud frames dropdown
        frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
            pointcloud_paths
        )
        return (
            rgb_path,
            depth_path,
            initial_path,
            frame_dropdown,
            pointcloud_paths,
            download_visible,
        )

    def run_selected_task(
        task_type,
        video_file,
        image_file,
        goal_file,
        height,
        width,
        num_frames,
        num_inference_steps,
        guidance_scale,
        sliding_window_stride,
        use_dynamic_cfg,
        raymap_option,
        post_reconstruction,
        fps,
        smooth_camera,
        align_pointmaps,
        max_depth,
        rtol,
        pointcloud_save_frame_interval,
        seed,
    ):
        """Pick the argument list matching the selected task's handler signature."""
        # Trailing arguments shared by all three handlers, in signature order.
        shared_tail = [
            fps,
            smooth_camera,
            align_pointmaps,
            max_depth,
            rtol,
            pointcloud_save_frame_interval,
            seed,
        ]
        if task_type == "reconstruction":
            task_args = [
                video_file,
                height,
                width,
                num_frames,
                num_inference_steps,
                guidance_scale,
                sliding_window_stride,
                *shared_tail,
            ]
        elif task_type == "prediction":
            task_args = [
                image_file,
                height,
                width,
                num_frames,
                num_inference_steps,
                guidance_scale,
                use_dynamic_cfg,
                raymap_option,
                post_reconstruction,
                *shared_tail,
            ]
        else:
            task_args = [
                image_file,
                goal_file,
                height,
                width,
                num_frames,
                num_inference_steps,
                guidance_scale,
                use_dynamic_cfg,
                post_reconstruction,
                *shared_tail,
            ]
        return process_task(task_type, *task_args)

    # Store all pointcloud paths for later use
    all_pointcloud_paths = gr.State([])

    run_button.click(
        fn=run_selected_task,
        inputs=[
            task,
            video_input,
            image_input,
            goal_input,
            height,
            width,
            num_frames,
            num_inference_steps,
            guidance_scale,
            sliding_window_stride,
            use_dynamic_cfg,
            raymap_option,
            post_reconstruction,
            fps,
            smooth_camera,
            align_pointmaps,
            max_depth,
            rtol,
            pointcloud_save_frame_interval,
            seed,
        ],
        outputs=[
            rgb_output,
            depth_output,
            model_output,
            pointcloud_frames,
            all_pointcloud_paths,
            pointcloud_download,
        ],
    )

    pointcloud_frames.change(
        fn=select_pointcloud_frame,
        inputs=[pointcloud_frames, all_pointcloud_paths],
        outputs=[model_output],
    ).then(
        fn=get_download_link,
        inputs=[pointcloud_frames, all_pointcloud_paths],
        outputs=[pointcloud_download],
    )

    # Example Accordion
    with gr.Accordion("Examples"):
        gr.Markdown(
            """
### Examples will be added soon

Check back for example inputs for each task type.
"""
        )

    # Load the model at startup
    demo.load(lambda: build_pipeline(torch.device("cpu")), inputs=None, outputs=None)


if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    demo.queue(max_size=20).launch(show_error=True, share=True)