import gc import os import random import re from datetime import datetime from typing import Dict, List, Optional, Tuple import gradio as gr import imageio.v3 as iio import numpy as np import PIL import rootutils import torch from diffusers import ( AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXTransformer3DModel, ) from transformers import AutoTokenizer, T5EncoderModel import spaces rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) from aether.pipelines.aetherv1_pipeline_cogvideox import ( # noqa: E402 AetherV1PipelineCogVideoX, AetherV1PipelineOutput, ) from aether.utils.postprocess_utils import ( # noqa: E402 align_camera_extrinsics, apply_transformation, colorize_depth, compute_scale, get_intrinsics, interpolate_poses, postprocess_pointmap, project, raymap_to_poses, ) from aether.utils.visualize_utils import predictions_to_glb # noqa: E402 def seed_all(seed: int = 0) -> None: """ Set random seeds of all components. """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) # # Global pipeline cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V" aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1" pipeline = AetherV1PipelineCogVideoX( tokenizer=AutoTokenizer.from_pretrained( cogvideox_pretrained_model_name_or_path, subfolder="tokenizer", ), text_encoder=T5EncoderModel.from_pretrained( cogvideox_pretrained_model_name_or_path, subfolder="text_encoder" ), vae=AutoencoderKLCogVideoX.from_pretrained( cogvideox_pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch.bfloat16 ), scheduler=CogVideoXDPMScheduler.from_pretrained( cogvideox_pretrained_model_name_or_path, subfolder="scheduler" ), transformer=CogVideoXTransformer3DModel.from_pretrained( aether_pretrained_model_name_or_path, subfolder="transformer", torch_dtype=torch.bfloat16 ), ) pipeline.vae.enable_slicing() pipeline.vae.enable_tiling() def build_pipeline(device: torch.device) -> AetherV1PipelineCogVideoX: """Initialize the model pipeline.""" pipeline.to(device) return pipeline def get_window_starts( total_frames: int, sliding_window_size: int, temporal_stride: int ) -> List[int]: """Calculate window start indices.""" starts = list( range( 0, total_frames - sliding_window_size + 1, temporal_stride, ) ) if ( total_frames > sliding_window_size and (total_frames - sliding_window_size) % temporal_stride != 0 ): starts.append(total_frames - sliding_window_size) return starts def blend_and_merge_window_results( window_results: List[AetherV1PipelineOutput], window_indices: List[int], args: Dict ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Blend and merge window results.""" merged_rgb = None merged_disparity = None merged_poses = None merged_focals = None align_pointmaps = args.get("align_pointmaps", True) smooth_camera = args.get("smooth_camera", True) smooth_method = args.get("smooth_method", "kalman") if smooth_camera else "none" if align_pointmaps: merged_pointmaps = None w1 = window_results[0].disparity for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)): t_end = t_start + window_result.rgb.shape[0] if idx == 0: merged_rgb = window_result.rgb merged_disparity = window_result.disparity pointmap_dict = postprocess_pointmap( window_result.disparity, window_result.raymap, vae_downsample_scale=8, ray_o_scale_inv=0.1, smooth_camera=smooth_camera, smooth_method=smooth_method if smooth_camera else "none", ) merged_poses = pointmap_dict["camera_pose"] merged_focals = ( pointmap_dict["intrinsics"][:, 0, 0] + pointmap_dict["intrinsics"][:, 1, 1] ) / 2 if align_pointmaps: merged_pointmaps = pointmap_dict["pointmap"] else: overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start window_disparity = window_result.disparity # Align disparity disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1 scale = compute_scale( window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]), merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]), disp_mask.reshape(1, -1, w1.shape[-1]), ) window_disparity = scale * window_disparity # Blend disparity result_disparity = np.ones((t_end, *w1.shape[1:])) result_disparity[:t_start] = merged_disparity[:t_start] result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:] weight = np.linspace(1, 0, overlap_t)[:, None, None] result_disparity[t_start : t_start + overlap_t] = merged_disparity[ t_start : t_start + overlap_t ] * weight + window_disparity[:overlap_t] * (1 - weight) merged_disparity = result_disparity # Blend RGB result_rgb = np.ones((t_end, *w1.shape[1:], 3)) result_rgb[:t_start] = merged_rgb[:t_start] result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:] weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None] result_rgb[t_start : t_start + overlap_t] = merged_rgb[ t_start : t_start + overlap_t ] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb) merged_rgb = result_rgb # Align poses window_raymap = window_result.raymap window_poses, window_Fov_x, window_Fov_y = raymap_to_poses( window_raymap, ray_o_scale_inv=0.1 ) rel_r, rel_t, rel_s = align_camera_extrinsics( torch.from_numpy(window_poses[:overlap_t]), torch.from_numpy(merged_poses[-overlap_t:]), ) aligned_window_poses = ( apply_transformation( torch.from_numpy(window_poses), rel_r, rel_t, rel_s, return_extri=True, ) .cpu() .numpy() ) result_poses = np.ones((t_end, 4, 4)) result_poses[:t_start] = merged_poses[:t_start] result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:] # Interpolate poses in overlap region weights = np.linspace(1, 0, overlap_t) for t in range(overlap_t): weight = weights[t] pose1 = merged_poses[t_start + t] pose2 = aligned_window_poses[t] result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight) merged_poses = result_poses # Align intrinsics window_intrinsics, _ = get_intrinsics( batch_size=window_poses.shape[0], h=window_result.disparity.shape[1], w=window_result.disparity.shape[2], fovx=window_Fov_x, fovy=window_Fov_y, ) window_focals = ( window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1] ) / 2 scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean() window_focals = scale * window_focals result_focals = np.ones((t_end,)) result_focals[:t_start] = merged_focals[:t_start] result_focals[t_start + overlap_t :] = window_focals[overlap_t:] weight = np.linspace(1, 0, overlap_t) result_focals[t_start : t_start + overlap_t] = merged_focals[ t_start : t_start + overlap_t ] * weight + window_focals[:overlap_t] * (1 - weight) merged_focals = result_focals if align_pointmaps: # Align pointmaps window_pointmaps = postprocess_pointmap( result_disparity[t_start:], window_raymap, vae_downsample_scale=8, camera_pose=aligned_window_poses, focal=window_focals, ray_o_scale_inv=0.1, smooth_camera=smooth_camera, smooth_method=smooth_method if smooth_camera else "none", ) result_pointmaps = np.ones((t_end, *w1.shape[1:], 3)) result_pointmaps[:t_start] = merged_pointmaps[:t_start] result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][ overlap_t: ] weight = np.linspace(1, 0, overlap_t)[:, None, None, None] result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[ t_start : t_start + overlap_t ] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight) merged_pointmaps = result_pointmaps # project to pointmaps height = args.get("height", 480) width = args.get("width", 720) intrinsics = [ np.array([[f, 0, 0.5 * width], [0, f, 0.5 * height], [0, 0, 1]]) for f in merged_focals ] if align_pointmaps: pointmaps = merged_pointmaps else: pointmaps = np.stack( [ project( 1 / np.clip(merged_disparity[i], 1e-8, 1e8), intrinsics[i], merged_poses[i], ) for i in range(merged_poses.shape[0]) ] ) return merged_rgb, merged_disparity, merged_poses, pointmaps def process_video_to_frames(video_path: str, fps_sample: int = 12) -> List[str]: """Process video into frames and save them locally.""" # Create a unique output directory timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") output_dir = f"temp_frames_{timestamp}" os.makedirs(output_dir, exist_ok=True) # Read video video = iio.imread(video_path) # Calculate frame interval based on original video fps if isinstance(video, np.ndarray): # For captured videos total_frames = len(video) frame_interval = max( 1, round(total_frames / (fps_sample * (total_frames / 30))) ) else: # Default if can't determine frame_interval = 2 frame_paths = [] for i, frame in enumerate(video[::frame_interval]): frame_path = os.path.join(output_dir, f"frame_{i:04d}.jpg") if isinstance(frame, np.ndarray): iio.imwrite(frame_path, frame) frame_paths.append(frame_path) return frame_paths, output_dir def save_output_files( rgb: np.ndarray, disparity: np.ndarray, poses: Optional[np.ndarray] = None, raymap: Optional[np.ndarray] = None, pointmap: Optional[np.ndarray] = None, task: str = "reconstruction", output_dir: str = "outputs", **kwargs, ) -> Dict[str, str]: """ Save outputs and return paths to saved files. """ os.makedirs(output_dir, exist_ok=True) if pointmap is None and raymap is not None: # # Generate pointmap from raymap and disparity # smooth_camera = kwargs.get("smooth_camera", True) # smooth_method = ( # kwargs.get("smooth_method", "kalman") if smooth_camera else "none" # ) # pointmap_dict = postprocess_pointmap( # disparity, # raymap, # vae_downsample_scale=8, # ray_o_scale_inv=0.1, # smooth_camera=smooth_camera, # smooth_method=smooth_method, # ) # pointmap = pointmap_dict["pointmap"] window_result = AetherV1PipelineOutput( rgb=rgb, disparity=disparity, raymap=raymap ) window_results = [window_result] window_indices = [0] _, _, poses_from_blend, pointmap = blend_and_merge_window_results(window_results, window_indices, kwargs) # Use poses from blend_and_merge_window_results if poses is None if poses is None: poses = poses_from_blend if poses is None and raymap is not None: poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1) # Create a unique filename timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") base_filename = f"{task}_{timestamp}" # Paths for saved files paths = {} # Save RGB video rgb_path = os.path.join(output_dir, f"{base_filename}_rgb.mp4") iio.imwrite( rgb_path, (np.clip(rgb, 0, 1) * 255).astype(np.uint8), fps=kwargs.get("fps", 12), ) paths["rgb"] = rgb_path # Save depth/disparity video depth_path = os.path.join(output_dir, f"{base_filename}_disparity.mp4") iio.imwrite( depth_path, (colorize_depth(disparity) * 255).astype(np.uint8), fps=kwargs.get("fps", 12), ) paths["disparity"] = depth_path # Save point cloud GLB files if pointmap is not None and poses is not None: pointcloud_save_frame_interval = kwargs.get( "pointcloud_save_frame_interval", 10 ) max_depth = kwargs.get("max_depth", 100.0) rtol = kwargs.get("rtol", 0.03) glb_paths = [] # Determine which frames to save based on the interval frames_to_save = list( range(0, pointmap.shape[0], pointcloud_save_frame_interval) ) # Always include the first and last frame if 0 not in frames_to_save: frames_to_save.insert(0, 0) if pointmap.shape[0] - 1 not in frames_to_save: frames_to_save.append(pointmap.shape[0] - 1) # Sort the frames to ensure they're in order frames_to_save = sorted(set(frames_to_save)) for frame_idx in frames_to_save: if frame_idx >= pointmap.shape[0]: continue # fix the problem of point cloud being upside down and left-right reversed: flip Y axis and X axis flipped_pointmap = pointmap[frame_idx:frame_idx+1].copy() flipped_pointmap[..., 1] = -flipped_pointmap[..., 1] # flip Y axis (up and down) flipped_pointmap[..., 0] = -flipped_pointmap[..., 0] # flip X axis (left and right) # flip camera poses flipped_poses = poses[frame_idx:frame_idx+1].copy() # flip Y axis and X axis of camera orientation flipped_poses[..., 1, :3] = -flipped_poses[..., 1, :3] # flip Y axis of camera orientation flipped_poses[..., 0, :3] = -flipped_poses[..., 0, :3] # flip X axis of camera orientation flipped_poses[..., :3, 1] = -flipped_poses[..., :3, 1] # flip Y axis of camera orientation flipped_poses[..., :3, 0] = -flipped_poses[..., :3, 0] # flip X axis of camera orientation # flip Y axis and X axis of camera position flipped_poses[..., 1, 3] = -flipped_poses[..., 1, 3] # flip Y axis position flipped_poses[..., 0, 3] = -flipped_poses[..., 0, 3] # flip X axis position # use flipped point cloud and camera poses predictions = { "world_points": flipped_pointmap, "images": rgb[frame_idx : frame_idx + 1], "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8), "camera_poses": flipped_poses, } glb_path = os.path.join( output_dir, f"{base_filename}_pointcloud_frame_{frame_idx}.glb" ) scene_3d = predictions_to_glb( predictions, filter_by_frames="all", show_cam=True, max_depth=max_depth, rtol=rtol, frame_rel_idx=float(frame_idx) / pointmap.shape[0], ) scene_3d.export(glb_path) glb_paths.append(glb_path) paths["pointcloud_glbs"] = glb_paths return paths @spaces.GPU(duration=300) def process_reconstruction( video_file, height, width, num_frames, num_inference_steps, guidance_scale, sliding_window_stride, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed, progress=gr.Progress(), ): """ Process reconstruction task. """ try: gc.collect() torch.cuda.empty_cache() seed_all(seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not torch.cuda.is_available(): raise ValueError("CUDA is not available. Check your environment.") pipeline = build_pipeline(device) progress(0.1, "Loading video") # Check if video_file is a string or a file object if isinstance(video_file, str): video_path = video_file else: video_path = video_file.name video = iio.imread(video_path).astype(np.float32) / 255.0 # Setup arguments args = { "height": height, "width": width, "num_frames": num_frames, "sliding_window_stride": sliding_window_stride, "smooth_camera": smooth_camera, "smooth_method": "kalman" if smooth_camera else "none", "align_pointmaps": align_pointmaps, "max_depth": max_depth, "rtol": rtol, "pointcloud_save_frame_interval": pointcloud_save_frame_interval, } # Process in sliding windows window_results = [] window_indices = get_window_starts( len(video), num_frames, sliding_window_stride ) progress(0.2, f"Processing video in {len(window_indices)} windows") for i, start_idx in enumerate(window_indices): progress_val = 0.2 + (0.6 * (i / len(window_indices))) progress(progress_val, f"Processing window {i+1}/{len(window_indices)}") output = pipeline( task="reconstruction", image=None, goal=None, video=video[start_idx : start_idx + num_frames], raymap=None, height=height, width=width, num_frames=num_frames, fps=fps, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, use_dynamic_cfg=False, generator=torch.Generator(device=device).manual_seed(seed), ) window_results.append(output) progress(0.8, "Merging results from all windows") # Merge window results ( merged_rgb, merged_disparity, merged_poses, pointmaps, ) = blend_and_merge_window_results(window_results, window_indices, args) progress(0.9, "Saving output files") # Save output files output_dir = "outputs" os.makedirs(output_dir, exist_ok=True) output_paths = save_output_files( rgb=merged_rgb, disparity=merged_disparity, poses=merged_poses, pointmap=pointmaps, task="reconstruction", output_dir=output_dir, fps=12, **args, ) progress(1.0, "Done!") # Return paths for displaying return ( output_paths["rgb"], output_paths["disparity"], output_paths.get("pointcloud_glbs", []), ) except Exception: import traceback traceback.print_exc() return None, None, [] @spaces.GPU(duration=300) def process_prediction( image_file, height, width, num_frames, num_inference_steps, guidance_scale, use_dynamic_cfg, raymap_option, post_reconstruction, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed, progress=gr.Progress(), ): """ Process prediction task. """ try: gc.collect() torch.cuda.empty_cache() # Set random seed seed_all(seed) # Check if CUDA is available device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not torch.cuda.is_available(): raise ValueError("CUDA is not available. Check your environment.") # Build the pipeline pipeline = build_pipeline(device) progress(0.1, "Loading image") # Check if image_file is a string or a file object if isinstance(image_file, str): image_path = image_file else: image_path = image_file.name image = PIL.Image.open(image_path) progress(0.2, "Running prediction") # Run prediction output = pipeline( task="prediction", image=image, video=None, goal=None, raymap=np.load(f"assets/example_raymaps/raymap_{raymap_option}.npy"), height=height, width=width, num_frames=num_frames, fps=fps, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, use_dynamic_cfg=use_dynamic_cfg, generator=torch.Generator(device=device).manual_seed(seed), return_dict=True, ) # Show RGB output immediately rgb_output = output.rgb # Setup arguments for saving args = { "height": height, "width": width, "smooth_camera": smooth_camera, "smooth_method": "kalman" if smooth_camera else "none", "align_pointmaps": align_pointmaps, "max_depth": max_depth, "rtol": rtol, "pointcloud_save_frame_interval": pointcloud_save_frame_interval, } if post_reconstruction: progress(0.5, "Running post-reconstruction for better quality") recon_output = pipeline( task="reconstruction", video=output.rgb, height=height, width=width, num_frames=num_frames, fps=fps, num_inference_steps=4, guidance_scale=1.0, use_dynamic_cfg=False, generator=torch.Generator(device=device).manual_seed(seed), ) disparity = recon_output.disparity raymap = recon_output.raymap else: disparity = output.disparity raymap = output.raymap progress(0.8, "Saving output files") # Save output files output_dir = "outputs" os.makedirs(output_dir, exist_ok=True) output_paths = save_output_files( rgb=rgb_output, disparity=disparity, raymap=raymap, task="prediction", output_dir=output_dir, fps=12, **args, ) progress(1.0, "Done!") # Return paths for displaying return ( output_paths["rgb"], output_paths["disparity"], output_paths.get("pointcloud_glbs", []), ) except Exception: import traceback traceback.print_exc() return None, None, [] @spaces.GPU(duration=300) def process_planning( image_file, goal_file, height, width, num_frames, num_inference_steps, guidance_scale, use_dynamic_cfg, post_reconstruction, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed, progress=gr.Progress(), ): """ Process planning task. """ try: gc.collect() torch.cuda.empty_cache() # Set random seed seed_all(seed) # Check if CUDA is available device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not torch.cuda.is_available(): raise ValueError("CUDA is not available. Check your environment.") # Build the pipeline pipeline = build_pipeline(device) progress(0.1, "Loading images") # Check if image_file and goal_file are strings or file objects if isinstance(image_file, str): image_path = image_file else: image_path = image_file.name if isinstance(goal_file, str): goal_path = goal_file else: goal_path = goal_file.name image = PIL.Image.open(image_path) goal = PIL.Image.open(goal_path) progress(0.2, "Running planning") # Run planning output = pipeline( task="planning", image=image, video=None, goal=goal, raymap=None, height=height, width=width, num_frames=num_frames, fps=fps, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, use_dynamic_cfg=use_dynamic_cfg, generator=torch.Generator(device=device).manual_seed(seed), return_dict=True, ) # Show RGB output immediately rgb_output = output.rgb # Setup arguments for saving args = { "height": height, "width": width, "smooth_camera": smooth_camera, "smooth_method": "kalman" if smooth_camera else "none", "align_pointmaps": align_pointmaps, "max_depth": max_depth, "rtol": rtol, "pointcloud_save_frame_interval": pointcloud_save_frame_interval, } if post_reconstruction: progress(0.5, "Running post-reconstruction for better quality") recon_output = pipeline( task="reconstruction", video=output.rgb, height=height, width=width, num_frames=num_frames, fps=12, num_inference_steps=4, guidance_scale=1.0, use_dynamic_cfg=False, generator=torch.Generator(device=device).manual_seed(seed), ) disparity = recon_output.disparity raymap = recon_output.raymap else: disparity = output.disparity raymap = output.raymap progress(0.8, "Saving output files") # Save output files output_dir = "outputs" os.makedirs(output_dir, exist_ok=True) output_paths = save_output_files( rgb=rgb_output, disparity=disparity, raymap=raymap, task="planning", output_dir=output_dir, fps=fps, **args, ) progress(1.0, "Done!") # Return paths for displaying return ( output_paths["rgb"], output_paths["disparity"], output_paths.get("pointcloud_glbs", []), ) except Exception: import traceback traceback.print_exc() return None, None, [] def update_task_ui(task): """Update UI elements based on selected task.""" if task == "reconstruction": return ( gr.update(visible=True), # reconstruction_group gr.update(visible=False), # prediction_group gr.update(visible=False), # planning_group gr.update(visible=False), # preview_row gr.update(value=4), # num_inference_steps gr.update(visible=True), # sliding_window_stride gr.update(visible=False), # use_dynamic_cfg gr.update(visible=False), # raymap_option gr.update(visible=False), # post_reconstruction gr.update(value=1.0), # guidance_scale ) elif task == "prediction": return ( gr.update(visible=False), # reconstruction_group gr.update(visible=True), # prediction_group gr.update(visible=False), # planning_group gr.update(visible=True), # preview_row gr.update(value=50), # num_inference_steps gr.update(visible=False), # sliding_window_stride gr.update(visible=True), # use_dynamic_cfg gr.update(visible=True), # raymap_option gr.update(visible=True), # post_reconstruction gr.update(value=3.0), # guidance_scale ) elif task == "planning": return ( gr.update(visible=False), # reconstruction_group gr.update(visible=False), # prediction_group gr.update(visible=True), # planning_group gr.update(visible=True), # preview_row gr.update(value=50), # num_inference_steps gr.update(visible=False), # sliding_window_stride gr.update(visible=True), # use_dynamic_cfg gr.update(visible=False), # raymap_option gr.update(visible=True), # post_reconstruction gr.update(value=3.0), # guidance_scale ) def update_image_preview(image_file): """Update the image preview.""" if image_file is None: return None if isinstance(image_file, str): return image_file return image_file.name if hasattr(image_file, 'name') else None def update_goal_preview(goal_file): """Update the goal preview.""" if goal_file is None: return None if isinstance(goal_file, str): return goal_file return goal_file.name if hasattr(goal_file, 'name') else None def get_download_link(selected_frame, all_paths): """Update the download button with the selected file path.""" if not selected_frame or not all_paths: return gr.update(visible=False, value=None) frame_num = int(re.search(r"Frame (\d+)", selected_frame).group(1)) for path in all_paths: if f"frame_{frame_num}" in path: # Make sure the file exists before setting it if os.path.exists(path): return gr.update(visible=True, value=path, interactive=True) return gr.update(visible=False, value=None) # Theme setup theme = gr.themes.Default( primary_hue="blue", secondary_hue="cyan", ) with gr.Blocks( theme=theme, css=""" .output-column { min-height: 400px; } .warning { color: #856404 !important; font-weight: bold !important; padding: 10px !important; background-color: #fff3cd !important; border-left: 4px solid #ffc107 !important; border-radius: 4px !important; margin: 10px 0 !important; } .dark .warning { background-color: rgba(255, 193, 7, 0.1) !important; color: #fbd38d !important; } .highlight { background-color: rgba(0, 123, 255, 0.1); padding: 10px; border-radius: 8px; border-left: 5px solid #007bff; margin: 10px 0; } .task-header { margin-top: 15px; margin-bottom: 20px; font-size: 1.4em; font-weight: bold; color: #007bff; } .flex-display { display: flex; flex-wrap: wrap; gap: 10px; } .output-subtitle { font-size: 1.1em; margin-top: 5px; margin-bottom: 5px; color: #505050; } .input-section, .params-section, .advanced-section { border: 1px solid #ddd; padding: 20px; border-radius: 8px; margin-bottom: 20px; } .logo-container { display: flex; justify-content: center; margin-bottom: 20px; } .logo-image { max-width: 300px; height: auto; } /* Optimize layout and spacing */ .container { margin: 0 auto; padding: 0 15px; max-width: 1800px; } .header { text-align: center; margin-bottom: 20px; padding: 15px; background: linear-gradient(to right, #f8f9fa, #e9ecef); border-radius: 10px; } .dark .header { background: linear-gradient(to right, #2d3748, #1a202c); } .main-title { font-size: 2.2em; font-weight: bold; margin: 0 auto; color: #2c3e50; max-width: 800px; } .dark .main-title { color: #e2e8f0; } .links-bar { display: flex; justify-content: center; gap: 15px; margin: 12px 0; } .link-button { display: inline-flex; align-items: center; padding: 6px 12px; background-color: #007bff; color: white !important; text-decoration: none; border-radius: 5px; transition: background-color 0.3s; font-size: 0.95em; } .link-button:hover { background-color: #0056b3; text-decoration: none; } .features-limitations-container { display: flex; gap: 15px; margin: 20px 0; } .capabilities-box, .limitations-box { flex: 1; padding: 18px; border-radius: 8px; margin-bottom: 15px; } .capabilities-box { background: #f0f9ff; border-left: 5px solid #3498db; } .dark .capabilities-box { background: #172a3a; border-left: 5px solid #3498db; } .limitations-box { background: #f8f9fa; border-left: 5px solid #ffc107; } .dark .limitations-box { background: #2d2a20; border-left: 5px solid #ffc107; } .capabilities-text, .limitations-text { color: #495057; line-height: 1.6; } .dark .capabilities-text, .dark .limitations-text { color: #cbd5e0; } .capabilities-text h3 { color: #2980b9; margin-top: 0; margin-bottom: 15px; } .dark .capabilities-text h3 { color: #63b3ed; } .limitations-text h3 { color: #d39e00; margin-top: 0; margin-bottom: 15px; } .dark .limitations-text h3 { color: #fbd38d; } .capabilities-text blockquote, .limitations-text blockquote { margin: 20px 0 0 0; padding: 10px 20px; font-style: italic; } .capabilities-text blockquote { border-left: 3px solid #3498db; background: rgba(52, 152, 219, 0.1); } .dark .capabilities-text blockquote { background: rgba(52, 152, 219, 0.2); } .limitations-text blockquote { border-left: 3px solid #ffc107; background: rgba(255, 193, 7, 0.1); } .dark .limitations-text blockquote { background: rgba(255, 193, 7, 0.2); } /* Optimize layout and spacing */ .main-interface { display: flex; gap: 30px; margin-top: 20px; } .input-column, .output-column { flex: 1; min-width: 0; display: flex; flex-direction: column; } .output-panel { border: 1px solid #ddd; border-radius: 8px; padding: 20px; height: 100%; display: flex; flex-direction: column; overflow-y: auto; } .dark .output-panel { border-color: #4a5568; } .run-button-container { display: flex; justify-content: center; margin: 15px 0; } .run-button { padding: 10px 30px; font-size: 1.1em; font-weight: bold; background: linear-gradient(to right, #3498db, #2980b9); border: none; border-radius: 5px; color: white; cursor: pointer; transition: all 0.3s; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); } .run-button:hover { background: linear-gradient(to right, #2980b9, #1a5276); box-shadow: 0 6px 8px rgba(0, 0, 0, 0.15); transform: translateY(-2px); } .task-selector { background-color: #f8f9fa; padding: 12px; border-radius: 8px; margin-bottom: 15px; border: 1px solid #e9ecef; } .dark .task-selector { background-color: #2d3748; border-color: #4a5568; } /* Compact parameter settings */ .compact-params .row { margin-bottom: 8px; } .compact-params label { margin-bottom: 4px; } /* More obvious advanced options */ .advanced-options-header { background-color: #e9ecef; padding: 10px 15px; border-radius: 6px; margin-top: 10px; font-weight: bold; color: #495057; border-left: 4px solid #6c757d; cursor: pointer; transition: all 0.2s; } .advanced-options-header:hover { background-color: #dee2e6; } .dark .advanced-options-header { background-color: #2d3748; color: #e2e8f0; border-left: 4px solid #a0aec0; } .dark .advanced-options-header:hover { background-color: #4a5568; } /* Vertical arrangement of output section */ .output-section { margin-bottom: 30px; border: 1px solid #e9ecef; border-radius: 8px; padding: 20px; } .output-section-title { font-weight: bold; color: #495057; margin-bottom: 15px; font-size: 1.2em; } .dark .output-section-title { color: #e2e8f0; } .pointcloud-controls { display: flex; gap: 10px; margin-bottom: 10px; align-items: center; } .note-box { background-color: #fff8e1 !important; border-left: 4px solid #ffc107 !important; padding: 12px !important; margin: 15px 0 !important; border-radius: 4px !important; color: #333 !important; } .dark .note-box { background-color: rgba(255, 193, 7, 0.1) !important; color: #e0e0e0 !important; } .note-box p, .note-box strong { color: inherit !important; } /* Ensure warning class styles are correctly applied */ .warning { color: #856404 !important; font-weight: bold !important; padding: 10px !important; background-color: #fff3cd !important; border-left: 4px solid #ffc107 !important; border-radius: 4px !important; margin: 10px 0 !important; } .dark .warning { background-color: rgba(255, 193, 7, 0.1) !important; color: #fbd38d !important; } .warning-box { background-color: #fff3cd; border-left: 4px solid #ffc107; padding: 12px; margin: 15px 0; border-radius: 4px; color: #856404; } .dark .warning-box { background-color: rgba(255, 193, 7, 0.1); color: #fbd38d; } """, ) as demo: with gr.Column(elem_classes=["container"]): with gr.Row(elem_classes=["header"]): with gr.Column(): gr.Markdown( """ # Aether: Geometric-Aware Unified World Modeling """, elem_classes=["main-title"] ) gr.Markdown( """ """, ) with gr.Row(elem_classes=["features-limitations-container"]): with gr.Column(elem_classes=["capabilities-box"]): gr.Markdown( """ ### 🚀 Key Capabilities Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities: - 🌏 **4D Dynamic Reconstruction**: Reconstruct dynamic point clouds from videos by estimating depths and camera poses. - 🎬 **Action-Conditioned Prediction**: Predict future frames based on initial observations, with optional camera trajectory actions. - 🎯 **Goal-Conditioned Planning**: Generate planning paths from pairs of observation and goal images. > *Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.* """, elem_classes=["capabilities-text"] ) with gr.Column(elem_classes=["limitations-box"]): gr.Markdown( """ ### 📝 Current Limitations Aether represents an initial step in our journey, trained entirely on synthetic data. While it demonstrates promising capabilities, it is important to be aware of its current limitations: - 🔄 **Dynamic Scenarios**: Struggles with highly dynamic scenarios involving significant motion or dense crowds. - 📸 **Camera Stability**: Camera pose estimation can be less stable in certain conditions. - 📐 **Planning Range**: For visual planning tasks, we recommend keeping the observations and goals relatively close to ensure optimal performance. > *We are actively working on the next generation of Aether and are committed to addressing these limitations in future releases.* """, elem_classes=["limitations-text"] ) with gr.Row(elem_classes=["main-interface"]): with gr.Column(elem_classes=["input-column"]): gpu_time_warning = gr.Markdown( """
⚠️ Warning:
Due to HuggingFace Spaces ZERO GPU quota limitations, only short video reconstruction tasks (less than 100 frames) can be completed online. 💻 Recommendation:
We strongly encourage you to deploy Aether locally for: - Processing longer video reconstruction tasks - Better performance and full access to prediction and planning tasks Visit our GitHub repository for local deployment instructions.
""", ) with gr.Group(elem_classes=["task-selector"]): task = gr.Radio( ["reconstruction", "prediction", "planning"], label="Select Task", value="reconstruction", info="Choose the task you want to perform", ) with gr.Group(elem_classes=["input-section"]): gr.Markdown("## 📥 Input", elem_classes=["task-header"]) # Task-specific inputs with gr.Group(visible=True) as reconstruction_group: video_input = gr.Video( label="Upload Input Video", sources=["upload"], interactive=True, elem_id="video_input", ) reconstruction_examples = gr.Examples( examples=[ ["assets/example_videos/bridge.mp4"], ["assets/example_videos/moviegen.mp4"], ["assets/example_videos/nuscenes.mp4"], ["assets/example_videos/veo2.mp4"], ], inputs=[video_input], label="Reconstruction Examples", examples_per_page=4, ) with gr.Group(visible=False) as prediction_group: image_input = gr.Image( label="Upload Start Image", type="filepath", interactive=True, elem_id="image_input", ) prediction_examples = gr.Examples( examples=[ ["assets/example_obs/car.png"], ["assets/example_obs/cartoon.png"], ["assets/example_obs/garden.jpg"], ["assets/example_obs/room.jpg"], ], inputs=[image_input], label="Prediction Examples", examples_per_page=4, ) with gr.Group(visible=False) as planning_group: with gr.Row(): image_input_planning = gr.Image( label="Upload Start Image", type="filepath", interactive=True, elem_id="image_input_planning", ) goal_input = gr.Image( label="Upload Goal Image", type="filepath", interactive=True, elem_id="goal_input", ) planning_examples = gr.Examples( examples=[ ["assets/example_obs_goal/01_obs.png", "assets/example_obs_goal/01_goal.png"], ["assets/example_obs_goal/02_obs.png", "assets/example_obs_goal/02_goal.png"], ["assets/example_obs_goal/03_obs.png", "assets/example_obs_goal/03_goal.png"], ["assets/example_obs_goal/04_obs.png", "assets/example_obs_goal/04_goal.png"], ], inputs=[image_input_planning, goal_input], label="Planning Examples", examples_per_page=4, ) with gr.Row(visible=False) as preview_row: image_preview = gr.Image( label="Start Image Preview", elem_id="image_preview", visible=False, ) goal_preview = gr.Image( label="Goal Image Preview", elem_id="goal_preview", visible=False, ) with gr.Group(elem_classes=["params-section", "compact-params"]): gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"]) with gr.Row(): with gr.Column(scale=1): height = gr.Dropdown( choices=[480], value=480, label="Height", info="Height of the output video", ) with gr.Column(scale=1): width = gr.Dropdown( choices=[720], value=720, label="Width", info="Width of the output video", ) with gr.Row(): with gr.Column(scale=1): num_frames = gr.Dropdown( choices=[17, 25, 33, 41], value=41, label="Number of Frames", info="Number of frames to predict", ) with gr.Column(scale=1): fps = gr.Dropdown( choices=[8, 10, 12, 15, 24], value=24, label="FPS", info="Frames per second", ) with gr.Row(): num_inference_steps = gr.Slider( minimum=1, maximum=60, value=4, step=1, label="Inference Steps", info="Number of inference step", ) sliding_window_stride = gr.Slider( minimum=1, maximum=40, value=24, step=1, label="Sliding Window Stride", info="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task", visible=True, ) use_dynamic_cfg = gr.Checkbox( label="Use Dynamic CFG", value=True, info="Use dynamic CFG", visible=False, ) raymap_option = gr.Radio( choices=["backward", "forward_right", "left_forward", "right"], label="Camera Movement Direction", value="forward_right", info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.", visible=False, ) post_reconstruction = gr.Checkbox( label="Post-Reconstruction", value=True, info="Run reconstruction after prediction for better quality", visible=False, ) with gr.Accordion( "Advanced Options", open=False, visible=True, elem_classes=["advanced-options-header"] ) as advanced_options: with gr.Group(elem_classes=["advanced-section"]): with gr.Row(): guidance_scale = gr.Slider( minimum=1.0, maximum=10.0, value=1.0, step=0.1, label="Guidance Scale", info="Guidance scale (only for prediction / planning)", ) with gr.Row(): seed = gr.Number( value=42, label="Random Seed", info="Set a seed for reproducible results", precision=0, minimum=0, maximum=2147483647, ) with gr.Row(): with gr.Column(scale=1): smooth_camera = gr.Checkbox( label="Smooth Camera", value=True, info="Apply smoothing to camera trajectory", ) with gr.Column(scale=1): align_pointmaps = gr.Checkbox( label="Align Point Maps", value=False, info="Align point maps across frames", ) with gr.Row(): with gr.Column(scale=1): max_depth = gr.Slider( minimum=10, maximum=200, value=60, step=10, label="Max Depth", info="Maximum depth for point cloud (higher = more distant points)", ) with gr.Column(scale=1): rtol = gr.Slider( minimum=0.01, maximum=2.0, value=0.2, step=0.01, label="Relative Tolerance", info="Used for depth edge detection. Lower = remove more edges", ) pointcloud_save_frame_interval = gr.Slider( minimum=1, maximum=20, value=10, step=1, label="Point Cloud Frame Interval", info="Save point cloud every N frames (higher = fewer files but less complete representation)", ) with gr.Group(elem_classes=["run-button-container"]): run_button = gr.Button("Run Aether", variant="primary", elem_classes=["run-button"]) with gr.Column(elem_classes=["output-column"]): with gr.Group(elem_classes=["output-panel"]): gr.Markdown("## 📤 Output", elem_classes=["task-header"]) with gr.Group(elem_classes=["output-section"]): gr.Markdown("### RGB Video", elem_classes=["output-section-title"]) rgb_output = gr.Video( label="RGB Output", interactive=False, elem_id="rgb_output" ) with gr.Group(elem_classes=["output-section"]): gr.Markdown("### Depth Video", elem_classes=["output-section-title"]) depth_output = gr.Video( label="Depth Output", interactive=False, elem_id="depth_output" ) with gr.Group(elem_classes=["output-section"]): gr.Markdown("### Point Clouds", elem_classes=["output-section-title"]) with gr.Row(elem_classes=["pointcloud-controls"]): pointcloud_frames = gr.Dropdown( label="Select Frame", choices=[], value=None, interactive=True, elem_id="pointcloud_frames", ) pointcloud_download = gr.DownloadButton( label="Download Point Cloud", visible=False, elem_id="pointcloud_download", ) model_output = gr.Model3D( label="Point Cloud Viewer", interactive=True, elem_id="model_output" ) gr.Markdown( """ > **Note:** 3D point clouds take a long time to visualize, and we show the keyframes only. > You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`. """ ) with gr.Group(elem_classes=["output-section"]): gr.Markdown("### About Results", elem_classes=["output-section-title"]) gr.Markdown( """ #### Understanding the Outputs - **RGB Video**: Shows the predicted or reconstructed RGB frames - **Depth Video**: Visualizes the disparity maps in color (closer = red, further = blue) - **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids """ ) # Event handlers task.change( fn=update_task_ui, inputs=[task], outputs=[ reconstruction_group, prediction_group, planning_group, preview_row, num_inference_steps, sliding_window_stride, use_dynamic_cfg, raymap_option, post_reconstruction, guidance_scale, ], ) image_input.change( fn=update_image_preview, inputs=[image_input], outputs=[image_preview] ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row]) goal_input.change( fn=update_goal_preview, inputs=[goal_input], outputs=[goal_preview] ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row]) def update_pointcloud_frames(pointcloud_paths): """Update the pointcloud frames dropdown with available frames.""" if not pointcloud_paths: return gr.update(choices=[], value=None), None, gr.update(visible=False) # Extract frame numbers from filenames frame_info = [] for path in pointcloud_paths: filename = os.path.basename(path) match = re.search(r"frame_(\d+)", filename) if match: frame_num = int(match.group(1)) frame_info.append((f"Frame {frame_num}", path)) # Sort by frame number frame_info.sort(key=lambda x: int(re.search(r"Frame (\d+)", x[0]).group(1))) choices = [label for label, _ in frame_info] paths = [path for _, path in frame_info] if not choices: return gr.update(choices=[], value=None), None, gr.update(visible=False) # Make download button visible when we have point cloud files return ( gr.update(choices=choices, value=choices[0]), paths[0], gr.update(visible=True), ) def select_pointcloud_frame(frame_label, all_paths): """Select a specific pointcloud frame.""" if not frame_label or not all_paths: return None frame_num = int(re.search(r"Frame (\d+)", frame_label).group(1)) for path in all_paths: if f"frame_{frame_num}" in path: return path return None # Then in the run button click handler: def process_task(task_type, *args): """Process selected task with appropriate function.""" if task_type == "reconstruction": rgb_path, depth_path, pointcloud_paths = process_reconstruction(*args) # Update the pointcloud frames dropdown frame_dropdown, initial_path, download_visible = update_pointcloud_frames( pointcloud_paths ) return ( rgb_path, depth_path, initial_path, frame_dropdown, pointcloud_paths, download_visible, ) elif task_type == "prediction": rgb_path, depth_path, pointcloud_paths = process_prediction(*args) frame_dropdown, initial_path, download_visible = update_pointcloud_frames( pointcloud_paths ) return ( rgb_path, depth_path, initial_path, frame_dropdown, pointcloud_paths, download_visible, ) elif task_type == "planning": rgb_path, depth_path, pointcloud_paths = process_planning(*args) frame_dropdown, initial_path, download_visible = update_pointcloud_frames( pointcloud_paths ) return ( rgb_path, depth_path, initial_path, frame_dropdown, pointcloud_paths, download_visible, ) return ( None, None, None, gr.update(choices=[], value=None), [], gr.update(visible=False), ) # Store all pointcloud paths for later use all_pointcloud_paths = gr.State([]) run_button.click( fn=lambda task_type, video_file, image_file, image_input_planning, goal_file, height, width, num_frames, num_inference_steps, guidance_scale, sliding_window_stride, use_dynamic_cfg, raymap_option, post_reconstruction, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed: process_task( task_type, *( [ video_file, height, width, num_frames, num_inference_steps, guidance_scale, sliding_window_stride, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed, ] if task_type == "reconstruction" else [ image_file, height, width, num_frames, num_inference_steps, guidance_scale, use_dynamic_cfg, raymap_option, post_reconstruction, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed, ] if task_type == "prediction" else [ image_input_planning, goal_file, height, width, num_frames, num_inference_steps, guidance_scale, use_dynamic_cfg, post_reconstruction, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed, ] ), ), inputs=[ task, video_input, image_input, image_input_planning, goal_input, height, width, num_frames, num_inference_steps, guidance_scale, sliding_window_stride, use_dynamic_cfg, raymap_option, post_reconstruction, fps, smooth_camera, align_pointmaps, max_depth, rtol, pointcloud_save_frame_interval, seed, ], outputs=[ rgb_output, depth_output, model_output, pointcloud_frames, all_pointcloud_paths, pointcloud_download, ], ) pointcloud_frames.change( fn=select_pointcloud_frame, inputs=[pointcloud_frames, all_pointcloud_paths], outputs=[model_output], ).then( fn=get_download_link, inputs=[pointcloud_frames, all_pointcloud_paths], outputs=[pointcloud_download], ) # Load the model at startup demo.load(lambda: build_pipeline(torch.device("cpu")), inputs=None, outputs=None) if __name__ == "__main__": os.environ["TOKENIZERS_PARALLELISM"] = "false" demo.queue(max_size=20).launch(show_error=True, share=True)