Wenzheng Chang committed on
Commit 19da45c · 1 Parent(s): ddee6ec

aetherv1 init

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,157 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ env/
106
+ .venv
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ ### VisualStudioCode
131
+ .vscode/*
132
+ !.vscode/settings.json
133
+ !.vscode/tasks.json
134
+ !.vscode/launch.json
135
+ !.vscode/extensions.json
136
+ *.code-workspace
137
+ **/.vscode
138
+
139
+ # JetBrains
140
+ .idea/
141
+
142
+ # Data & Models
143
+ *.h5
144
+ *.tar
145
+ *.tar.gz
146
+
147
+ # Lightning-Hydra-Template
148
+ configs/local/default.yaml
149
+ # data/
150
+ /logs/
151
+ .env
152
+
153
+ # Aim logging
154
+ .aim
155
+
156
+ # local files
157
+ logs
.pre-commit-config.yaml ADDED
@@ -0,0 +1,14 @@
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ rev: v0.2.1
4
+ hooks:
5
+ - id: ruff
6
+ exclude: ^test/ # Skip the entire "test" directory
7
+ args:
8
+ - --fix
9
+ - id: ruff-format
10
+ - repo: https://github.com/pre-commit/pre-commit-hooks
11
+ rev: v4.5.0
12
+ hooks:
13
+ - id: check-merge-conflict
14
+ - id: check-yaml
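For local development, these hooks can typically be enabled with `pre-commit install` and run across the whole repository with `pre-commit run --all-files` (standard pre-commit commands, not part of this commit).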
.project-root ADDED
@@ -0,0 +1,2 @@
1
+ # this file is required for inferring the project root directory
2
+ # do not delete
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Aether Team
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,71 @@
1
  ---
2
  title: AetherV1
3
- emoji: 😻
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 5.23.1
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
  short_description: 'Aether: Geometric-Aware Unified World Modeling'
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: AetherV1
3
+ emoji: 🌏
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 5.23.1
8
+ app_file: scripts/demo_gradio.py
9
  pinned: false
10
  license: mit
11
  short_description: 'Aether: Geometric-Aware Unified World Modeling'
12
  ---
13
 
14
+ <div align="center">
15
+
16
+ # Aether: Geometric-Aware Unified World Modeling
17
+
18
+ </div>
19
+
20
+ <div align="center">
21
+ <img width="400" alt="image" src="assets/logo.png">
22
+ <!-- <br> -->
23
+ </div>
24
+
25
+ <div align="center">
26
+ <a href='https://arxiv.org/abs/2503.18945'><img src='https://img.shields.io/badge/arXiv-2503.18945-red'></a> &nbsp;
27
+ <a href='https://aether-world.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;
28
+ <a href=''><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo%20(Coming%20Soon)-blue'></a> &nbsp;
29
+ </div>
30
+
31
+ Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with generative modeling
32
+ for human-like spatial reasoning. Our framework unifies three core capabilities: (1) **4D dynamic reconstruction**,
33
+ (2) **action-conditioned video prediction**, and (3) **goal-conditioned visual planning**. Trained entirely on
34
+ synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
35
+
36
+ <div align="center">
37
+ <img src="assets/teaser.png" alt="Teaser" width="800"/>
38
+ </div>
39
+
40
+
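The pipeline added in this commit (`aether/pipelines/aetherv1_pipeline_cogvideox.py`) exposes the three capabilities above through a single `task` argument. Below is a minimal, hedged usage sketch; the Hub checkpoint id and the input image path are assumptions for illustration, not taken from this commit.

```python
# Minimal sketch of driving the AetherV1 pipeline defined in this commit.
import PIL.Image
import torch

from aether.pipelines.aetherv1_pipeline_cogvideox import AetherV1PipelineCogVideoX

pipeline = AetherV1PipelineCogVideoX.from_pretrained(
    "AetherWorldModel/AetherV1",  # hypothetical checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = PIL.Image.open("example_frame.png")  # hypothetical observation frame

# `task` is inferred automatically (video -> reconstruction, goal -> planning, otherwise
# prediction), but it can also be passed explicitly. `num_frames` must be one of
# [17, 25, 33, 41] and `fps` one of [8, 10, 12, 15, 24], as enforced by `check_inputs`.
output = pipeline(task="prediction", image=image, num_frames=41, fps=12)
rgb, disparity, raymap = output.rgb, output.disparity, output.raymap
```

The resulting disparity and raymap can then be turned into depth, camera poses, and point clouds with the helpers added in `aether/utils/postprocess_utils.py`.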
41
+ ## 📝 Citation
42
+ If you find this work useful in your research, please consider citing:
43
+
44
+ ```bibtex
45
+ @article{aether,
46
+ title = {Aether: Geometric-Aware Unified World Modeling},
47
+ author = {Aether Team and Haoyi Zhu and Yifan Wang and Jianjun Zhou and Wenzheng Chang and Yang Zhou and Zizun Li and Junyi Chen and Chunhua Shen and Jiangmiao Pang and Tong He},
48
+ journal = {arXiv preprint arXiv:2503.18945},
49
+ year = {2025}
50
+ }
51
+ ```
52
+
53
+ ## ⚖️ License
54
+ This repository is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
55
+
56
+ ## 🙏 Acknowledgements
57
+ Our work is primarily built upon
58
+ [Accelerate](https://github.com/huggingface/accelerate),
59
+ [Diffusers](https://github.com/huggingface/diffusers),
60
+ [CogVideoX](https://github.com/THUDM/CogVideo),
61
+ [Finetrainers](https://github.com/a-r-r-o-w/finetrainers),
62
+ [DepthAnyVideo](https://github.com/Nightmare-n/DepthAnyVideo),
63
+ [CUT3R](https://github.com/CUT3R/CUT3R),
64
+ [MonST3R](https://github.com/Junyi42/monst3r),
65
+ [VBench](https://github.com/Vchitect/VBench),
66
+ [GST](https://github.com/SOTAMak1r/GST),
67
+ [SPA](https://github.com/HaoyiZhu/SPA),
68
+ [DroidCalib](https://github.com/boschresearch/DroidCalib),
69
+ [Grounded-SAM-2](https://github.com/IDEA-Research/Grounded-SAM-2),
70
+ [ceres-solver](https://github.com/ceres-solver/ceres-solver), etc.
71
+ We are grateful to all these authors for generously open-sourcing their code and for their significant contributions to the community.
aether/__init__.py ADDED
File without changes
aether/pipelines/__init__.py ADDED
File without changes
aether/pipelines/aetherv1_pipeline_cogvideox.py ADDED
@@ -0,0 +1,959 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ from diffusers import (
10
+ AutoencoderKLCogVideoX,
11
+ CogVideoXDPMScheduler,
12
+ CogVideoXImageToVideoPipeline,
13
+ CogVideoXTransformer3DModel,
14
+ )
15
+ from diffusers.image_processor import PipelineImageInput
16
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
17
+ from diffusers.utils import BaseOutput
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from einops import rearrange
20
+ from transformers import AutoTokenizer, T5EncoderModel
21
+
22
+ from aether.utils.preprocess_utils import imcrop_center
23
+
24
+
25
+ def get_3d_rotary_pos_embed(
26
+ embed_dim,
27
+ crops_coords,
28
+ grid_size,
29
+ temporal_size,
30
+ theta: int = 10000,
31
+ use_real: bool = True,
32
+ grid_type: str = "linspace",
33
+ max_size: Optional[Tuple[int, int]] = None,
34
+ device: Optional[torch.device] = None,
35
+ fps_factor: Optional[float] = 1.0,
36
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
37
+ """
38
+ RoPE for video tokens with 3D structure.
39
+
40
+ Args:
41
+ embed_dim (`int`):
42
+ The embedding dimension size, corresponding to hidden_size_head.
43
+ crops_coords (`Tuple[int]`):
44
+ The top-left and bottom-right coordinates of the crop.
45
+ grid_size (`Tuple[int]`):
46
+ The grid size of the spatial positional embedding (height, width).
47
+ temporal_size (`int`):
48
+ The size of the temporal dimension.
49
+ theta (`float`):
50
+ Scaling factor for frequency computation.
51
+ grid_type (`str`):
52
+ Whether to use "linspace" or "slice" to compute grids.
53
+ fps_factor (`float`):
54
+ The relative fps factor of the video, computed as base_fps / fps. Useful for variable-fps training.
55
+
56
+ Returns:
57
+ `Tuple[torch.Tensor, torch.Tensor]`: the cosine and sine rotary embeddings (since `use_real=True`), each with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
58
+ """
59
+ if use_real is not True:
60
+ raise ValueError(
61
+ " `use_real = False` is not currently supported for get_3d_rotary_pos_embed"
62
+ )
63
+
64
+ if grid_type == "linspace":
65
+ start, stop = crops_coords
66
+ grid_size_h, grid_size_w = grid_size
67
+ grid_h = torch.linspace(
68
+ start[0],
69
+ stop[0] * (grid_size_h - 1) / grid_size_h,
70
+ grid_size_h,
71
+ device=device,
72
+ dtype=torch.float32,
73
+ )
74
+ grid_w = torch.linspace(
75
+ start[1],
76
+ stop[1] * (grid_size_w - 1) / grid_size_w,
77
+ grid_size_w,
78
+ device=device,
79
+ dtype=torch.float32,
80
+ )
81
+ grid_t = (
82
+ torch.linspace(
83
+ 0,
84
+ temporal_size * (temporal_size - 1) / temporal_size,
85
+ temporal_size,
86
+ device=device,
87
+ dtype=torch.float32,
88
+ )
89
+ * fps_factor
90
+ )
91
+ elif grid_type == "slice":
92
+ max_h, max_w = max_size
93
+ grid_size_h, grid_size_w = grid_size
94
+ grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
95
+ grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
96
+ grid_t = (
97
+ torch.arange(temporal_size, device=device, dtype=torch.float32) * fps_factor
98
+ )
99
+ else:
100
+ raise ValueError("Invalid value passed for `grid_type`.")
101
+
102
+ # Compute dimensions for each axis
103
+ dim_t = embed_dim // 4
104
+ dim_h = embed_dim // 8 * 3
105
+ dim_w = embed_dim // 8 * 3
106
+
107
+ # Temporal frequencies
108
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
109
+ # Spatial frequencies for height and width
110
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
111
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
112
+
113
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
114
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
115
+ freqs_t = freqs_t[:, None, None, :].expand(
116
+ -1, grid_size_h, grid_size_w, -1
117
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
118
+ freqs_h = freqs_h[None, :, None, :].expand(
119
+ temporal_size, -1, grid_size_w, -1
120
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
121
+ freqs_w = freqs_w[None, None, :, :].expand(
122
+ temporal_size, grid_size_h, -1, -1
123
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
124
+
125
+ freqs = torch.cat(
126
+ [freqs_t, freqs_h, freqs_w], dim=-1
127
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
128
+ freqs = freqs.view(
129
+ temporal_size * grid_size_h * grid_size_w, -1
130
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
131
+ return freqs
132
+
133
+ t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
134
+ h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
135
+ w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
136
+
137
+ if grid_type == "slice":
138
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
139
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
140
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
141
+
142
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
143
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
144
+ return cos, sin
145
+
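A hedged usage sketch of the helper above; all sizes are illustrative rather than read from a model config.

```python
# Illustrative call: an 11-frame latent grid of 30 x 45 patches, head dim 64 (assumed),
# for a 24 fps clip against the 12 fps training base used by the pipeline below.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=None,      # unused when grid_type="slice"
    grid_size=(30, 45),     # latent height x width in patches
    temporal_size=11,       # latent frames for a 41-frame clip, (41 - 1) // 4 + 1
    grid_type="slice",
    max_size=(30, 45),
    fps_factor=12 / 24,     # base_fps / fps: temporal positions advance half as fast
)
# cos and sin are later passed to the transformer as `image_rotary_emb`.
```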
146
+
147
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
148
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
149
+ tw = tgt_width
150
+ th = tgt_height
151
+ h, w = src
152
+ r = h / w
153
+ if r > (th / tw):
154
+ resize_height = th
155
+ resize_width = int(round(th / h * w))
156
+ else:
157
+ resize_width = tw
158
+ resize_height = int(round(tw / w * h))
159
+
160
+ crop_top = int(round((th - resize_height) / 2.0))
161
+ crop_left = int(round((tw - resize_width) / 2.0))
162
+
163
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
164
+
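For intuition, a small worked example with illustrative numbers: a square source placed on a wider 48 x 72 grid keeps the full height and is centered horizontally.

```python
(top, left), (bottom, right) = get_resize_crop_region_for_grid((480, 480), tgt_width=72, tgt_height=48)
# (top, left) == (0, 12) and (bottom, right) == (48, 60)
```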
165
+
166
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
167
+ def retrieve_timesteps(
168
+ scheduler,
169
+ num_inference_steps: Optional[int] = None,
170
+ device: Optional[Union[str, torch.device]] = None,
171
+ timesteps: Optional[List[int]] = None,
172
+ sigmas: Optional[List[float]] = None,
173
+ **kwargs,
174
+ ):
175
+ r"""
176
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
177
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
178
+
179
+ Args:
180
+ scheduler (`SchedulerMixin`):
181
+ The scheduler to get timesteps from.
182
+ num_inference_steps (`int`):
183
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
184
+ must be `None`.
185
+ device (`str` or `torch.device`, *optional*):
186
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
187
+ timesteps (`List[int]`, *optional*):
188
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
189
+ `num_inference_steps` and `sigmas` must be `None`.
190
+ sigmas (`List[float]`, *optional*):
191
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
192
+ `num_inference_steps` and `timesteps` must be `None`.
193
+
194
+ Returns:
195
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
196
+ second element is the number of inference steps.
197
+ """
198
+ if timesteps is not None and sigmas is not None:
199
+ raise ValueError(
200
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
201
+ )
202
+ if timesteps is not None:
203
+ accepts_timesteps = "timesteps" in set(
204
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
205
+ )
206
+ if not accepts_timesteps:
207
+ raise ValueError(
208
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
209
+ f" timestep schedules. Please check whether you are using the correct scheduler."
210
+ )
211
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
212
+ timesteps = scheduler.timesteps
213
+ num_inference_steps = len(timesteps)
214
+ elif sigmas is not None:
215
+ accept_sigmas = "sigmas" in set(
216
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
217
+ )
218
+ if not accept_sigmas:
219
+ raise ValueError(
220
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
221
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
222
+ )
223
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
224
+ timesteps = scheduler.timesteps
225
+ num_inference_steps = len(timesteps)
226
+ else:
227
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
228
+ timesteps = scheduler.timesteps
229
+ return timesteps, num_inference_steps
230
+
231
+
232
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
233
+ def retrieve_latents(
234
+ encoder_output: torch.Tensor,
235
+ generator: Optional[torch.Generator] = None,
236
+ sample_mode: str = "sample",
237
+ ):
238
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
239
+ return encoder_output.latent_dist.sample(generator)
240
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
241
+ return encoder_output.latent_dist.mode()
242
+ elif hasattr(encoder_output, "latents"):
243
+ return encoder_output.latents
244
+ else:
245
+ raise AttributeError("Could not access latents of provided encoder_output")
246
+
247
+
248
+ @dataclass
249
+ class AetherV1PipelineOutput(BaseOutput):
250
+ rgb: np.ndarray  # decoded RGB frames, (t, h, w, 3), range [0, 1]
251
+ disparity: np.ndarray  # decoded disparity, (t, h, w), range [0, 1]
252
+ raymap: np.ndarray  # per-frame camera raymap, (t, 6, h // 8, w // 8)
253
+
254
+
255
+ class AetherV1PipelineCogVideoX(CogVideoXImageToVideoPipeline):
256
+ _supported_tasks = ["reconstruction", "prediction", "planning"]
257
+ _default_num_inference_steps = {
258
+ "reconstruction": 4,
259
+ "prediction": 50,
260
+ "planning": 50,
261
+ }
262
+ _default_guidance_scale = {
263
+ "reconstruction": 1.0,
264
+ "prediction": 3.0,
265
+ "planning": 3.0,
266
+ }
267
+ _default_use_dynamic_cfg = {
268
+ "reconstruction": False,
269
+ "prediction": True,
270
+ "planning": True,
271
+ }
272
+ _base_fps = 12
273
+
274
+ def __init__(
275
+ self,
276
+ tokenizer: AutoTokenizer,
277
+ text_encoder: T5EncoderModel,
278
+ vae: AutoencoderKLCogVideoX,
279
+ scheduler: CogVideoXDPMScheduler,
280
+ transformer: CogVideoXTransformer3DModel,
281
+ ):
282
+ super().__init__(
283
+ tokenizer=tokenizer,
284
+ text_encoder=text_encoder,
285
+ vae=vae,
286
+ scheduler=scheduler,
287
+ transformer=transformer,
288
+ )
289
+
290
+ self.empty_prompt_embeds, _ = self.encode_prompt(
291
+ prompt="",
292
+ negative_prompt=None,
293
+ do_classifier_free_guidance=False,
294
+ num_videos_per_prompt=1,
295
+ prompt_embeds=None,
296
+ )
297
+
298
+ def _prepare_rotary_positional_embeddings(
299
+ self,
300
+ height: int,
301
+ width: int,
302
+ num_frames: int,
303
+ device: torch.device,
304
+ fps: Optional[int] = None,
305
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
306
+ grid_height = height // (
307
+ self.vae_scale_factor_spatial * self.transformer.config.patch_size
308
+ )
309
+ grid_width = width // (
310
+ self.vae_scale_factor_spatial * self.transformer.config.patch_size
311
+ )
312
+
313
+ p = self.transformer.config.patch_size
314
+ p_t = self.transformer.config.patch_size_t
315
+
316
+ base_size_width = self.transformer.config.sample_width // p
317
+ base_size_height = self.transformer.config.sample_height // p
318
+
319
+ if p_t is None:
320
+ # CogVideoX 1.0
321
+ grid_crops_coords = get_resize_crop_region_for_grid(
322
+ (grid_height, grid_width), base_size_width, base_size_height
323
+ )
324
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
325
+ embed_dim=self.transformer.config.attention_head_dim,
326
+ crops_coords=grid_crops_coords,
327
+ grid_size=(grid_height, grid_width),
328
+ temporal_size=num_frames,
329
+ device=device,
330
+ fps_factor=self._base_fps / fps,
331
+ )
332
+ else:
333
+ # CogVideoX 1.5
334
+ base_num_frames = (num_frames + p_t - 1) // p_t
335
+
336
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
337
+ embed_dim=self.transformer.config.attention_head_dim,
338
+ crops_coords=None,
339
+ grid_size=(grid_height, grid_width),
340
+ temporal_size=base_num_frames,
341
+ grid_type="slice",
342
+ max_size=(base_size_height, base_size_width),
343
+ device=device,
344
+ fps_factor=self._base_fps / fps,
345
+ )
346
+
347
+ return freqs_cos, freqs_sin
348
+
349
+ def check_inputs(
350
+ self,
351
+ task,
352
+ image,
353
+ video,
354
+ goal,
355
+ raymap,
356
+ height,
357
+ width,
358
+ num_frames,
359
+ fps,
360
+ ):
361
+ if task not in self._supported_tasks:
362
+ raise ValueError(f"`task` has to be one of {self._supported_tasks}.")
363
+
364
+ if image is None and video is None:
365
+ raise ValueError("`image` or `video` has to be provided.")
366
+
367
+ if image is not None and video is not None:
368
+ raise ValueError("`image` and `video` cannot both be provided.")
369
+
370
+ if image is not None:
371
+ if task == "reconstruction":
372
+ raise ValueError("`image` is not supported for `reconstruction` task.")
373
+ if (
374
+ not isinstance(image, torch.Tensor)
375
+ and not isinstance(image, np.ndarray)
376
+ and not isinstance(image, PIL.Image.Image)
377
+ ):
378
+ raise ValueError(
379
+ "`image` has to be of type `torch.Tensor` or `np.ndarray` or `PIL.Image.Image` but is"
380
+ f" {type(image)}"
381
+ )
382
+
383
+ if goal is not None:
384
+ if task != "planning":
385
+ raise ValueError("`goal` is only supported for `planning` task.")
386
+
387
+ if (
388
+ not isinstance(goal, torch.Tensor)
389
+ and not isinstance(goal, np.ndarray)
390
+ and not isinstance(goal, PIL.Image.Image)
391
+ ):
392
+ raise ValueError(
393
+ "`goal` has to be of type `torch.Tensor` or `np.ndarray` or `PIL.Image.Image` but is"
394
+ f" {type(goal)}"
395
+ )
396
+
397
+ if video is not None:
398
+ if task != "reconstruction":
399
+ raise ValueError("`video` is only supported for `reconstruction` task.")
400
+
401
+ if (
402
+ not isinstance(video, torch.Tensor)
403
+ and not isinstance(video, np.ndarray)
404
+ and not (
405
+ isinstance(video, list)
406
+ and all(isinstance(v, PIL.Image.Image) for v in video)
407
+ )
408
+ ):
409
+ raise ValueError(
410
+ "`video` has to be of type `torch.Tensor` or `np.ndarray` or `List[PIL.Image.Image]` but is"
411
+ f" {type(video)}"
412
+ )
413
+
414
+ if height % 8 != 0 or width % 8 != 0:
415
+ raise ValueError(
416
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
417
+ )
418
+
419
+ if num_frames is None:
420
+ raise ValueError("`num_frames` is required.")
421
+
422
+ if num_frames not in [17, 25, 33, 41]:
423
+ raise ValueError("`num_frames` has to be one of [17, 25, 33, 41].")
424
+
425
+ if fps not in [8, 10, 12, 15, 24]:
426
+ raise ValueError("`fps` has to be one of [8, 10, 12, 15, 24].")
427
+
428
+ if (
429
+ raymap is not None
430
+ and not isinstance(raymap, torch.Tensor)
431
+ and not isinstance(raymap, np.ndarray)
432
+ ):
433
+ raise ValueError(
434
+ "`raymap` has to be of type `torch.Tensor` or `np.ndarray`."
435
+ )
436
+
437
+ if raymap is not None:
438
+ if raymap.shape[-4:] != (
439
+ num_frames,
440
+ 6,
441
+ height // self.vae_scale_factor_spatial,
442
+ width // self.vae_scale_factor_spatial,
443
+ ):
444
+ raise ValueError(
445
+ f"`raymap` shape is not correct. "
446
+ f"Expected {num_frames, 6, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial}, "
447
+ f"got {raymap.shape}."
448
+ )
449
+
450
+ def _preprocess_image(self, image, height, width):
451
+ if isinstance(image, torch.Tensor):
452
+ image = image.cpu().numpy()
453
+ if image.dtype == np.uint8:
454
+ image = image.astype(np.float32) / 255.0
455
+ if image.ndim == 3:
456
+ image = [image]
457
+ image = imcrop_center(image, height, width)
458
+ image = self.video_processor.preprocess(image, height, width)
459
+ return image
460
+
461
+ def preprocess_inputs(
462
+ self,
463
+ image,
464
+ goal,
465
+ video,
466
+ raymap,
467
+ height,
468
+ width,
469
+ num_frames,
470
+ ):
471
+ if image is not None:
472
+ if isinstance(image, PIL.Image.Image):
473
+ image = self.video_processor.preprocess(
474
+ image, height, width, resize_mode="crop"
475
+ ).to(self._execution_device)
476
+ else:
477
+ image = self._preprocess_image(image, height, width).to(
478
+ self._execution_device
479
+ )
480
+ if goal is not None:
481
+ if isinstance(goal, PIL.Image.Image):
482
+ goal = self.video_processor.preprocess(
483
+ goal, height, width, resize_mode="crop"
484
+ ).to(self._execution_device)
485
+ else:
486
+ goal = self._preprocess_image(goal, height, width).to(
487
+ self._execution_device
488
+ )
489
+ if video is not None:
490
+ if isinstance(video, list) and all(
491
+ isinstance(v, PIL.Image.Image) for v in video
492
+ ):
493
+ video = self.video_processor.preprocess(
494
+ video, height, width, resize_mode="crop"
495
+ ).to(self._execution_device)
496
+ else:
497
+ video = self._preprocess_image(video, height, width).to(
498
+ self._execution_device
499
+ )
500
+ # TODO: check raymap shape
501
+ if raymap is not None:
502
+ if isinstance(raymap, np.ndarray):
503
+ raymap = torch.from_numpy(raymap).to(self._execution_device)
504
+ if raymap.ndim == 4:
505
+ raymap = raymap.unsqueeze(0)
506
+
507
+ return image, goal, video, raymap
508
+
509
+ @torch.no_grad()
510
+ def prepare_latents(
511
+ self,
512
+ image: Optional[torch.Tensor] = None,
513
+ goal: Optional[torch.Tensor] = None,
514
+ video: Optional[torch.Tensor] = None,
515
+ raymap: Optional[torch.Tensor] = None,
516
+ batch_size: int = 1,
517
+ num_frames: int = 13,
518
+ height: int = 60,
519
+ width: int = 90,
520
+ dtype: Optional[torch.dtype] = None,
521
+ device: Optional[torch.device] = None,
522
+ generator: Optional[torch.Generator] = None,
523
+ ):
524
+ if isinstance(generator, list) and len(generator) != batch_size:
525
+ raise ValueError(
526
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
527
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
528
+ )
529
+
530
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
531
+ shape = (
532
+ batch_size,
533
+ num_frames,
534
+ 56,  # rgb latents + disparity latents + 24 raymap channels; split back out after denoising
535
+ height // self.vae_scale_factor_spatial,
536
+ width // self.vae_scale_factor_spatial,
537
+ )
538
+
539
+ # For CogVideoX 1.5, the latent should add 1 for padding (not used)
540
+ if self.transformer.config.patch_size_t is not None:
541
+ shape = (
542
+ shape[:1]
543
+ + (shape[1] + shape[1] % self.transformer.config.patch_size_t,)
544
+ + shape[2:]
545
+ )
546
+
547
+ if image is not None:
548
+ image = image.unsqueeze(2)
549
+ if isinstance(generator, list):
550
+ image_latents = [
551
+ retrieve_latents(
552
+ self.vae.encode(image[i].unsqueeze(0)), generator[i]
553
+ )
554
+ for i in range(batch_size)
555
+ ]
556
+ else:
557
+ image_latents = [
558
+ retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator)
559
+ for img in image
560
+ ]
561
+
562
+ image_latents = (
563
+ torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)
564
+ ) # [B, F, C, H, W]
565
+
566
+ if not self.vae.config.invert_scale_latents:
567
+ image_latents = self.vae_scaling_factor_image * image_latents
568
+ else:
569
+ # This is awkward but required because the CogVideoX team forgot to multiply the
570
+ # scaling factor during training :)
571
+ image_latents = 1 / self.vae_scaling_factor_image * image_latents
572
+
573
+ if goal is not None:
574
+ goal = goal.unsqueeze(2)
575
+ if isinstance(generator, list):
576
+ goal_latents = [
577
+ retrieve_latents(
578
+ self.vae.encode(goal[i].unsqueeze(0)), generator[i]
579
+ )
580
+ for i in range(batch_size)
581
+ ]
582
+ else:
583
+ goal_latents = [
584
+ retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator)
585
+ for img in goal
586
+ ]
587
+
588
+ goal_latents = (
589
+ torch.cat(goal_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)
590
+ ) # [B, F, C, H, W]
591
+
592
+ if not self.vae.config.invert_scale_latents:
593
+ goal_latents = self.vae_scaling_factor_image * goal_latents
594
+ else:
595
+ # This is awkward but required because the CogVideoX team forgot to multiply the
596
+ # scaling factor during training :)
597
+ goal_latents = 1 / self.vae_scaling_factor_image * goal_latents
598
+
599
+ if video is not None:
600
+ if video.ndim == 4:
601
+ video = video.unsqueeze(0)
602
+
603
+ video = video.permute(0, 2, 1, 3, 4)
604
+ if isinstance(generator, list):
605
+ video_latents = [
606
+ retrieve_latents(
607
+ self.vae.encode(video[i].unsqueeze(0)), generator[i]
608
+ )
609
+ for i in range(batch_size)
610
+ ]
611
+ else:
612
+ video_latents = [
613
+ retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator)
614
+ for img in video
615
+ ]
616
+
617
+ video_latents = (
618
+ torch.cat(video_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)
619
+ ) # [B, F, C, H, W]
620
+
621
+ if not self.vae.config.invert_scale_latents:
622
+ video_latents = self.vae_scaling_factor_image * video_latents
623
+ else:
624
+ # This is awkward but required because the CogVideoX team forgot to multiply the
625
+ # scaling factor during training :)
626
+ video_latents = 1 / self.vae_scaling_factor_image * video_latents
627
+
628
+ if image is not None and goal is None:
629
+ padding_shape = (
630
+ batch_size,
631
+ num_frames - image_latents.shape[1],
632
+ *image_latents.shape[2:],
633
+ )
634
+ padding = torch.zeros(padding_shape, device=device, dtype=dtype)
635
+ condition_latents = torch.cat([image_latents, padding], dim=1)
636
+ elif goal is not None:
637
+ padding_shape = (
638
+ batch_size,
639
+ num_frames - goal_latents.shape[1] - image_latents.shape[1],
640
+ *image_latents.shape[2:],
641
+ )
642
+ padding = torch.zeros(padding_shape, device=device, dtype=dtype)
643
+ condition_latents = torch.cat([image_latents, padding, goal_latents], dim=1)
644
+ elif video is not None:
645
+ condition_latents = video_latents
646
+
647
+ if raymap is not None:
648
+ if raymap.shape[1] % self.vae_scale_factor_temporal != 0:
649
+ # repeat
650
+ raymap = torch.cat(
651
+ [
652
+ raymap[
653
+ :,
654
+ : self.vae_scale_factor_temporal
655
+ - raymap.shape[1] % self.vae_scale_factor_temporal,
656
+ ],
657
+ raymap,
658
+ ],
659
+ dim=1,
660
+ )
661
+ camera_conditions = rearrange(
662
+ raymap,
663
+ "b (n t) c h w -> b t (n c) h w",
664
+ n=self.vae_scale_factor_temporal,
665
+ )
666
+ else:
667
+ camera_conditions = torch.zeros(
668
+ batch_size,
669
+ num_frames,
670
+ 24,
671
+ height // self.vae_scale_factor_spatial,
672
+ width // self.vae_scale_factor_spatial,
673
+ device=device,
674
+ dtype=dtype,
675
+ )
676
+
677
+ condition_latents = torch.cat([condition_latents, camera_conditions], dim=2)
678
+ latents = randn_tensor(shape, device=device, generator=generator, dtype=dtype)
679
+
680
+ # scale the initial noise by the standard deviation required by the scheduler
681
+ latents = latents * self.scheduler.init_noise_sigma
682
+
683
+ return latents, condition_latents
684
+
685
+ @torch.no_grad()
686
+ def __call__(
687
+ self,
688
+ task: Optional[str] = None,
689
+ image: Optional[PipelineImageInput] = None,
690
+ video: Optional[PipelineImageInput] = None,
691
+ goal: Optional[PipelineImageInput] = None,
692
+ raymap: Optional[Union[torch.Tensor, np.ndarray]] = None,
693
+ height: Optional[int] = None,
694
+ width: Optional[int] = None,
695
+ num_frames: Optional[int] = None,
696
+ num_inference_steps: Optional[int] = None,
697
+ timesteps: Optional[List[int]] = None,
698
+ guidance_scale: Optional[float] = None,
699
+ use_dynamic_cfg: bool = False,
700
+ num_videos_per_prompt: int = 1,
701
+ eta: float = 0.0,
702
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
703
+ return_dict: bool = True,
704
+ attention_kwargs: Optional[Dict] = None,
705
+ fps: Optional[int] = None,
706
+ ) -> Union[AetherV1PipelineOutput, Tuple]:
707
+ if task is None:
708
+ if video is not None:
709
+ task = "reconstruction"
710
+ elif goal is not None:
711
+ task = "planning"
712
+ else:
713
+ task = "prediction"
714
+
715
+ height = (
716
+ height
717
+ or self.transformer.config.sample_height * self.vae_scale_factor_spatial
718
+ )
719
+ width = (
720
+ width
721
+ or self.transformer.config.sample_width * self.vae_scale_factor_spatial
722
+ )
723
+ num_frames = num_frames or self.transformer.config.sample_frames
724
+ fps = fps or self._base_fps
725
+
726
+ num_videos_per_prompt = 1
727
+
728
+ # 1. Check inputs. Raise error if not correct
729
+ self.check_inputs(
730
+ task=task,
731
+ image=image,
732
+ video=video,
733
+ goal=goal,
734
+ raymap=raymap,
735
+ height=height,
736
+ width=width,
737
+ num_frames=num_frames,
738
+ fps=fps,
739
+ )
740
+
741
+ # 2. Preprocess inputs
742
+ image, goal, video, raymap = self.preprocess_inputs(
743
+ image=image,
744
+ goal=goal,
745
+ video=video,
746
+ raymap=raymap,
747
+ height=height,
748
+ width=width,
749
+ num_frames=num_frames,
750
+ )
751
+ self._guidance_scale = guidance_scale
752
+ self._current_timestep = None
753
+ self._attention_kwargs = attention_kwargs
754
+ self._interrupt = False
755
+
756
+ batch_size = 1
757
+
758
+ device = self._execution_device
759
+
760
+ # 3. Encode input prompt
761
+ prompt_embeds = self.empty_prompt_embeds.to(device)
762
+
763
+ num_inference_steps = (
764
+ num_inference_steps or self._default_num_inference_steps[task]
765
+ )
766
+ guidance_scale = guidance_scale or self._default_guidance_scale[task]
767
+ use_dynamic_cfg = use_dynamic_cfg or self._default_use_dynamic_cfg[task]
768
+
769
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
770
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
771
+ # corresponds to doing no classifier free guidance.
772
+ do_classifier_free_guidance = guidance_scale > 1.0
773
+
774
+ # 4. Prepare timesteps
775
+ timesteps, num_inference_steps = retrieve_timesteps(
776
+ self.scheduler, num_inference_steps, device, timesteps
777
+ )
778
+ self._num_timesteps = len(timesteps)
779
+
780
+ # 5. Prepare latents
781
+ latents, condition_latents = self.prepare_latents(
782
+ image,
783
+ goal,
784
+ video,
785
+ raymap,
786
+ batch_size * num_videos_per_prompt,
787
+ num_frames,
788
+ height,
789
+ width,
790
+ prompt_embeds.dtype,
791
+ device,
792
+ generator,
793
+ )
794
+
795
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
796
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
797
+
798
+ # 7. Create rotary embeds if required
799
+ image_rotary_emb = (
800
+ self._prepare_rotary_positional_embeddings(
801
+ height, width, latents.size(1), device, fps=fps
802
+ )
803
+ if self.transformer.config.use_rotary_positional_embeddings
804
+ else None
805
+ )
806
+
807
+ # 8. Create ofs embeds if required
808
+ ofs_emb = (
809
+ None
810
+ if self.transformer.config.ofs_embed_dim is None
811
+ else latents.new_full((1,), fill_value=2.0)
812
+ )
813
+
814
+ # 9. Denoising loop
815
+ num_warmup_steps = max(
816
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
817
+ )
818
+
819
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
820
+ # for DPM-solver++
821
+ old_pred_original_sample = None
822
+ for i, t in enumerate(timesteps):
823
+ if self.interrupt:
824
+ continue
825
+
826
+ self._current_timestep = t
827
+ latent_model_input = (
828
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
829
+ )
830
+ latent_model_input = self.scheduler.scale_model_input(
831
+ latent_model_input, t
832
+ )
833
+
834
+ if do_classifier_free_guidance:
835
+ if task == "planning":
836
+ assert goal is not None
837
+ uncond = condition_latents.clone()
838
+ uncond[:, :, : self.vae.config.latent_channels] = 0
839
+ latent_condition = torch.cat([uncond, condition_latents])
840
+ elif task == "prediction":
841
+ uncond = condition_latents.clone()
842
+ uncond[:, :1, : self.vae.config.latent_channels] = 0
843
+ latent_condition = torch.cat([uncond, condition_latents])
844
+ else:
845
+ raise ValueError(
846
+ f"Task {task} not supported for classifier-free guidance."
847
+ )
848
+
849
+ else:
850
+ latent_condition = condition_latents
851
+
852
+ latent_model_input = torch.cat(
853
+ [latent_model_input, latent_condition], dim=2
854
+ )
855
+
856
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
857
+ timestep = t.expand(latent_model_input.shape[0])
858
+
859
+ # predict noise model_output
860
+ noise_pred = self.transformer(
861
+ hidden_states=latent_model_input,
862
+ encoder_hidden_states=prompt_embeds.repeat(
863
+ latent_model_input.shape[0], 1, 1
864
+ ),
865
+ timestep=timestep,
866
+ ofs=ofs_emb,
867
+ image_rotary_emb=image_rotary_emb,
868
+ attention_kwargs=attention_kwargs,
869
+ return_dict=False,
870
+ )[0]
871
+ noise_pred = noise_pred.float()
872
+
873
+ # perform guidance
874
+ if use_dynamic_cfg:
875
+ self._guidance_scale = 1 + guidance_scale * (
876
+ (
877
+ 1
878
+ - math.cos(
879
+ math.pi
880
+ * (
881
+ (num_inference_steps - t.item())
882
+ / num_inference_steps
883
+ )
884
+ ** 5.0
885
+ )
886
+ )
887
+ / 2
888
+ )
889
+
890
+ if do_classifier_free_guidance:
891
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
892
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
893
+ noise_pred_text - noise_pred_uncond
894
+ )
895
+
896
+ # compute the previous noisy sample x_t -> x_t-1
897
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
898
+ latents = self.scheduler.step(
899
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
900
+ )[0]
901
+ else:
902
+ latents, old_pred_original_sample = self.scheduler.step(
903
+ noise_pred,
904
+ old_pred_original_sample,
905
+ t,
906
+ timesteps[i - 1] if i > 0 else None,
907
+ latents,
908
+ **extra_step_kwargs,
909
+ return_dict=False,
910
+ )
911
+ latents = latents.to(prompt_embeds.dtype)
912
+
913
+ if i == len(timesteps) - 1 or (
914
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
915
+ ):
916
+ progress_bar.update()
917
+
918
+ self._current_timestep = None
919
+
920
+ rgb_latents = latents[:, :, : self.vae.config.latent_channels]
921
+ disparity_latents = latents[
922
+ :, :, self.vae.config.latent_channels : self.vae.config.latent_channels * 2
923
+ ]
924
+ camera_latents = latents[:, :, self.vae.config.latent_channels * 2 :]
925
+
926
+ rgb_video = self.decode_latents(rgb_latents)
927
+ rgb_video = self.video_processor.postprocess_video(
928
+ video=rgb_video, output_type="np"
929
+ )
930
+
931
+ disparity_video = self.decode_latents(disparity_latents)
932
+ disparity_video = disparity_video.mean(dim=1, keepdim=False)
933
+ disparity_video = disparity_video * 0.5 + 0.5
934
+ disparity_video = torch.square(disparity_video)
935
+ disparity_video = disparity_video.float().cpu().numpy()
936
+
937
+ raymap = (
938
+ rearrange(camera_latents, "b t (n c) h w -> b (n t) c h w", n=4)[
939
+ :, -rgb_video.shape[1] :, :, :
940
+ ]
941
+ .cpu()
942
+ .numpy()
943
+ )
944
+
945
+ # Offload all models
946
+ self.maybe_free_model_hooks()
947
+
948
+ if not return_dict:
949
+ return (
950
+ rgb_video,
951
+ disparity_video,
952
+ raymap,
953
+ )
954
+
955
+ return AetherV1PipelineOutput(
956
+ rgb=rgb_video.squeeze(0),
957
+ disparity=disparity_video.squeeze(0),
958
+ raymap=raymap.squeeze(0),
959
+ )
aether/utils/__init__.py ADDED
File without changes
aether/utils/postprocess_utils.py ADDED
@@ -0,0 +1,842 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ import matplotlib
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from plyfile import PlyData, PlyElement
11
+
12
+
13
+ def signed_log1p_inverse(x):
14
+ """
15
+ Computes the inverse of signed_log1p: x = sign(x) * (exp(abs(x)) - 1).
16
+
17
+ Args:
18
+ y (torch.Tensor): Input tensor (output of signed_log1p).
19
+
20
+ Returns:
21
+ torch.Tensor: Original tensor x.
22
+ """
23
+ if isinstance(x, torch.Tensor):
24
+ return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
25
+ elif isinstance(x, np.ndarray):
26
+ return np.sign(x) * (np.exp(np.abs(x)) - 1)
27
+ else:
28
+ raise TypeError("Input must be a torch.Tensor or numpy.ndarray")
29
+
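The forward transform is not included in this excerpt; the sketch below is what it would look like, assumed from the inverse above rather than taken from the repository.

```python
def signed_log1p(x):
    """Assumed counterpart of `signed_log1p_inverse`: y = sign(x) * log(1 + |x|)."""
    if isinstance(x, torch.Tensor):
        return torch.sign(x) * torch.log1p(torch.abs(x))
    elif isinstance(x, np.ndarray):
        return np.sign(x) * np.log1p(np.abs(x))
    else:
        raise TypeError("Input must be a torch.Tensor or numpy.ndarray")
```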
30
+
31
+ def colorize_depth(depth, cmap="Spectral"):
32
+ min_d, max_d = (depth[depth > 0]).min(), (depth[depth > 0]).max()
33
+ depth = (max_d - depth) / (max_d - min_d)
34
+
35
+ cm = matplotlib.colormaps[cmap]
36
+ depth = depth.clip(0, 1)
37
+ depth = cm(depth, bytes=False)[..., 0:3]
38
+ return depth
39
+
40
+
41
+ def save_ply(pointmap, image, output_file, downsample=20, mask=None):
42
+ _, h, w, _ = pointmap.shape
43
+ image = image[:, :h, :w]
44
+ pointmap = pointmap[:, :h, :w]
45
+
46
+ points = pointmap.reshape(-1, 3) # (H*W, 3)
47
+ colors = image.reshape(-1, 3) # (H*W, 3)
48
+ if mask is not None:
49
+ points = points[mask.reshape(-1)]
50
+ colors = colors[mask.reshape(-1)]
51
+
52
+ indices = np.random.choice(
53
+ colors.shape[0], int(colors.shape[0] / downsample), replace=False
54
+ )
55
+ points = points[indices]
56
+ colors = colors[indices]
57
+
58
+ vertices = []
59
+ for p, c in zip(points, colors):
60
+ vertex = (p[0], p[1], p[2], int(c[0]), int(c[1]), int(c[2]))
61
+ vertices.append(vertex)
62
+
63
+ vertex_dtype = np.dtype(
64
+ [
65
+ ("x", "f4"),
66
+ ("y", "f4"),
67
+ ("z", "f4"),
68
+ ("red", "u1"),
69
+ ("green", "u1"),
70
+ ("blue", "u1"),
71
+ ]
72
+ )
73
+ vertex_array = np.array(vertices, dtype=vertex_dtype)
74
+
75
+ ply_element = PlyElement.describe(vertex_array, "vertex")
76
+ PlyData([ply_element], text=True).write(output_file)
77
+
78
+
79
+ def fov_to_focal(fovx, fovy, h, w):
80
+ focal_x = w * 0.5 / np.tan(fovx)
81
+ focal_y = h * 0.5 / np.tan(fovy)
82
+ focal = (focal_x + focal_y) / 2
83
+ return focal
84
+
85
+
86
+ def get_rays(pose, h, w, focal=None, fovx=None, fovy=None):
87
+ import torch.nn.functional as F
88
+
89
+ pose = torch.from_numpy(pose).float()
90
+ x, y = torch.meshgrid(
91
+ torch.arange(w),
92
+ torch.arange(h),
93
+ indexing="xy",
94
+ )
95
+ x = x.flatten().unsqueeze(0).repeat(pose.shape[0], 1)
96
+ y = y.flatten().unsqueeze(0).repeat(pose.shape[0], 1)
97
+
98
+ cx = w * 0.5
99
+ cy = h * 0.5
100
+ intrinsics, focal = get_intrinsics(pose.shape[0], h, w, fovx, fovy, focal)
101
+ focal = torch.from_numpy(focal).float()
102
+ camera_dirs = F.pad(
103
+ torch.stack(
104
+ [
105
+ (x - cx + 0.5) / focal.unsqueeze(-1),
106
+ (y - cy + 0.5) / focal.unsqueeze(-1),
107
+ ],
108
+ dim=-1,
109
+ ),
110
+ (0, 1),
111
+ value=1.0,
112
+ ) # [t, hw, 3]
113
+
114
+ pose = pose.to(dtype=camera_dirs.dtype)
115
+ rays_d = camera_dirs @ pose[:, :3, :3].transpose(1, 2) # [t, hw, 3]
116
+
117
+ rays_o = pose[:, :3, 3].unsqueeze(1).expand_as(rays_d) # [hw, 3]
118
+
119
+ rays_o = rays_o.view(pose.shape[0], h, w, 3)
120
+ rays_d = rays_d.view(pose.shape[0], h, w, 3)
121
+
122
+ return rays_o.float().numpy(), rays_d.float().numpy(), intrinsics
123
+
124
+
125
+ def get_intrinsics(batch_size, h, w, fovx=None, fovy=None, focal=None):
126
+ if focal is None:
127
+ focal_x = w * 0.5 / np.tan(fovx)
128
+ focal_y = h * 0.5 / np.tan(fovy)
129
+ focal = (focal_x + focal_y) / 2
130
+ cx = w * 0.5
131
+ cy = h * 0.5
132
+ intrinsics = np.zeros((batch_size, 3, 3))
133
+ intrinsics[:, 0, 0] = focal
134
+ intrinsics[:, 1, 1] = focal
135
+ intrinsics[:, 0, 2] = cx
136
+ intrinsics[:, 1, 2] = cy
137
+ intrinsics[:, 2, 2] = 1.0
138
+
139
+ return intrinsics, focal
140
+
141
+
142
+ def save_pointmap(
143
+ rgb,
144
+ disparity,
145
+ raymap,
146
+ save_file,
147
+ vae_downsample_scale=8,
148
+ camera_pose=None,
149
+ ray_o_scale_inv=1.0,
150
+ max_depth=1e2,
151
+ save_full_pcd_videos=False,
152
+ smooth_camera=False,
153
+ smooth_method="kalman", # or simple
154
+ **kwargs,
155
+ ):
156
+ """
157
+
158
+ Args:
159
+ rgb (numpy.ndarray): Shape of (t, h, w, 3), range [0, 1]
160
+ disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
161
+ raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
162
+ ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 1.0.
163
+ """
164
+ rgb = np.clip(rgb, 0, 1) * 255
165
+
166
+ pointmap_dict = postprocess_pointmap(
167
+ disparity,
168
+ raymap,
169
+ vae_downsample_scale,
170
+ camera_pose,
171
+ ray_o_scale_inv=ray_o_scale_inv,
172
+ smooth_camera=smooth_camera,
173
+ smooth_method=smooth_method,
174
+ **kwargs,
175
+ )
176
+
177
+ save_ply(
178
+ pointmap_dict["pointmap"],
179
+ rgb,
180
+ save_file,
181
+ mask=(pointmap_dict["depth"] < max_depth),
182
+ )
183
+
184
+ if save_full_pcd_videos:
185
+ pcd_dict = {
186
+ "points": pointmap_dict["pointmap"],
187
+ "colors": rgb,
188
+ "intrinsics": pointmap_dict["intrinsics"],
189
+ "poses": pointmap_dict["camera_pose"],
190
+ "depths": pointmap_dict["depth"],
191
+ }
192
+ np.save(save_file.replace(".ply", "_pcd.npy"), pcd_dict)
193
+
194
+ return pointmap_dict
195
+
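A hedged end-to-end sketch tying this helper to the pipeline output; the `output` variable is assumed to come from `AetherV1PipelineCogVideoX`, and the file name is illustrative.

```python
# Assumed workflow: export an AetherV1PipelineOutput as a downsampled colored point cloud.
save_pointmap(
    rgb=output.rgb,              # (t, h, w, 3), range [0, 1]
    disparity=output.disparity,  # (t, h, w), range [0, 1]
    raymap=output.raymap,        # (t, 6, h // 8, w // 8)
    save_file="scene.ply",
    smooth_camera=True,          # optional pose smoothing, see below
)
```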
196
+
197
+ def raymap_to_poses(
198
+ raymap, camera_pose=None, ray_o_scale_inv=1.0, return_intrinsics=True
199
+ ):
200
+ ts = raymap.shape[0]
201
+ if (not return_intrinsics) and (camera_pose is not None):
202
+ return camera_pose, None, None
203
+
204
+ raymap[:, 3:] = signed_log1p_inverse(raymap[:, 3:])
205
+
206
+ # Extract ray origins and directions
207
+ ray_o = (
208
+ rearrange(raymap[:, 3:], "t c h w -> t h w c") * ray_o_scale_inv
209
+ ) # [T, H, W, C]
210
+ ray_d = rearrange(raymap[:, :3], "t c h w -> t h w c") # [T, H, W, C]
211
+
212
+ # Compute orientation and directions
213
+ orient = ray_o.reshape(ts, -1, 3).mean(axis=1) # T, 3
214
+ image_orient = (ray_o + ray_d).reshape(ts, -1, 3).mean(axis=1) # T, 3
215
+ Focal = np.linalg.norm(image_orient - orient, axis=-1) # T,
216
+ Z_Dir = image_orient - orient # T, 3
217
+
218
+ # Compute the width (W) and field of view (FoV_x)
219
+ W_Left = ray_d[:, :, :1, :].reshape(ts, -1, 3).mean(axis=1)
220
+ W_Right = ray_d[:, :, -1:, :].reshape(ts, -1, 3).mean(axis=1)
221
+ W = W_Right - W_Left
222
+ W_real = (
223
+ np.linalg.norm(np.cross(W, Z_Dir), axis=-1)
224
+ / (raymap.shape[-1] - 1)
225
+ * raymap.shape[-1]
226
+ )
227
+ Fov_x = np.arctan(W_real / (2 * Focal))
228
+
229
+ # Compute the height (H) and field of view (FoV_y)
230
+ H_Up = ray_d[:, :1, :, :].reshape(ts, -1, 3).mean(axis=1)
231
+ H_Down = ray_d[:, -1:, :, :].reshape(ts, -1, 3).mean(axis=1)
232
+ H = H_Up - H_Down
233
+ H_real = (
234
+ np.linalg.norm(np.cross(H, Z_Dir), axis=-1)
235
+ / (raymap.shape[-2] - 1)
236
+ * raymap.shape[-2]
237
+ )
238
+ Fov_y = np.arctan(H_real / (2 * Focal))
239
+
240
+ # Compute X, Y, and Z directions for the camera
241
+ X_Dir = W_Right - W_Left
242
+ Y_Dir = np.cross(Z_Dir, X_Dir)
243
+ X_Dir = np.cross(Y_Dir, Z_Dir)
244
+
245
+ X_Dir /= np.linalg.norm(X_Dir, axis=-1, keepdims=True)
246
+ Y_Dir /= np.linalg.norm(Y_Dir, axis=-1, keepdims=True)
247
+ Z_Dir /= np.linalg.norm(Z_Dir, axis=-1, keepdims=True)
248
+
249
+ # Create the camera-to-world (camera_pose) transformation matrix
250
+ if camera_pose is None:
251
+ camera_pose = np.zeros((ts, 4, 4))
252
+ camera_pose[:, :3, 0] = X_Dir
253
+ camera_pose[:, :3, 1] = Y_Dir
254
+ camera_pose[:, :3, 2] = Z_Dir
255
+ camera_pose[:, :3, 3] = orient
256
+ camera_pose[:, 3, 3] = 1.0
257
+
258
+ return camera_pose, Fov_x, Fov_y
259
+
260
+
261
+ def postprocess_pointmap(
262
+ disparity,
263
+ raymap,
264
+ vae_downsample_scale=8,
265
+ camera_pose=None,
266
+ focal=None,
267
+ ray_o_scale_inv=1.0,
268
+ smooth_camera=False,
269
+ smooth_method="simple",
270
+ **kwargs,
271
+ ):
272
+ """
273
+
274
+ Args:
275
+ disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
276
+ raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
277
+ ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 1.0.
278
+ """
279
+ depth = np.clip(1.0 / np.clip(disparity, 1e-3, 1), 0, 1e8)
280
+
281
+ camera_pose, fov_x, fov_y = raymap_to_poses(
282
+ raymap,
283
+ camera_pose=camera_pose,
284
+ ray_o_scale_inv=ray_o_scale_inv,
285
+ return_intrinsics=(focal is not None),
286
+ )
287
+ if focal is None:
288
+ focal = fov_to_focal(
289
+ fov_x,
290
+ fov_y,
291
+ int(raymap.shape[2] * vae_downsample_scale),
292
+ int(raymap.shape[3] * vae_downsample_scale),
293
+ )
294
+
295
+ if smooth_camera:
296
+ # Check if sequence is static
297
+ is_static, trans_diff, rot_diff = detect_static_sequence(camera_pose)
298
+
299
+ if is_static:
300
+ print(
301
+ f"Detected static/near-static sequence (trans_diff={trans_diff:.6f}, rot_diff={rot_diff:.6f})"
302
+ )
303
+ # Apply stronger smoothing for static sequences
304
+ camera_pose = adaptive_pose_smoothing(camera_pose, trans_diff, rot_diff)
305
+ else:
306
+ if smooth_method == "simple":
307
+ camera_pose = smooth_poses(
308
+ camera_pose, window_size=5, method="gaussian"
309
+ )
310
+ elif smooth_method == "kalman":
311
+ camera_pose = smooth_trajectory(camera_pose, window_size=5)
312
+
313
+ ray_o, ray_d, intrinsics = get_rays(
314
+ camera_pose,
315
+ int(raymap.shape[2] * vae_downsample_scale),
316
+ int(raymap.shape[3] * vae_downsample_scale),
317
+ focal,
318
+ )
319
+
320
+ pointmap = depth[..., None] * ray_d + ray_o
321
+
322
+ return {
323
+ "pointmap": pointmap,
324
+ "camera_pose": camera_pose,
325
+ "intrinsics": intrinsics,
326
+ "ray_o": ray_o,
327
+ "ray_d": ray_d,
328
+ "depth": depth,
329
+ }
330
+
331
+
332
+ def detect_static_sequence(poses, threshold=0.01):
333
+ """Detect if the camera sequence is static based on pose differences."""
334
+ translations = poses[:, :3, 3]
335
+ rotations = poses[:, :3, :3]
336
+
337
+ # Compute translation differences
338
+ trans_diff = np.linalg.norm(translations[1:] - translations[:-1], axis=1).mean()
339
+
340
+ # Compute rotation differences (using matrix frobenius norm)
341
+ rot_diff = np.linalg.norm(rotations[1:] - rotations[:-1], axis=(1, 2)).mean()
342
+
343
+ return trans_diff < threshold and rot_diff < threshold, trans_diff, rot_diff
344
+
345
+
346
+ def adaptive_pose_smoothing(poses, trans_diff, rot_diff, base_window=5):
347
+ """Apply adaptive smoothing based on motion magnitude."""
348
+ # Increase window size for low motion sequences
349
+ motion_magnitude = trans_diff + rot_diff
350
+ adaptive_window = min(
351
+ 41, max(base_window, int(base_window * (0.1 / max(motion_magnitude, 1e-6))))
352
+ )
353
+
354
+ # Apply stronger smoothing for low motion
355
+ poses_smooth = smooth_poses(poses, window_size=adaptive_window, method="gaussian")
356
+ return poses_smooth
357
+
358
+
359
+ def get_pixel(H, W):
360
+ # get 2D pixels (u, v) for image_a in cam_a pixel space
361
+ u_a, v_a = np.meshgrid(np.arange(W), np.arange(H))
362
+ # u_a = np.flip(u_a, axis=1)
363
+ # v_a = np.flip(v_a, axis=0)
364
+ pixels_a = np.stack(
365
+ [u_a.flatten() + 0.5, v_a.flatten() + 0.5, np.ones_like(u_a.flatten())], axis=0
366
+ )
367
+
368
+ return pixels_a
369
+
370
+
371
+ def project(depth, intrinsic, pose):
372
+ H, W = depth.shape
373
+ pixel = get_pixel(H, W).astype(np.float32)
374
+ points = (np.linalg.inv(intrinsic) @ pixel) * depth.reshape(-1)
375
+ points = pose[:3, :4] @ np.concatenate(
376
+ [points, np.ones((1, points.shape[1]))], axis=0
377
+ )
378
+
379
+ points = points.T.reshape(H, W, 3)
380
+
381
+ return points
382
+
383
+
384
+ def depth_edge(
385
+ depth: torch.Tensor,
386
+ atol: float = None,
387
+ rtol: float = None,
388
+ kernel_size: int = 3,
389
+ mask: Optional[torch.Tensor] = None,
390
+ ) -> torch.BoolTensor:
391
+ """
392
+ Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.
393
+
394
+ Args:
395
+ depth (torch.Tensor): shape (..., height, width), linear depth map
396
+ atol (float): absolute tolerance
397
+ rtol (float): relative tolerance
398
+
399
+ Returns:
400
+ edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
401
+ """
402
+ is_numpy = isinstance(depth, np.ndarray)
403
+ if is_numpy:
404
+ depth = torch.from_numpy(depth)
405
+ if isinstance(mask, np.ndarray):
406
+ mask = torch.from_numpy(mask)
407
+
408
+ shape = depth.shape
409
+ depth = depth.reshape(-1, 1, *shape[-2:])
410
+ if mask is not None:
411
+ mask = mask.reshape(-1, 1, *shape[-2:])
412
+
413
+ if mask is None:
414
+ diff = F.max_pool2d(
415
+ depth, kernel_size, stride=1, padding=kernel_size // 2
416
+ ) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
417
+ else:
418
+ diff = F.max_pool2d(
419
+ torch.where(mask, depth, -torch.inf),
420
+ kernel_size,
421
+ stride=1,
422
+ padding=kernel_size // 2,
423
+ ) + F.max_pool2d(
424
+ torch.where(mask, -depth, -torch.inf),
425
+ kernel_size,
426
+ stride=1,
427
+ padding=kernel_size // 2,
428
+ )
429
+
430
+ edge = torch.zeros_like(depth, dtype=torch.bool)
431
+ if atol is not None:
432
+ edge |= diff > atol
433
+ if rtol is not None:
434
+ edge |= (diff / depth).nan_to_num_() > rtol
435
+ edge = edge.reshape(*shape)
436
+
437
+ if is_numpy:
438
+ return edge.numpy()
439
+ return edge
440
+
441
+
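A small self-contained check of `depth_edge` on a synthetic depth map with a single discontinuity (toy values; the absolute tolerance is chosen only for illustration):

```python
import torch

depth = torch.ones(1, 8, 8)
depth[..., :, 4:] = 5.0                      # sharp 1 m -> 5 m jump between halves

edge = depth_edge(depth, atol=1.0, kernel_size=3)
print(edge.shape)         # torch.Size([1, 8, 8])
print(edge.any().item())  # True: pixels bordering the jump are flagged
print(edge.all().item())  # False: flat regions stay unflagged
```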
442
+ @torch.jit.script
443
+ def align_rigid(
444
+ p,
445
+ q,
446
+ weights,
447
+ ):
448
+ """Compute a rigid transformation that, when applied to p, minimizes the weighted
449
+ squared distance between transformed points in p and points in q. See "Least-Squares
450
+ Rigid Motion Using SVD" by Olga Sorkine-Hornung and Michael Rabinovich for more
451
+ details (https://igl.ethz.ch/projects/ARAP/svd_rot.pdf).
452
+ """
453
+
454
+ device = p.device
455
+ dtype = p.dtype
456
+ batch, _, _ = p.shape
457
+
458
+ # 1. Compute the centroids of both point sets.
459
+ weights_normalized = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
460
+ p_centroid = (weights_normalized[..., None] * p).sum(dim=-2)
461
+ q_centroid = (weights_normalized[..., None] * q).sum(dim=-2)
462
+
463
+ # 2. Compute the centered vectors.
464
+ p_centered = p - p_centroid[..., None, :]
465
+ q_centered = q - q_centroid[..., None, :]
466
+
467
+ # 3. Compute the 3x3 covariance matrix.
468
+ covariance = (q_centered * weights[..., None]).transpose(-1, -2) @ p_centered
469
+
470
+ # 4. Compute the singular value decomposition and then the rotation.
471
+ u, _, vt = torch.linalg.svd(covariance)
472
+ s = torch.eye(3, dtype=dtype, device=device)
473
+ s = s.expand((batch, 3, 3)).contiguous()
474
+ s[..., 2, 2] = (u.det() * vt.det()).sign()
475
+ rotation = u @ s @ vt
476
+
477
+ # 5. Compute the optimal scale
478
+ scale = (
479
+ (torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum(
480
+ -1
481
+ )
482
+ * weights
483
+ ).sum(-1) / ((p_centered**2).sum(-1) * weights).sum(-1)
484
+ # scale = (torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum([-1, -2]) / (p_centered**2).sum([-1, -2])
485
+
486
+ # 6. Compute the optimal translation.
487
+ translation = q_centroid - torch.einsum(
488
+ "b i j, b j -> b i", rotation, p_centroid * scale[:, None]
489
+ )
490
+
491
+ return rotation, translation, scale
492
+
493
+
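A sanity-check sketch for `align_rigid`: build the target cloud as a known similarity transform of a random source cloud and confirm that the rotation and scale are recovered (toy data, seeded for reproducibility):

```python
import math
import torch

torch.manual_seed(0)
p = torch.randn(1, 100, 3)                                 # source points
c, s = math.cos(0.3), math.sin(0.3)
R_gt = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t_gt = torch.tensor([0.5, -1.0, 0.25])
q = 2.0 * p @ R_gt.T + t_gt                                # scaled, rotated, shifted copy

rotation, translation, scale = align_rigid(p, q, torch.ones(1, 100))
print(torch.allclose(rotation[0], R_gt, atol=1e-4))             # expected: True
print(torch.allclose(scale[0], torch.tensor(2.0), atol=1e-3))   # expected: True
```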
494
+ def align_camera_extrinsics(
495
+ cameras_src: torch.Tensor, # Bx3x4 tensor representing [R | t]
496
+ cameras_tgt: torch.Tensor, # Bx3x4 tensor representing [R | t]
497
+ estimate_scale: bool = True,
498
+ eps: float = 1e-9,
499
+ ):
500
+ """
501
+ Align the source camera extrinsics to the target camera extrinsics.
502
+ NOTE Assume OPENCV convention
503
+
504
+ Args:
505
+ cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
506
+ cameras_tgt (torch.Tensor): Bx3x4 tensor representing [R | t] for target cameras.
507
+ estimate_scale (bool, optional): Whether to estimate the scale factor. Default is True.
508
+ eps (float, optional): Small value to avoid division by zero. Default is 1e-9.
509
+
510
+ Returns:
511
+ align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
512
+ align_t_T (torch.Tensor): 1x3 translation vector for alignment.
513
+ align_t_s (float): Scaling factor for alignment.
514
+ """
515
+
516
+ R_src = cameras_src[:, :, :3] # Extracting the rotation matrices from [R | t]
517
+ R_tgt = cameras_tgt[:, :, :3] # Extracting the rotation matrices from [R | t]
518
+
519
+ RRcov = torch.bmm(R_tgt.transpose(2, 1), R_src).mean(0)
520
+ U, _, V = torch.svd(RRcov)
521
+ align_t_R = V @ U.t()
522
+
523
+ T_src = cameras_src[:, :, 3] # Extracting the translation vectors from [R | t]
524
+ T_tgt = cameras_tgt[:, :, 3] # Extracting the translation vectors from [R | t]
525
+
526
+ A = torch.bmm(T_src[:, None], R_src)[:, 0]
527
+ B = torch.bmm(T_tgt[:, None], R_src)[:, 0]
528
+
529
+ Amu = A.mean(0, keepdim=True)
530
+ Bmu = B.mean(0, keepdim=True)
531
+
532
+ if estimate_scale and A.shape[0] > 1:
533
+ # get the scaling component by matching covariances
534
+ # of centered A and centered B
535
+ Ac = A - Amu
536
+ Bc = B - Bmu
537
+ align_t_s = (Ac * Bc).mean() / (Ac**2).mean().clamp(eps)
538
+ else:
539
+ # set the scale to identity
540
+ align_t_s = 1.0
541
+
542
+ # get the translation as the difference between the means of A and B
543
+ align_t_T = Bmu - align_t_s * Amu
544
+
545
+ align_t_R = align_t_R[None]
546
+ return align_t_R, align_t_T, align_t_s
547
+
548
+
549
+ def apply_transformation(
550
+ cameras_src: torch.Tensor, # Bx3x4 tensor representing [R | t]
551
+ align_t_R: torch.Tensor, # 1x3x3 rotation matrix
552
+ align_t_T: torch.Tensor, # 1x3 translation vector
553
+ align_t_s: float, # Scaling factor
554
+ return_extri: bool = True,
555
+ ) -> torch.Tensor:
556
+ """
557
+ Align and transform the source cameras using the provided rotation, translation, and scaling factors.
558
+ NOTE Assume OPENCV convention
559
+
560
+ Args:
561
+ cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
562
+ align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
563
+ align_t_T (torch.Tensor): 1x3 translation vector for alignment.
564
+ align_t_s (float): Scaling factor for alignment.
565
+
566
+ Returns:
567
+ aligned_R (torch.Tensor): Bx3x3 tensor representing the aligned rotation matrices.
568
+ aligned_T (torch.Tensor): Bx3 tensor representing the aligned translation vectors.
569
+ """
570
+
571
+ R_src = cameras_src[:, :, :3]
572
+ T_src = cameras_src[:, :, 3]
573
+
574
+ aligned_R = torch.bmm(R_src, align_t_R.expand(R_src.shape[0], 3, 3))
575
+
576
+ # Apply the translation alignment to the source translations
577
+ align_t_T_expanded = align_t_T[..., None].repeat(R_src.shape[0], 1, 1)
578
+ transformed_T = torch.bmm(R_src, align_t_T_expanded)[..., 0]
579
+ aligned_T = transformed_T + T_src * align_t_s
580
+
581
+ if return_extri:
582
+ extri = torch.cat([aligned_R, aligned_T.unsqueeze(-1)], dim=-1)
583
+ return extri
584
+
585
+ return aligned_R, aligned_T
586
+
587
+
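A minimal end-to-end sketch of `align_camera_extrinsics` followed by `apply_transformation`, using toy Bx3x4 extrinsics (identity rotations, a pure x-shift between source and target) so the aligned result can be compared against the target directly:

```python
import numpy as np
import torch

N = 8
src = np.tile(np.eye(4)[:3, :4], (N, 1, 1)).astype(np.float32)
src[:, 0, 3] = np.arange(N, dtype=np.float32) * 0.1            # move along x over time
tgt = src.copy()
tgt[:, :3, 3] += np.array([1.0, 0.0, 0.0], dtype=np.float32)   # shifted target frame

rel_r, rel_t, rel_s = align_camera_extrinsics(
    torch.from_numpy(src), torch.from_numpy(tgt)
)
aligned = apply_transformation(torch.from_numpy(src), rel_r, rel_t, rel_s)
print(aligned.shape)                                           # torch.Size([8, 3, 4])
print(torch.allclose(aligned, torch.from_numpy(tgt), atol=1e-5))  # expected: True
```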
588
+ def slerp(q1, q2, t):
589
+ """Spherical Linear Interpolation between quaternions.
590
+ Args:
591
+ q1: (4,) first quaternion
592
+ q2: (4,) second quaternion
593
+ t: float between 0 and 1
594
+ Returns:
595
+ (4,) interpolated quaternion
596
+ """
597
+ # Compute the cosine of the angle between the two quaternions
598
+ dot = np.sum(q1 * q2)
599
+
600
+ # If the dot product is negative, slerp won't take the shorter path
601
+ # Fix by negating one of the input quaternions
602
+ if dot < 0.0:
603
+ q2 = -q2
604
+ dot = -dot
605
+
606
+ # Threshold for using linear interpolation instead of spherical
607
+ DOT_THRESHOLD = 0.9995
608
+ if dot > DOT_THRESHOLD:
609
+ # If the inputs are too close for comfort, linearly interpolate
610
+ # and normalize the result
611
+ result = q1 + t * (q2 - q1)
612
+ return result / np.linalg.norm(result)
613
+
614
+ # Compute the angle between the quaternions
615
+ theta_0 = np.arccos(dot)
616
+ sin_theta_0 = np.sin(theta_0)
617
+
618
+ # Compute interpolation factors
619
+ theta = theta_0 * t
620
+ sin_theta = np.sin(theta)
621
+
622
+ s0 = np.cos(theta) - dot * sin_theta / sin_theta_0
623
+ s1 = sin_theta / sin_theta_0
624
+
625
+ return (s0 * q1) + (s1 * q2)
626
+
627
+
628
+ def interpolate_poses(pose1, pose2, weight):
629
+ """Interpolate between two camera poses with weight.
630
+ Args:
631
+ pose1: (4, 4) first camera pose
632
+ pose2: (4, 4) second camera pose
633
+ weight: float between 0 and 1, weight for pose1 (1-weight for pose2)
634
+ Returns:
635
+ (4, 4) interpolated pose
636
+ """
637
+ from scipy.spatial.transform import Rotation as R
638
+
639
+ # Extract rotations and translations
640
+ R1 = R.from_matrix(pose1[:3, :3])
641
+ R2 = R.from_matrix(pose2[:3, :3])
642
+ t1 = pose1[:3, 3]
643
+ t2 = pose2[:3, 3]
644
+
645
+ # Get quaternions
646
+ q1 = R1.as_quat()
647
+ q2 = R2.as_quat()
648
+
649
+ # Interpolate rotation using our slerp implementation
650
+ q_interp = slerp(q1, q2, 1 - weight) # 1-weight because weight is for pose1
651
+ R_interp = R.from_quat(q_interp)
652
+
653
+ # Linear interpolation for translation
654
+ t_interp = weight * t1 + (1 - weight) * t2
655
+
656
+ # Construct interpolated pose
657
+ pose_interp = np.eye(4)
658
+ pose_interp[:3, :3] = R_interp.as_matrix()
659
+ pose_interp[:3, 3] = t_interp
660
+
661
+ return pose_interp
662
+
663
+
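A short sketch of `interpolate_poses` halfway between two toy poses (a 90° yaw plus a 2 m x-shift), which should yield a 45° yaw and a 1 m shift:

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

pose_a = np.eye(4)
pose_b = np.eye(4)
pose_b[:3, :3] = R.from_euler("z", 90, degrees=True).as_matrix()
pose_b[:3, 3] = [2.0, 0.0, 0.0]

mid = interpolate_poses(pose_a, pose_b, weight=0.5)
print(np.round(mid[:3, 3], 3))                                      # [1. 0. 0.]
print(R.from_matrix(mid[:3, :3]).as_euler("zyx", degrees=True)[0])  # ~45.0
```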
664
+ def smooth_poses(poses, window_size=5, method="gaussian"):
665
+ """Smooth camera poses temporally.
666
+ Args:
667
+ poses: (N, 4, 4) camera poses
668
+ window_size: int, must be odd number
669
+ method: str, 'gaussian' or 'savgol' or 'ma'
670
+ Returns:
671
+ (N, 4, 4) smoothed poses
672
+ """
673
+ from scipy.ndimage import gaussian_filter1d
674
+ from scipy.signal import savgol_filter
675
+ from scipy.spatial.transform import Rotation as R
676
+
677
+ assert window_size % 2 == 1, "window_size must be odd"
678
+ N = poses.shape[0]
679
+ smoothed = np.zeros_like(poses)
680
+
681
+ # Extract translations and quaternions
682
+ translations = poses[:, :3, 3]
683
+ rotations = R.from_matrix(poses[:, :3, :3])
684
+ quats = rotations.as_quat() # (N, 4)
685
+
686
+ # Ensure consistent quaternion signs to prevent interpolation artifacts
687
+ for i in range(1, N):
688
+ if np.dot(quats[i], quats[i - 1]) < 0:
689
+ quats[i] = -quats[i]
690
+
691
+ # Smooth translations
692
+ if method == "gaussian":
693
+ sigma = window_size / 6.0 # approximately 99.7% of the weight within the window
694
+ smoothed_trans = gaussian_filter1d(translations, sigma, axis=0, mode="nearest")
695
+ smoothed_quats = gaussian_filter1d(quats, sigma, axis=0, mode="nearest")
696
+ elif method == "savgol":
697
+ # Savitzky-Golay filter: polynomial fitting
698
+ poly_order = min(window_size - 1, 3)
699
+ smoothed_trans = savgol_filter(
700
+ translations, window_size, poly_order, axis=0, mode="nearest"
701
+ )
702
+ smoothed_quats = savgol_filter(
703
+ quats, window_size, poly_order, axis=0, mode="nearest"
704
+ )
705
+ elif method == "ma":
706
+ # Simple moving average
707
+ kernel = np.ones(window_size) / window_size
708
+ smoothed_trans = np.array(
709
+ [np.convolve(translations[:, i], kernel, mode="same") for i in range(3)]
710
+ ).T
711
+ smoothed_quats = np.array(
712
+ [np.convolve(quats[:, i], kernel, mode="same") for i in range(4)]
713
+ ).T
+ else:
+ raise ValueError(f"Unknown smoothing method: {method}")
714
+
715
+ # Normalize quaternions
716
+ smoothed_quats /= np.linalg.norm(smoothed_quats, axis=1, keepdims=True)
717
+
718
+ # Reconstruct poses
719
+ smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
720
+
721
+ for i in range(N):
722
+ smoothed[i] = np.eye(4)
723
+ smoothed[i, :3, :3] = smoothed_rots[i]
724
+ smoothed[i, :3, 3] = smoothed_trans[i]
725
+
726
+ return smoothed
727
+
728
+
729
+ def smooth_trajectory(poses, window_size=5):
730
+ """Smooth camera trajectory using Kalman filter.
731
+ Args:
732
+ poses: (N, 4, 4) camera poses
733
+ window_size: int, window size for initial smoothing
734
+ Returns:
735
+ (N, 4, 4) smoothed poses
736
+ """
737
+ from filterpy.kalman import KalmanFilter
738
+ from scipy.spatial.transform import Rotation as R
739
+
740
+ N = poses.shape[0]
741
+
742
+ # Initialize Kalman filter for position and velocity
743
+ kf = KalmanFilter(dim_x=6, dim_z=3) # 3D position and velocity
744
+ dt = 1.0 # assume uniform time steps
745
+
746
+ # State transition matrix
747
+ kf.F = np.array(
748
+ [
749
+ [1, 0, 0, dt, 0, 0],
750
+ [0, 1, 0, 0, dt, 0],
751
+ [0, 0, 1, 0, 0, dt],
752
+ [0, 0, 0, 1, 0, 0],
753
+ [0, 0, 0, 0, 1, 0],
754
+ [0, 0, 0, 0, 0, 1],
755
+ ]
756
+ )
757
+
758
+ # Measurement matrix
759
+ kf.H = np.array([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]])
760
+
761
+ # Measurement noise
762
+ kf.R *= 0.1
763
+
764
+ # Process noise
765
+ kf.Q *= 0.1
766
+
767
+ # Initial state uncertainty
768
+ kf.P *= 1.0
769
+
770
+ # Extract translations and rotations
771
+ translations = poses[:, :3, 3]
772
+ rotations = R.from_matrix(poses[:, :3, :3])
773
+ quats = rotations.as_quat()
774
+
775
+ # First pass: simple smoothing for initial estimates
776
+ smoothed = smooth_poses(poses, window_size, method="gaussian")
777
+ smooth_trans = smoothed[:, :3, 3]
778
+
779
+ # Second pass: Kalman filter for trajectory
780
+ filtered_trans = np.zeros_like(translations)
781
+ kf.x = np.zeros(6)
782
+ kf.x[:3] = smooth_trans[0]
783
+
784
+ filtered_trans[0] = smooth_trans[0]
785
+
786
+ # Forward pass
787
+ for i in range(1, N):
788
+ kf.predict()
789
+ kf.update(smooth_trans[i])
790
+ filtered_trans[i] = kf.x[:3]
791
+
792
+ # Backward smoothing for rotations using SLERP
793
+ window_half = window_size // 2
794
+ smoothed_quats = np.zeros_like(quats)
795
+
796
+ for i in range(N):
797
+ start_idx = max(0, i - window_half)
798
+ end_idx = min(N, i + window_half + 1)
799
+ weights = np.exp(
800
+ -0.5 * ((np.arange(start_idx, end_idx) - i) / (window_half / 2)) ** 2
801
+ )
802
+ weights /= weights.sum()
803
+
804
+ # Weighted average of nearby quaternions
805
+ avg_quat = np.zeros(4)
806
+ for j, w in zip(range(start_idx, end_idx), weights):
807
+ if np.dot(quats[j], quats[i]) < 0:
808
+ avg_quat += w * -quats[j]
809
+ else:
810
+ avg_quat += w * quats[j]
811
+ smoothed_quats[i] = avg_quat / np.linalg.norm(avg_quat)
812
+
813
+ # Reconstruct final smoothed poses
814
+ final_smoothed = np.zeros_like(poses)
815
+ smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
816
+
817
+ for i in range(N):
818
+ final_smoothed[i] = np.eye(4)
819
+ final_smoothed[i, :3, :3] = smoothed_rots[i]
820
+ final_smoothed[i, :3, 3] = filtered_trans[i]
821
+
822
+ return final_smoothed
823
+
824
+
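A minimal usage sketch for the smoothing helpers on a jittery synthetic trajectory (identity rotations, noisy straight-line translation); note that the Kalman-based `smooth_trajectory` additionally requires the `filterpy` package listed in requirements.txt:

```python
import numpy as np

rng = np.random.default_rng(0)
poses = np.tile(np.eye(4), (60, 1, 1))
poses[:, 0, 3] = np.linspace(0.0, 1.0, 60) + 0.01 * rng.standard_normal(60)  # jittery x-translation

smoothed = smooth_poses(poses, window_size=5, method="gaussian")
print(smoothed.shape)                                  # (60, 4, 4)
smoothed_kf = smooth_trajectory(poses, window_size=5)  # Kalman variant, needs filterpy
```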
825
+ def compute_scale(prediction, target, mask):
826
+ if isinstance(prediction, np.ndarray):
827
+ prediction = torch.from_numpy(prediction).float()
828
+ if isinstance(target, np.ndarray):
829
+ target = torch.from_numpy(target).float()
830
+ if isinstance(mask, np.ndarray):
831
+ mask = torch.from_numpy(mask).bool()
832
+
833
+ numerator = torch.sum(mask * prediction * target, (1, 2))
834
+ denominator = torch.sum(mask * prediction * prediction, (1, 2))
835
+
836
+ scale = torch.zeros_like(numerator)
837
+
838
+ valid = (denominator != 0).nonzero()
839
+
840
+ scale[valid] = numerator[valid] / denominator[valid]
841
+
842
+ return scale.item()
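`compute_scale` solves the closed-form least-squares scale s = Σ(m·p·t) / Σ(m·p·p) over masked pixels, which is how overlapping sliding-window disparities are aligned in scripts/demo.py. A toy check (the (1, H, W) shapes mirror that usage):

```python
import numpy as np

prediction = np.random.rand(1, 32, 32).astype(np.float32) + 0.1
target = 2.5 * prediction                      # target is exactly 2.5x the prediction
mask = np.ones_like(prediction, dtype=bool)

print(compute_scale(prediction, target, mask))  # ~2.5
```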
aether/utils/preprocess_utils.py ADDED
@@ -0,0 +1,39 @@
1
+ import numpy as np
2
+
3
+
4
+ def imcrop_center(img_list, crop_p_h, crop_p_w):
5
+ new_img = []
6
+ for i, _img in enumerate(img_list):
7
+ if crop_p_h / crop_p_w > _img.shape[0] / _img.shape[1]: # crop left and right
8
+ start_h = int(0)
9
+ start_w = int((_img.shape[1] - _img.shape[0] / crop_p_h * crop_p_w) / 2)
10
+ crop_size = (_img.shape[0], int(_img.shape[0] / crop_p_h * crop_p_w))
11
+ else:
12
+ start_h = int((_img.shape[0] - _img.shape[1] / crop_p_w * crop_p_h) / 2)
13
+ start_w = int(0)
14
+ crop_size = (int(_img.shape[1] / crop_p_w * crop_p_h), _img.shape[1])
15
+
16
+ _img_src = crop(_img, start_h, start_w, crop_size[0], crop_size[1])
17
+ new_img.append(_img_src)
18
+
19
+ return new_img
20
+
21
+
22
+ def crop(img, start_h, start_w, crop_h, crop_w):
23
+ img_src = np.zeros((crop_h, crop_w, *img.shape[2:]), dtype=img.dtype)
24
+ hsize, wsize = crop_h, crop_w
25
+ dh, dw, sh, sw = start_h, start_w, 0, 0
26
+ if dh < 0:
27
+ sh = -dh
28
+ hsize += dh
29
+ dh = 0
30
+ if dh + hsize > img.shape[0]:
31
+ hsize = img.shape[0] - dh
32
+ if dw < 0:
33
+ sw = -dw
34
+ wsize += dw
35
+ dw = 0
36
+ if dw + wsize > img.shape[1]:
37
+ wsize = img.shape[1] - dw
38
+ img_src[sh : sh + hsize, sw : sw + wsize] = img[dh : dh + hsize, dw : dw + wsize]
39
+ return img_src
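A small sketch of `imcrop_center`, cropping a frame to a 3:4 (h:w) aspect ratio; the 480x720 frame size is an illustrative assumption:

```python
import numpy as np

frame = np.zeros((480, 720, 3), dtype=np.uint8)
cropped = imcrop_center([frame], crop_p_h=3, crop_p_w=4)[0]
print(cropped.shape)  # (480, 640, 3): the width is trimmed symmetrically
```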
aether/utils/visualize_utils.py ADDED
@@ -0,0 +1,255 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/vggt/blob/main/visual_util.py
8
+
9
+
10
+ import matplotlib
11
+ import numpy as np
12
+ import trimesh
13
+ from scipy.spatial.transform import Rotation
14
+
15
+ from aether.utils.postprocess_utils import depth_edge
16
+
17
+
18
+ def predictions_to_glb(
19
+ predictions,
20
+ filter_by_frames="all",
21
+ show_cam=True,
22
+ max_depth=100.0,
23
+ rtol=0.03,
24
+ frame_rel_idx: float = 0.0,
25
+ ) -> trimesh.Scene:
26
+ """
27
+ Converts predictions to a 3D scene represented as a GLB file.
28
+
29
+ Args:
30
+ predictions (dict): Dictionary containing model predictions with keys:
31
+ - world_points: 3D point coordinates (S, H, W, 3)
32
+ - images: Input images (S, H, W, 3)
33
+ - depths: Depths (S, H, W)
34
+ - camera_poses: Camera poses (S, 4, 4)
35
+ filter_by_frames (str): Frame filter specification (default: "all")
36
+ show_cam (bool): Include camera visualization (default: True)
37
+ max_depth (float): Maximum depth value (default: 100.0)
38
+ rtol (float): Relative tolerance for depth edge detection (default: 0.03)
39
+ frame_rel_idx (float): Relative index of the frame to visualize (default: 0.0)
40
+ Returns:
41
+ trimesh.Scene: Processed 3D scene containing point cloud and cameras
42
+
43
+ Raises:
44
+ ValueError: If input predictions structure is invalid
45
+ """
46
+ if not isinstance(predictions, dict):
47
+ raise ValueError("predictions must be a dictionary")
48
+
49
+ selected_frame_idx = None
50
+ if filter_by_frames != "all" and filter_by_frames != "All":
51
+ try:
52
+ # Extract the index part before the colon
53
+ selected_frame_idx = int(filter_by_frames.split(":")[0])
54
+ except (ValueError, IndexError):
55
+ pass
56
+
57
+ pred_world_points = predictions["world_points"]
58
+
59
+ # Get images from predictions
60
+ images = predictions["images"]
61
+ # Use extrinsic matrices instead of pred_extrinsic_list
62
+ camera_poses = predictions["camera_poses"]
63
+
64
+ if selected_frame_idx is not None:
65
+ pred_world_points = pred_world_points[selected_frame_idx][None]
66
+ images = images[selected_frame_idx][None]
67
+ camera_poses = camera_poses[selected_frame_idx][None]
68
+
69
+ vertices_3d = pred_world_points.reshape(-1, 3)
70
+ # Handle different image formats - check if images need transposing
71
+ if images.ndim == 4 and images.shape[1] == 3: # NCHW format
72
+ colors_rgb = np.transpose(images, (0, 2, 3, 1))
73
+ else: # Assume already in NHWC format
74
+ colors_rgb = images
75
+ colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
76
+
77
+ depths = predictions["depths"]
78
+ masks = depths < max_depth
79
+ edge = ~depth_edge(depths, rtol=rtol, mask=masks)
80
+ masks = (masks & edge).reshape(-1)
81
+ vertices_3d = vertices_3d[masks]
82
+ colors_rgb = colors_rgb[masks]
83
+
84
+ if vertices_3d is None or np.asarray(vertices_3d).size == 0:
85
+ vertices_3d = np.array([[1, 0, 0]])
86
+ colors_rgb = np.array([[255, 255, 255]])
87
+ scene_scale = 1
88
+ else:
89
+ # Calculate the 5th and 95th percentiles along each axis
90
+ lower_percentile = np.percentile(vertices_3d, 5, axis=0)
91
+ upper_percentile = np.percentile(vertices_3d, 95, axis=0)
92
+
93
+ # Calculate the diagonal length of the percentile bounding box
94
+ scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
95
+
96
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
97
+
98
+ # Initialize a 3D scene
99
+ scene_3d = trimesh.Scene()
100
+
101
+ # Add point cloud data to the scene
102
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
103
+
104
+ scene_3d.add_geometry(point_cloud_data)
105
+
106
+ # Prepare 4x4 matrices for camera extrinsics
107
+ num_cameras = len(camera_poses)
108
+ extrinsics_matrices = np.zeros((num_cameras, 4, 4))
109
+ extrinsics_matrices[:, :3, :4] = camera_poses[:, :3, :4]
110
+ extrinsics_matrices[:, 3, 3] = 1
111
+
112
+ if show_cam:
113
+ # Add camera models to the scene
114
+ for i in range(num_cameras):
115
+ camera_to_world = camera_poses[i]
116
+ rgba_color = colormap(frame_rel_idx)
117
+ current_color = tuple(int(255 * x) for x in rgba_color[:3])
118
+
119
+ integrate_camera_into_scene(
120
+ scene_3d, camera_to_world, current_color, scene_scale
121
+ )
122
+
123
+ return scene_3d
124
+
125
+
126
+ def integrate_camera_into_scene(
127
+ scene: trimesh.Scene,
128
+ transform: np.ndarray,
129
+ face_colors: tuple,
130
+ scene_scale: float,
131
+ ):
132
+ """
133
+ Integrates a fake camera mesh into the 3D scene.
134
+
135
+ Args:
136
+ scene (trimesh.Scene): The 3D scene to add the camera model.
137
+ transform (np.ndarray): Transformation matrix for camera positioning.
138
+ face_colors (tuple): Color of the camera face.
139
+ scene_scale (float): Scale of the scene.
140
+ """
141
+
142
+ cam_width = scene_scale * 0.025
143
+ cam_height = scene_scale * 0.05
144
+
145
+ # Create cone shape for camera
146
+ rot_45_degree = np.eye(4)
147
+ rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
148
+ rot_45_degree[2, 3] = -cam_height
149
+
150
+ opengl_transform = get_opengl_conversion_matrix()
151
+ # Combine transformations
152
+ complete_transform = transform @ opengl_transform @ rot_45_degree
153
+ camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
154
+
155
+ # Generate mesh for the camera
156
+ slight_rotation = np.eye(4)
157
+ slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
158
+
159
+ vertices_combined = np.concatenate(
160
+ [
161
+ camera_cone_shape.vertices,
162
+ 0.95 * camera_cone_shape.vertices,
163
+ transform_points(slight_rotation, camera_cone_shape.vertices),
164
+ ]
165
+ )
166
+ vertices_transformed = transform_points(complete_transform, vertices_combined)
167
+
168
+ mesh_faces = compute_camera_faces(camera_cone_shape)
169
+
170
+ # Add the camera mesh to the scene
171
+ camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
172
+ camera_mesh.visual.face_colors[:, :3] = face_colors
173
+ scene.add_geometry(camera_mesh)
174
+
175
+
176
+ def get_opengl_conversion_matrix() -> np.ndarray:
177
+ """
178
+ Constructs and returns the OpenGL conversion matrix.
179
+
180
+ Returns:
181
+ numpy.ndarray: A 4x4 OpenGL conversion matrix.
182
+ """
183
+ # Create an identity matrix
184
+ matrix = np.identity(4)
185
+
186
+ # Flip the y and z axes
187
+ matrix[1, 1] = -1
188
+ matrix[2, 2] = -1
189
+
190
+ return matrix
191
+
192
+
193
+ def transform_points(
194
+ transformation: np.ndarray, points: np.ndarray, dim: int = None
195
+ ) -> np.ndarray:
196
+ """
197
+ Applies a 4x4 transformation to a set of points.
198
+
199
+ Args:
200
+ transformation (np.ndarray): Transformation matrix.
201
+ points (np.ndarray): Points to be transformed.
202
+ dim (int, optional): Dimension for reshaping the result.
203
+
204
+ Returns:
205
+ np.ndarray: Transformed points.
206
+ """
207
+ points = np.asarray(points)
208
+ initial_shape = points.shape[:-1]
209
+ dim = dim or points.shape[-1]
210
+
211
+ # Apply transformation
212
+ transformation = transformation.swapaxes(
213
+ -1, -2
214
+ ) # Transpose the transformation matrix
215
+ points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
216
+
217
+ # Reshape the result
218
+ result = points[..., :dim].reshape(*initial_shape, dim)
219
+ return result
220
+
221
+
222
+ def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
223
+ """
224
+ Computes the faces for the camera mesh.
225
+
226
+ Args:
227
+ cone_shape (trimesh.Trimesh): The shape of the camera cone.
228
+
229
+ Returns:
230
+ np.ndarray: Array of faces for the camera mesh.
231
+ """
232
+ # Create pseudo cameras
233
+ faces_list = []
234
+ num_vertices_cone = len(cone_shape.vertices)
235
+
236
+ for face in cone_shape.faces:
237
+ if 0 in face:
238
+ continue
239
+ v1, v2, v3 = face
240
+ v1_offset, v2_offset, v3_offset = face + num_vertices_cone
241
+ v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
242
+
243
+ faces_list.extend(
244
+ [
245
+ (v1, v2, v2_offset),
246
+ (v1, v1_offset, v3),
247
+ (v3_offset, v2, v3),
248
+ (v1, v2, v2_offset_2),
249
+ (v1, v1_offset_2, v3),
250
+ (v3_offset_2, v2, v3),
251
+ ]
252
+ )
253
+
254
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
255
+ return np.array(faces_list)
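A minimal sketch of exporting a GLB scene from dummy predictions; the array shapes follow the `predictions_to_glb` docstring and the values are placeholders:

```python
import numpy as np

S, H, W = 1, 4, 4
predictions = {
    "world_points": np.random.rand(S, H, W, 3).astype(np.float32),
    "images": np.random.rand(S, H, W, 3).astype(np.float32),   # NHWC in [0, 1]
    "depths": np.full((S, H, W), 2.0, dtype=np.float32),
    "camera_poses": np.tile(np.eye(4, dtype=np.float32)[None], (S, 1, 1)),
}
scene = predictions_to_glb(predictions, max_depth=10.0, rtol=0.03)
scene.export("toy_scene.glb")
```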
app.py DELETED
@@ -1,7 +0,0 @@
1
- import gradio as gr
2
-
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
-
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
assets/example_obs/car.png ADDED

Git LFS Details

  • SHA256: e69a7af1f0aeb161c76f0d6b251b3394ee09705c0d1823207cc280add07a3933
  • Pointer size: 131 Bytes
  • Size of remote file: 475 kB
assets/example_obs/cartoon.png ADDED

Git LFS Details

  • SHA256: 82df67cd9a1393cbaed97ce3052fbf04a60cbb0e2603b8871e6364a432cb086d
  • Pointer size: 131 Bytes
  • Size of remote file: 564 kB
assets/example_obs/garden.jpg ADDED
assets/example_obs/room.jpg ADDED
assets/example_obs_goal/01_goal.png ADDED

Git LFS Details

  • SHA256: 42cecbddab958627643db8651fe18886ab05db51faa1288549efbca3c85e2276
  • Pointer size: 131 Bytes
  • Size of remote file: 514 kB
assets/example_obs_goal/01_obs.png ADDED

Git LFS Details

  • SHA256: 705f540f79b897aa44bb7f09618015def3b71f7a5a251adbef30dca7e323d740
  • Pointer size: 131 Bytes
  • Size of remote file: 451 kB
assets/example_obs_goal/02_goal.png ADDED

Git LFS Details

  • SHA256: ed95356dd64f889a4516ae71196f4e26fdc33927918c35d692013224cfdee0c3
  • Pointer size: 131 Bytes
  • Size of remote file: 506 kB
assets/example_obs_goal/02_obs.png ADDED

Git LFS Details

  • SHA256: 9915c1c5f6d09c1d1486822e4f91aad0a3240bb750d8e2ba87a3383d843065b1
  • Pointer size: 131 Bytes
  • Size of remote file: 578 kB
assets/example_obs_goal/03_goal.png ADDED

Git LFS Details

  • SHA256: ff452a47a87bb8df41b7feb310d6dc430c18e1a4c8d662a186ebf706e37dd521
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB
assets/example_obs_goal/03_obs.png ADDED

Git LFS Details

  • SHA256: 7d05892ce404b3170c682bf82ea71b215d6549ba1cf4903ad368a9fc6cb8b363
  • Pointer size: 131 Bytes
  • Size of remote file: 365 kB
assets/example_obs_goal/04_goal.png ADDED

Git LFS Details

  • SHA256: cc896d573a888dbf2d77e95e4a901efa239f07ee5cd10787199c45e559ece166
  • Pointer size: 131 Bytes
  • Size of remote file: 588 kB
assets/example_obs_goal/04_obs.png ADDED

Git LFS Details

  • SHA256: 9e3d6ec20f188b5f5501933b75aae89c9c8e057e544a5d10bc562e69b5bfdf6d
  • Pointer size: 131 Bytes
  • Size of remote file: 599 kB
assets/example_raymaps/raymap_backward.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:396d0f398d8a59bc27a1d76b91222a58c8d751c6aff5a3ccc10a56a2beecd540
3
+ size 5313728
assets/example_raymaps/raymap_forward_right.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46f73167495d98d19e86cd47cd062f6d18f2c2882d73c1239effc1b5f039bd32
3
+ size 5313728
assets/example_raymaps/raymap_left_forward.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6e744a0adf4c5f0386ce8c1d586c85f1b17d392978fa34057c638aff5b84ac3
3
+ size 5313728
assets/example_raymaps/raymap_right.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eab316eab3e94297b0d4d5cffae367484265873ce3c4c98e24bffb4361d4a8d
3
+ size 5313728
assets/example_videos/bridge.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dcb3a3dff9d02adfa97534fb3cb96b0618c101c049cf37cb81ce375bf71f252
3
+ size 6714828
assets/example_videos/moviegen.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4204806ad057c53068710d9f3db1ad74f3e947be92d1e8c3f3e06343efd0c1c6
3
+ size 2164737
assets/example_videos/nuscenes.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1230424dfae7159140968649e06f0d8a874f0e52f5ad296ba0fa5fd9c1c2d467
3
+ size 6256650
assets/example_videos/veo2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9206a10f368f4334338c0ca44d6c85aa2d260294d720309a855741d57116ed60
3
+ size 2909620
assets/logo.png ADDED

Git LFS Details

  • SHA256: 1fcc6a3c8e5fc8206ce96ca50f85b06aa337d38354b98b4faef986f06026550e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.29 MB
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 3b9cfe7dbabbb999ad75f78ef3a38ffb0ed9f56303cff3a0d9ebfa90bf29031c
  • Pointer size: 133 Bytes
  • Size of remote file: 11.8 MB
pyproject.toml ADDED
@@ -0,0 +1,30 @@
1
+ [tool.ruff]
2
+ line-length = 88
3
+
4
+ [tool.ruff.lint]
5
+ # Never enforce `E501` (line length violations).
6
+ ignore = ["C901", "E501", "E741", "F402", "F823"]
7
+ select = ["C", "E", "F", "I", "W"]
8
+
9
+ # Ignore import violations in all `__init__.py` files.
10
+ [tool.ruff.lint.per-file-ignores]
11
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
12
+ "aether/*.py" = ["E402"]
13
+ "tests/*.py" = ["E402", "F841"]
14
+
15
+ [tool.ruff.lint.isort]
16
+ lines-after-imports = 2
17
+ known-first-party = ["aether"]
18
+
19
+ [tool.ruff.format]
20
+ # Like Black, use double quotes for strings.
21
+ quote-style = "double"
22
+
23
+ # Like Black, indent with spaces, rather than tabs.
24
+ indent-style = "space"
25
+
26
+ # Like Black, respect magic trailing commas.
27
+ skip-magic-trailing-comma = false
28
+
29
+ # Like Black, automatically detect the appropriate line ending.
30
+ line-ending = "auto"
requirements.txt ADDED
@@ -0,0 +1,41 @@
1
+ accelerate>=1.2.1
2
+ coloredlogs>=15.0.1
3
+ colorlog>=6.9.0
4
+ diffusers>=0.32.2
5
+ easydict>=1.13
6
+ einops>=0.8.0
7
+ hf_transfer>=0.1.8
8
+ huggingface-hub>=0.27.1
9
+ imageio>=2.33.1
10
+ imageio-ffmpeg>=0.5.1
11
+ iopath>=0.1.10
12
+ matplotlib>=3.10.0
13
+ numpy>=1.26.4
14
+ omegaconf>=2.3.0
15
+ opencv-python-headless>=4.10.0.84
16
+ pillow>=11.1.0
17
+ plotly>=5.24.1
18
+ plyfile>=1.1
19
+ pre_commit>=4.0.1
20
+ python-dotenv>=1.0.1
21
+ PyYAML>=6.0.2
22
+ rich>=13.9.4
23
+ rootutils>=1.0.7
24
+ safetensors>=0.5.2
25
+ scikit-image>=0.25.0
26
+ scipy>=1.15.0
27
+ sentencepiece>=0.2.0
28
+ six>=1.17.0
29
+ tokenizers>=0.21.0
30
+ torchaudio>=2.5.1
31
+ torchmetrics>=1.6.1
32
+ torchvision>=0.20.1
33
+ tqdm>=4.67.1
34
+ transformers>=4.48.0
35
+ triton>=3.1.0
36
+ typer>=0.15.1
37
+ typing_extensions>=4.12.2
38
+ viser>=0.2.23
39
+ filterpy
40
+ trimesh
41
+ gradio
scripts/demo.py ADDED
@@ -0,0 +1,614 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ from typing import List, Optional, Tuple
5
+
6
+ import imageio.v3 as iio
7
+ import numpy as np
8
+ import PIL.Image
9
+ import rootutils
10
+ import torch
11
+ from diffusers import (
12
+ AutoencoderKLCogVideoX,
13
+ CogVideoXDPMScheduler,
14
+ CogVideoXTransformer3DModel,
15
+ )
16
+ from transformers import AutoTokenizer, T5EncoderModel
17
+
18
+
19
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
+
21
+ from aether.pipelines.aetherv1_pipeline_cogvideox import ( # noqa: E402
22
+ AetherV1PipelineCogVideoX,
23
+ AetherV1PipelineOutput,
24
+ )
25
+ from aether.utils.postprocess_utils import ( # noqa: E402
26
+ align_camera_extrinsics,
27
+ apply_transformation,
28
+ colorize_depth,
29
+ compute_scale,
30
+ get_intrinsics,
31
+ interpolate_poses,
32
+ postprocess_pointmap,
33
+ project,
34
+ raymap_to_poses,
35
+ )
36
+ from aether.utils.visualize_utils import predictions_to_glb # noqa: E402
37
+
38
+
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+
42
+ def seed_all(seed: int = 0) -> None:
43
+ """
44
+ Set random seeds of all components.
45
+ """
46
+ random.seed(seed)
47
+ np.random.seed(seed)
48
+ torch.manual_seed(seed)
49
+ torch.cuda.manual_seed_all(seed)
50
+
51
+
52
+ def parse_args() -> argparse.Namespace:
53
+ """Parse command line arguments."""
54
+ parser = argparse.ArgumentParser(description="AetherV1-CogvideoX Inference Demo")
55
+
56
+ parser.add_argument(
57
+ "--task",
58
+ type=str,
59
+ required=True,
60
+ choices=["reconstruction", "prediction", "planning"],
61
+ help="Task to perform: 'reconstruction', 'prediction' or 'planning'.",
62
+ )
63
+ parser.add_argument(
64
+ "--video",
65
+ type=str,
66
+ default=None,
67
+ help="Path to a video file. Only used for 'reconstruction' task.",
68
+ )
69
+ parser.add_argument(
70
+ "--image",
71
+ type=str,
72
+ default=None,
73
+ help="Path to an image file. Only used for 'prediction' and 'planning' tasks.",
74
+ )
75
+ parser.add_argument(
76
+ "--goal",
77
+ type=str,
78
+ default=None,
79
+ help="Path to a goal image file. Only used for 'planning' task.",
80
+ )
81
+ parser.add_argument(
82
+ "--raymap_action",
83
+ type=str,
84
+ default=None,
85
+ help="Path to a raymap action file. Should be a numpy array of shape (num_frame, 6, latent_height, latent_width).",
86
+ )
87
+ parser.add_argument(
88
+ "--output_dir",
89
+ type=str,
90
+ default="outputs",
91
+ help="Path to save the outputs.",
92
+ )
93
+ parser.add_argument(
94
+ "--seed",
95
+ type=int,
96
+ default=42,
97
+ help="Random seed.",
98
+ )
99
+ parser.add_argument(
100
+ "--fps",
101
+ type=int,
102
+ default=12,
103
+ choices=[8, 10, 12, 15, 24],
104
+ help="Frames per second. Options: 8, 10, 12, 15, 24.",
105
+ )
106
+ parser.add_argument(
107
+ "--num_inference_steps",
108
+ type=int,
109
+ default=None,
110
+ help="Number of inference steps. If not specified, will use the default number of steps for the task.",
111
+ )
112
+ parser.add_argument(
113
+ "--guidance_scale",
114
+ type=float,
115
+ default=None,
116
+ help="Guidance scale. If not specified, will use the default guidance scale for the task.",
117
+ )
118
+ parser.add_argument(
119
+ "--use_dynamic_cfg",
120
+ action="store_true",
121
+ default=True,
122
+ help="Use dynamic cfg.",
123
+ )
124
+ parser.add_argument(
125
+ "--height",
126
+ type=int,
127
+ default=480,
128
+ help="Height of the output video.",
129
+ )
130
+ parser.add_argument(
131
+ "--width",
132
+ type=int,
133
+ default=720,
134
+ help="Width of the output video.",
135
+ )
136
+ parser.add_argument(
137
+ "--num_frames",
138
+ type=int,
139
+ default=41,
140
+ help="Number of frames to predict.",
141
+ )
142
+ parser.add_argument(
143
+ "--max_depth",
144
+ type=float,
145
+ default=100.0,
146
+ help="Maximum depth of the scene in meters.",
147
+ )
148
+ parser.add_argument(
149
+ "--rtol",
150
+ type=float,
151
+ default=0.03,
152
+ help="Relative tolerance for depth edge detection.",
153
+ )
154
+ parser.add_argument(
155
+ "--cogvideox_pretrained_model_name_or_path",
156
+ type=str,
157
+ default="THUDM/CogVideoX-5b-I2V",
158
+ help="Name or path of the CogVideoX model to use.",
159
+ )
160
+ parser.add_argument(
161
+ "--aether_pretrained_model_name_or_path",
162
+ type=str,
163
+ default="AetherWorldModel/AetherV1-CogVideoX",
164
+ help="Name or path of the Aether model to use.",
165
+ )
166
+ parser.add_argument(
167
+ "--smooth_camera",
168
+ action="store_true",
169
+ default=True,
170
+ help="Smooth the camera trajectory.",
171
+ )
172
+ parser.add_argument(
173
+ "--smooth_method",
174
+ type=str,
175
+ default="kalman",
176
+ choices=["kalman", "simple"],
177
+ help="Smooth method.",
178
+ )
179
+ parser.add_argument(
180
+ "--sliding_window_stride",
181
+ type=int,
182
+ default=24,
183
+ help="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task.",
184
+ )
185
+ parser.add_argument(
186
+ "--post_reconstruction",
187
+ action="store_true",
188
+ default=True,
189
+ help="Run reconstruction after prediction for better quality. Only used for 'prediction' and 'planning' tasks.",
190
+ )
191
+ parser.add_argument(
192
+ "--pointcloud_save_frame_interval",
193
+ type=int,
194
+ default=10,
195
+ help="Pointcloud save frame interval.",
196
+ )
197
+ parser.add_argument(
198
+ "--align_pointmaps",
199
+ action="store_true",
200
+ default=False,
201
+ help="Align pointmaps.",
202
+ )
203
+ return parser.parse_args()
204
+
205
+
206
+ def build_pipeline(args: argparse.Namespace) -> AetherV1PipelineCogVideoX:
207
+ pipeline = AetherV1PipelineCogVideoX(
208
+ tokenizer=AutoTokenizer.from_pretrained(
209
+ args.cogvideox_pretrained_model_name_or_path,
210
+ subfolder="tokenizer",
211
+ ),
212
+ text_encoder=T5EncoderModel.from_pretrained(
213
+ args.cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
214
+ ),
215
+ vae=AutoencoderKLCogVideoX.from_pretrained(
216
+ args.cogvideox_pretrained_model_name_or_path, subfolder="vae"
217
+ ),
218
+ scheduler=CogVideoXDPMScheduler.from_pretrained(
219
+ args.cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
220
+ ),
221
+ transformer=CogVideoXTransformer3DModel.from_pretrained(
222
+ args.aether_pretrained_model_name_or_path, subfolder="transformer"
223
+ ),
224
+ )
225
+ pipeline.vae.enable_slicing()
226
+ pipeline.vae.enable_tiling()
227
+ pipeline.to(device)
228
+ return pipeline
229
+
230
+
231
+ def get_window_starts(
232
+ total_frames: int, sliding_window_size: int, temporal_stride: int
233
+ ) -> List[int]:
234
+ """Calculate window start indices."""
235
+ starts = list(
236
+ range(
237
+ 0,
238
+ total_frames - sliding_window_size + 1,
239
+ temporal_stride,
240
+ )
241
+ )
242
+ if (
243
+ total_frames > sliding_window_size
244
+ and (total_frames - sliding_window_size) % temporal_stride != 0
245
+ ):
246
+ starts.append(total_frames - sliding_window_size)
247
+ return starts
248
+
249
+
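A worked example of the window layout (illustrative numbers): for a 100-frame clip with 41-frame windows and stride 24, a trailing window is appended so the final frames are still covered:

```python
print(get_window_starts(total_frames=100, sliding_window_size=41, temporal_stride=24))
# -> [0, 24, 48, 59]
```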
250
+ def blend_and_merge_window_results(
251
+ window_results: List[AetherV1PipelineOutput],
252
+ window_indices: List[int],
253
+ args: argparse.Namespace,
254
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
255
+ """Blend and merge window results."""
256
+ merged_rgb = None
257
+ merged_disparity = None
258
+ merged_poses = None
259
+ merged_focals = None
260
+ if args.align_pointmaps:
261
+ merged_pointmaps = None
262
+
263
+ w1 = window_results[0].disparity
264
+
265
+ for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
266
+ t_end = t_start + window_result.rgb.shape[0]
267
+ if idx == 0:
268
+ merged_rgb = window_result.rgb
269
+ merged_disparity = window_result.disparity
270
+ pointmap_dict = postprocess_pointmap(
271
+ window_result.disparity,
272
+ window_result.raymap,
273
+ vae_downsample_scale=8,
274
+ ray_o_scale_inv=0.1,
275
+ smooth_camera=args.smooth_camera,
276
+ smooth_method=args.smooth_method if args.smooth_camera else "none",
277
+ )
278
+ merged_poses = pointmap_dict["camera_pose"]
279
+ merged_focals = (
280
+ pointmap_dict["intrinsics"][:, 0, 0]
281
+ + pointmap_dict["intrinsics"][:, 1, 1]
282
+ ) / 2
283
+ if args.align_pointmaps:
284
+ merged_pointmaps = pointmap_dict["pointmap"]
285
+ else:
286
+ overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start
287
+
288
+ window_disparity = window_result.disparity
289
+
290
+ # Align disparity
291
+ disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
292
+ scale = compute_scale(
293
+ window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
294
+ merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
295
+ disp_mask.reshape(1, -1, w1.shape[-1]),
296
+ )
297
+ window_disparity = scale * window_disparity
298
+
299
+ # Blend disparity
300
+ result_disparity = np.ones((t_end, *w1.shape[1:]))
301
+ result_disparity[:t_start] = merged_disparity[:t_start]
302
+ result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
303
+ weight = np.linspace(1, 0, overlap_t)[:, None, None]
304
+ result_disparity[t_start : t_start + overlap_t] = merged_disparity[
305
+ t_start : t_start + overlap_t
306
+ ] * weight + window_disparity[:overlap_t] * (1 - weight)
307
+ merged_disparity = result_disparity
308
+
309
+ # Blend RGB
310
+ result_rgb = np.ones((t_end, *w1.shape[1:], 3))
311
+ result_rgb[:t_start] = merged_rgb[:t_start]
312
+ result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
313
+ weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
314
+ result_rgb[t_start : t_start + overlap_t] = merged_rgb[
315
+ t_start : t_start + overlap_t
316
+ ] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
317
+ merged_rgb = result_rgb
318
+
319
+ # Align poses
320
+ window_raymap = window_result.raymap
321
+ window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
322
+ window_raymap, ray_o_scale_inv=0.1
323
+ )
324
+ rel_r, rel_t, rel_s = align_camera_extrinsics(
325
+ torch.from_numpy(window_poses[:overlap_t]),
326
+ torch.from_numpy(merged_poses[-overlap_t:]),
327
+ )
328
+ aligned_window_poses = (
329
+ apply_transformation(
330
+ torch.from_numpy(window_poses),
331
+ rel_r,
332
+ rel_t,
333
+ rel_s,
334
+ return_extri=True,
335
+ )
336
+ .cpu()
337
+ .numpy()
338
+ )
339
+
340
+ result_poses = np.ones((t_end, 4, 4))
341
+ result_poses[:t_start] = merged_poses[:t_start]
342
+ result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]
343
+
344
+ # Interpolate poses in overlap region
345
+ weights = np.linspace(1, 0, overlap_t)
346
+ for t in range(overlap_t):
347
+ weight = weights[t]
348
+ pose1 = merged_poses[t_start + t]
349
+ pose2 = aligned_window_poses[t]
350
+ result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)
351
+
352
+ merged_poses = result_poses
353
+
354
+ # Align intrinsics
355
+ window_intrinsics, _ = get_intrinsics(
356
+ batch_size=window_poses.shape[0],
357
+ h=window_result.disparity.shape[1],
358
+ w=window_result.disparity.shape[2],
359
+ fovx=window_Fov_x,
360
+ fovy=window_Fov_y,
361
+ )
362
+ window_focals = (
363
+ window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
364
+ ) / 2
365
+ scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
366
+ window_focals = scale * window_focals
367
+ result_focals = np.ones((t_end,))
368
+ result_focals[:t_start] = merged_focals[:t_start]
369
+ result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
370
+ weight = np.linspace(1, 0, overlap_t)
371
+ result_focals[t_start : t_start + overlap_t] = merged_focals[
372
+ t_start : t_start + overlap_t
373
+ ] * weight + window_focals[:overlap_t] * (1 - weight)
374
+ merged_focals = result_focals
375
+
376
+ if args.align_pointmaps:
377
+ # Align pointmaps
378
+ window_pointmaps = postprocess_pointmap(
379
+ result_disparity[t_start:],
380
+ window_raymap,
381
+ vae_downsample_scale=8,
382
+ camera_pose=aligned_window_poses,
383
+ focal=window_focals,
384
+ ray_o_scale_inv=0.1,
385
+ smooth_camera=args.smooth_camera,
386
+ smooth_method=args.smooth_method if args.smooth_camera else "none",
387
+ )
388
+ result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
389
+ result_pointmaps[:t_start] = merged_pointmaps[:t_start]
390
+ result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
391
+ overlap_t:
392
+ ]
393
+ weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
394
+ result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
395
+ t_start : t_start + overlap_t
396
+ ] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
397
+ merged_pointmaps = result_pointmaps
398
+
399
+ # project to pointmaps
400
+ intrinsics = [
401
+ np.array([[f, 0, 0.5 * args.width], [0, f, 0.5 * args.height], [0, 0, 1]])
402
+ for f in merged_focals
403
+ ]
404
+ if args.align_pointmaps:
405
+ pointmaps = merged_pointmaps
406
+ else:
407
+ pointmaps = np.stack(
408
+ [
409
+ project(
410
+ 1 / np.clip(merged_disparity[i], 1e-8, 1e8),
411
+ intrinsics[i],
412
+ merged_poses[i],
413
+ )
414
+ for i in range(merged_poses.shape[0])
415
+ ]
416
+ )
417
+
418
+ return merged_rgb, merged_disparity, merged_poses, pointmaps
419
+
420
+
421
+ def save_output(
422
+ rgb: np.ndarray,
423
+ disparity: np.ndarray,
424
+ poses: Optional[np.ndarray] = None,
425
+ raymap: Optional[np.ndarray] = None,
426
+ pointmap: Optional[np.ndarray] = None,
427
+ args: argparse.Namespace = None,
428
+ ) -> None:
429
+ output_dir = args.output_dir
430
+ os.makedirs(output_dir, exist_ok=True)
431
+
432
+ if pointmap is None:
433
+ assert raymap is not None, "Raymap is required for saving pointmap."
434
+ pointmap_dict = postprocess_pointmap(
435
+ disparity,
436
+ raymap,
437
+ vae_downsample_scale=8,
438
+ ray_o_scale_inv=0.1,
439
+ smooth_camera=args.smooth_camera,
440
+ smooth_method=args.smooth_method,
441
+ )
442
+ pointmap = pointmap_dict["pointmap"]
443
+
444
+ if poses is None:
445
+ assert raymap is not None, "Raymap is required for saving poses."
446
+ poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)
447
+
448
+ if args.task == "reconstruction":
449
+ filename = f"reconstruction_{args.video.split('/')[-1].split('.')[0]}"
450
+ elif args.task == "prediction":
451
+ filename = f"prediction_{args.image.split('/')[-1].split('.')[0]}"
452
+ elif args.task == "planning":
453
+ filename = f"planning_{args.image.split('/')[-1].split('.')[0]}_{args.goal.split('/')[-1].split('.')[0]}"
454
+
455
+ filename = os.path.join(output_dir, filename)
456
+
457
+ iio.imwrite(
458
+ f"{filename}_rgb.mp4",
459
+ (np.clip(rgb, 0, 1) * 255).astype(np.uint8),
460
+ fps=12,
461
+ )
462
+ iio.imwrite(
463
+ f"{filename}_disparity.mp4",
464
+ (colorize_depth(disparity) * 255).astype(np.uint8),
465
+ fps=12,
466
+ )
467
+
468
+ print("Building GLB scene")
469
+ for frame_idx in range(pointmap.shape[0])[:: args.pointcloud_save_frame_interval]:
470
+ predictions = {
471
+ "world_points": pointmap[frame_idx : frame_idx + 1],
472
+ "images": rgb[frame_idx : frame_idx + 1],
473
+ "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
474
+ "camera_poses": poses[frame_idx : frame_idx + 1],
475
+ }
476
+ scene_3d = predictions_to_glb(
477
+ predictions,
478
+ filter_by_frames="all",
479
+ show_cam=True,
480
+ max_depth=args.max_depth,
481
+ rtol=args.rtol,
482
+ frame_rel_idx=float(frame_idx) / pointmap.shape[0],
483
+ )
484
+ scene_3d.export(f"{filename}_pointcloud_frame_{frame_idx}.glb")
485
+ print("GLB Scene built")
486
+
487
+
488
+ def main() -> None:
489
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
490
+ args = parse_args()
491
+ seed_all(args.seed)
492
+
493
+ if args.num_inference_steps is None:
494
+ args.num_inference_steps = 4 if args.task == "reconstruction" else 50
495
+
496
+ if args.guidance_scale is None:
497
+ args.guidance_scale = 1.0 if args.task == "reconstruction" else 3.0
498
+
499
+ pipeline = build_pipeline(args)
500
+
501
+ if args.task == "reconstruction":
502
+ assert args.video is not None, "Video is required for reconstruction task."
503
+ assert args.image is None, "Image is not required for reconstruction task."
504
+ assert args.goal is None, "Goal is not required for reconstruction task."
505
+
506
+ video = iio.imread(args.video).astype(np.float32) / 255.0
507
+ image, goal = None, None
508
+ elif args.task == "prediction":
509
+ assert args.image is not None, "Image is required for prediction task."
510
+ assert args.goal is None, "Goal is not required for prediction task."
511
+
512
+ image = PIL.Image.open(args.image)
513
+ video, goal = None, None
514
+ elif args.task == "planning":
515
+ assert args.image is not None, "Image is required for planning task."
516
+ assert args.goal is not None, "Goal is required for planning task."
517
+
518
+ image = PIL.Image.open(args.image)
519
+ goal = PIL.Image.open(args.goal)
520
+
521
+ video = None
522
+
523
+ if args.raymap_action is not None:
524
+ raymap = np.load(args.raymap_action)
525
+ else:
526
+ raymap = None
527
+
528
+ if args.task != "reconstruction":
529
+ output = pipeline(
530
+ task=args.task,
531
+ image=image,
532
+ video=video,
533
+ goal=goal,
534
+ raymap=raymap,
535
+ height=args.height,
536
+ width=args.width,
537
+ num_frames=args.num_frames,
538
+ fps=args.fps,
539
+ num_inference_steps=args.num_inference_steps,
540
+ guidance_scale=args.guidance_scale,
541
+ use_dynamic_cfg=args.use_dynamic_cfg,
542
+ generator=torch.Generator(device=device).manual_seed(args.seed),
543
+ return_dict=True,
544
+ )
545
+ if not args.post_reconstruction:
546
+ save_output(
547
+ rgb=output.rgb,
548
+ disparity=output.disparity,
549
+ raymap=output.raymap,
550
+ args=args,
551
+ )
552
+ else:
553
+ recon_output = pipeline(
554
+ task="reconstruction",
555
+ video=output.rgb,
556
+ height=args.height,
557
+ width=args.width,
558
+ num_frames=args.num_frames,
559
+ fps=args.fps,
560
+ num_inference_steps=4,
561
+ guidance_scale=1.0, # guidance is not needed for the reconstruction task
562
+ use_dynamic_cfg=False,
563
+ generator=torch.Generator(device=device).manual_seed(args.seed),
564
+ )
565
+ save_output(
566
+ rgb=output.rgb,
567
+ disparity=recon_output.disparity,
568
+ raymap=recon_output.raymap,
569
+ args=args,
570
+ )
571
+ else:
572
+ # for the reconstruction task, we run a sliding window over long videos
573
+ window_results = []
574
+ window_indices = get_window_starts(
575
+ len(video), args.num_frames, args.sliding_window_stride
576
+ )
577
+ for start_idx in window_indices:
578
+ output = pipeline(
579
+ task=args.task,
580
+ image=None,
581
+ goal=None,
582
+ video=video[start_idx : start_idx + args.num_frames],
583
+ raymap=raymap[start_idx : start_idx + args.num_frames]
584
+ if raymap is not None
585
+ else None,
586
+ height=args.height,
587
+ width=args.width,
588
+ num_frames=args.num_frames,
589
+ fps=args.fps,
590
+ num_inference_steps=args.num_inference_steps,
591
+ guidance_scale=1.0, # guidance is not needed for the reconstruction task
592
+ use_dynamic_cfg=False,
593
+ generator=torch.Generator(device=device).manual_seed(args.seed),
594
+ )
595
+ window_results.append(output)
596
+
597
+ # merge window results
598
+ (
599
+ merged_rgb,
600
+ merged_disparity,
601
+ merged_poses,
602
+ pointmaps,
603
+ ) = blend_and_merge_window_results(window_results, window_indices, args)
604
+ save_output(
605
+ rgb=merged_rgb,
606
+ disparity=merged_disparity,
607
+ poses=merged_poses,
608
+ pointmap=pointmaps,
609
+ args=args,
610
+ )
611
+
612
+
613
+ if __name__ == "__main__":
614
+ main()
scripts/demo_gradio.py ADDED
@@ -0,0 +1,1470 @@
1
+ import gc
2
+ import os
3
+ import random
4
+ import re
5
+ from datetime import datetime
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import gradio as gr
9
+ import imageio.v3 as iio
10
+ import numpy as np
11
+ import PIL.Image
12
+ import rootutils
13
+ import torch
14
+ from diffusers import (
15
+ AutoencoderKLCogVideoX,
16
+ CogVideoXDPMScheduler,
17
+ CogVideoXTransformer3DModel,
18
+ )
19
+ from transformers import AutoTokenizer, T5EncoderModel
20
+
21
+
22
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
23
+
24
+ from aether.pipelines.aetherv1_pipeline_cogvideox import ( # noqa: E402
25
+ AetherV1PipelineCogVideoX,
26
+ AetherV1PipelineOutput,
27
+ )
28
+ from aether.utils.postprocess_utils import ( # noqa: E402
29
+ align_camera_extrinsics,
30
+ apply_transformation,
31
+ colorize_depth,
32
+ compute_scale,
33
+ get_intrinsics,
34
+ interpolate_poses,
35
+ postprocess_pointmap,
36
+ project,
37
+ raymap_to_poses,
38
+ )
39
+ from aether.utils.visualize_utils import predictions_to_glb # noqa: E402
40
+
41
+
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+
44
+
45
+ def seed_all(seed: int = 0) -> None:
46
+ """
47
+ Set random seeds of all components.
48
+ """
49
+ random.seed(seed)
50
+ np.random.seed(seed)
51
+ torch.manual_seed(seed)
52
+ torch.cuda.manual_seed_all(seed)
53
+
54
+
55
+ # Global pipeline
56
+ cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
57
+ aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
58
+ pipeline = AetherV1PipelineCogVideoX(
59
+ tokenizer=AutoTokenizer.from_pretrained(
60
+ cogvideox_pretrained_model_name_or_path,
61
+ subfolder="tokenizer",
62
+ ),
63
+ text_encoder=T5EncoderModel.from_pretrained(
64
+ cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
65
+ ),
66
+ vae=AutoencoderKLCogVideoX.from_pretrained(
67
+ cogvideox_pretrained_model_name_or_path, subfolder="vae"
68
+ ),
69
+ scheduler=CogVideoXDPMScheduler.from_pretrained(
70
+ cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
71
+ ),
72
+ transformer=CogVideoXTransformer3DModel.from_pretrained(
73
+ aether_pretrained_model_name_or_path, subfolder="transformer"
74
+ ),
75
+ )
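+ # VAE slicing and tiling keep decoder memory bounded when decoding long,
+ # high-resolution latent videos; the assembled pipeline is then moved to the
+ # selected device below.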
76
+ pipeline.vae.enable_slicing()
77
+ pipeline.vae.enable_tiling()
78
+ pipeline.to(device)
79
+
80
+
81
+ def build_pipeline() -> AetherV1PipelineCogVideoX:
82
+ """Initialize the model pipeline."""
83
+ return pipeline
84
+
85
+
86
+ def get_window_starts(
87
+ total_frames: int, sliding_window_size: int, temporal_stride: int
88
+ ) -> List[int]:
89
+ """Calculate window start indices."""
90
+ starts = list(
91
+ range(
92
+ 0,
93
+ total_frames - sliding_window_size + 1,
94
+ temporal_stride,
95
+ )
96
+ )
97
+ if (
98
+ total_frames > sliding_window_size
99
+ and (total_frames - sliding_window_size) % temporal_stride != 0
100
+ ):
101
+ starts.append(total_frames - sliding_window_size)
102
+ return starts
103
+
104
+
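+ # Worked example (illustrative values, not used by the app): with
+ # total_frames=100, sliding_window_size=41 and temporal_stride=24,
+ # get_window_starts(100, 41, 24) returns [0, 24, 48, 59]; the trailing 59 is
+ # appended so the final window still ends exactly at the last frame.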
105
+ def blend_and_merge_window_results(
106
+ window_results: List[AetherV1PipelineOutput], window_indices: List[int], args: Dict
107
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
108
+ """Blend and merge window results."""
109
+ merged_rgb = None
110
+ merged_disparity = None
111
+ merged_poses = None
112
+ merged_focals = None
113
+ align_pointmaps = args.get("align_pointmaps", True)
114
+ smooth_camera = args.get("smooth_camera", True)
115
+ smooth_method = args.get("smooth_method", "kalman") if smooth_camera else "none"
116
+
117
+ if align_pointmaps:
118
+ merged_pointmaps = None
119
+
120
+ w1 = window_results[0].disparity
121
+
122
+ for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
123
+ t_end = t_start + window_result.rgb.shape[0]
124
+ if idx == 0:
125
+ merged_rgb = window_result.rgb
126
+ merged_disparity = window_result.disparity
127
+ pointmap_dict = postprocess_pointmap(
128
+ window_result.disparity,
129
+ window_result.raymap,
130
+ vae_downsample_scale=8,
131
+ ray_o_scale_inv=0.1,
132
+ smooth_camera=smooth_camera,
133
+ smooth_method=smooth_method if smooth_camera else "none",
134
+ )
135
+ merged_poses = pointmap_dict["camera_pose"]
136
+ merged_focals = (
137
+ pointmap_dict["intrinsics"][:, 0, 0]
138
+ + pointmap_dict["intrinsics"][:, 1, 1]
139
+ ) / 2
140
+ if align_pointmaps:
141
+ merged_pointmaps = pointmap_dict["pointmap"]
142
+ else:
143
+ overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start
144
+
145
+ window_disparity = window_result.disparity
146
+
147
+ # Align disparity
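+ # Rescale the new window's disparity so it agrees with the already-merged
+ # disparity over the overlapping frames; pixels with disparity <= 0.1 are
+ # masked out of the scale estimate.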
148
+ disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
149
+ scale = compute_scale(
150
+ window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
151
+ merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
152
+ disp_mask.reshape(1, -1, w1.shape[-1]),
153
+ )
154
+ window_disparity = scale * window_disparity
155
+
156
+ # Blend disparity
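+ # The overlapping frames are cross-faded with a linear weight running from 1
+ # to 0, so the already-merged disparity dominates at the start of the overlap
+ # and the newly aligned window takes over towards its end.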
157
+ result_disparity = np.ones((t_end, *w1.shape[1:]))
158
+ result_disparity[:t_start] = merged_disparity[:t_start]
159
+ result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
160
+ weight = np.linspace(1, 0, overlap_t)[:, None, None]
161
+ result_disparity[t_start : t_start + overlap_t] = merged_disparity[
162
+ t_start : t_start + overlap_t
163
+ ] * weight + window_disparity[:overlap_t] * (1 - weight)
164
+ merged_disparity = result_disparity
165
+
166
+ # Blend RGB
167
+ result_rgb = np.ones((t_end, *w1.shape[1:], 3))
168
+ result_rgb[:t_start] = merged_rgb[:t_start]
169
+ result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
170
+ weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
171
+ result_rgb[t_start : t_start + overlap_t] = merged_rgb[
172
+ t_start : t_start + overlap_t
173
+ ] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
174
+ merged_rgb = result_rgb
175
+
176
+ # Align poses
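+ # Estimate a rotation/translation/scale that maps the new window's first
+ # overlap_t poses onto the corresponding merged poses, then apply it to the
+ # whole window so both camera trajectories share one coordinate frame.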
177
+ window_raymap = window_result.raymap
178
+ window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
179
+ window_raymap, ray_o_scale_inv=0.1
180
+ )
181
+ rel_r, rel_t, rel_s = align_camera_extrinsics(
182
+ torch.from_numpy(window_poses[:overlap_t]),
183
+ torch.from_numpy(merged_poses[-overlap_t:]),
184
+ )
185
+ aligned_window_poses = (
186
+ apply_transformation(
187
+ torch.from_numpy(window_poses),
188
+ rel_r,
189
+ rel_t,
190
+ rel_s,
191
+ return_extri=True,
192
+ )
193
+ .cpu()
194
+ .numpy()
195
+ )
196
+
197
+ result_poses = np.ones((t_end, 4, 4))
198
+ result_poses[:t_start] = merged_poses[:t_start]
199
+ result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]
200
+
201
+ # Interpolate poses in overlap region
202
+ weights = np.linspace(1, 0, overlap_t)
203
+ for t in range(overlap_t):
204
+ weight = weights[t]
205
+ pose1 = merged_poses[t_start + t]
206
+ pose2 = aligned_window_poses[t]
207
+ result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)
208
+
209
+ merged_poses = result_poses
210
+
211
+ # Align intrinsics
212
+ window_intrinsics, _ = get_intrinsics(
213
+ batch_size=window_poses.shape[0],
214
+ h=window_result.disparity.shape[1],
215
+ w=window_result.disparity.shape[2],
216
+ fovx=window_Fov_x,
217
+ fovy=window_Fov_y,
218
+ )
219
+ window_focals = (
220
+ window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
221
+ ) / 2
222
+ scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
223
+ window_focals = scale * window_focals
224
+ result_focals = np.ones((t_end,))
225
+ result_focals[:t_start] = merged_focals[:t_start]
226
+ result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
227
+ weight = np.linspace(1, 0, overlap_t)
228
+ result_focals[t_start : t_start + overlap_t] = merged_focals[
229
+ t_start : t_start + overlap_t
230
+ ] * weight + window_focals[:overlap_t] * (1 - weight)
231
+ merged_focals = result_focals
232
+
233
+ if align_pointmaps:
234
+ # Align pointmaps
235
+ window_pointmaps = postprocess_pointmap(
236
+ result_disparity[t_start:],
237
+ window_raymap,
238
+ vae_downsample_scale=8,
239
+ camera_pose=aligned_window_poses,
240
+ focal=window_focals,
241
+ ray_o_scale_inv=0.1,
242
+ smooth_camera=smooth_camera,
243
+ smooth_method=smooth_method if smooth_camera else "none",
244
+ )
245
+ result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
246
+ result_pointmaps[:t_start] = merged_pointmaps[:t_start]
247
+ result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
248
+ overlap_t:
249
+ ]
250
+ weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
251
+ result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
252
+ t_start : t_start + overlap_t
253
+ ] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
254
+ merged_pointmaps = result_pointmaps
255
+
256
+ # project to pointmaps
257
+ height = args.get("height", 480)
258
+ width = args.get("width", 720)
259
+
260
+ intrinsics = [
261
+ np.array([[f, 0, 0.5 * width], [0, f, 0.5 * height], [0, 0, 1]])
262
+ for f in merged_focals
263
+ ]
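+ # Each K above is a pinhole intrinsic matrix [[f, 0, cx], [0, f, cy], [0, 0, 1]]
+ # with the principal point assumed to sit at the image center.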
264
+ if align_pointmaps:
265
+ pointmaps = merged_pointmaps
266
+ else:
267
+ pointmaps = np.stack(
268
+ [
269
+ project(
270
+ 1 / np.clip(merged_disparity[i], 1e-8, 1e8),
271
+ intrinsics[i],
272
+ merged_poses[i],
273
+ )
274
+ for i in range(merged_poses.shape[0])
275
+ ]
276
+ )
277
+
278
+ return merged_rgb, merged_disparity, merged_poses, pointmaps
279
+
280
+
281
+ def process_video_to_frames(video_path: str, fps_sample: int = 12) -> Tuple[List[str], str]:
282
+ """Split a video into frames, save them locally, and return (frame_paths, output_dir)."""
283
+ # Create a unique output directory
284
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
285
+ output_dir = f"temp_frames_{timestamp}"
286
+ os.makedirs(output_dir, exist_ok=True)
287
+
288
+ # Read video
289
+ video = iio.imread(video_path)
290
+
291
+ # Calculate frame interval; this assumes a ~30 fps source, so the expression reduces to round(30 / fps_sample)
292
+ if isinstance(video, np.ndarray):
293
+ # For captured videos
294
+ total_frames = len(video)
295
+ frame_interval = max(
296
+ 1, round(total_frames / (fps_sample * (total_frames / 30)))
297
+ )
298
+ else:
299
+ # Default if can't determine
300
+ frame_interval = 2
301
+
302
+ frame_paths = []
303
+ for i, frame in enumerate(video[::frame_interval]):
304
+ frame_path = os.path.join(output_dir, f"frame_{i:04d}.jpg")
305
+ if isinstance(frame, np.ndarray):
306
+ iio.imwrite(frame_path, frame)
307
+ frame_paths.append(frame_path)
308
+
309
+ return frame_paths, output_dir
310
+
311
+
312
+ def save_output_files(
313
+ rgb: np.ndarray,
314
+ disparity: np.ndarray,
315
+ poses: Optional[np.ndarray] = None,
316
+ raymap: Optional[np.ndarray] = None,
317
+ pointmap: Optional[np.ndarray] = None,
318
+ task: str = "reconstruction",
319
+ output_dir: str = "outputs",
320
+ **kwargs,
321
+ ) -> Dict[str, str]:
322
+ """
323
+ Save outputs and return paths to saved files.
324
+ """
325
+ os.makedirs(output_dir, exist_ok=True)
326
+
327
+ if pointmap is None and raymap is not None:
328
+ # Generate pointmap from raymap and disparity
329
+ smooth_camera = kwargs.get("smooth_camera", True)
330
+ smooth_method = (
331
+ kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
332
+ )
333
+
334
+ pointmap_dict = postprocess_pointmap(
335
+ disparity,
336
+ raymap,
337
+ vae_downsample_scale=8,
338
+ ray_o_scale_inv=0.1,
339
+ smooth_camera=smooth_camera,
340
+ smooth_method=smooth_method,
341
+ )
342
+ pointmap = pointmap_dict["pointmap"]
343
+
344
+ if poses is None and raymap is not None:
345
+ poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)
346
+
347
+ # Create a unique filename
348
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
349
+ base_filename = f"{task}_{timestamp}"
350
+
351
+ # Paths for saved files
352
+ paths = {}
353
+
354
+ # Save RGB video
355
+ rgb_path = os.path.join(output_dir, f"{base_filename}_rgb.mp4")
356
+ iio.imwrite(
357
+ rgb_path,
358
+ (np.clip(rgb, 0, 1) * 255).astype(np.uint8),
359
+ fps=kwargs.get("fps", 12),
360
+ )
361
+ paths["rgb"] = rgb_path
362
+
363
+ # Save depth/disparity video
364
+ depth_path = os.path.join(output_dir, f"{base_filename}_disparity.mp4")
365
+ iio.imwrite(
366
+ depth_path,
367
+ (colorize_depth(disparity) * 255).astype(np.uint8),
368
+ fps=kwargs.get("fps", 12),
369
+ )
370
+ paths["disparity"] = depth_path
371
+
372
+ # Save point cloud GLB files
373
+ if pointmap is not None and poses is not None:
374
+ pointcloud_save_frame_interval = kwargs.get(
375
+ "pointcloud_save_frame_interval", 10
376
+ )
377
+ max_depth = kwargs.get("max_depth", 100.0)
378
+ rtol = kwargs.get("rtol", 0.03)
379
+
380
+ glb_paths = []
381
+ # Determine which frames to save based on the interval
382
+ frames_to_save = list(
383
+ range(0, pointmap.shape[0], pointcloud_save_frame_interval)
384
+ )
385
+
386
+ # Always include the first and last frame
387
+ if 0 not in frames_to_save:
388
+ frames_to_save.insert(0, 0)
389
+ if pointmap.shape[0] - 1 not in frames_to_save:
390
+ frames_to_save.append(pointmap.shape[0] - 1)
391
+
392
+ # Sort the frames to ensure they're in order
393
+ frames_to_save = sorted(set(frames_to_save))
394
+
395
+ for frame_idx in frames_to_save:
396
+ if frame_idx >= pointmap.shape[0]:
397
+ continue
398
+
399
+ predictions = {
400
+ "world_points": pointmap[frame_idx : frame_idx + 1],
401
+ "images": rgb[frame_idx : frame_idx + 1],
402
+ "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
403
+ "camera_poses": poses[frame_idx : frame_idx + 1],
404
+ }
405
+
406
+ glb_path = os.path.join(
407
+ output_dir, f"{base_filename}_pointcloud_frame_{frame_idx}.glb"
408
+ )
409
+
410
+ scene_3d = predictions_to_glb(
411
+ predictions,
412
+ filter_by_frames="all",
413
+ show_cam=True,
414
+ max_depth=max_depth,
415
+ rtol=rtol,
416
+ frame_rel_idx=float(frame_idx) / pointmap.shape[0],
417
+ )
418
+ scene_3d.export(glb_path)
419
+ glb_paths.append(glb_path)
420
+
421
+ paths["pointcloud_glbs"] = glb_paths
422
+
423
+ return paths
424
+
425
+
426
+ def process_reconstruction(
427
+ video_file,
428
+ height,
429
+ width,
430
+ num_frames,
431
+ num_inference_steps,
432
+ guidance_scale,
433
+ sliding_window_stride,
434
+ fps,
435
+ smooth_camera,
436
+ align_pointmaps,
437
+ max_depth,
438
+ rtol,
439
+ pointcloud_save_frame_interval,
440
+ seed,
441
+ progress=gr.Progress(),
442
+ ):
443
+ """
444
+ Process reconstruction task.
445
+ """
446
+ try:
447
+ gc.collect()
448
+ torch.cuda.empty_cache()
449
+
450
+ # Set random seed
451
+ seed_all(seed)
452
+
453
+ # Build the pipeline
454
+ pipeline = build_pipeline()
455
+
456
+ progress(0.1, "Loading video")
457
+ # Check if video_file is a string or a file object
458
+ if isinstance(video_file, str):
459
+ video_path = video_file
460
+ else:
461
+ video_path = video_file.name
462
+
463
+ video = iio.imread(video_path).astype(np.float32) / 255.0
464
+
465
+ # Setup arguments
466
+ args = {
467
+ "height": height,
468
+ "width": width,
469
+ "num_frames": num_frames,
470
+ "sliding_window_stride": sliding_window_stride,
471
+ "smooth_camera": smooth_camera,
472
+ "smooth_method": "kalman" if smooth_camera else "none",
473
+ "align_pointmaps": align_pointmaps,
474
+ "max_depth": max_depth,
475
+ "rtol": rtol,
476
+ "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
477
+ }
478
+
479
+ # Process in sliding windows
480
+ window_results = []
481
+ window_indices = get_window_starts(
482
+ len(video), num_frames, sliding_window_stride
483
+ )
484
+
485
+ progress(0.2, f"Processing video in {len(window_indices)} windows")
486
+
487
+ for i, start_idx in enumerate(window_indices):
488
+ progress_val = 0.2 + (0.6 * (i / len(window_indices)))
489
+ progress(progress_val, f"Processing window {i+1}/{len(window_indices)}")
490
+
491
+ output = pipeline(
492
+ task="reconstruction",
493
+ image=None,
494
+ goal=None,
495
+ video=video[start_idx : start_idx + num_frames],
496
+ raymap=None,
497
+ height=height,
498
+ width=width,
499
+ num_frames=num_frames,
500
+ fps=fps,
501
+ num_inference_steps=num_inference_steps,
502
+ guidance_scale=guidance_scale,
503
+ use_dynamic_cfg=False,
504
+ generator=torch.Generator(device=device).manual_seed(seed),
505
+ )
506
+ window_results.append(output)
507
+
508
+ progress(0.8, "Merging results from all windows")
509
+ # Merge window results
510
+ (
511
+ merged_rgb,
512
+ merged_disparity,
513
+ merged_poses,
514
+ pointmaps,
515
+ ) = blend_and_merge_window_results(window_results, window_indices, args)
516
+
517
+ progress(0.9, "Saving output files")
518
+ # Save output files
519
+ output_dir = "outputs"
520
+ os.makedirs(output_dir, exist_ok=True)
521
+ output_paths = save_output_files(
522
+ rgb=merged_rgb,
523
+ disparity=merged_disparity,
524
+ poses=merged_poses,
525
+ pointmap=pointmaps,
526
+ task="reconstruction",
527
+ output_dir=output_dir,
528
+ fps=12,
529
+ **args,
530
+ )
531
+
532
+ progress(1.0, "Done!")
533
+
534
+ # Return paths for displaying
535
+ return (
536
+ output_paths["rgb"],
537
+ output_paths["disparity"],
538
+ output_paths.get("pointcloud_glbs", []),
539
+ )
540
+
541
+ except Exception:
542
+ import traceback
543
+
544
+ traceback.print_exc()
545
+ return None, None, []
546
+
547
+
548
+ def process_prediction(
549
+ image_file,
550
+ height,
551
+ width,
552
+ num_frames,
553
+ num_inference_steps,
554
+ guidance_scale,
555
+ use_dynamic_cfg,
556
+ raymap_option,
557
+ post_reconstruction,
558
+ fps,
559
+ smooth_camera,
560
+ align_pointmaps,
561
+ max_depth,
562
+ rtol,
563
+ pointcloud_save_frame_interval,
564
+ seed,
565
+ progress=gr.Progress(),
566
+ ):
567
+ """
568
+ Process prediction task.
569
+ """
570
+ try:
571
+ gc.collect()
572
+ torch.cuda.empty_cache()
573
+
574
+ # Set random seed
575
+ seed_all(seed)
576
+
577
+ # Build the pipeline
578
+ pipeline = build_pipeline()
579
+
580
+ progress(0.1, "Loading image")
581
+ # Check if image_file is a string or a file object
582
+ if isinstance(image_file, str):
583
+ image_path = image_file
584
+ else:
585
+ image_path = image_file.name
586
+
587
+ image = PIL.Image.open(image_path)
588
+
589
+ progress(0.2, "Running prediction")
590
+ # Run prediction
591
+ output = pipeline(
592
+ task="prediction",
593
+ image=image,
594
+ video=None,
595
+ goal=None,
596
+ raymap=np.load(f"assets/example_raymaps/raymap_{raymap_option}.npy"),
597
+ height=height,
598
+ width=width,
599
+ num_frames=num_frames,
600
+ fps=fps,
601
+ num_inference_steps=num_inference_steps,
602
+ guidance_scale=guidance_scale,
603
+ use_dynamic_cfg=use_dynamic_cfg,
604
+ generator=torch.Generator(device=device).manual_seed(seed),
605
+ return_dict=True,
606
+ )
607
+
608
+ # Show RGB output immediately
609
+ rgb_output = output.rgb
610
+
611
+ # Setup arguments for saving
612
+ args = {
613
+ "height": height,
614
+ "width": width,
615
+ "smooth_camera": smooth_camera,
616
+ "smooth_method": "kalman" if smooth_camera else "none",
617
+ "align_pointmaps": align_pointmaps,
618
+ "max_depth": max_depth,
619
+ "rtol": rtol,
620
+ "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
621
+ }
622
+
623
+ if post_reconstruction:
624
+ progress(0.5, "Running post-reconstruction for better quality")
625
+ recon_output = pipeline(
626
+ task="reconstruction",
627
+ video=output.rgb,
628
+ height=height,
629
+ width=width,
630
+ num_frames=num_frames,
631
+ fps=fps,
632
+ num_inference_steps=4,
633
+ guidance_scale=1.0,
634
+ use_dynamic_cfg=False,
635
+ generator=torch.Generator(device=device).manual_seed(seed),
636
+ )
637
+
638
+ disparity = recon_output.disparity
639
+ raymap = recon_output.raymap
640
+ else:
641
+ disparity = output.disparity
642
+ raymap = output.raymap
643
+
644
+ progress(0.8, "Saving output files")
645
+ # Save output files
646
+ output_dir = "outputs"
647
+ os.makedirs(output_dir, exist_ok=True)
648
+ output_paths = save_output_files(
649
+ rgb=rgb_output,
650
+ disparity=disparity,
651
+ raymap=raymap,
652
+ task="prediction",
653
+ output_dir=output_dir,
654
+ fps=12,
655
+ **args,
656
+ )
657
+
658
+ progress(1.0, "Done!")
659
+
660
+ # Return paths for displaying
661
+ return (
662
+ output_paths["rgb"],
663
+ output_paths["disparity"],
664
+ output_paths.get("pointcloud_glbs", []),
665
+ )
666
+
667
+ except Exception:
668
+ import traceback
669
+
670
+ traceback.print_exc()
671
+ return None, None, []
672
+
673
+
674
+ def process_planning(
675
+ image_file,
676
+ goal_file,
677
+ height,
678
+ width,
679
+ num_frames,
680
+ num_inference_steps,
681
+ guidance_scale,
682
+ use_dynamic_cfg,
683
+ post_reconstruction,
684
+ fps,
685
+ smooth_camera,
686
+ align_pointmaps,
687
+ max_depth,
688
+ rtol,
689
+ pointcloud_save_frame_interval,
690
+ seed,
691
+ progress=gr.Progress(),
692
+ ):
693
+ """
694
+ Process planning task.
695
+ """
696
+ try:
697
+ gc.collect()
698
+ torch.cuda.empty_cache()
699
+
700
+ # Set random seed
701
+ seed_all(seed)
702
+
703
+ # Build the pipeline
704
+ pipeline = build_pipeline()
705
+
706
+ progress(0.1, "Loading images")
707
+ # Check if image_file and goal_file are strings or file objects
708
+ if isinstance(image_file, str):
709
+ image_path = image_file
710
+ else:
711
+ image_path = image_file.name
712
+
713
+ if isinstance(goal_file, str):
714
+ goal_path = goal_file
715
+ else:
716
+ goal_path = goal_file.name
717
+
718
+ image = PIL.Image.open(image_path)
719
+ goal = PIL.Image.open(goal_path)
720
+
721
+ progress(0.2, "Running planning")
722
+ # Run planning
723
+ output = pipeline(
724
+ task="planning",
725
+ image=image,
726
+ video=None,
727
+ goal=goal,
728
+ raymap=None,
729
+ height=height,
730
+ width=width,
731
+ num_frames=num_frames,
732
+ fps=fps,
733
+ num_inference_steps=num_inference_steps,
734
+ guidance_scale=guidance_scale,
735
+ use_dynamic_cfg=use_dynamic_cfg,
736
+ generator=torch.Generator(device=device).manual_seed(seed),
737
+ return_dict=True,
738
+ )
739
+
740
+ # Show RGB output immediately
741
+ rgb_output = output.rgb
742
+
743
+ # Setup arguments for saving
744
+ args = {
745
+ "height": height,
746
+ "width": width,
747
+ "smooth_camera": smooth_camera,
748
+ "smooth_method": "kalman" if smooth_camera else "none",
749
+ "align_pointmaps": align_pointmaps,
750
+ "max_depth": max_depth,
751
+ "rtol": rtol,
752
+ "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
753
+ }
754
+
755
+ if post_reconstruction:
756
+ progress(0.5, "Running post-reconstruction for better quality")
757
+ recon_output = pipeline(
758
+ task="reconstruction",
759
+ video=output.rgb,
760
+ height=height,
761
+ width=width,
762
+ num_frames=num_frames,
763
+ fps=12,
764
+ num_inference_steps=4,
765
+ guidance_scale=1.0,
766
+ use_dynamic_cfg=False,
767
+ generator=torch.Generator(device=device).manual_seed(seed),
768
+ )
769
+
770
+ disparity = recon_output.disparity
771
+ raymap = recon_output.raymap
772
+ else:
773
+ disparity = output.disparity
774
+ raymap = output.raymap
775
+
776
+ progress(0.8, "Saving output files")
777
+ # Save output files
778
+ output_dir = "outputs"
779
+ os.makedirs(output_dir, exist_ok=True)
780
+ output_paths = save_output_files(
781
+ rgb=rgb_output,
782
+ disparity=disparity,
783
+ raymap=raymap,
784
+ task="planning",
785
+ output_dir=output_dir,
786
+ fps=fps,
787
+ **args,
788
+ )
789
+
790
+ progress(1.0, "Done!")
791
+
792
+ # Return paths for displaying
793
+ return (
794
+ output_paths["rgb"],
795
+ output_paths["disparity"],
796
+ output_paths.get("pointcloud_glbs", []),
797
+ )
798
+
799
+ except Exception:
800
+ import traceback
801
+
802
+ traceback.print_exc()
803
+ return None, None, []
804
+
805
+
806
+ def update_task_ui(task):
807
+ """Update UI elements based on selected task."""
808
+ if task == "reconstruction":
809
+ return (
810
+ gr.update(visible=True), # video_input
811
+ gr.update(visible=False), # image_input
812
+ gr.update(visible=False), # goal_input
813
+ gr.update(visible=False), # image_preview
814
+ gr.update(visible=False), # goal_preview
815
+ gr.update(value=4), # num_inference_steps
816
+ gr.update(visible=True), # sliding_window_stride
817
+ gr.update(visible=False), # use_dynamic_cfg
818
+ gr.update(visible=False), # raymap_option
819
+ gr.update(visible=False), # post_reconstruction
820
+ gr.update(value=1.0), # guidance_scale
821
+ )
822
+ elif task == "prediction":
823
+ return (
824
+ gr.update(visible=False), # video_input
825
+ gr.update(visible=True), # image_input
826
+ gr.update(visible=False), # goal_input
827
+ gr.update(visible=True), # image_preview
828
+ gr.update(visible=False), # goal_preview
829
+ gr.update(value=50), # num_inference_steps
830
+ gr.update(visible=False), # sliding_window_stride
831
+ gr.update(visible=True), # use_dynamic_cfg
832
+ gr.update(visible=True), # raymap_option
833
+ gr.update(visible=True), # post_reconstruction
834
+ gr.update(value=3.0), # guidance_scale
835
+ )
836
+ elif task == "planning":
837
+ return (
838
+ gr.update(visible=False), # video_input
839
+ gr.update(visible=True), # image_input
840
+ gr.update(visible=True), # goal_input
841
+ gr.update(visible=True), # image_preview
842
+ gr.update(visible=True), # goal_preview
843
+ gr.update(value=50), # num_inference_steps
844
+ gr.update(visible=False), # sliding_window_stride
845
+ gr.update(visible=True), # use_dynamic_cfg
846
+ gr.update(visible=False), # raymap_option
847
+ gr.update(visible=True), # post_reconstruction
848
+ gr.update(value=3.0), # guidance_scale
849
+ )
850
+
851
+
852
+ def update_image_preview(image_file):
853
+ """Update the image preview."""
854
+ if image_file:
855
+ return image_file.name
856
+ return None
857
+
858
+
859
+ def update_goal_preview(goal_file):
860
+ """Update the goal preview."""
861
+ if goal_file:
862
+ return goal_file.name
863
+ return None
864
+
865
+
866
+ def get_download_link(selected_frame, all_paths):
867
+ """Update the download button with the selected file path."""
868
+ if not selected_frame or not all_paths:
869
+ return gr.update(visible=False, value=None)
870
+
871
+ frame_num = int(re.search(r"Frame (\d+)", selected_frame).group(1))
872
+
873
+ for path in all_paths:
874
+ if f"frame_{frame_num}" in path:
875
+ # Make sure the file exists before setting it
876
+ if os.path.exists(path):
877
+ return gr.update(visible=True, value=path, interactive=True)
878
+
879
+ return gr.update(visible=False, value=None)
880
+
881
+
882
+ # Theme setup
883
+ theme = gr.themes.Default(
884
+ primary_hue="blue",
885
+ secondary_hue="cyan",
886
+ )
887
+
888
+ with gr.Blocks(
889
+ theme=theme,
890
+ css="""
891
+ .output-column {
892
+ min-height: 400px;
893
+ }
894
+ .warning {
895
+ color: #ff9800;
896
+ font-weight: bold;
897
+ }
898
+ .highlight {
899
+ background-color: rgba(0, 123, 255, 0.1);
900
+ padding: 10px;
901
+ border-radius: 8px;
902
+ border-left: 5px solid #007bff;
903
+ margin: 10px 0;
904
+ }
905
+ .task-header {
906
+ margin-top: 10px;
907
+ margin-bottom: 15px;
908
+ font-size: 1.2em;
909
+ font-weight: bold;
910
+ color: #007bff;
911
+ }
912
+ .flex-display {
913
+ display: flex;
914
+ flex-wrap: wrap;
915
+ gap: 10px;
916
+ }
917
+ .output-subtitle {
918
+ font-size: 1.1em;
919
+ margin-top: 5px;
920
+ margin-bottom: 5px;
921
+ color: #505050;
922
+ }
923
+ .input-section, .params-section, .advanced-section {
924
+ border: 1px solid #ddd;
925
+ padding: 15px;
926
+ border-radius: 8px;
927
+ margin-bottom: 15px;
928
+ }
929
+ .logo-container {
930
+ display: flex;
931
+ justify-content: center;
932
+ margin-bottom: 20px;
933
+ }
934
+ .logo-image {
935
+ max-width: 300px;
936
+ height: auto;
937
+ }
938
+ """,
939
+ ) as demo:
940
+ with gr.Row(elem_classes=["logo-container"]):
941
+ gr.Image("assets/logo.png", show_label=False, elem_classes=["logo-image"])
942
+
943
+ gr.Markdown(
944
+ """
945
+ # Aether: Geometric-Aware Unified World Modeling
946
+
947
+ Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with
948
+ generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:
949
+
950
+ 1. **4D Dynamic Reconstruction** - Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
951
+ 2. **Action-Conditioned Video Prediction** - Predict future frames from initial observation images, optionally conditioned on camera trajectory actions.
952
+ 3. **Goal-Conditioned Visual Planning** - Generate planning paths from pairs of observation and goal images.
953
+
954
+ Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
955
+ """
956
+ )
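+ # How the three tasks map onto pipeline calls (see the process_* functions
+ # above for the full argument lists):
+ #   reconstruction: pipeline(task="reconstruction", video=frames, ...)
+ #   prediction:     pipeline(task="prediction", image=image, raymap=raymap, ...)
+ #   planning:       pipeline(task="planning", image=image, goal=goal, ...)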
957
+
958
+ with gr.Row():
959
+ with gr.Column(scale=1):
960
+ task = gr.Radio(
961
+ ["reconstruction", "prediction", "planning"],
962
+ label="Select Task",
963
+ value="reconstruction",
964
+ info="Choose the task you want to perform",
965
+ )
966
+
967
+ with gr.Group(elem_classes=["input-section"]):
968
+ # Input section - changes based on task
969
+ gr.Markdown("## 📥 Input", elem_classes=["task-header"])
970
+
971
+ # Task-specific inputs
972
+ video_input = gr.Video(
973
+ label="Upload Input Video",
974
+ sources=["upload"],
975
+ visible=True,
976
+ interactive=True,
977
+ elem_id="video_input",
978
+ )
979
+
980
+ image_input = gr.File(
981
+ label="Upload Start Image",
982
+ file_count="single",
983
+ file_types=["image"],
984
+ visible=False,
985
+ interactive=True,
986
+ elem_id="image_input",
987
+ )
988
+
989
+ goal_input = gr.File(
990
+ label="Upload Goal Image",
991
+ file_count="single",
992
+ file_types=["image"],
993
+ visible=False,
994
+ interactive=True,
995
+ elem_id="goal_input",
996
+ )
997
+
998
+ with gr.Row(visible=False) as preview_row:
999
+ image_preview = gr.Image(
1000
+ label="Start Image Preview",
1001
+ elem_id="image_preview",
1002
+ visible=False,
1003
+ )
1004
+ goal_preview = gr.Image(
1005
+ label="Goal Image Preview",
1006
+ elem_id="goal_preview",
1007
+ visible=False,
1008
+ )
1009
+
1010
+ with gr.Group(elem_classes=["params-section"]):
1011
+ gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])
1012
+
1013
+ with gr.Row():
1014
+ with gr.Column(scale=1):
1015
+ height = gr.Dropdown(
1016
+ choices=[480],
1017
+ value=480,
1018
+ label="Height",
1019
+ info="Height of the output video",
1020
+ )
1021
+
1022
+ with gr.Column(scale=1):
1023
+ width = gr.Dropdown(
1024
+ choices=[720],
1025
+ value=720,
1026
+ label="Width",
1027
+ info="Width of the output video",
1028
+ )
1029
+
1030
+ with gr.Row():
1031
+ with gr.Column(scale=1):
1032
+ num_frames = gr.Dropdown(
1033
+ choices=[17, 25, 33, 41],
1034
+ value=41,
1035
+ label="Number of Frames",
1036
+ info="Number of frames to predict",
1037
+ )
1038
+
1039
+ with gr.Column(scale=1):
1040
+ fps = gr.Dropdown(
1041
+ choices=[8, 10, 12, 15, 24],
1042
+ value=12,
1043
+ label="FPS",
1044
+ info="Frames per second",
1045
+ )
1046
+
1047
+ with gr.Row():
1048
+ with gr.Column(scale=1):
1049
+ num_inference_steps = gr.Slider(
1050
+ minimum=1,
1051
+ maximum=60,
1052
+ value=4,
1053
+ step=1,
1054
+ label="Inference Steps",
1055
+ info="Number of inference step",
1056
+ )
1057
+
1058
+ sliding_window_stride = gr.Slider(
1059
+ minimum=1,
1060
+ maximum=40,
1061
+ value=24,
1062
+ step=1,
1063
+ label="Sliding Window Stride",
1064
+ info="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task",
1065
+ visible=True,
1066
+ )
1067
+
1068
+ use_dynamic_cfg = gr.Checkbox(
1069
+ label="Use Dynamic CFG",
1070
+ value=True,
1071
+ info="Use dynamic CFG",
1072
+ visible=False,
1073
+ )
1074
+
1075
+ raymap_option = gr.Radio(
1076
+ choices=["backward", "forward_right", "left_forward", "right"],
1077
+ label="Camera Movement Direction",
1078
+ value="forward_right",
1079
+ info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
1080
+ visible=False,
1081
+ )
1082
+
1083
+ post_reconstruction = gr.Checkbox(
1084
+ label="Post-Reconstruction",
1085
+ value=True,
1086
+ info="Run reconstruction after prediction for better quality",
1087
+ visible=False,
1088
+ )
1089
+
1090
+ with gr.Accordion(
1091
+ "Advanced Options", open=False, visible=True
1092
+ ) as advanced_options:
1093
+ with gr.Group(elem_classes=["advanced-section"]):
1094
+ with gr.Row():
1095
+ with gr.Column(scale=1):
1096
+ guidance_scale = gr.Slider(
1097
+ minimum=1.0,
1098
+ maximum=10.0,
1099
+ value=1.0,
1100
+ step=0.1,
1101
+ label="Guidance Scale",
1102
+ info="Guidance scale (only for prediction / planning)",
1103
+ )
1104
+
1105
+ with gr.Row():
1106
+ with gr.Column(scale=1):
1107
+ seed = gr.Number(
1108
+ value=42,
1109
+ label="Random Seed",
1110
+ info="Set a seed for reproducible results",
1111
+ precision=0,
1112
+ minimum=0,
1113
+ maximum=2147483647,
1114
+ )
1115
+
1116
+ with gr.Row():
1117
+ with gr.Column(scale=1):
1118
+ smooth_camera = gr.Checkbox(
1119
+ label="Smooth Camera",
1120
+ value=True,
1121
+ info="Apply smoothing to camera trajectory",
1122
+ )
1123
+
1124
+ with gr.Column(scale=1):
1125
+ align_pointmaps = gr.Checkbox(
1126
+ label="Align Point Maps",
1127
+ value=False,
1128
+ info="Align point maps across frames",
1129
+ )
1130
+
1131
+ with gr.Row():
1132
+ with gr.Column(scale=1):
1133
+ max_depth = gr.Slider(
1134
+ minimum=10,
1135
+ maximum=200,
1136
+ value=60,
1137
+ step=10,
1138
+ label="Max Depth",
1139
+ info="Maximum depth for point cloud (higher = more distant points)",
1140
+ )
1141
+
1142
+ with gr.Column(scale=1):
1143
+ rtol = gr.Slider(
1144
+ minimum=0.01,
1145
+ maximum=2.0,
1146
+ value=0.03,
1147
+ step=0.01,
1148
+ label="Relative Tolerance",
1149
+ info="Used for depth edge detection. Lower = remove more edges",
1150
+ )
1151
+
1152
+ pointcloud_save_frame_interval = gr.Slider(
1153
+ minimum=1,
1154
+ maximum=20,
1155
+ value=10,
1156
+ step=1,
1157
+ label="Point Cloud Frame Interval",
1158
+ info="Save point cloud every N frames (higher = fewer files but less complete representation)",
1159
+ )
1160
+
1161
+ run_button = gr.Button("Run Aether", variant="primary")
1162
+
1163
+ with gr.Column(scale=1, elem_classes=["output-column"]):
1164
+ with gr.Group():
1165
+ gr.Markdown("## 📤 Output", elem_classes=["task-header"])
1166
+
1167
+ gr.Markdown("### RGB Video", elem_classes=["output-subtitle"])
1168
+ rgb_output = gr.Video(
1169
+ label="RGB Output", interactive=False, elem_id="rgb_output"
1170
+ )
1171
+
1172
+ gr.Markdown("### Depth Video", elem_classes=["output-subtitle"])
1173
+ depth_output = gr.Video(
1174
+ label="Depth Output", interactive=False, elem_id="depth_output"
1175
+ )
1176
+
1177
+ gr.Markdown("### Point Clouds", elem_classes=["output-subtitle"])
1178
+ with gr.Row(elem_classes=["flex-display"]):
1179
+ pointcloud_frames = gr.Dropdown(
1180
+ label="Select Frame",
1181
+ choices=[],
1182
+ value=None,
1183
+ interactive=True,
1184
+ elem_id="pointcloud_frames",
1185
+ )
1186
+ pointcloud_download = gr.DownloadButton(
1187
+ label="Download Point Cloud",
1188
+ visible=False,
1189
+ elem_id="pointcloud_download",
1190
+ )
1191
+
1192
+ model_output = gr.Model3D(
1193
+ label="Point Cloud Viewer", interactive=True, elem_id="model_output"
1194
+ )
1195
+
1196
+ with gr.Tab("About Results"):
1197
+ gr.Markdown(
1198
+ """
1199
+ ### Understanding the Outputs
1200
+
1201
+ - **RGB Video**: Shows the predicted or reconstructed RGB frames
1202
+ - **Depth Video**: Visualizes the disparity maps in color (closer = red, farther = blue)
1203
+ - **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids
1204
+
1205
+ <p class="warning">Note: 3D point clouds take a long time to visualize, and we show the keyframes only.
1206
+ You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`.</p>
1207
+ """
1208
+ )
1209
+
1210
+ # Event handlers
1211
+ task.change(
1212
+ fn=update_task_ui,
1213
+ inputs=[task],
1214
+ outputs=[
1215
+ video_input,
1216
+ image_input,
1217
+ goal_input,
1218
+ image_preview,
1219
+ goal_preview,
1220
+ num_inference_steps,
1221
+ sliding_window_stride,
1222
+ use_dynamic_cfg,
1223
+ raymap_option,
1224
+ post_reconstruction,
1225
+ guidance_scale,
1226
+ ],
1227
+ )
1228
+
1229
+ image_input.change(
1230
+ fn=update_image_preview, inputs=[image_input], outputs=[image_preview]
1231
+ ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
1232
+
1233
+ goal_input.change(
1234
+ fn=update_goal_preview, inputs=[goal_input], outputs=[goal_preview]
1235
+ ).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
1236
+
1237
+ def update_pointcloud_frames(pointcloud_paths):
1238
+ """Update the pointcloud frames dropdown with available frames."""
1239
+ if not pointcloud_paths:
1240
+ return gr.update(choices=[], value=None), None, gr.update(visible=False)
1241
+
1242
+ # Extract frame numbers from filenames
1243
+ frame_info = []
1244
+ for path in pointcloud_paths:
1245
+ filename = os.path.basename(path)
1246
+ match = re.search(r"frame_(\d+)", filename)
1247
+ if match:
1248
+ frame_num = int(match.group(1))
1249
+ frame_info.append((f"Frame {frame_num}", path))
1250
+
1251
+ # Sort by frame number
1252
+ frame_info.sort(key=lambda x: int(re.search(r"Frame (\d+)", x[0]).group(1)))
1253
+
1254
+ choices = [label for label, _ in frame_info]
1255
+ paths = [path for _, path in frame_info]
1256
+
1257
+ if not choices:
1258
+ return gr.update(choices=[], value=None), None, gr.update(visible=False)
1259
+
1260
+ # Make download button visible when we have point cloud files
1261
+ return (
1262
+ gr.update(choices=choices, value=choices[0]),
1263
+ paths[0],
1264
+ gr.update(visible=True),
1265
+ )
1266
+
1267
+ def select_pointcloud_frame(frame_label, all_paths):
1268
+ """Select a specific pointcloud frame."""
1269
+ if not frame_label or not all_paths:
1270
+ return None
1271
+
1272
+ frame_num = int(re.search(r"Frame (\d+)", frame_label).group(1))
1273
+
1274
+ for path in all_paths:
1275
+ if f"frame_{frame_num}" in path:
1276
+ return path
1277
+
1278
+ return None
1279
+
1280
+ # Then in the run button click handler:
1281
+ def process_task(task_type, *args):
1282
+ """Process selected task with appropriate function."""
1283
+ if task_type == "reconstruction":
1284
+ rgb_path, depth_path, pointcloud_paths = process_reconstruction(*args)
1285
+ # Update the pointcloud frames dropdown
1286
+ frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
1287
+ pointcloud_paths
1288
+ )
1289
+ return (
1290
+ rgb_path,
1291
+ depth_path,
1292
+ initial_path,
1293
+ frame_dropdown,
1294
+ pointcloud_paths,
1295
+ download_visible,
1296
+ )
1297
+ elif task_type == "prediction":
1298
+ rgb_path, depth_path, pointcloud_paths = process_prediction(*args)
1299
+ frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
1300
+ pointcloud_paths
1301
+ )
1302
+ return (
1303
+ rgb_path,
1304
+ depth_path,
1305
+ initial_path,
1306
+ frame_dropdown,
1307
+ pointcloud_paths,
1308
+ download_visible,
1309
+ )
1310
+ elif task_type == "planning":
1311
+ rgb_path, depth_path, pointcloud_paths = process_planning(*args)
1312
+ frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
1313
+ pointcloud_paths
1314
+ )
1315
+ return (
1316
+ rgb_path,
1317
+ depth_path,
1318
+ initial_path,
1319
+ frame_dropdown,
1320
+ pointcloud_paths,
1321
+ download_visible,
1322
+ )
1323
+ return (
1324
+ None,
1325
+ None,
1326
+ None,
1327
+ gr.update(choices=[], value=None),
1328
+ [],
1329
+ gr.update(visible=False),
1330
+ )
1331
+
1332
+ # Store all pointcloud paths for later use
1333
+ all_pointcloud_paths = gr.State([])
1334
+
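+ # The lambda below only re-orders the shared UI inputs into the argument list
+ # expected by the selected task before dispatching through process_task.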
1335
+ run_button.click(
1336
+ fn=lambda task_type,
1337
+ video_file,
1338
+ image_file,
1339
+ goal_file,
1340
+ height,
1341
+ width,
1342
+ num_frames,
1343
+ num_inference_steps,
1344
+ guidance_scale,
1345
+ sliding_window_stride,
1346
+ use_dynamic_cfg,
1347
+ raymap_option,
1348
+ post_reconstruction,
1349
+ fps,
1350
+ smooth_camera,
1351
+ align_pointmaps,
1352
+ max_depth,
1353
+ rtol,
1354
+ pointcloud_save_frame_interval,
1355
+ seed: process_task(
1356
+ task_type,
1357
+ *(
1358
+ [
1359
+ video_file,
1360
+ height,
1361
+ width,
1362
+ num_frames,
1363
+ num_inference_steps,
1364
+ guidance_scale,
1365
+ sliding_window_stride,
1366
+ fps,
1367
+ smooth_camera,
1368
+ align_pointmaps,
1369
+ max_depth,
1370
+ rtol,
1371
+ pointcloud_save_frame_interval,
1372
+ seed,
1373
+ ]
1374
+ if task_type == "reconstruction"
1375
+ else [
1376
+ image_file,
1377
+ height,
1378
+ width,
1379
+ num_frames,
1380
+ num_inference_steps,
1381
+ guidance_scale,
1382
+ use_dynamic_cfg,
1383
+ raymap_option,
1384
+ post_reconstruction,
1385
+ fps,
1386
+ smooth_camera,
1387
+ align_pointmaps,
1388
+ max_depth,
1389
+ rtol,
1390
+ pointcloud_save_frame_interval,
1391
+ seed,
1392
+ ]
1393
+ if task_type == "prediction"
1394
+ else [
1395
+ image_file,
1396
+ goal_file,
1397
+ height,
1398
+ width,
1399
+ num_frames,
1400
+ num_inference_steps,
1401
+ guidance_scale,
1402
+ use_dynamic_cfg,
1403
+ post_reconstruction,
1404
+ fps,
1405
+ smooth_camera,
1406
+ align_pointmaps,
1407
+ max_depth,
1408
+ rtol,
1409
+ pointcloud_save_frame_interval,
1410
+ seed,
1411
+ ]
1412
+ ),
1413
+ ),
1414
+ inputs=[
1415
+ task,
1416
+ video_input,
1417
+ image_input,
1418
+ goal_input,
1419
+ height,
1420
+ width,
1421
+ num_frames,
1422
+ num_inference_steps,
1423
+ guidance_scale,
1424
+ sliding_window_stride,
1425
+ use_dynamic_cfg,
1426
+ raymap_option,
1427
+ post_reconstruction,
1428
+ fps,
1429
+ smooth_camera,
1430
+ align_pointmaps,
1431
+ max_depth,
1432
+ rtol,
1433
+ pointcloud_save_frame_interval,
1434
+ seed,
1435
+ ],
1436
+ outputs=[
1437
+ rgb_output,
1438
+ depth_output,
1439
+ model_output,
1440
+ pointcloud_frames,
1441
+ all_pointcloud_paths,
1442
+ pointcloud_download,
1443
+ ],
1444
+ )
1445
+
1446
+ pointcloud_frames.change(
1447
+ fn=select_pointcloud_frame,
1448
+ inputs=[pointcloud_frames, all_pointcloud_paths],
1449
+ outputs=[model_output],
1450
+ ).then(
1451
+ fn=get_download_link,
1452
+ inputs=[pointcloud_frames, all_pointcloud_paths],
1453
+ outputs=[pointcloud_download],
1454
+ )
1455
+
1456
+ # Example Accordion
1457
+ with gr.Accordion("Examples"):
1458
+ gr.Markdown(
1459
+ """
1460
+ ### Examples will be added soon
1461
+ Check back for example inputs for each task type.
1462
+ """
1463
+ )
1464
+
1465
+ # The pipeline is already initialized at import time; this load hook simply touches it when the demo starts
1466
+ demo.load(lambda: build_pipeline(), inputs=None, outputs=None)
1467
+
1468
+ if __name__ == "__main__":
1469
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
1470
+ demo.queue(max_size=20).launch(show_error=True, share=True)
setup.py ADDED
@@ -0,0 +1,69 @@
1
+ #!/usr/bin/env python
2
+ import pathlib
3
+ import sys
4
+
5
+ import pkg_resources
6
+ from setuptools import find_packages, setup
7
+
8
+
9
+ PKG_NAME = "aether"
10
+ VERSION = "0.1"
11
+ EXTRAS = {}
12
+
13
+
14
+ def _read_file(fname):
15
+ with pathlib.Path(fname).open() as fp:
16
+ return fp.read()
17
+
18
+
19
+ def _read_install_requires():
20
+ with pathlib.Path("requirements.txt").open() as fp:
21
+ return [
22
+ str(requirement) for requirement in pkg_resources.parse_requirements(fp)
23
+ ]
24
+
25
+
26
+ def _fill_extras(extras):
27
+ if extras:
28
+ extras["all"] = list({item for group in extras.values() for item in group})
29
+ return extras
30
+
31
+
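+ # version_range_max below advertises classifiers from Python 3.8 up to the
+ # running interpreter's minor version (and at least up to 3.10).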
32
+ version_range_max = max(sys.version_info[1], 10) + 1
33
+ setup(
34
+ name=PKG_NAME,
35
+ version=VERSION,
36
+ author="Aether Team",
37
+ author_email="[email protected]",
38
+ url="https://github.com/OpenRobotLab/Aether",
39
+ description="",
40
+ long_description=_read_file("README.md"),
41
+ long_description_content_type="text/markdown",
42
+ keywords=[
43
+ "Deep Learning",
44
+ "Machine Learning",
45
+ "World Model",
46
+ "3D Vision",
47
+ "Reconstruction",
48
+ "Sythetic Data",
49
+ "Embodied AI",
50
+ ],
51
+ license="MIT License",
52
+ packages=find_packages(include=[PKG_NAME, f"{PKG_NAME}.*"]),
53
+ include_package_data=True,
54
+ zip_safe=False,
55
+ install_requires=_read_install_requires(),
56
+ extras_require=_fill_extras(EXTRAS),
57
+ python_requires=">=3.8",
58
+ classifiers=[
59
+ "Development Status :: 5 - Production/Stable",
60
+ "Intended Audience :: Developers",
61
+ "Intended Audience :: Education",
62
+ "Intended Audience :: Science/Research",
63
+ "License :: OSI Approved :: MIT License",
64
+ "Operating System :: OS Independent",
65
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
66
+ "Programming Language :: Python :: 3",
67
+ ]
68
+ + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
69
+ )