Spaces:
Running
on
Zero
Running
on
Zero
Wenzheng Chang
commited on
Commit
·
19da45c
1
Parent(s):
ddee6ec
aetherv1 init
Browse files- .gitattributes +1 -0
- .gitignore +157 -0
- .pre-commit-config.yaml +14 -0
- .project-root +2 -0
- LICENSE +21 -0
- README.md +60 -3
- aether/__init__.py +0 -0
- aether/pipelines/__init__.py +0 -0
- aether/pipelines/aetherv1_pipeline_cogvideox.py +959 -0
- aether/utils/__init__.py +0 -0
- aether/utils/postprocess_utils.py +842 -0
- aether/utils/preprocess_utils.py +39 -0
- aether/utils/visualize_utils.py +255 -0
- app.py +0 -7
- assets/example_obs/car.png +3 -0
- assets/example_obs/cartoon.png +3 -0
- assets/example_obs/garden.jpg +0 -0
- assets/example_obs/room.jpg +0 -0
- assets/example_obs_goal/01_goal.png +3 -0
- assets/example_obs_goal/01_obs.png +3 -0
- assets/example_obs_goal/02_goal.png +3 -0
- assets/example_obs_goal/02_obs.png +3 -0
- assets/example_obs_goal/03_goal.png +3 -0
- assets/example_obs_goal/03_obs.png +3 -0
- assets/example_obs_goal/04_goal.png +3 -0
- assets/example_obs_goal/04_obs.png +3 -0
- assets/example_raymaps/raymap_backward.npy +3 -0
- assets/example_raymaps/raymap_forward_right.npy +3 -0
- assets/example_raymaps/raymap_left_forward.npy +3 -0
- assets/example_raymaps/raymap_right.npy +3 -0
- assets/example_videos/bridge.mp4 +3 -0
- assets/example_videos/moviegen.mp4 +3 -0
- assets/example_videos/nuscenes.mp4 +3 -0
- assets/example_videos/veo2.mp4 +3 -0
- assets/logo.png +3 -0
- assets/teaser.png +3 -0
- pyproject.toml +30 -0
- requirements.txt +41 -0
- scripts/demo.py +614 -0
- scripts/demo_gradio.py +1470 -0
- setup.py +69 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
env/
|
106 |
+
.venv
|
107 |
+
venv/
|
108 |
+
ENV/
|
109 |
+
env.bak/
|
110 |
+
venv.bak/
|
111 |
+
|
112 |
+
# Spyder project settings
|
113 |
+
.spyderproject
|
114 |
+
.spyproject
|
115 |
+
|
116 |
+
# Rope project settings
|
117 |
+
.ropeproject
|
118 |
+
|
119 |
+
# mkdocs documentation
|
120 |
+
/site
|
121 |
+
|
122 |
+
# mypy
|
123 |
+
.mypy_cache/
|
124 |
+
.dmypy.json
|
125 |
+
dmypy.json
|
126 |
+
|
127 |
+
# Pyre type checker
|
128 |
+
.pyre/
|
129 |
+
|
130 |
+
### VisualStudioCode
|
131 |
+
.vscode/*
|
132 |
+
!.vscode/settings.json
|
133 |
+
!.vscode/tasks.json
|
134 |
+
!.vscode/launch.json
|
135 |
+
!.vscode/extensions.json
|
136 |
+
*.code-workspace
|
137 |
+
**/.vscode
|
138 |
+
|
139 |
+
# JetBrains
|
140 |
+
.idea/
|
141 |
+
|
142 |
+
# Data & Models
|
143 |
+
*.h5
|
144 |
+
*.tar
|
145 |
+
*.tar.gz
|
146 |
+
|
147 |
+
# Lightning-Hydra-Template
|
148 |
+
configs/local/default.yaml
|
149 |
+
# data/
|
150 |
+
/logs/
|
151 |
+
.env
|
152 |
+
|
153 |
+
# Aim logging
|
154 |
+
.aim
|
155 |
+
|
156 |
+
# local files
|
157 |
+
logs
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
3 |
+
rev: v0.2.1
|
4 |
+
hooks:
|
5 |
+
- id: ruff
|
6 |
+
exclude: ^test/ # Skip the entire "test" directory
|
7 |
+
args:
|
8 |
+
- --fix
|
9 |
+
- id: ruff-format
|
10 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
11 |
+
rev: v4.5.0
|
12 |
+
hooks:
|
13 |
+
- id: check-merge-conflict
|
14 |
+
- id: check-yaml
|
.project-root
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# this file is required for inferring the project root directory
|
2 |
+
# do not delete
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Aether Team
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,14 +1,71 @@
|
|
1 |
---
|
2 |
title: AetherV1
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.23.1
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: 'Aether: Geometric-Aware Unified World Modeling'
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: AetherV1
|
3 |
+
emoji: 🌏
|
4 |
colorFrom: purple
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.23.1
|
8 |
+
app_file: scripts/demo_gradio.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: 'Aether: Geometric-Aware Unified World Modeling'
|
12 |
---
|
13 |
|
14 |
+
<div align="center">
|
15 |
+
|
16 |
+
# Aether: Geometric-Aware Unified World Modeling
|
17 |
+
|
18 |
+
</div>
|
19 |
+
|
20 |
+
<div align="center">
|
21 |
+
<img width="400" alt="image" src="assets/logo.png">
|
22 |
+
<!-- <br> -->
|
23 |
+
</div>
|
24 |
+
|
25 |
+
<div align="center">
|
26 |
+
<a href='https://arxiv.org/abs/2503.18945'><img src='https://img.shields.io/badge/arXiv-2503.18945-red'></a>
|
27 |
+
<a href='https://aether-world.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
|
28 |
+
<a href=''><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo%20(Coming%20Soon)-blue'></a>
|
29 |
+
</div>
|
30 |
+
|
31 |
+
Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with generative modeling
|
32 |
+
for human-like spatial reasoning. Our framework unifies three core capabilities: (1) **4D dynamic reconstruction**,
|
33 |
+
(2) **action-conditioned video prediction**, and (3) **goal-conditioned visual planning**. Trained entirely on
|
34 |
+
synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
|
35 |
+
|
36 |
+
<div align="center">
|
37 |
+
<img src="assets/teaser.png" alt="Teaser" width="800"/>
|
38 |
+
</div>
|
39 |
+
|
40 |
+
|
41 |
+
## 📝 Citation
|
42 |
+
If you find this work useful in your research, please consider citing:
|
43 |
+
|
44 |
+
```bibtex
|
45 |
+
@article{aether,
|
46 |
+
title = {Aether: Geometric-Aware Unified World Modeling},
|
47 |
+
author = {Aether Team and Haoyi Zhu and Yifan Wang and Jianjun Zhou and Wenzheng Chang and Yang Zhou and Zizun Li and Junyi Chen and Chunhua Shen and Jiangmiao Pang and Tong He},
|
48 |
+
journal = {arXiv preprint arXiv:2503.18945},
|
49 |
+
year = {2025}
|
50 |
+
}
|
51 |
+
```
|
52 |
+
|
53 |
+
## ⚖️ License
|
54 |
+
This repository is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
55 |
+
|
56 |
+
## 🙏 Acknowledgements
|
57 |
+
Our work is primarily built upon
|
58 |
+
[Accelerate](https://github.com/huggingface/accelerate),
|
59 |
+
[Diffusers](https://github.com/huggingface/diffusers),
|
60 |
+
[CogVideoX](https://github.com/THUDM/CogVideo),
|
61 |
+
[Finetrainers](https://github.com/a-r-r-o-w/finetrainers),
|
62 |
+
[DepthAnyVideo](https://github.com/Nightmare-n/DepthAnyVideo),
|
63 |
+
[CUT3R](https://github.com/CUT3R/CUT3R),
|
64 |
+
[MonST3R](https://github.com/Junyi42/monst3r),
|
65 |
+
[VBench](https://github.com/Vchitect/VBench),
|
66 |
+
[GST](https://github.com/SOTAMak1r/GST),
|
67 |
+
[SPA](https://github.com/HaoyiZhu/SPA),
|
68 |
+
[DroidCalib](https://github.com/boschresearch/DroidCalib),
|
69 |
+
[Grounded-SAM-2](https://github.com/IDEA-Research/Grounded-SAM-2),
|
70 |
+
[ceres-solver](https://github.com/ceres-solver/ceres-solver), etc.
|
71 |
+
We extend our gratitude to all these authors for their generously open-sourced code and their significant contributions to the community.
|
aether/__init__.py
ADDED
File without changes
|
aether/pipelines/__init__.py
ADDED
File without changes
|
aether/pipelines/aetherv1_pipeline_cogvideox.py
ADDED
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import math
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL
|
8 |
+
import torch
|
9 |
+
from diffusers import (
|
10 |
+
AutoencoderKLCogVideoX,
|
11 |
+
CogVideoXDPMScheduler,
|
12 |
+
CogVideoXImageToVideoPipeline,
|
13 |
+
CogVideoXTransformer3DModel,
|
14 |
+
)
|
15 |
+
from diffusers.image_processor import PipelineImageInput
|
16 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
17 |
+
from diffusers.utils import BaseOutput
|
18 |
+
from diffusers.utils.torch_utils import randn_tensor
|
19 |
+
from einops import rearrange
|
20 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
21 |
+
|
22 |
+
from aether.utils.preprocess_utils import imcrop_center
|
23 |
+
|
24 |
+
|
25 |
+
def get_3d_rotary_pos_embed(
|
26 |
+
embed_dim,
|
27 |
+
crops_coords,
|
28 |
+
grid_size,
|
29 |
+
temporal_size,
|
30 |
+
theta: int = 10000,
|
31 |
+
use_real: bool = True,
|
32 |
+
grid_type: str = "linspace",
|
33 |
+
max_size: Optional[Tuple[int, int]] = None,
|
34 |
+
device: Optional[torch.device] = None,
|
35 |
+
fps_factor: Optional[float] = 1.0,
|
36 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
37 |
+
"""
|
38 |
+
RoPE for video tokens with 3D structure.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
embed_dim: (`int`):
|
42 |
+
The embedding dimension size, corresponding to hidden_size_head.
|
43 |
+
crops_coords (`Tuple[int]`):
|
44 |
+
The top-left and bottom-right coordinates of the crop.
|
45 |
+
grid_size (`Tuple[int]`):
|
46 |
+
The grid size of the spatial positional embedding (height, width).
|
47 |
+
temporal_size (`int`):
|
48 |
+
The size of the temporal dimension.
|
49 |
+
theta (`float`):
|
50 |
+
Scaling factor for frequency computation.
|
51 |
+
grid_type (`str`):
|
52 |
+
Whether to use "linspace" or "slice" to compute grids.
|
53 |
+
fps_factor (`float`):
|
54 |
+
The relative fps factor of the video, computed by base_fps / fps. Useful for variable fps training.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
58 |
+
"""
|
59 |
+
if use_real is not True:
|
60 |
+
raise ValueError(
|
61 |
+
" `use_real = False` is not currently supported for get_3d_rotary_pos_embed"
|
62 |
+
)
|
63 |
+
|
64 |
+
if grid_type == "linspace":
|
65 |
+
start, stop = crops_coords
|
66 |
+
grid_size_h, grid_size_w = grid_size
|
67 |
+
grid_h = torch.linspace(
|
68 |
+
start[0],
|
69 |
+
stop[0] * (grid_size_h - 1) / grid_size_h,
|
70 |
+
grid_size_h,
|
71 |
+
device=device,
|
72 |
+
dtype=torch.float32,
|
73 |
+
)
|
74 |
+
grid_w = torch.linspace(
|
75 |
+
start[1],
|
76 |
+
stop[1] * (grid_size_w - 1) / grid_size_w,
|
77 |
+
grid_size_w,
|
78 |
+
device=device,
|
79 |
+
dtype=torch.float32,
|
80 |
+
)
|
81 |
+
grid_t = (
|
82 |
+
torch.linspace(
|
83 |
+
0,
|
84 |
+
temporal_size * (temporal_size - 1) / temporal_size,
|
85 |
+
temporal_size,
|
86 |
+
device=device,
|
87 |
+
dtype=torch.float32,
|
88 |
+
)
|
89 |
+
* fps_factor
|
90 |
+
)
|
91 |
+
elif grid_type == "slice":
|
92 |
+
max_h, max_w = max_size
|
93 |
+
grid_size_h, grid_size_w = grid_size
|
94 |
+
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
|
95 |
+
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
|
96 |
+
grid_t = (
|
97 |
+
torch.arange(temporal_size, device=device, dtype=torch.float32) * fps_factor
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
raise ValueError("Invalid value passed for `grid_type`.")
|
101 |
+
|
102 |
+
# Compute dimensions for each axis
|
103 |
+
dim_t = embed_dim // 4
|
104 |
+
dim_h = embed_dim // 8 * 3
|
105 |
+
dim_w = embed_dim // 8 * 3
|
106 |
+
|
107 |
+
# Temporal frequencies
|
108 |
+
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
|
109 |
+
# Spatial frequencies for height and width
|
110 |
+
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
|
111 |
+
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
|
112 |
+
|
113 |
+
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
|
114 |
+
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
|
115 |
+
freqs_t = freqs_t[:, None, None, :].expand(
|
116 |
+
-1, grid_size_h, grid_size_w, -1
|
117 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_t
|
118 |
+
freqs_h = freqs_h[None, :, None, :].expand(
|
119 |
+
temporal_size, -1, grid_size_w, -1
|
120 |
+
) # temporal_size, grid_size_h, grid_size_2, dim_h
|
121 |
+
freqs_w = freqs_w[None, None, :, :].expand(
|
122 |
+
temporal_size, grid_size_h, -1, -1
|
123 |
+
) # temporal_size, grid_size_h, grid_size_2, dim_w
|
124 |
+
|
125 |
+
freqs = torch.cat(
|
126 |
+
[freqs_t, freqs_h, freqs_w], dim=-1
|
127 |
+
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
|
128 |
+
freqs = freqs.view(
|
129 |
+
temporal_size * grid_size_h * grid_size_w, -1
|
130 |
+
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
|
131 |
+
return freqs
|
132 |
+
|
133 |
+
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
|
134 |
+
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
|
135 |
+
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
|
136 |
+
|
137 |
+
if grid_type == "slice":
|
138 |
+
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
|
139 |
+
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
|
140 |
+
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
|
141 |
+
|
142 |
+
cos = combine_time_height_width(t_cos, h_cos, w_cos)
|
143 |
+
sin = combine_time_height_width(t_sin, h_sin, w_sin)
|
144 |
+
return cos, sin
|
145 |
+
|
146 |
+
|
147 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
148 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
149 |
+
tw = tgt_width
|
150 |
+
th = tgt_height
|
151 |
+
h, w = src
|
152 |
+
r = h / w
|
153 |
+
if r > (th / tw):
|
154 |
+
resize_height = th
|
155 |
+
resize_width = int(round(th / h * w))
|
156 |
+
else:
|
157 |
+
resize_width = tw
|
158 |
+
resize_height = int(round(tw / w * h))
|
159 |
+
|
160 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
161 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
162 |
+
|
163 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
164 |
+
|
165 |
+
|
166 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
167 |
+
def retrieve_timesteps(
|
168 |
+
scheduler,
|
169 |
+
num_inference_steps: Optional[int] = None,
|
170 |
+
device: Optional[Union[str, torch.device]] = None,
|
171 |
+
timesteps: Optional[List[int]] = None,
|
172 |
+
sigmas: Optional[List[float]] = None,
|
173 |
+
**kwargs,
|
174 |
+
):
|
175 |
+
r"""
|
176 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
177 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
scheduler (`SchedulerMixin`):
|
181 |
+
The scheduler to get timesteps from.
|
182 |
+
num_inference_steps (`int`):
|
183 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
184 |
+
must be `None`.
|
185 |
+
device (`str` or `torch.device`, *optional*):
|
186 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
187 |
+
timesteps (`List[int]`, *optional*):
|
188 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
189 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
190 |
+
sigmas (`List[float]`, *optional*):
|
191 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
192 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
196 |
+
second element is the number of inference steps.
|
197 |
+
"""
|
198 |
+
if timesteps is not None and sigmas is not None:
|
199 |
+
raise ValueError(
|
200 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
201 |
+
)
|
202 |
+
if timesteps is not None:
|
203 |
+
accepts_timesteps = "timesteps" in set(
|
204 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
205 |
+
)
|
206 |
+
if not accepts_timesteps:
|
207 |
+
raise ValueError(
|
208 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
209 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
210 |
+
)
|
211 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
212 |
+
timesteps = scheduler.timesteps
|
213 |
+
num_inference_steps = len(timesteps)
|
214 |
+
elif sigmas is not None:
|
215 |
+
accept_sigmas = "sigmas" in set(
|
216 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
217 |
+
)
|
218 |
+
if not accept_sigmas:
|
219 |
+
raise ValueError(
|
220 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
221 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
222 |
+
)
|
223 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
224 |
+
timesteps = scheduler.timesteps
|
225 |
+
num_inference_steps = len(timesteps)
|
226 |
+
else:
|
227 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
228 |
+
timesteps = scheduler.timesteps
|
229 |
+
return timesteps, num_inference_steps
|
230 |
+
|
231 |
+
|
232 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
233 |
+
def retrieve_latents(
|
234 |
+
encoder_output: torch.Tensor,
|
235 |
+
generator: Optional[torch.Generator] = None,
|
236 |
+
sample_mode: str = "sample",
|
237 |
+
):
|
238 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
239 |
+
return encoder_output.latent_dist.sample(generator)
|
240 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
241 |
+
return encoder_output.latent_dist.mode()
|
242 |
+
elif hasattr(encoder_output, "latents"):
|
243 |
+
return encoder_output.latents
|
244 |
+
else:
|
245 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
246 |
+
|
247 |
+
|
248 |
+
@dataclass
|
249 |
+
class AetherV1PipelineOutput(BaseOutput):
|
250 |
+
rgb: np.ndarray
|
251 |
+
disparity: np.ndarray
|
252 |
+
raymap: np.ndarray
|
253 |
+
|
254 |
+
|
255 |
+
class AetherV1PipelineCogVideoX(CogVideoXImageToVideoPipeline):
|
256 |
+
_supported_tasks = ["reconstruction", "prediction", "planning"]
|
257 |
+
_default_num_inference_steps = {
|
258 |
+
"reconstruction": 4,
|
259 |
+
"prediction": 50,
|
260 |
+
"planning": 50,
|
261 |
+
}
|
262 |
+
_default_guidance_scale = {
|
263 |
+
"reconstruction": 1.0,
|
264 |
+
"prediction": 3.0,
|
265 |
+
"planning": 3.0,
|
266 |
+
}
|
267 |
+
_default_use_dynamic_cfg = {
|
268 |
+
"reconstruction": False,
|
269 |
+
"prediction": True,
|
270 |
+
"planning": True,
|
271 |
+
}
|
272 |
+
_base_fps = 12
|
273 |
+
|
274 |
+
def __init__(
|
275 |
+
self,
|
276 |
+
tokenizer: AutoTokenizer,
|
277 |
+
text_encoder: T5EncoderModel,
|
278 |
+
vae: AutoencoderKLCogVideoX,
|
279 |
+
scheduler: CogVideoXDPMScheduler,
|
280 |
+
transformer: CogVideoXTransformer3DModel,
|
281 |
+
):
|
282 |
+
super().__init__(
|
283 |
+
tokenizer=tokenizer,
|
284 |
+
text_encoder=text_encoder,
|
285 |
+
vae=vae,
|
286 |
+
scheduler=scheduler,
|
287 |
+
transformer=transformer,
|
288 |
+
)
|
289 |
+
|
290 |
+
self.empty_prompt_embeds, _ = self.encode_prompt(
|
291 |
+
prompt="",
|
292 |
+
negative_prompt=None,
|
293 |
+
do_classifier_free_guidance=False,
|
294 |
+
num_videos_per_prompt=1,
|
295 |
+
prompt_embeds=None,
|
296 |
+
)
|
297 |
+
|
298 |
+
def _prepare_rotary_positional_embeddings(
|
299 |
+
self,
|
300 |
+
height: int,
|
301 |
+
width: int,
|
302 |
+
num_frames: int,
|
303 |
+
device: torch.device,
|
304 |
+
fps: Optional[int] = None,
|
305 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
306 |
+
grid_height = height // (
|
307 |
+
self.vae_scale_factor_spatial * self.transformer.config.patch_size
|
308 |
+
)
|
309 |
+
grid_width = width // (
|
310 |
+
self.vae_scale_factor_spatial * self.transformer.config.patch_size
|
311 |
+
)
|
312 |
+
|
313 |
+
p = self.transformer.config.patch_size
|
314 |
+
p_t = self.transformer.config.patch_size_t
|
315 |
+
|
316 |
+
base_size_width = self.transformer.config.sample_width // p
|
317 |
+
base_size_height = self.transformer.config.sample_height // p
|
318 |
+
|
319 |
+
if p_t is None:
|
320 |
+
# CogVideoX 1.0
|
321 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
322 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
323 |
+
)
|
324 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
325 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
326 |
+
crops_coords=grid_crops_coords,
|
327 |
+
grid_size=(grid_height, grid_width),
|
328 |
+
temporal_size=num_frames,
|
329 |
+
device=device,
|
330 |
+
fps_factor=self._base_fps / fps,
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
# CogVideoX 1.5
|
334 |
+
base_num_frames = (num_frames + p_t - 1) // p_t
|
335 |
+
|
336 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
337 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
338 |
+
crops_coords=None,
|
339 |
+
grid_size=(grid_height, grid_width),
|
340 |
+
temporal_size=base_num_frames,
|
341 |
+
grid_type="slice",
|
342 |
+
max_size=(base_size_height, base_size_width),
|
343 |
+
device=device,
|
344 |
+
fps_factor=self._base_fps / fps,
|
345 |
+
)
|
346 |
+
|
347 |
+
return freqs_cos, freqs_sin
|
348 |
+
|
349 |
+
def check_inputs(
|
350 |
+
self,
|
351 |
+
task,
|
352 |
+
image,
|
353 |
+
video,
|
354 |
+
goal,
|
355 |
+
raymap,
|
356 |
+
height,
|
357 |
+
width,
|
358 |
+
num_frames,
|
359 |
+
fps,
|
360 |
+
):
|
361 |
+
if task not in self._supported_tasks:
|
362 |
+
raise ValueError(f"`task` has to be one of {self._supported_tasks}.")
|
363 |
+
|
364 |
+
if image is None and video is None:
|
365 |
+
raise ValueError("`image` or `video` has to be provided.")
|
366 |
+
|
367 |
+
if image is not None and video is not None:
|
368 |
+
raise ValueError("`image` and `video` cannot both be provided.")
|
369 |
+
|
370 |
+
if image is not None:
|
371 |
+
if task == "reconstruction":
|
372 |
+
raise ValueError("`image` is not supported for `reconstruction` task.")
|
373 |
+
if (
|
374 |
+
not isinstance(image, torch.Tensor)
|
375 |
+
and not isinstance(image, np.ndarray)
|
376 |
+
and not isinstance(image, PIL.Image.Image)
|
377 |
+
):
|
378 |
+
raise ValueError(
|
379 |
+
"`image` has to be of type `torch.Tensor` or `np.ndarray` or `PIL.Image.Image` but is"
|
380 |
+
f" {type(image)}"
|
381 |
+
)
|
382 |
+
|
383 |
+
if goal is not None:
|
384 |
+
if task != "planning":
|
385 |
+
raise ValueError("`goal` is only supported for `planning` task.")
|
386 |
+
|
387 |
+
if (
|
388 |
+
not isinstance(goal, torch.Tensor)
|
389 |
+
and not isinstance(goal, np.ndarray)
|
390 |
+
and not isinstance(goal, PIL.Image.Image)
|
391 |
+
):
|
392 |
+
raise ValueError(
|
393 |
+
"`goal` has to be of type `torch.Tensor` or `np.ndarray` or `PIL.Image.Image` but is"
|
394 |
+
f" {type(goal)}"
|
395 |
+
)
|
396 |
+
|
397 |
+
if video is not None:
|
398 |
+
if task != "reconstruction":
|
399 |
+
raise ValueError("`video` is only supported for `reconstruction` task.")
|
400 |
+
|
401 |
+
if (
|
402 |
+
not isinstance(video, torch.Tensor)
|
403 |
+
and not isinstance(video, np.ndarray)
|
404 |
+
and not (
|
405 |
+
isinstance(video, list)
|
406 |
+
and all(isinstance(v, PIL.Image.Image) for v in video)
|
407 |
+
)
|
408 |
+
):
|
409 |
+
raise ValueError(
|
410 |
+
"`video` has to be of type `torch.Tensor` or `np.ndarray` or `List[PIL.Image.Image]` but is"
|
411 |
+
f" {type(video)}"
|
412 |
+
)
|
413 |
+
|
414 |
+
if height % 8 != 0 or width % 8 != 0:
|
415 |
+
raise ValueError(
|
416 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
417 |
+
)
|
418 |
+
|
419 |
+
if num_frames is None:
|
420 |
+
raise ValueError("`num_frames` is required.")
|
421 |
+
|
422 |
+
if num_frames not in [17, 25, 33, 41]:
|
423 |
+
raise ValueError("`num_frames` has to be one of [17, 25, 33, 41].")
|
424 |
+
|
425 |
+
if fps not in [8, 10, 12, 15, 24]:
|
426 |
+
raise ValueError("`fps` has to be one of [8, 10, 12, 15, 24].")
|
427 |
+
|
428 |
+
if (
|
429 |
+
raymap is not None
|
430 |
+
and not isinstance(raymap, torch.Tensor)
|
431 |
+
and not isinstance(raymap, np.ndarray)
|
432 |
+
):
|
433 |
+
raise ValueError(
|
434 |
+
"`raymap` has to be of type `torch.Tensor` or `np.ndarray`."
|
435 |
+
)
|
436 |
+
|
437 |
+
if raymap is not None:
|
438 |
+
if raymap.shape[-4:] != (
|
439 |
+
num_frames,
|
440 |
+
6,
|
441 |
+
height // self.vae_scale_factor_spatial,
|
442 |
+
width // self.vae_scale_factor_spatial,
|
443 |
+
):
|
444 |
+
raise ValueError(
|
445 |
+
f"`raymap` shape is not correct. "
|
446 |
+
f"Expected {num_frames, 6, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial}, "
|
447 |
+
f"got {raymap.shape}."
|
448 |
+
)
|
449 |
+
|
450 |
+
def _preprocess_image(self, image, height, width):
|
451 |
+
if isinstance(image, torch.Tensor):
|
452 |
+
image = image.cpu().numpy()
|
453 |
+
if image.dtype == np.uint8:
|
454 |
+
image = image.astype(np.float32) / 255.0
|
455 |
+
if image.ndim == 3:
|
456 |
+
image = [image]
|
457 |
+
image = imcrop_center(image, height, width)
|
458 |
+
image = self.video_processor.preprocess(image, height, width)
|
459 |
+
return image
|
460 |
+
|
461 |
+
def preprocess_inputs(
|
462 |
+
self,
|
463 |
+
image,
|
464 |
+
goal,
|
465 |
+
video,
|
466 |
+
raymap,
|
467 |
+
height,
|
468 |
+
width,
|
469 |
+
num_frames,
|
470 |
+
):
|
471 |
+
if image is not None:
|
472 |
+
if isinstance(image, PIL.Image.Image):
|
473 |
+
image = self.video_processor.preprocess(
|
474 |
+
image, height, width, resize_mode="crop"
|
475 |
+
).to(self._execution_device)
|
476 |
+
else:
|
477 |
+
image = self._preprocess_image(image, height, width).to(
|
478 |
+
self._execution_device
|
479 |
+
)
|
480 |
+
if goal is not None:
|
481 |
+
if isinstance(goal, PIL.Image.Image):
|
482 |
+
goal = self.video_processor.preprocess(
|
483 |
+
goal, height, width, resize_mode="crop"
|
484 |
+
).to(self._execution_device)
|
485 |
+
else:
|
486 |
+
goal = self._preprocess_image(goal, height, width).to(
|
487 |
+
self._execution_device
|
488 |
+
)
|
489 |
+
if video is not None:
|
490 |
+
if isinstance(video, list) and all(
|
491 |
+
isinstance(v, PIL.Image.Image) for v in video
|
492 |
+
):
|
493 |
+
video = self.video_processor.preprocess(
|
494 |
+
video, height, width, resize_mode="crop"
|
495 |
+
).to(self._execution_device)
|
496 |
+
else:
|
497 |
+
video = self._preprocess_image(video, height, width).to(
|
498 |
+
self._execution_device
|
499 |
+
)
|
500 |
+
# TODO: check raymap shape
|
501 |
+
if raymap is not None:
|
502 |
+
if isinstance(raymap, np.ndarray):
|
503 |
+
raymap = torch.from_numpy(raymap).to(self._execution_device)
|
504 |
+
if raymap.ndim == 4:
|
505 |
+
raymap = raymap.unsqueeze(0)
|
506 |
+
|
507 |
+
return image, goal, video, raymap
|
508 |
+
|
509 |
+
@torch.no_grad()
|
510 |
+
def prepare_latents(
|
511 |
+
self,
|
512 |
+
image: Optional[torch.Tensor] = None,
|
513 |
+
goal: Optional[torch.Tensor] = None,
|
514 |
+
video: Optional[torch.Tensor] = None,
|
515 |
+
raymap: Optional[torch.Tensor] = None,
|
516 |
+
batch_size: int = 1,
|
517 |
+
num_frames: int = 13,
|
518 |
+
height: int = 60,
|
519 |
+
width: int = 90,
|
520 |
+
dtype: Optional[torch.dtype] = None,
|
521 |
+
device: Optional[torch.device] = None,
|
522 |
+
generator: Optional[torch.Generator] = None,
|
523 |
+
):
|
524 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
525 |
+
raise ValueError(
|
526 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
527 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
528 |
+
)
|
529 |
+
|
530 |
+
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
531 |
+
shape = (
|
532 |
+
batch_size,
|
533 |
+
num_frames,
|
534 |
+
56,
|
535 |
+
height // self.vae_scale_factor_spatial,
|
536 |
+
width // self.vae_scale_factor_spatial,
|
537 |
+
)
|
538 |
+
|
539 |
+
# For CogVideoX1.5, the latent should add 1 for padding (Not use)
|
540 |
+
if self.transformer.config.patch_size_t is not None:
|
541 |
+
shape = (
|
542 |
+
shape[:1]
|
543 |
+
+ (shape[1] + shape[1] % self.transformer.config.patch_size_t,)
|
544 |
+
+ shape[2:]
|
545 |
+
)
|
546 |
+
|
547 |
+
if image is not None:
|
548 |
+
image = image.unsqueeze(2)
|
549 |
+
if isinstance(generator, list):
|
550 |
+
image_latents = [
|
551 |
+
retrieve_latents(
|
552 |
+
self.vae.encode(image[i].unsqueeze(0)), generator[i]
|
553 |
+
)
|
554 |
+
for i in range(batch_size)
|
555 |
+
]
|
556 |
+
else:
|
557 |
+
image_latents = [
|
558 |
+
retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator)
|
559 |
+
for img in image
|
560 |
+
]
|
561 |
+
|
562 |
+
image_latents = (
|
563 |
+
torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)
|
564 |
+
) # [B, F, C, H, W]
|
565 |
+
|
566 |
+
if not self.vae.config.invert_scale_latents:
|
567 |
+
image_latents = self.vae_scaling_factor_image * image_latents
|
568 |
+
else:
|
569 |
+
# This is awkward but required because the CogVideoX team forgot to multiply the
|
570 |
+
# scaling factor during training :)
|
571 |
+
image_latents = 1 / self.vae_scaling_factor_image * image_latents
|
572 |
+
|
573 |
+
if goal is not None:
|
574 |
+
goal = goal.unsqueeze(2)
|
575 |
+
if isinstance(generator, list):
|
576 |
+
goal_latents = [
|
577 |
+
retrieve_latents(
|
578 |
+
self.vae.encode(goal[i].unsqueeze(0)), generator[i]
|
579 |
+
)
|
580 |
+
for i in range(batch_size)
|
581 |
+
]
|
582 |
+
else:
|
583 |
+
goal_latents = [
|
584 |
+
retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator)
|
585 |
+
for img in goal
|
586 |
+
]
|
587 |
+
|
588 |
+
goal_latents = (
|
589 |
+
torch.cat(goal_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)
|
590 |
+
) # [B, F, C, H, W]
|
591 |
+
|
592 |
+
if not self.vae.config.invert_scale_latents:
|
593 |
+
goal_latents = self.vae_scaling_factor_image * goal_latents
|
594 |
+
else:
|
595 |
+
# This is awkward but required because the CogVideoX team forgot to multiply the
|
596 |
+
# scaling factor during training :)
|
597 |
+
goal_latents = 1 / self.vae_scaling_factor_image * goal_latents
|
598 |
+
|
599 |
+
if video is not None:
|
600 |
+
if video.ndim == 4:
|
601 |
+
video = video.unsqueeze(0)
|
602 |
+
|
603 |
+
video = video.permute(0, 2, 1, 3, 4)
|
604 |
+
if isinstance(generator, list):
|
605 |
+
video_latents = [
|
606 |
+
retrieve_latents(
|
607 |
+
self.vae.encode(video[i].unsqueeze(0)), generator[i]
|
608 |
+
)
|
609 |
+
for i in range(batch_size)
|
610 |
+
]
|
611 |
+
else:
|
612 |
+
video_latents = [
|
613 |
+
retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator)
|
614 |
+
for img in video
|
615 |
+
]
|
616 |
+
|
617 |
+
video_latents = (
|
618 |
+
torch.cat(video_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)
|
619 |
+
) # [B, F, C, H, W]
|
620 |
+
|
621 |
+
if not self.vae.config.invert_scale_latents:
|
622 |
+
video_latents = self.vae_scaling_factor_image * video_latents
|
623 |
+
else:
|
624 |
+
# This is awkward but required because the CogVideoX team forgot to multiply the
|
625 |
+
# scaling factor during training :)
|
626 |
+
video_latents = 1 / self.vae_scaling_factor_image * video_latents
|
627 |
+
|
628 |
+
if image is not None and goal is None:
|
629 |
+
padding_shape = (
|
630 |
+
batch_size,
|
631 |
+
num_frames - image_latents.shape[1],
|
632 |
+
*image_latents.shape[2:],
|
633 |
+
)
|
634 |
+
padding = torch.zeros(padding_shape, device=device, dtype=dtype)
|
635 |
+
condition_latents = torch.cat([image_latents, padding], dim=1)
|
636 |
+
elif goal is not None:
|
637 |
+
padding_shape = (
|
638 |
+
batch_size,
|
639 |
+
num_frames - goal_latents.shape[1] - image_latents.shape[1],
|
640 |
+
*image_latents.shape[2:],
|
641 |
+
)
|
642 |
+
padding = torch.zeros(padding_shape, device=device, dtype=dtype)
|
643 |
+
condition_latents = torch.cat([image_latents, padding, goal_latents], dim=1)
|
644 |
+
elif video is not None:
|
645 |
+
condition_latents = video_latents
|
646 |
+
|
647 |
+
if raymap is not None:
|
648 |
+
if raymap.shape[1] % self.vae_scale_factor_temporal != 0:
|
649 |
+
# repeat
|
650 |
+
raymap = torch.cat(
|
651 |
+
[
|
652 |
+
raymap[
|
653 |
+
:,
|
654 |
+
: self.vae_scale_factor_temporal
|
655 |
+
- raymap.shape[1] % self.vae_scale_factor_temporal,
|
656 |
+
],
|
657 |
+
raymap,
|
658 |
+
],
|
659 |
+
dim=1,
|
660 |
+
)
|
661 |
+
camera_conditions = rearrange(
|
662 |
+
raymap,
|
663 |
+
"b (n t) c h w -> b t (n c) h w",
|
664 |
+
n=self.vae_scale_factor_temporal,
|
665 |
+
)
|
666 |
+
else:
|
667 |
+
camera_conditions = torch.zeros(
|
668 |
+
batch_size,
|
669 |
+
num_frames,
|
670 |
+
24,
|
671 |
+
height // self.vae_scale_factor_spatial,
|
672 |
+
width // self.vae_scale_factor_spatial,
|
673 |
+
device=device,
|
674 |
+
dtype=dtype,
|
675 |
+
)
|
676 |
+
|
677 |
+
condition_latents = torch.cat([condition_latents, camera_conditions], dim=2)
|
678 |
+
latents = randn_tensor(shape, device=device, generator=generator, dtype=dtype)
|
679 |
+
|
680 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
681 |
+
latents = latents * self.scheduler.init_noise_sigma
|
682 |
+
|
683 |
+
return latents, condition_latents
|
684 |
+
|
685 |
+
@torch.no_grad()
|
686 |
+
def __call__(
|
687 |
+
self,
|
688 |
+
task: Optional[str] = None,
|
689 |
+
image: Optional[PipelineImageInput] = None,
|
690 |
+
video: Optional[PipelineImageInput] = None,
|
691 |
+
goal: Optional[PipelineImageInput] = None,
|
692 |
+
raymap: Optional[Union[torch.Tensor, np.ndarray]] = None,
|
693 |
+
height: Optional[int] = None,
|
694 |
+
width: Optional[int] = None,
|
695 |
+
num_frames: Optional[int] = None,
|
696 |
+
num_inference_steps: Optional[int] = None,
|
697 |
+
timesteps: Optional[List[int]] = None,
|
698 |
+
guidance_scale: Optional[float] = None,
|
699 |
+
use_dynamic_cfg: bool = False,
|
700 |
+
num_videos_per_prompt: int = 1,
|
701 |
+
eta: float = 0.0,
|
702 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
703 |
+
return_dict: bool = True,
|
704 |
+
attention_kwargs: Optional[Dict] = None,
|
705 |
+
fps: Optional[int] = None,
|
706 |
+
) -> Union[AetherV1PipelineOutput, Tuple]:
|
707 |
+
if task is None:
|
708 |
+
if video is not None:
|
709 |
+
task = "reconstruction"
|
710 |
+
elif goal is not None:
|
711 |
+
task = "planning"
|
712 |
+
else:
|
713 |
+
task = "prediction"
|
714 |
+
|
715 |
+
height = (
|
716 |
+
height
|
717 |
+
or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
718 |
+
)
|
719 |
+
width = (
|
720 |
+
width
|
721 |
+
or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
722 |
+
)
|
723 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
724 |
+
fps = fps or self._base_fps
|
725 |
+
|
726 |
+
num_videos_per_prompt = 1
|
727 |
+
|
728 |
+
# 1. Check inputs. Raise error if not correct
|
729 |
+
self.check_inputs(
|
730 |
+
task=task,
|
731 |
+
image=image,
|
732 |
+
video=video,
|
733 |
+
goal=goal,
|
734 |
+
raymap=raymap,
|
735 |
+
height=height,
|
736 |
+
width=width,
|
737 |
+
num_frames=num_frames,
|
738 |
+
fps=fps,
|
739 |
+
)
|
740 |
+
|
741 |
+
# 2. Preprocess inputs
|
742 |
+
image, goal, video, raymap = self.preprocess_inputs(
|
743 |
+
image=image,
|
744 |
+
goal=goal,
|
745 |
+
video=video,
|
746 |
+
raymap=raymap,
|
747 |
+
height=height,
|
748 |
+
width=width,
|
749 |
+
num_frames=num_frames,
|
750 |
+
)
|
751 |
+
self._guidance_scale = guidance_scale
|
752 |
+
self._current_timestep = None
|
753 |
+
self._attention_kwargs = attention_kwargs
|
754 |
+
self._interrupt = False
|
755 |
+
|
756 |
+
batch_size = 1
|
757 |
+
|
758 |
+
device = self._execution_device
|
759 |
+
|
760 |
+
# 3. Encode input prompt
|
761 |
+
prompt_embeds = self.empty_prompt_embeds.to(device)
|
762 |
+
|
763 |
+
num_inference_steps = (
|
764 |
+
num_inference_steps or self._default_num_inference_steps[task]
|
765 |
+
)
|
766 |
+
guidance_scale = guidance_scale or self._default_guidance_scale[task]
|
767 |
+
use_dynamic_cfg = use_dynamic_cfg or self._default_use_dynamic_cfg[task]
|
768 |
+
|
769 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
770 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
771 |
+
# corresponds to doing no classifier free guidance.
|
772 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
773 |
+
|
774 |
+
# 4. Prepare timesteps
|
775 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
776 |
+
self.scheduler, num_inference_steps, device, timesteps
|
777 |
+
)
|
778 |
+
self._num_timesteps = len(timesteps)
|
779 |
+
|
780 |
+
# 5. Prepare latents
|
781 |
+
latents, condition_latents = self.prepare_latents(
|
782 |
+
image,
|
783 |
+
goal,
|
784 |
+
video,
|
785 |
+
raymap,
|
786 |
+
batch_size * num_videos_per_prompt,
|
787 |
+
num_frames,
|
788 |
+
height,
|
789 |
+
width,
|
790 |
+
prompt_embeds.dtype,
|
791 |
+
device,
|
792 |
+
generator,
|
793 |
+
)
|
794 |
+
|
795 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
796 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
797 |
+
|
798 |
+
# 7. Create rotary embeds if required
|
799 |
+
image_rotary_emb = (
|
800 |
+
self._prepare_rotary_positional_embeddings(
|
801 |
+
height, width, latents.size(1), device, fps=fps
|
802 |
+
)
|
803 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
804 |
+
else None
|
805 |
+
)
|
806 |
+
|
807 |
+
# 8. Create ofs embeds if required
|
808 |
+
ofs_emb = (
|
809 |
+
None
|
810 |
+
if self.transformer.config.ofs_embed_dim is None
|
811 |
+
else latents.new_full((1,), fill_value=2.0)
|
812 |
+
)
|
813 |
+
|
814 |
+
# 8. Denoising loop
|
815 |
+
num_warmup_steps = max(
|
816 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
817 |
+
)
|
818 |
+
|
819 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
820 |
+
# for DPM-solver++
|
821 |
+
old_pred_original_sample = None
|
822 |
+
for i, t in enumerate(timesteps):
|
823 |
+
if self.interrupt:
|
824 |
+
continue
|
825 |
+
|
826 |
+
self._current_timestep = t
|
827 |
+
latent_model_input = (
|
828 |
+
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
829 |
+
)
|
830 |
+
latent_model_input = self.scheduler.scale_model_input(
|
831 |
+
latent_model_input, t
|
832 |
+
)
|
833 |
+
|
834 |
+
if do_classifier_free_guidance:
|
835 |
+
if task == "planning":
|
836 |
+
assert goal is not None
|
837 |
+
uncond = condition_latents.clone()
|
838 |
+
uncond[:, :, : self.vae.config.latent_channels] = 0
|
839 |
+
latent_condition = torch.cat([uncond, condition_latents])
|
840 |
+
elif task == "prediction":
|
841 |
+
uncond = condition_latents.clone()
|
842 |
+
uncond[:, :1, : self.vae.config.latent_channels] = 0
|
843 |
+
latent_condition = torch.cat([uncond, condition_latents])
|
844 |
+
else:
|
845 |
+
raise ValueError(
|
846 |
+
f"Task {task} not supported for classifier-free guidance."
|
847 |
+
)
|
848 |
+
|
849 |
+
else:
|
850 |
+
latent_condition = condition_latents
|
851 |
+
|
852 |
+
latent_model_input = torch.cat(
|
853 |
+
[latent_model_input, latent_condition], dim=2
|
854 |
+
)
|
855 |
+
|
856 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
857 |
+
timestep = t.expand(latent_model_input.shape[0])
|
858 |
+
|
859 |
+
# predict noise model_output
|
860 |
+
noise_pred = self.transformer(
|
861 |
+
hidden_states=latent_model_input,
|
862 |
+
encoder_hidden_states=prompt_embeds.repeat(
|
863 |
+
latent_model_input.shape[0], 1, 1
|
864 |
+
),
|
865 |
+
timestep=timestep,
|
866 |
+
ofs=ofs_emb,
|
867 |
+
image_rotary_emb=image_rotary_emb,
|
868 |
+
attention_kwargs=attention_kwargs,
|
869 |
+
return_dict=False,
|
870 |
+
)[0]
|
871 |
+
noise_pred = noise_pred.float()
|
872 |
+
|
873 |
+
# perform guidance
|
874 |
+
if use_dynamic_cfg:
|
875 |
+
self._guidance_scale = 1 + guidance_scale * (
|
876 |
+
(
|
877 |
+
1
|
878 |
+
- math.cos(
|
879 |
+
math.pi
|
880 |
+
* (
|
881 |
+
(num_inference_steps - t.item())
|
882 |
+
/ num_inference_steps
|
883 |
+
)
|
884 |
+
** 5.0
|
885 |
+
)
|
886 |
+
)
|
887 |
+
/ 2
|
888 |
+
)
|
889 |
+
|
890 |
+
if do_classifier_free_guidance:
|
891 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
892 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
893 |
+
noise_pred_text - noise_pred_uncond
|
894 |
+
)
|
895 |
+
|
896 |
+
# compute the previous noisy sample x_t -> x_t-1
|
897 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
898 |
+
latents = self.scheduler.step(
|
899 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
900 |
+
)[0]
|
901 |
+
else:
|
902 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
903 |
+
noise_pred,
|
904 |
+
old_pred_original_sample,
|
905 |
+
t,
|
906 |
+
timesteps[i - 1] if i > 0 else None,
|
907 |
+
latents,
|
908 |
+
**extra_step_kwargs,
|
909 |
+
return_dict=False,
|
910 |
+
)
|
911 |
+
latents = latents.to(prompt_embeds.dtype)
|
912 |
+
|
913 |
+
if i == len(timesteps) - 1 or (
|
914 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
915 |
+
):
|
916 |
+
progress_bar.update()
|
917 |
+
|
918 |
+
self._current_timestep = None
|
919 |
+
|
920 |
+
rgb_latents = latents[:, :, : self.vae.config.latent_channels]
|
921 |
+
disparity_latents = latents[
|
922 |
+
:, :, self.vae.config.latent_channels : self.vae.config.latent_channels * 2
|
923 |
+
]
|
924 |
+
camera_latents = latents[:, :, self.vae.config.latent_channels * 2 :]
|
925 |
+
|
926 |
+
rgb_video = self.decode_latents(rgb_latents)
|
927 |
+
rgb_video = self.video_processor.postprocess_video(
|
928 |
+
video=rgb_video, output_type="np"
|
929 |
+
)
|
930 |
+
|
931 |
+
disparity_video = self.decode_latents(disparity_latents)
|
932 |
+
disparity_video = disparity_video.mean(dim=1, keepdim=False)
|
933 |
+
disparity_video = disparity_video * 0.5 + 0.5
|
934 |
+
disparity_video = torch.square(disparity_video)
|
935 |
+
disparity_video = disparity_video.float().cpu().numpy()
|
936 |
+
|
937 |
+
raymap = (
|
938 |
+
rearrange(camera_latents, "b t (n c) h w -> b (n t) c h w", n=4)[
|
939 |
+
:, -rgb_video.shape[1] :, :, :
|
940 |
+
]
|
941 |
+
.cpu()
|
942 |
+
.numpy()
|
943 |
+
)
|
944 |
+
|
945 |
+
# Offload all models
|
946 |
+
self.maybe_free_model_hooks()
|
947 |
+
|
948 |
+
if not return_dict:
|
949 |
+
return (
|
950 |
+
rgb_video,
|
951 |
+
disparity_video,
|
952 |
+
raymap,
|
953 |
+
)
|
954 |
+
|
955 |
+
return AetherV1PipelineOutput(
|
956 |
+
rgb=rgb_video.squeeze(0),
|
957 |
+
disparity=disparity_video.squeeze(0),
|
958 |
+
raymap=raymap.squeeze(0),
|
959 |
+
)
|
aether/utils/__init__.py
ADDED
File without changes
|
aether/utils/postprocess_utils.py
ADDED
@@ -0,0 +1,842 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import matplotlib
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange
|
10 |
+
from plyfile import PlyData, PlyElement
|
11 |
+
|
12 |
+
|
13 |
+
def signed_log1p_inverse(x):
|
14 |
+
"""
|
15 |
+
Computes the inverse of signed_log1p: x = sign(x) * (exp(abs(x)) - 1).
|
16 |
+
|
17 |
+
Args:
|
18 |
+
y (torch.Tensor): Input tensor (output of signed_log1p).
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
torch.Tensor: Original tensor x.
|
22 |
+
"""
|
23 |
+
if isinstance(x, torch.Tensor):
|
24 |
+
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
|
25 |
+
elif isinstance(x, np.ndarray):
|
26 |
+
return np.sign(x) * (np.exp(np.abs(x)) - 1)
|
27 |
+
else:
|
28 |
+
raise TypeError("Input must be a torch.Tensor or numpy.ndarray")
|
29 |
+
|
30 |
+
|
31 |
+
def colorize_depth(depth, cmap="Spectral"):
|
32 |
+
min_d, max_d = (depth[depth > 0]).min(), (depth[depth > 0]).max()
|
33 |
+
depth = (max_d - depth) / (max_d - min_d)
|
34 |
+
|
35 |
+
cm = matplotlib.colormaps[cmap]
|
36 |
+
depth = depth.clip(0, 1)
|
37 |
+
depth = cm(depth, bytes=False)[..., 0:3]
|
38 |
+
return depth
|
39 |
+
|
40 |
+
|
41 |
+
def save_ply(pointmap, image, output_file, downsample=20, mask=None):
|
42 |
+
_, h, w, _ = pointmap.shape
|
43 |
+
image = image[:, :h, :w]
|
44 |
+
pointmap = pointmap[:, :h, :w]
|
45 |
+
|
46 |
+
points = pointmap.reshape(-1, 3) # (H*W, 3)
|
47 |
+
colors = image.reshape(-1, 3) # (H*W, 3)
|
48 |
+
if mask is not None:
|
49 |
+
points = points[mask.reshape(-1)]
|
50 |
+
colors = colors[mask.reshape(-1)]
|
51 |
+
|
52 |
+
indices = np.random.choice(
|
53 |
+
colors.shape[0], int(colors.shape[0] / downsample), replace=False
|
54 |
+
)
|
55 |
+
points = points[indices]
|
56 |
+
colors = colors[indices]
|
57 |
+
|
58 |
+
vertices = []
|
59 |
+
for p, c in zip(points, colors):
|
60 |
+
vertex = (p[0], p[1], p[2], int(c[0]), int(c[1]), int(c[2]))
|
61 |
+
vertices.append(vertex)
|
62 |
+
|
63 |
+
vertex_dtype = np.dtype(
|
64 |
+
[
|
65 |
+
("x", "f4"),
|
66 |
+
("y", "f4"),
|
67 |
+
("z", "f4"),
|
68 |
+
("red", "u1"),
|
69 |
+
("green", "u1"),
|
70 |
+
("blue", "u1"),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
vertex_array = np.array(vertices, dtype=vertex_dtype)
|
74 |
+
|
75 |
+
ply_element = PlyElement.describe(vertex_array, "vertex")
|
76 |
+
PlyData([ply_element], text=True).write(output_file)
|
77 |
+
|
78 |
+
|
79 |
+
def fov_to_focal(fovx, fovy, h, w):
|
80 |
+
focal_x = w * 0.5 / np.tan(fovx)
|
81 |
+
focal_y = h * 0.5 / np.tan(fovy)
|
82 |
+
focal = (focal_x + focal_y) / 2
|
83 |
+
return focal
|
84 |
+
|
85 |
+
|
86 |
+
def get_rays(pose, h, w, focal=None, fovx=None, fovy=None):
|
87 |
+
import torch.nn.functional as F
|
88 |
+
|
89 |
+
pose = torch.from_numpy(pose).float()
|
90 |
+
x, y = torch.meshgrid(
|
91 |
+
torch.arange(w),
|
92 |
+
torch.arange(h),
|
93 |
+
indexing="xy",
|
94 |
+
)
|
95 |
+
x = x.flatten().unsqueeze(0).repeat(pose.shape[0], 1)
|
96 |
+
y = y.flatten().unsqueeze(0).repeat(pose.shape[0], 1)
|
97 |
+
|
98 |
+
cx = w * 0.5
|
99 |
+
cy = h * 0.5
|
100 |
+
intrinsics, focal = get_intrinsics(pose.shape[0], h, w, fovx, fovy, focal)
|
101 |
+
focal = torch.from_numpy(focal).float()
|
102 |
+
camera_dirs = F.pad(
|
103 |
+
torch.stack(
|
104 |
+
[
|
105 |
+
(x - cx + 0.5) / focal.unsqueeze(-1),
|
106 |
+
(y - cy + 0.5) / focal.unsqueeze(-1),
|
107 |
+
],
|
108 |
+
dim=-1,
|
109 |
+
),
|
110 |
+
(0, 1),
|
111 |
+
value=1.0,
|
112 |
+
) # [t, hw, 3]
|
113 |
+
|
114 |
+
pose = pose.to(dtype=camera_dirs.dtype)
|
115 |
+
rays_d = camera_dirs @ pose[:, :3, :3].transpose(1, 2) # [t, hw, 3]
|
116 |
+
|
117 |
+
rays_o = pose[:, :3, 3].unsqueeze(1).expand_as(rays_d) # [hw, 3]
|
118 |
+
|
119 |
+
rays_o = rays_o.view(pose.shape[0], h, w, 3)
|
120 |
+
rays_d = rays_d.view(pose.shape[0], h, w, 3)
|
121 |
+
|
122 |
+
return rays_o.float().numpy(), rays_d.float().numpy(), intrinsics
|
123 |
+
|
124 |
+
|
125 |
+
def get_intrinsics(batch_size, h, w, fovx=None, fovy=None, focal=None):
|
126 |
+
if focal is None:
|
127 |
+
focal_x = w * 0.5 / np.tan(fovx)
|
128 |
+
focal_y = h * 0.5 / np.tan(fovy)
|
129 |
+
focal = (focal_x + focal_y) / 2
|
130 |
+
cx = w * 0.5
|
131 |
+
cy = h * 0.5
|
132 |
+
intrinsics = np.zeros((batch_size, 3, 3))
|
133 |
+
intrinsics[:, 0, 0] = focal
|
134 |
+
intrinsics[:, 1, 1] = focal
|
135 |
+
intrinsics[:, 0, 2] = cx
|
136 |
+
intrinsics[:, 1, 2] = cy
|
137 |
+
intrinsics[:, 2, 2] = 1.0
|
138 |
+
|
139 |
+
return intrinsics, focal
|
140 |
+
|
141 |
+
|
142 |
+
def save_pointmap(
|
143 |
+
rgb,
|
144 |
+
disparity,
|
145 |
+
raymap,
|
146 |
+
save_file,
|
147 |
+
vae_downsample_scale=8,
|
148 |
+
camera_pose=None,
|
149 |
+
ray_o_scale_inv=1.0,
|
150 |
+
max_depth=1e2,
|
151 |
+
save_full_pcd_videos=False,
|
152 |
+
smooth_camera=False,
|
153 |
+
smooth_method="kalman", # or simple
|
154 |
+
**kwargs,
|
155 |
+
):
|
156 |
+
"""
|
157 |
+
|
158 |
+
Args:
|
159 |
+
rgb (numpy.ndarray): Shape of (t, h, w, 3), range [0, 1]
|
160 |
+
disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
|
161 |
+
raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
|
162 |
+
ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 10.
|
163 |
+
"""
|
164 |
+
rgb = np.clip(rgb, 0, 1) * 255
|
165 |
+
|
166 |
+
pointmap_dict = postprocess_pointmap(
|
167 |
+
disparity,
|
168 |
+
raymap,
|
169 |
+
vae_downsample_scale,
|
170 |
+
camera_pose,
|
171 |
+
ray_o_scale_inv=ray_o_scale_inv,
|
172 |
+
smooth_camera=smooth_camera,
|
173 |
+
smooth_method=smooth_method,
|
174 |
+
**kwargs,
|
175 |
+
)
|
176 |
+
|
177 |
+
save_ply(
|
178 |
+
pointmap_dict["pointmap"],
|
179 |
+
rgb,
|
180 |
+
save_file,
|
181 |
+
mask=(pointmap_dict["depth"] < max_depth),
|
182 |
+
)
|
183 |
+
|
184 |
+
if save_full_pcd_videos:
|
185 |
+
pcd_dict = {
|
186 |
+
"points": pointmap_dict["pointmap"],
|
187 |
+
"colors": rgb,
|
188 |
+
"intrinsics": pointmap_dict["intrinsics"],
|
189 |
+
"poses": pointmap_dict["camera_pose"],
|
190 |
+
"depths": pointmap_dict["depth"],
|
191 |
+
}
|
192 |
+
np.save(save_file.replace(".ply", "_pcd.npy"), pcd_dict)
|
193 |
+
|
194 |
+
return pointmap_dict
|
195 |
+
|
196 |
+
|
197 |
+
def raymap_to_poses(
|
198 |
+
raymap, camera_pose=None, ray_o_scale_inv=1.0, return_intrinsics=True
|
199 |
+
):
|
200 |
+
ts = raymap.shape[0]
|
201 |
+
if (not return_intrinsics) and (camera_pose is not None):
|
202 |
+
return camera_pose, None, None
|
203 |
+
|
204 |
+
raymap[:, 3:] = signed_log1p_inverse(raymap[:, 3:])
|
205 |
+
|
206 |
+
# Extract ray origins and directions
|
207 |
+
ray_o = (
|
208 |
+
rearrange(raymap[:, 3:], "t c h w -> t h w c") * ray_o_scale_inv
|
209 |
+
) # [T, H, W, C]
|
210 |
+
ray_d = rearrange(raymap[:, :3], "t c h w -> t h w c") # [T, H, W, C]
|
211 |
+
|
212 |
+
# Compute orientation and directions
|
213 |
+
orient = ray_o.reshape(ts, -1, 3).mean(axis=1) # T, 3
|
214 |
+
image_orient = (ray_o + ray_d).reshape(ts, -1, 3).mean(axis=1) # T, 3
|
215 |
+
Focal = np.linalg.norm(image_orient - orient, axis=-1) # T,
|
216 |
+
Z_Dir = image_orient - orient # T, 3
|
217 |
+
|
218 |
+
# Compute the width (W) and field of view (FoV_x)
|
219 |
+
W_Left = ray_d[:, :, :1, :].reshape(ts, -1, 3).mean(axis=1)
|
220 |
+
W_Right = ray_d[:, :, -1:, :].reshape(ts, -1, 3).mean(axis=1)
|
221 |
+
W = W_Right - W_Left
|
222 |
+
W_real = (
|
223 |
+
np.linalg.norm(np.cross(W, Z_Dir), axis=-1)
|
224 |
+
/ (raymap.shape[-1] - 1)
|
225 |
+
* raymap.shape[-1]
|
226 |
+
)
|
227 |
+
Fov_x = np.arctan(W_real / (2 * Focal))
|
228 |
+
|
229 |
+
# Compute the height (H) and field of view (FoV_y)
|
230 |
+
H_Up = ray_d[:, :1, :, :].reshape(ts, -1, 3).mean(axis=1)
|
231 |
+
H_Down = ray_d[:, -1:, :, :].reshape(ts, -1, 3).mean(axis=1)
|
232 |
+
H = H_Up - H_Down
|
233 |
+
H_real = (
|
234 |
+
np.linalg.norm(np.cross(H, Z_Dir), axis=-1)
|
235 |
+
/ (raymap.shape[-2] - 1)
|
236 |
+
* raymap.shape[-2]
|
237 |
+
)
|
238 |
+
Fov_y = np.arctan(H_real / (2 * Focal))
|
239 |
+
|
240 |
+
# Compute X, Y, and Z directions for the camera
|
241 |
+
X_Dir = W_Right - W_Left
|
242 |
+
Y_Dir = np.cross(Z_Dir, X_Dir)
|
243 |
+
X_Dir = np.cross(Y_Dir, Z_Dir)
|
244 |
+
|
245 |
+
X_Dir /= np.linalg.norm(X_Dir, axis=-1, keepdims=True)
|
246 |
+
Y_Dir /= np.linalg.norm(Y_Dir, axis=-1, keepdims=True)
|
247 |
+
Z_Dir /= np.linalg.norm(Z_Dir, axis=-1, keepdims=True)
|
248 |
+
|
249 |
+
# Create the camera-to-world (camera_pose) transformation matrix
|
250 |
+
if camera_pose is None:
|
251 |
+
camera_pose = np.zeros((ts, 4, 4))
|
252 |
+
camera_pose[:, :3, 0] = X_Dir
|
253 |
+
camera_pose[:, :3, 1] = Y_Dir
|
254 |
+
camera_pose[:, :3, 2] = Z_Dir
|
255 |
+
camera_pose[:, :3, 3] = orient
|
256 |
+
camera_pose[:, 3, 3] = 1.0
|
257 |
+
|
258 |
+
return camera_pose, Fov_x, Fov_y
|
259 |
+
|
260 |
+
|
261 |
+
def postprocess_pointmap(
|
262 |
+
disparity,
|
263 |
+
raymap,
|
264 |
+
vae_downsample_scale=8,
|
265 |
+
camera_pose=None,
|
266 |
+
focal=None,
|
267 |
+
ray_o_scale_inv=1.0,
|
268 |
+
smooth_camera=False,
|
269 |
+
smooth_method="simple",
|
270 |
+
**kwargs,
|
271 |
+
):
|
272 |
+
"""
|
273 |
+
|
274 |
+
Args:
|
275 |
+
disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
|
276 |
+
raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
|
277 |
+
ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 10.
|
278 |
+
"""
|
279 |
+
depth = np.clip(1.0 / np.clip(disparity, 1e-3, 1), 0, 1e8)
|
280 |
+
|
281 |
+
camera_pose, fov_x, fov_y = raymap_to_poses(
|
282 |
+
raymap,
|
283 |
+
camera_pose=camera_pose,
|
284 |
+
ray_o_scale_inv=ray_o_scale_inv,
|
285 |
+
return_intrinsics=(focal is not None),
|
286 |
+
)
|
287 |
+
if focal is None:
|
288 |
+
focal = fov_to_focal(
|
289 |
+
fov_x,
|
290 |
+
fov_y,
|
291 |
+
int(raymap.shape[2] * vae_downsample_scale),
|
292 |
+
int(raymap.shape[3] * vae_downsample_scale),
|
293 |
+
)
|
294 |
+
|
295 |
+
if smooth_camera:
|
296 |
+
# Check if sequence is static
|
297 |
+
is_static, trans_diff, rot_diff = detect_static_sequence(camera_pose)
|
298 |
+
|
299 |
+
if is_static:
|
300 |
+
print(
|
301 |
+
f"Detected static/near-static sequence (trans_diff={trans_diff:.6f}, rot_diff={rot_diff:.6f})"
|
302 |
+
)
|
303 |
+
# Apply stronger smoothing for static sequences
|
304 |
+
camera_pose = adaptive_pose_smoothing(camera_pose, trans_diff, rot_diff)
|
305 |
+
else:
|
306 |
+
if smooth_method == "simple":
|
307 |
+
camera_pose = smooth_poses(
|
308 |
+
camera_pose, window_size=5, method="gaussian"
|
309 |
+
)
|
310 |
+
elif smooth_method == "kalman":
|
311 |
+
camera_pose = smooth_trajectory(camera_pose, window_size=5)
|
312 |
+
|
313 |
+
ray_o, ray_d, intrinsics = get_rays(
|
314 |
+
camera_pose,
|
315 |
+
int(raymap.shape[2] * vae_downsample_scale),
|
316 |
+
int(raymap.shape[3] * vae_downsample_scale),
|
317 |
+
focal,
|
318 |
+
)
|
319 |
+
|
320 |
+
pointmap = depth[..., None] * ray_d + ray_o
|
321 |
+
|
322 |
+
return {
|
323 |
+
"pointmap": pointmap,
|
324 |
+
"camera_pose": camera_pose,
|
325 |
+
"intrinsics": intrinsics,
|
326 |
+
"ray_o": ray_o,
|
327 |
+
"ray_d": ray_d,
|
328 |
+
"depth": depth,
|
329 |
+
}
|
330 |
+
|
331 |
+
|
332 |
+
def detect_static_sequence(poses, threshold=0.01):
|
333 |
+
"""Detect if the camera sequence is static based on pose differences."""
|
334 |
+
translations = poses[:, :3, 3]
|
335 |
+
rotations = poses[:, :3, :3]
|
336 |
+
|
337 |
+
# Compute translation differences
|
338 |
+
trans_diff = np.linalg.norm(translations[1:] - translations[:-1], axis=1).mean()
|
339 |
+
|
340 |
+
# Compute rotation differences (using matrix frobenius norm)
|
341 |
+
rot_diff = np.linalg.norm(rotations[1:] - rotations[:-1], axis=(1, 2)).mean()
|
342 |
+
|
343 |
+
return trans_diff < threshold and rot_diff < threshold, trans_diff, rot_diff
|
344 |
+
|
345 |
+
|
346 |
+
def adaptive_pose_smoothing(poses, trans_diff, rot_diff, base_window=5):
|
347 |
+
"""Apply adaptive smoothing based on motion magnitude."""
|
348 |
+
# Increase window size for low motion sequences
|
349 |
+
motion_magnitude = trans_diff + rot_diff
|
350 |
+
adaptive_window = min(
|
351 |
+
41, max(base_window, int(base_window * (0.1 / max(motion_magnitude, 1e-6))))
|
352 |
+
)
|
353 |
+
|
354 |
+
# Apply stronger smoothing for low motion
|
355 |
+
poses_smooth = smooth_poses(poses, window_size=adaptive_window, method="gaussian")
|
356 |
+
return poses_smooth
|
357 |
+
|
358 |
+
|
359 |
+
def get_pixel(H, W):
|
360 |
+
# get 2D pixels (u, v) for image_a in cam_a pixel space
|
361 |
+
u_a, v_a = np.meshgrid(np.arange(W), np.arange(H))
|
362 |
+
# u_a = np.flip(u_a, axis=1)
|
363 |
+
# v_a = np.flip(v_a, axis=0)
|
364 |
+
pixels_a = np.stack(
|
365 |
+
[u_a.flatten() + 0.5, v_a.flatten() + 0.5, np.ones_like(u_a.flatten())], axis=0
|
366 |
+
)
|
367 |
+
|
368 |
+
return pixels_a
|
369 |
+
|
370 |
+
|
371 |
+
def project(depth, intrinsic, pose):
|
372 |
+
H, W = depth.shape
|
373 |
+
pixel = get_pixel(H, W).astype(np.float32)
|
374 |
+
points = (np.linalg.inv(intrinsic) @ pixel) * depth.reshape(-1)
|
375 |
+
points = pose[:3, :4] @ np.concatenate(
|
376 |
+
[points, np.ones((1, points.shape[1]))], axis=0
|
377 |
+
)
|
378 |
+
|
379 |
+
points = points.T.reshape(H, W, 3)
|
380 |
+
|
381 |
+
return points
|
382 |
+
|
383 |
+
|
384 |
+
def depth_edge(
|
385 |
+
depth: torch.Tensor,
|
386 |
+
atol: float = None,
|
387 |
+
rtol: float = None,
|
388 |
+
kernel_size: int = 3,
|
389 |
+
mask: Optional[torch.Tensor] = None,
|
390 |
+
) -> torch.BoolTensor:
|
391 |
+
"""
|
392 |
+
Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.
|
393 |
+
|
394 |
+
Args:
|
395 |
+
depth (torch.Tensor): shape (..., height, width), linear depth map
|
396 |
+
atol (float): absolute tolerance
|
397 |
+
rtol (float): relative tolerance
|
398 |
+
|
399 |
+
Returns:
|
400 |
+
edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
|
401 |
+
"""
|
402 |
+
is_numpy = isinstance(depth, np.ndarray)
|
403 |
+
if is_numpy:
|
404 |
+
depth = torch.from_numpy(depth)
|
405 |
+
if isinstance(mask, np.ndarray):
|
406 |
+
mask = torch.from_numpy(mask)
|
407 |
+
|
408 |
+
shape = depth.shape
|
409 |
+
depth = depth.reshape(-1, 1, *shape[-2:])
|
410 |
+
if mask is not None:
|
411 |
+
mask = mask.reshape(-1, 1, *shape[-2:])
|
412 |
+
|
413 |
+
if mask is None:
|
414 |
+
diff = F.max_pool2d(
|
415 |
+
depth, kernel_size, stride=1, padding=kernel_size // 2
|
416 |
+
) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
|
417 |
+
else:
|
418 |
+
diff = F.max_pool2d(
|
419 |
+
torch.where(mask, depth, -torch.inf),
|
420 |
+
kernel_size,
|
421 |
+
stride=1,
|
422 |
+
padding=kernel_size // 2,
|
423 |
+
) + F.max_pool2d(
|
424 |
+
torch.where(mask, -depth, -torch.inf),
|
425 |
+
kernel_size,
|
426 |
+
stride=1,
|
427 |
+
padding=kernel_size // 2,
|
428 |
+
)
|
429 |
+
|
430 |
+
edge = torch.zeros_like(depth, dtype=torch.bool)
|
431 |
+
if atol is not None:
|
432 |
+
edge |= diff > atol
|
433 |
+
if rtol is not None:
|
434 |
+
edge |= (diff / depth).nan_to_num_() > rtol
|
435 |
+
edge = edge.reshape(*shape)
|
436 |
+
|
437 |
+
if is_numpy:
|
438 |
+
return edge.numpy()
|
439 |
+
return edge
|
440 |
+
|
441 |
+
|
442 |
+
@torch.jit.script
|
443 |
+
def align_rigid(
|
444 |
+
p,
|
445 |
+
q,
|
446 |
+
weights,
|
447 |
+
):
|
448 |
+
"""Compute a rigid transformation that, when applied to p, minimizes the weighted
|
449 |
+
squared distance between transformed points in p and points in q. See "Least-Squares
|
450 |
+
Rigid Motion Using SVD" by Olga Sorkine-Hornung and Michael Rabinovich for more
|
451 |
+
details (https://igl.ethz.ch/projects/ARAP/svd_rot.pdf).
|
452 |
+
"""
|
453 |
+
|
454 |
+
device = p.device
|
455 |
+
dtype = p.dtype
|
456 |
+
batch, _, _ = p.shape
|
457 |
+
|
458 |
+
# 1. Compute the centroids of both point sets.
|
459 |
+
weights_normalized = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
|
460 |
+
p_centroid = (weights_normalized[..., None] * p).sum(dim=-2)
|
461 |
+
q_centroid = (weights_normalized[..., None] * q).sum(dim=-2)
|
462 |
+
|
463 |
+
# 2. Compute the centered vectors.
|
464 |
+
p_centered = p - p_centroid[..., None, :]
|
465 |
+
q_centered = q - q_centroid[..., None, :]
|
466 |
+
|
467 |
+
# 3. Compute the 3x3 covariance matrix.
|
468 |
+
covariance = (q_centered * weights[..., None]).transpose(-1, -2) @ p_centered
|
469 |
+
|
470 |
+
# 4. Compute the singular value decomposition and then the rotation.
|
471 |
+
u, _, vt = torch.linalg.svd(covariance)
|
472 |
+
s = torch.eye(3, dtype=dtype, device=device)
|
473 |
+
s = s.expand((batch, 3, 3)).contiguous()
|
474 |
+
s[..., 2, 2] = (u.det() * vt.det()).sign()
|
475 |
+
rotation = u @ s @ vt
|
476 |
+
|
477 |
+
# 5. Compute the optimal scale
|
478 |
+
scale = (
|
479 |
+
(torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum(
|
480 |
+
-1
|
481 |
+
)
|
482 |
+
* weights
|
483 |
+
).sum(-1) / ((p_centered**2).sum(-1) * weights).sum(-1)
|
484 |
+
# scale = (torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum([-1, -2]) / (p_centered**2).sum([-1, -2])
|
485 |
+
|
486 |
+
# 6. Compute the optimal translation.
|
487 |
+
translation = q_centroid - torch.einsum(
|
488 |
+
"b i j, b j -> b i", rotation, p_centroid * scale[:, None]
|
489 |
+
)
|
490 |
+
|
491 |
+
return rotation, translation, scale
|
492 |
+
|
493 |
+
|
494 |
+
def align_camera_extrinsics(
|
495 |
+
cameras_src: torch.Tensor, # Bx3x4 tensor representing [R | t]
|
496 |
+
cameras_tgt: torch.Tensor, # Bx3x4 tensor representing [R | t]
|
497 |
+
estimate_scale: bool = True,
|
498 |
+
eps: float = 1e-9,
|
499 |
+
):
|
500 |
+
"""
|
501 |
+
Align the source camera extrinsics to the target camera extrinsics.
|
502 |
+
NOTE Assume OPENCV convention
|
503 |
+
|
504 |
+
Args:
|
505 |
+
cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
|
506 |
+
cameras_tgt (torch.Tensor): Bx3x4 tensor representing [R | t] for target cameras.
|
507 |
+
estimate_scale (bool, optional): Whether to estimate the scale factor. Default is True.
|
508 |
+
eps (float, optional): Small value to avoid division by zero. Default is 1e-9.
|
509 |
+
|
510 |
+
Returns:
|
511 |
+
align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
|
512 |
+
align_t_T (torch.Tensor): 1x3 translation vector for alignment.
|
513 |
+
align_t_s (float): Scaling factor for alignment.
|
514 |
+
"""
|
515 |
+
|
516 |
+
R_src = cameras_src[:, :, :3] # Extracting the rotation matrices from [R | t]
|
517 |
+
R_tgt = cameras_tgt[:, :, :3] # Extracting the rotation matrices from [R | t]
|
518 |
+
|
519 |
+
RRcov = torch.bmm(R_tgt.transpose(2, 1), R_src).mean(0)
|
520 |
+
U, _, V = torch.svd(RRcov)
|
521 |
+
align_t_R = V @ U.t()
|
522 |
+
|
523 |
+
T_src = cameras_src[:, :, 3] # Extracting the translation vectors from [R | t]
|
524 |
+
T_tgt = cameras_tgt[:, :, 3] # Extracting the translation vectors from [R | t]
|
525 |
+
|
526 |
+
A = torch.bmm(T_src[:, None], R_src)[:, 0]
|
527 |
+
B = torch.bmm(T_tgt[:, None], R_src)[:, 0]
|
528 |
+
|
529 |
+
Amu = A.mean(0, keepdim=True)
|
530 |
+
Bmu = B.mean(0, keepdim=True)
|
531 |
+
|
532 |
+
if estimate_scale and A.shape[0] > 1:
|
533 |
+
# get the scaling component by matching covariances
|
534 |
+
# of centered A and centered B
|
535 |
+
Ac = A - Amu
|
536 |
+
Bc = B - Bmu
|
537 |
+
align_t_s = (Ac * Bc).mean() / (Ac**2).mean().clamp(eps)
|
538 |
+
else:
|
539 |
+
# set the scale to identity
|
540 |
+
align_t_s = 1.0
|
541 |
+
|
542 |
+
# get the translation as the difference between the means of A and B
|
543 |
+
align_t_T = Bmu - align_t_s * Amu
|
544 |
+
|
545 |
+
align_t_R = align_t_R[None]
|
546 |
+
return align_t_R, align_t_T, align_t_s
|
547 |
+
|
548 |
+
|
549 |
+
def apply_transformation(
|
550 |
+
cameras_src: torch.Tensor, # Bx3x4 tensor representing [R | t]
|
551 |
+
align_t_R: torch.Tensor, # 1x3x3 rotation matrix
|
552 |
+
align_t_T: torch.Tensor, # 1x3 translation vector
|
553 |
+
align_t_s: float, # Scaling factor
|
554 |
+
return_extri: bool = True,
|
555 |
+
) -> torch.Tensor:
|
556 |
+
"""
|
557 |
+
Align and transform the source cameras using the provided rotation, translation, and scaling factors.
|
558 |
+
NOTE Assume OPENCV convention
|
559 |
+
|
560 |
+
Args:
|
561 |
+
cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
|
562 |
+
align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
|
563 |
+
align_t_T (torch.Tensor): 1x3 translation vector for alignment.
|
564 |
+
align_t_s (float): Scaling factor for alignment.
|
565 |
+
|
566 |
+
Returns:
|
567 |
+
aligned_R (torch.Tensor): Bx3x3 tensor representing the aligned rotation matrices.
|
568 |
+
aligned_T (torch.Tensor): Bx3 tensor representing the aligned translation vectors.
|
569 |
+
"""
|
570 |
+
|
571 |
+
R_src = cameras_src[:, :, :3]
|
572 |
+
T_src = cameras_src[:, :, 3]
|
573 |
+
|
574 |
+
aligned_R = torch.bmm(R_src, align_t_R.expand(R_src.shape[0], 3, 3))
|
575 |
+
|
576 |
+
# Apply the translation alignment to the source translations
|
577 |
+
align_t_T_expanded = align_t_T[..., None].repeat(R_src.shape[0], 1, 1)
|
578 |
+
transformed_T = torch.bmm(R_src, align_t_T_expanded)[..., 0]
|
579 |
+
aligned_T = transformed_T + T_src * align_t_s
|
580 |
+
|
581 |
+
if return_extri:
|
582 |
+
extri = torch.cat([aligned_R, aligned_T.unsqueeze(-1)], dim=-1)
|
583 |
+
return extri
|
584 |
+
|
585 |
+
return aligned_R, aligned_T
|
586 |
+
|
587 |
+
|
588 |
+
def slerp(q1, q2, t):
|
589 |
+
"""Spherical Linear Interpolation between quaternions.
|
590 |
+
Args:
|
591 |
+
q1: (4,) first quaternion
|
592 |
+
q2: (4,) second quaternion
|
593 |
+
t: float between 0 and 1
|
594 |
+
Returns:
|
595 |
+
(4,) interpolated quaternion
|
596 |
+
"""
|
597 |
+
# Compute the cosine of the angle between the two vectors
|
598 |
+
dot = np.sum(q1 * q2)
|
599 |
+
|
600 |
+
# If the dot product is negative, slerp won't take the shorter path
|
601 |
+
# Fix by negating one of the input quaternions
|
602 |
+
if dot < 0.0:
|
603 |
+
q2 = -q2
|
604 |
+
dot = -dot
|
605 |
+
|
606 |
+
# Threshold for using linear interpolation instead of spherical
|
607 |
+
DOT_THRESHOLD = 0.9995
|
608 |
+
if dot > DOT_THRESHOLD:
|
609 |
+
# If the inputs are too close for comfort, linearly interpolate
|
610 |
+
# and normalize the result
|
611 |
+
result = q1 + t * (q2 - q1)
|
612 |
+
return result / np.linalg.norm(result)
|
613 |
+
|
614 |
+
# Compute the angle between the quaternions
|
615 |
+
theta_0 = np.arccos(dot)
|
616 |
+
sin_theta_0 = np.sin(theta_0)
|
617 |
+
|
618 |
+
# Compute interpolation factors
|
619 |
+
theta = theta_0 * t
|
620 |
+
sin_theta = np.sin(theta)
|
621 |
+
|
622 |
+
s0 = np.cos(theta) - dot * sin_theta / sin_theta_0
|
623 |
+
s1 = sin_theta / sin_theta_0
|
624 |
+
|
625 |
+
return (s0 * q1) + (s1 * q2)
|
626 |
+
|
627 |
+
|
628 |
+
def interpolate_poses(pose1, pose2, weight):
|
629 |
+
"""Interpolate between two camera poses with weight.
|
630 |
+
Args:
|
631 |
+
pose1: (4, 4) first camera pose
|
632 |
+
pose2: (4, 4) second camera pose
|
633 |
+
weight: float between 0 and 1, weight for pose1 (1-weight for pose2)
|
634 |
+
Returns:
|
635 |
+
(4, 4) interpolated pose
|
636 |
+
"""
|
637 |
+
from scipy.spatial.transform import Rotation as R
|
638 |
+
|
639 |
+
# Extract rotations and translations
|
640 |
+
R1 = R.from_matrix(pose1[:3, :3])
|
641 |
+
R2 = R.from_matrix(pose2[:3, :3])
|
642 |
+
t1 = pose1[:3, 3]
|
643 |
+
t2 = pose2[:3, 3]
|
644 |
+
|
645 |
+
# Get quaternions
|
646 |
+
q1 = R1.as_quat()
|
647 |
+
q2 = R2.as_quat()
|
648 |
+
|
649 |
+
# Interpolate rotation using our slerp implementation
|
650 |
+
q_interp = slerp(q1, q2, 1 - weight) # 1-weight because weight is for pose1
|
651 |
+
R_interp = R.from_quat(q_interp)
|
652 |
+
|
653 |
+
# Linear interpolation for translation
|
654 |
+
t_interp = weight * t1 + (1 - weight) * t2
|
655 |
+
|
656 |
+
# Construct interpolated pose
|
657 |
+
pose_interp = np.eye(4)
|
658 |
+
pose_interp[:3, :3] = R_interp.as_matrix()
|
659 |
+
pose_interp[:3, 3] = t_interp
|
660 |
+
|
661 |
+
return pose_interp
|
662 |
+
|
663 |
+
|
664 |
+
def smooth_poses(poses, window_size=5, method="gaussian"):
|
665 |
+
"""Smooth camera poses temporally.
|
666 |
+
Args:
|
667 |
+
poses: (N, 4, 4) camera poses
|
668 |
+
window_size: int, must be odd number
|
669 |
+
method: str, 'gaussian' or 'savgol' or 'ma'
|
670 |
+
Returns:
|
671 |
+
(N, 4, 4) smoothed poses
|
672 |
+
"""
|
673 |
+
from scipy.ndimage import gaussian_filter1d
|
674 |
+
from scipy.signal import savgol_filter
|
675 |
+
from scipy.spatial.transform import Rotation as R
|
676 |
+
|
677 |
+
assert window_size % 2 == 1, "window_size must be odd"
|
678 |
+
N = poses.shape[0]
|
679 |
+
smoothed = np.zeros_like(poses)
|
680 |
+
|
681 |
+
# Extract translations and quaternions
|
682 |
+
translations = poses[:, :3, 3]
|
683 |
+
rotations = R.from_matrix(poses[:, :3, :3])
|
684 |
+
quats = rotations.as_quat() # (N, 4)
|
685 |
+
|
686 |
+
# Ensure consistent quaternion signs to prevent interpolation artifacts
|
687 |
+
for i in range(1, N):
|
688 |
+
if np.dot(quats[i], quats[i - 1]) < 0:
|
689 |
+
quats[i] = -quats[i]
|
690 |
+
|
691 |
+
# Smooth translations
|
692 |
+
if method == "gaussian":
|
693 |
+
sigma = window_size / 6.0 # approximately 99.7% of the weight within the window
|
694 |
+
smoothed_trans = gaussian_filter1d(translations, sigma, axis=0, mode="nearest")
|
695 |
+
smoothed_quats = gaussian_filter1d(quats, sigma, axis=0, mode="nearest")
|
696 |
+
elif method == "savgol":
|
697 |
+
# Savitzky-Golay filter: polynomial fitting
|
698 |
+
poly_order = min(window_size - 1, 3)
|
699 |
+
smoothed_trans = savgol_filter(
|
700 |
+
translations, window_size, poly_order, axis=0, mode="nearest"
|
701 |
+
)
|
702 |
+
smoothed_quats = savgol_filter(
|
703 |
+
quats, window_size, poly_order, axis=0, mode="nearest"
|
704 |
+
)
|
705 |
+
elif method == "ma":
|
706 |
+
# Simple moving average
|
707 |
+
kernel = np.ones(window_size) / window_size
|
708 |
+
smoothed_trans = np.array(
|
709 |
+
[np.convolve(translations[:, i], kernel, mode="same") for i in range(3)]
|
710 |
+
).T
|
711 |
+
smoothed_quats = np.array(
|
712 |
+
[np.convolve(quats[:, i], kernel, mode="same") for i in range(4)]
|
713 |
+
).T
|
714 |
+
|
715 |
+
# Normalize quaternions
|
716 |
+
smoothed_quats /= np.linalg.norm(smoothed_quats, axis=1, keepdims=True)
|
717 |
+
|
718 |
+
# Reconstruct poses
|
719 |
+
smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
|
720 |
+
|
721 |
+
for i in range(N):
|
722 |
+
smoothed[i] = np.eye(4)
|
723 |
+
smoothed[i, :3, :3] = smoothed_rots[i]
|
724 |
+
smoothed[i, :3, 3] = smoothed_trans[i]
|
725 |
+
|
726 |
+
return smoothed
|
727 |
+
|
728 |
+
|
729 |
+
def smooth_trajectory(poses, window_size=5):
|
730 |
+
"""Smooth camera trajectory using Kalman filter.
|
731 |
+
Args:
|
732 |
+
poses: (N, 4, 4) camera poses
|
733 |
+
window_size: int, window size for initial smoothing
|
734 |
+
Returns:
|
735 |
+
(N, 4, 4) smoothed poses
|
736 |
+
"""
|
737 |
+
from filterpy.kalman import KalmanFilter
|
738 |
+
from scipy.spatial.transform import Rotation as R
|
739 |
+
|
740 |
+
N = poses.shape[0]
|
741 |
+
|
742 |
+
# Initialize Kalman filter for position and velocity
|
743 |
+
kf = KalmanFilter(dim_x=6, dim_z=3) # 3D position and velocity
|
744 |
+
dt = 1.0 # assume uniform time steps
|
745 |
+
|
746 |
+
# State transition matrix
|
747 |
+
kf.F = np.array(
|
748 |
+
[
|
749 |
+
[1, 0, 0, dt, 0, 0],
|
750 |
+
[0, 1, 0, 0, dt, 0],
|
751 |
+
[0, 0, 1, 0, 0, dt],
|
752 |
+
[0, 0, 0, 1, 0, 0],
|
753 |
+
[0, 0, 0, 0, 1, 0],
|
754 |
+
[0, 0, 0, 0, 0, 1],
|
755 |
+
]
|
756 |
+
)
|
757 |
+
|
758 |
+
# Measurement matrix
|
759 |
+
kf.H = np.array([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]])
|
760 |
+
|
761 |
+
# Measurement noise
|
762 |
+
kf.R *= 0.1
|
763 |
+
|
764 |
+
# Process noise
|
765 |
+
kf.Q *= 0.1
|
766 |
+
|
767 |
+
# Initial state uncertainty
|
768 |
+
kf.P *= 1.0
|
769 |
+
|
770 |
+
# Extract translations and rotations
|
771 |
+
translations = poses[:, :3, 3]
|
772 |
+
rotations = R.from_matrix(poses[:, :3, :3])
|
773 |
+
quats = rotations.as_quat()
|
774 |
+
|
775 |
+
# First pass: simple smoothing for initial estimates
|
776 |
+
smoothed = smooth_poses(poses, window_size, method="gaussian")
|
777 |
+
smooth_trans = smoothed[:, :3, 3]
|
778 |
+
|
779 |
+
# Second pass: Kalman filter for trajectory
|
780 |
+
filtered_trans = np.zeros_like(translations)
|
781 |
+
kf.x = np.zeros(6)
|
782 |
+
kf.x[:3] = smooth_trans[0]
|
783 |
+
|
784 |
+
filtered_trans[0] = smooth_trans[0]
|
785 |
+
|
786 |
+
# Forward pass
|
787 |
+
for i in range(1, N):
|
788 |
+
kf.predict()
|
789 |
+
kf.update(smooth_trans[i])
|
790 |
+
filtered_trans[i] = kf.x[:3]
|
791 |
+
|
792 |
+
# Backward smoothing for rotations using SLERP
|
793 |
+
window_half = window_size // 2
|
794 |
+
smoothed_quats = np.zeros_like(quats)
|
795 |
+
|
796 |
+
for i in range(N):
|
797 |
+
start_idx = max(0, i - window_half)
|
798 |
+
end_idx = min(N, i + window_half + 1)
|
799 |
+
weights = np.exp(
|
800 |
+
-0.5 * ((np.arange(start_idx, end_idx) - i) / (window_half / 2)) ** 2
|
801 |
+
)
|
802 |
+
weights /= weights.sum()
|
803 |
+
|
804 |
+
# Weighted average of nearby quaternions
|
805 |
+
avg_quat = np.zeros(4)
|
806 |
+
for j, w in zip(range(start_idx, end_idx), weights):
|
807 |
+
if np.dot(quats[j], quats[i]) < 0:
|
808 |
+
avg_quat += w * -quats[j]
|
809 |
+
else:
|
810 |
+
avg_quat += w * quats[j]
|
811 |
+
smoothed_quats[i] = avg_quat / np.linalg.norm(avg_quat)
|
812 |
+
|
813 |
+
# Reconstruct final smoothed poses
|
814 |
+
final_smoothed = np.zeros_like(poses)
|
815 |
+
smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
|
816 |
+
|
817 |
+
for i in range(N):
|
818 |
+
final_smoothed[i] = np.eye(4)
|
819 |
+
final_smoothed[i, :3, :3] = smoothed_rots[i]
|
820 |
+
final_smoothed[i, :3, 3] = filtered_trans[i]
|
821 |
+
|
822 |
+
return final_smoothed
|
823 |
+
|
824 |
+
|
825 |
+
def compute_scale(prediction, target, mask):
|
826 |
+
if isinstance(prediction, np.ndarray):
|
827 |
+
prediction = torch.from_numpy(prediction).float()
|
828 |
+
if isinstance(target, np.ndarray):
|
829 |
+
target = torch.from_numpy(target).float()
|
830 |
+
if isinstance(mask, np.ndarray):
|
831 |
+
mask = torch.from_numpy(mask).bool()
|
832 |
+
|
833 |
+
numerator = torch.sum(mask * prediction * target, (1, 2))
|
834 |
+
denominator = torch.sum(mask * prediction * prediction, (1, 2))
|
835 |
+
|
836 |
+
scale = torch.zeros_like(numerator)
|
837 |
+
|
838 |
+
valid = (denominator != 0).nonzero()
|
839 |
+
|
840 |
+
scale[valid] = numerator[valid] / denominator[valid]
|
841 |
+
|
842 |
+
return scale.item()
|
aether/utils/preprocess_utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def imcrop_center(img_list, crop_p_h, crop_p_w):
|
5 |
+
new_img = []
|
6 |
+
for i, _img in enumerate(img_list):
|
7 |
+
if crop_p_h / crop_p_w > _img.shape[0] / _img.shape[1]: # crop left and right
|
8 |
+
start_h = int(0)
|
9 |
+
start_w = int((_img.shape[1] - _img.shape[0] / crop_p_h * crop_p_w) / 2)
|
10 |
+
crop_size = (_img.shape[0], int(_img.shape[0] / crop_p_h * crop_p_w))
|
11 |
+
else:
|
12 |
+
start_h = int((_img.shape[0] - _img.shape[1] / crop_p_w * crop_p_h) / 2)
|
13 |
+
start_w = int(0)
|
14 |
+
crop_size = (int(_img.shape[1] / crop_p_w * crop_p_h), _img.shape[1])
|
15 |
+
|
16 |
+
_img_src = crop(_img, start_h, start_w, crop_size[0], crop_size[1])
|
17 |
+
new_img.append(_img_src)
|
18 |
+
|
19 |
+
return new_img
|
20 |
+
|
21 |
+
|
22 |
+
def crop(img, start_h, start_w, crop_h, crop_w):
|
23 |
+
img_src = np.zeros((crop_h, crop_w, *img.shape[2:]), dtype=img.dtype)
|
24 |
+
hsize, wsize = crop_h, crop_w
|
25 |
+
dh, dw, sh, sw = start_h, start_w, 0, 0
|
26 |
+
if dh < 0:
|
27 |
+
sh = -dh
|
28 |
+
hsize += dh
|
29 |
+
dh = 0
|
30 |
+
if dh + hsize > img.shape[0]:
|
31 |
+
hsize = img.shape[0] - dh
|
32 |
+
if dw < 0:
|
33 |
+
sw = -dw
|
34 |
+
wsize += dw
|
35 |
+
dw = 0
|
36 |
+
if dw + wsize > img.shape[1]:
|
37 |
+
wsize = img.shape[1] - dw
|
38 |
+
img_src[sh : sh + hsize, sw : sw + wsize] = img[dh : dh + hsize, dw : dw + wsize]
|
39 |
+
return img_src
|
aether/utils/visualize_utils.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Adapted from https://github.com/facebookresearch/vggt/blob/main/visual_util.py
|
8 |
+
|
9 |
+
|
10 |
+
import matplotlib
|
11 |
+
import numpy as np
|
12 |
+
import trimesh
|
13 |
+
from scipy.spatial.transform import Rotation
|
14 |
+
|
15 |
+
from aether.utils.postprocess_utils import depth_edge
|
16 |
+
|
17 |
+
|
18 |
+
def predictions_to_glb(
|
19 |
+
predictions,
|
20 |
+
filter_by_frames="all",
|
21 |
+
show_cam=True,
|
22 |
+
max_depth=100.0,
|
23 |
+
rtol=0.03,
|
24 |
+
frame_rel_idx: float = 0.0,
|
25 |
+
) -> trimesh.Scene:
|
26 |
+
"""
|
27 |
+
Converts predictions to a 3D scene represented as a GLB file.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
predictions (dict): Dictionary containing model predictions with keys:
|
31 |
+
- world_points: 3D point coordinates (S, H, W, 3)
|
32 |
+
- images: Input images (S, H, W, 3)
|
33 |
+
- depths: Depths (S, H, W)
|
34 |
+
- camera poses: Camera poses (S, 4, 4)
|
35 |
+
filter_by_frames (str): Frame filter specification (default: "all")
|
36 |
+
show_cam (bool): Include camera visualization (default: True)
|
37 |
+
max_depth (float): Maximum depth value (default: 100.0)
|
38 |
+
rtol (float): Relative tolerance for depth edge detection (default: 0.2)
|
39 |
+
frame_rel_idx (float): Relative index of the frame to visualize (default: 0.0)
|
40 |
+
Returns:
|
41 |
+
trimesh.Scene: Processed 3D scene containing point cloud and cameras
|
42 |
+
|
43 |
+
Raises:
|
44 |
+
ValueError: If input predictions structure is invalid
|
45 |
+
"""
|
46 |
+
if not isinstance(predictions, dict):
|
47 |
+
raise ValueError("predictions must be a dictionary")
|
48 |
+
|
49 |
+
selected_frame_idx = None
|
50 |
+
if filter_by_frames != "all" and filter_by_frames != "All":
|
51 |
+
try:
|
52 |
+
# Extract the index part before the colon
|
53 |
+
selected_frame_idx = int(filter_by_frames.split(":")[0])
|
54 |
+
except (ValueError, IndexError):
|
55 |
+
pass
|
56 |
+
|
57 |
+
pred_world_points = predictions["world_points"]
|
58 |
+
|
59 |
+
# Get images from predictions
|
60 |
+
images = predictions["images"]
|
61 |
+
# Use extrinsic matrices instead of pred_extrinsic_list
|
62 |
+
camera_poses = predictions["camera_poses"]
|
63 |
+
|
64 |
+
if selected_frame_idx is not None:
|
65 |
+
pred_world_points = pred_world_points[selected_frame_idx][None]
|
66 |
+
images = images[selected_frame_idx][None]
|
67 |
+
camera_poses = camera_poses[selected_frame_idx][None]
|
68 |
+
|
69 |
+
vertices_3d = pred_world_points.reshape(-1, 3)
|
70 |
+
# Handle different image formats - check if images need transposing
|
71 |
+
if images.ndim == 4 and images.shape[1] == 3: # NCHW format
|
72 |
+
colors_rgb = np.transpose(images, (0, 2, 3, 1))
|
73 |
+
else: # Assume already in NHWC format
|
74 |
+
colors_rgb = images
|
75 |
+
colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
|
76 |
+
|
77 |
+
depths = predictions["depths"]
|
78 |
+
masks = depths < max_depth
|
79 |
+
edge = ~depth_edge(depths, rtol=rtol, mask=masks)
|
80 |
+
masks = (masks & edge).reshape(-1)
|
81 |
+
vertices_3d = vertices_3d[masks]
|
82 |
+
colors_rgb = colors_rgb[masks]
|
83 |
+
|
84 |
+
if vertices_3d is None or np.asarray(vertices_3d).size == 0:
|
85 |
+
vertices_3d = np.array([[1, 0, 0]])
|
86 |
+
colors_rgb = np.array([[255, 255, 255]])
|
87 |
+
scene_scale = 1
|
88 |
+
else:
|
89 |
+
# Calculate the 5th and 95th percentiles along each axis
|
90 |
+
lower_percentile = np.percentile(vertices_3d, 5, axis=0)
|
91 |
+
upper_percentile = np.percentile(vertices_3d, 95, axis=0)
|
92 |
+
|
93 |
+
# Calculate the diagonal length of the percentile bounding box
|
94 |
+
scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
|
95 |
+
|
96 |
+
colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
|
97 |
+
|
98 |
+
# Initialize a 3D scene
|
99 |
+
scene_3d = trimesh.Scene()
|
100 |
+
|
101 |
+
# Add point cloud data to the scene
|
102 |
+
point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
|
103 |
+
|
104 |
+
scene_3d.add_geometry(point_cloud_data)
|
105 |
+
|
106 |
+
# Prepare 4x4 matrices for camera extrinsics
|
107 |
+
num_cameras = len(camera_poses)
|
108 |
+
extrinsics_matrices = np.zeros((num_cameras, 4, 4))
|
109 |
+
extrinsics_matrices[:, :3, :4] = camera_poses[:, :3, :4]
|
110 |
+
extrinsics_matrices[:, 3, 3] = 1
|
111 |
+
|
112 |
+
if show_cam:
|
113 |
+
# Add camera models to the scene
|
114 |
+
for i in range(num_cameras):
|
115 |
+
camera_to_world = camera_poses[i]
|
116 |
+
rgba_color = colormap(frame_rel_idx)
|
117 |
+
current_color = tuple(int(255 * x) for x in rgba_color[:3])
|
118 |
+
|
119 |
+
integrate_camera_into_scene(
|
120 |
+
scene_3d, camera_to_world, current_color, scene_scale
|
121 |
+
)
|
122 |
+
|
123 |
+
return scene_3d
|
124 |
+
|
125 |
+
|
126 |
+
def integrate_camera_into_scene(
|
127 |
+
scene: trimesh.Scene,
|
128 |
+
transform: np.ndarray,
|
129 |
+
face_colors: tuple,
|
130 |
+
scene_scale: float,
|
131 |
+
):
|
132 |
+
"""
|
133 |
+
Integrates a fake camera mesh into the 3D scene.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
scene (trimesh.Scene): The 3D scene to add the camera model.
|
137 |
+
transform (np.ndarray): Transformation matrix for camera positioning.
|
138 |
+
face_colors (tuple): Color of the camera face.
|
139 |
+
scene_scale (float): Scale of the scene.
|
140 |
+
"""
|
141 |
+
|
142 |
+
cam_width = scene_scale * 0.025
|
143 |
+
cam_height = scene_scale * 0.05
|
144 |
+
|
145 |
+
# Create cone shape for camera
|
146 |
+
rot_45_degree = np.eye(4)
|
147 |
+
rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
|
148 |
+
rot_45_degree[2, 3] = -cam_height
|
149 |
+
|
150 |
+
opengl_transform = get_opengl_conversion_matrix()
|
151 |
+
# Combine transformations
|
152 |
+
complete_transform = transform @ opengl_transform @ rot_45_degree
|
153 |
+
camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
|
154 |
+
|
155 |
+
# Generate mesh for the camera
|
156 |
+
slight_rotation = np.eye(4)
|
157 |
+
slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
|
158 |
+
|
159 |
+
vertices_combined = np.concatenate(
|
160 |
+
[
|
161 |
+
camera_cone_shape.vertices,
|
162 |
+
0.95 * camera_cone_shape.vertices,
|
163 |
+
transform_points(slight_rotation, camera_cone_shape.vertices),
|
164 |
+
]
|
165 |
+
)
|
166 |
+
vertices_transformed = transform_points(complete_transform, vertices_combined)
|
167 |
+
|
168 |
+
mesh_faces = compute_camera_faces(camera_cone_shape)
|
169 |
+
|
170 |
+
# Add the camera mesh to the scene
|
171 |
+
camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
|
172 |
+
camera_mesh.visual.face_colors[:, :3] = face_colors
|
173 |
+
scene.add_geometry(camera_mesh)
|
174 |
+
|
175 |
+
|
176 |
+
def get_opengl_conversion_matrix() -> np.ndarray:
|
177 |
+
"""
|
178 |
+
Constructs and returns the OpenGL conversion matrix.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
numpy.ndarray: A 4x4 OpenGL conversion matrix.
|
182 |
+
"""
|
183 |
+
# Create an identity matrix
|
184 |
+
matrix = np.identity(4)
|
185 |
+
|
186 |
+
# Flip the y and z axes
|
187 |
+
matrix[1, 1] = -1
|
188 |
+
matrix[2, 2] = -1
|
189 |
+
|
190 |
+
return matrix
|
191 |
+
|
192 |
+
|
193 |
+
def transform_points(
|
194 |
+
transformation: np.ndarray, points: np.ndarray, dim: int = None
|
195 |
+
) -> np.ndarray:
|
196 |
+
"""
|
197 |
+
Applies a 4x4 transformation to a set of points.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
transformation (np.ndarray): Transformation matrix.
|
201 |
+
points (np.ndarray): Points to be transformed.
|
202 |
+
dim (int, optional): Dimension for reshaping the result.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
np.ndarray: Transformed points.
|
206 |
+
"""
|
207 |
+
points = np.asarray(points)
|
208 |
+
initial_shape = points.shape[:-1]
|
209 |
+
dim = dim or points.shape[-1]
|
210 |
+
|
211 |
+
# Apply transformation
|
212 |
+
transformation = transformation.swapaxes(
|
213 |
+
-1, -2
|
214 |
+
) # Transpose the transformation matrix
|
215 |
+
points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
|
216 |
+
|
217 |
+
# Reshape the result
|
218 |
+
result = points[..., :dim].reshape(*initial_shape, dim)
|
219 |
+
return result
|
220 |
+
|
221 |
+
|
222 |
+
def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
|
223 |
+
"""
|
224 |
+
Computes the faces for the camera mesh.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
cone_shape (trimesh.Trimesh): The shape of the camera cone.
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
np.ndarray: Array of faces for the camera mesh.
|
231 |
+
"""
|
232 |
+
# Create pseudo cameras
|
233 |
+
faces_list = []
|
234 |
+
num_vertices_cone = len(cone_shape.vertices)
|
235 |
+
|
236 |
+
for face in cone_shape.faces:
|
237 |
+
if 0 in face:
|
238 |
+
continue
|
239 |
+
v1, v2, v3 = face
|
240 |
+
v1_offset, v2_offset, v3_offset = face + num_vertices_cone
|
241 |
+
v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
|
242 |
+
|
243 |
+
faces_list.extend(
|
244 |
+
[
|
245 |
+
(v1, v2, v2_offset),
|
246 |
+
(v1, v1_offset, v3),
|
247 |
+
(v3_offset, v2, v3),
|
248 |
+
(v1, v2, v2_offset_2),
|
249 |
+
(v1, v1_offset_2, v3),
|
250 |
+
(v3_offset_2, v2, v3),
|
251 |
+
]
|
252 |
+
)
|
253 |
+
|
254 |
+
faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
|
255 |
+
return np.array(faces_list)
|
app.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
-
|
6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assets/example_obs/car.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs/cartoon.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs/garden.jpg
ADDED
![]() |
assets/example_obs/room.jpg
ADDED
![]() |
assets/example_obs_goal/01_goal.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs_goal/01_obs.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs_goal/02_goal.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs_goal/02_obs.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs_goal/03_goal.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs_goal/03_obs.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs_goal/04_goal.png
ADDED
![]() |
Git LFS Details
|
assets/example_obs_goal/04_obs.png
ADDED
![]() |
Git LFS Details
|
assets/example_raymaps/raymap_backward.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:396d0f398d8a59bc27a1d76b91222a58c8d751c6aff5a3ccc10a56a2beecd540
|
3 |
+
size 5313728
|
assets/example_raymaps/raymap_forward_right.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:46f73167495d98d19e86cd47cd062f6d18f2c2882d73c1239effc1b5f039bd32
|
3 |
+
size 5313728
|
assets/example_raymaps/raymap_left_forward.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6e744a0adf4c5f0386ce8c1d586c85f1b17d392978fa34057c638aff5b84ac3
|
3 |
+
size 5313728
|
assets/example_raymaps/raymap_right.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2eab316eab3e94297b0d4d5cffae367484265873ce3c4c98e24bffb4361d4a8d
|
3 |
+
size 5313728
|
assets/example_videos/bridge.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6dcb3a3dff9d02adfa97534fb3cb96b0618c101c049cf37cb81ce375bf71f252
|
3 |
+
size 6714828
|
assets/example_videos/moviegen.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4204806ad057c53068710d9f3db1ad74f3e947be92d1e8c3f3e06343efd0c1c6
|
3 |
+
size 2164737
|
assets/example_videos/nuscenes.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1230424dfae7159140968649e06f0d8a874f0e52f5ad296ba0fa5fd9c1c2d467
|
3 |
+
size 6256650
|
assets/example_videos/veo2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9206a10f368f4334338c0ca44d6c85aa2d260294d720309a855741d57116ed60
|
3 |
+
size 2909620
|
assets/logo.png
ADDED
![]() |
Git LFS Details
|
assets/teaser.png
ADDED
![]() |
Git LFS Details
|
pyproject.toml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.ruff]
|
2 |
+
line-length = 88
|
3 |
+
|
4 |
+
[tool.ruff.lint]
|
5 |
+
# Never enforce `E501` (line length violations).
|
6 |
+
ignore = ["C901", "E501", "E741", "F402", "F823"]
|
7 |
+
select = ["C", "E", "F", "I", "W"]
|
8 |
+
|
9 |
+
# Ignore import violations in all `__init__.py` files.
|
10 |
+
[tool.ruff.lint.per-file-ignores]
|
11 |
+
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
12 |
+
"aether/*.py" = ["E402"]
|
13 |
+
"tests/*.py" = ["E402", "F841"]
|
14 |
+
|
15 |
+
[tool.ruff.lint.isort]
|
16 |
+
lines-after-imports = 2
|
17 |
+
known-first-party = ["aether"]
|
18 |
+
|
19 |
+
[tool.ruff.format]
|
20 |
+
# Like Black, use double quotes for strings.
|
21 |
+
quote-style = "double"
|
22 |
+
|
23 |
+
# Like Black, indent with spaces, rather than tabs.
|
24 |
+
indent-style = "space"
|
25 |
+
|
26 |
+
# Like Black, respect magic trailing commas.
|
27 |
+
skip-magic-trailing-comma = false
|
28 |
+
|
29 |
+
# Like Black, automatically detect the appropriate line ending.
|
30 |
+
line-ending = "auto"
|
requirements.txt
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate>=1.2.1
|
2 |
+
coloredlogs>=15.0.1
|
3 |
+
colorlog>=6.9.0
|
4 |
+
diffusers>=0.32.2
|
5 |
+
easydict>=1.13
|
6 |
+
einops>=0.8.0
|
7 |
+
hf_transfer>=0.1.8
|
8 |
+
huggingface-hub>=0.27.1
|
9 |
+
imageio>=2.33.1
|
10 |
+
imageio-ffmpeg>=0.5.1
|
11 |
+
iopath>=0.1.10
|
12 |
+
matplotlib>=3.10.0
|
13 |
+
numpy>=1.26.4
|
14 |
+
omegaconf>=2.3.0
|
15 |
+
opencv-python-headless>=4.10.0.84
|
16 |
+
pillow>=11.1.0
|
17 |
+
plotly>=5.24.1
|
18 |
+
plyfile>=1.1
|
19 |
+
pre_commit>=4.0.1
|
20 |
+
python-dotenv>=1.0.1
|
21 |
+
PyYAML>=6.0.2
|
22 |
+
rich>=13.9.4
|
23 |
+
rootutils>=1.0.7
|
24 |
+
safetensors>=0.5.2
|
25 |
+
scikit-image>=0.25.0
|
26 |
+
scipy>=1.15.0
|
27 |
+
sentencepiece>=0.2.0
|
28 |
+
six>=1.17.0
|
29 |
+
tokenizers>=0.21.0
|
30 |
+
torchaudio>=2.5.1
|
31 |
+
torchmetrics>=1.6.1
|
32 |
+
torchvision>=0.20.1
|
33 |
+
tqdm>=4.67.1
|
34 |
+
transformers>=4.48.0
|
35 |
+
triton>=3.1.0
|
36 |
+
typer>=0.15.1
|
37 |
+
typing_extensions>=4.12.2
|
38 |
+
viser>=0.2.23
|
39 |
+
filterpy
|
40 |
+
trimesh
|
41 |
+
gradio
|
scripts/demo.py
ADDED
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
+
|
6 |
+
import imageio.v3 as iio
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
import rootutils
|
10 |
+
import torch
|
11 |
+
from diffusers import (
|
12 |
+
AutoencoderKLCogVideoX,
|
13 |
+
CogVideoXDPMScheduler,
|
14 |
+
CogVideoXTransformer3DModel,
|
15 |
+
)
|
16 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
17 |
+
|
18 |
+
|
19 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
20 |
+
|
21 |
+
from aether.pipelines.aetherv1_pipeline_cogvideox import ( # noqa: E402
|
22 |
+
AetherV1PipelineCogVideoX,
|
23 |
+
AetherV1PipelineOutput,
|
24 |
+
)
|
25 |
+
from aether.utils.postprocess_utils import ( # noqa: E402
|
26 |
+
align_camera_extrinsics,
|
27 |
+
apply_transformation,
|
28 |
+
colorize_depth,
|
29 |
+
compute_scale,
|
30 |
+
get_intrinsics,
|
31 |
+
interpolate_poses,
|
32 |
+
postprocess_pointmap,
|
33 |
+
project,
|
34 |
+
raymap_to_poses,
|
35 |
+
)
|
36 |
+
from aether.utils.visualize_utils import predictions_to_glb # noqa: E402
|
37 |
+
|
38 |
+
|
39 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
40 |
+
|
41 |
+
|
42 |
+
def seed_all(seed: int = 0) -> None:
|
43 |
+
"""
|
44 |
+
Set random seeds of all components.
|
45 |
+
"""
|
46 |
+
random.seed(seed)
|
47 |
+
np.random.seed(seed)
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
torch.cuda.manual_seed_all(seed)
|
50 |
+
|
51 |
+
|
52 |
+
def parse_args() -> argparse.Namespace:
|
53 |
+
"""Parse command line arguments."""
|
54 |
+
parser = argparse.ArgumentParser(description="AetherV1-CogvideoX Inference Demo")
|
55 |
+
|
56 |
+
parser.add_argument(
|
57 |
+
"--task",
|
58 |
+
type=str,
|
59 |
+
required=True,
|
60 |
+
choices=["reconstruction", "prediction", "planning"],
|
61 |
+
help="Task to perform: 'reconstruction', 'prediction' or 'planning'.",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--video",
|
65 |
+
type=str,
|
66 |
+
default=None,
|
67 |
+
help="Path to a video file. Only used for 'reconstruction' task.",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--image",
|
71 |
+
type=str,
|
72 |
+
default=None,
|
73 |
+
help="Path to an image file. Only used for 'prediction' and 'planning' tasks.",
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--goal",
|
77 |
+
type=str,
|
78 |
+
default=None,
|
79 |
+
help="Path to a goal image file. Only used for 'planning' task.",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--raymap_action",
|
83 |
+
type=str,
|
84 |
+
default=None,
|
85 |
+
help="Path to a raymap action file. Should be a numpy array of shape (num_frame, 6, latent_height, latent_width).",
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--output_dir",
|
89 |
+
type=str,
|
90 |
+
default="outputs",
|
91 |
+
help="Path to save the outputs.",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--seed",
|
95 |
+
type=int,
|
96 |
+
default=42,
|
97 |
+
help="Random seed.",
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--fps",
|
101 |
+
type=int,
|
102 |
+
default=12,
|
103 |
+
choices=[8, 10, 12, 15, 24],
|
104 |
+
help="Frames per second. Options: 8, 10, 12, 15, 24.",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--num_inference_steps",
|
108 |
+
type=int,
|
109 |
+
default=None,
|
110 |
+
help="Number of inference steps. If not specified, will use the default number of steps for the task.",
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--guidance_scale",
|
114 |
+
type=float,
|
115 |
+
default=None,
|
116 |
+
help="Guidance scale. If not specified, will use the default guidance scale for the task.",
|
117 |
+
)
|
118 |
+
parser.add_argument(
|
119 |
+
"--use_dynamic_cfg",
|
120 |
+
action="store_true",
|
121 |
+
default=True,
|
122 |
+
help="Use dynamic cfg.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--height",
|
126 |
+
type=int,
|
127 |
+
default=480,
|
128 |
+
help="Height of the output video.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--width",
|
132 |
+
type=int,
|
133 |
+
default=720,
|
134 |
+
help="Width of the output video.",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--num_frames",
|
138 |
+
type=int,
|
139 |
+
default=41,
|
140 |
+
help="Number of frames to predict.",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--max_depth",
|
144 |
+
type=float,
|
145 |
+
default=100.0,
|
146 |
+
help="Maximum depth of the scene in meters.",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--rtol",
|
150 |
+
type=float,
|
151 |
+
default=0.03,
|
152 |
+
help="Relative tolerance for depth edge detection.",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--cogvideox_pretrained_model_name_or_path",
|
156 |
+
type=str,
|
157 |
+
default="THUDM/CogVideoX-5b-I2V",
|
158 |
+
help="Name or path of the CogVideoX model to use.",
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--aether_pretrained_model_name_or_path",
|
162 |
+
type=str,
|
163 |
+
default="AetherWorldModel/AetherV1-CogVideoX",
|
164 |
+
help="Name or path of the Aether model to use.",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--smooth_camera",
|
168 |
+
action="store_true",
|
169 |
+
default=True,
|
170 |
+
help="Smooth the camera trajectory.",
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--smooth_method",
|
174 |
+
type=str,
|
175 |
+
default="kalman",
|
176 |
+
choices=["kalman", "simple"],
|
177 |
+
help="Smooth method.",
|
178 |
+
)
|
179 |
+
parser.add_argument(
|
180 |
+
"--sliding_window_stride",
|
181 |
+
type=int,
|
182 |
+
default=24,
|
183 |
+
help="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task.",
|
184 |
+
)
|
185 |
+
parser.add_argument(
|
186 |
+
"--post_reconstruction",
|
187 |
+
action="store_true",
|
188 |
+
default=True,
|
189 |
+
help="Run reconstruction after prediction for better quality. Only used for 'prediction' and 'planning' tasks.",
|
190 |
+
)
|
191 |
+
parser.add_argument(
|
192 |
+
"--pointcloud_save_frame_interval",
|
193 |
+
type=int,
|
194 |
+
default=10,
|
195 |
+
help="Pointcloud save frame interval.",
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--align_pointmaps",
|
199 |
+
action="store_true",
|
200 |
+
default=False,
|
201 |
+
help="Align pointmaps.",
|
202 |
+
)
|
203 |
+
return parser.parse_args()
|
204 |
+
|
205 |
+
|
206 |
+
def build_pipeline(args: argparse.Namespace) -> AetherV1PipelineCogVideoX:
|
207 |
+
pipeline = AetherV1PipelineCogVideoX(
|
208 |
+
tokenizer=AutoTokenizer.from_pretrained(
|
209 |
+
args.cogvideox_pretrained_model_name_or_path,
|
210 |
+
subfolder="tokenizer",
|
211 |
+
),
|
212 |
+
text_encoder=T5EncoderModel.from_pretrained(
|
213 |
+
args.cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
|
214 |
+
),
|
215 |
+
vae=AutoencoderKLCogVideoX.from_pretrained(
|
216 |
+
args.cogvideox_pretrained_model_name_or_path, subfolder="vae"
|
217 |
+
),
|
218 |
+
scheduler=CogVideoXDPMScheduler.from_pretrained(
|
219 |
+
args.cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
|
220 |
+
),
|
221 |
+
transformer=CogVideoXTransformer3DModel.from_pretrained(
|
222 |
+
args.aether_pretrained_model_name_or_path, subfolder="transformer"
|
223 |
+
),
|
224 |
+
)
|
225 |
+
pipeline.vae.enable_slicing()
|
226 |
+
pipeline.vae.enable_tiling()
|
227 |
+
pipeline.to(device)
|
228 |
+
return pipeline
|
229 |
+
|
230 |
+
|
231 |
+
def get_window_starts(
|
232 |
+
total_frames: int, sliding_window_size: int, temporal_stride: int
|
233 |
+
) -> List[int]:
|
234 |
+
"""Calculate window start indices."""
|
235 |
+
starts = list(
|
236 |
+
range(
|
237 |
+
0,
|
238 |
+
total_frames - sliding_window_size + 1,
|
239 |
+
temporal_stride,
|
240 |
+
)
|
241 |
+
)
|
242 |
+
if (
|
243 |
+
total_frames > sliding_window_size
|
244 |
+
and (total_frames - sliding_window_size) % temporal_stride != 0
|
245 |
+
):
|
246 |
+
starts.append(total_frames - sliding_window_size)
|
247 |
+
return starts
|
248 |
+
|
249 |
+
|
250 |
+
def blend_and_merge_window_results(
|
251 |
+
window_results: List[AetherV1PipelineOutput],
|
252 |
+
window_indices: List[int],
|
253 |
+
args: argparse.Namespace,
|
254 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
255 |
+
"""Blend and merge window results."""
|
256 |
+
merged_rgb = None
|
257 |
+
merged_disparity = None
|
258 |
+
merged_poses = None
|
259 |
+
merged_focals = None
|
260 |
+
if args.align_pointmaps:
|
261 |
+
merged_pointmaps = None
|
262 |
+
|
263 |
+
w1 = window_results[0].disparity
|
264 |
+
|
265 |
+
for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
|
266 |
+
t_end = t_start + window_result.rgb.shape[0]
|
267 |
+
if idx == 0:
|
268 |
+
merged_rgb = window_result.rgb
|
269 |
+
merged_disparity = window_result.disparity
|
270 |
+
pointmap_dict = postprocess_pointmap(
|
271 |
+
window_result.disparity,
|
272 |
+
window_result.raymap,
|
273 |
+
vae_downsample_scale=8,
|
274 |
+
ray_o_scale_inv=0.1,
|
275 |
+
smooth_camera=args.smooth_camera,
|
276 |
+
smooth_method=args.smooth_method if args.smooth_camera else "none",
|
277 |
+
)
|
278 |
+
merged_poses = pointmap_dict["camera_pose"]
|
279 |
+
merged_focals = (
|
280 |
+
pointmap_dict["intrinsics"][:, 0, 0]
|
281 |
+
+ pointmap_dict["intrinsics"][:, 1, 1]
|
282 |
+
) / 2
|
283 |
+
if args.align_pointmaps:
|
284 |
+
merged_pointmaps = pointmap_dict["pointmap"]
|
285 |
+
else:
|
286 |
+
overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start
|
287 |
+
|
288 |
+
window_disparity = window_result.disparity
|
289 |
+
|
290 |
+
# Align disparity
|
291 |
+
disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
|
292 |
+
scale = compute_scale(
|
293 |
+
window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
|
294 |
+
merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
|
295 |
+
disp_mask.reshape(1, -1, w1.shape[-1]),
|
296 |
+
)
|
297 |
+
window_disparity = scale * window_disparity
|
298 |
+
|
299 |
+
# Blend disparity
|
300 |
+
result_disparity = np.ones((t_end, *w1.shape[1:]))
|
301 |
+
result_disparity[:t_start] = merged_disparity[:t_start]
|
302 |
+
result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
|
303 |
+
weight = np.linspace(1, 0, overlap_t)[:, None, None]
|
304 |
+
result_disparity[t_start : t_start + overlap_t] = merged_disparity[
|
305 |
+
t_start : t_start + overlap_t
|
306 |
+
] * weight + window_disparity[:overlap_t] * (1 - weight)
|
307 |
+
merged_disparity = result_disparity
|
308 |
+
|
309 |
+
# Blend RGB
|
310 |
+
result_rgb = np.ones((t_end, *w1.shape[1:], 3))
|
311 |
+
result_rgb[:t_start] = merged_rgb[:t_start]
|
312 |
+
result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
|
313 |
+
weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
|
314 |
+
result_rgb[t_start : t_start + overlap_t] = merged_rgb[
|
315 |
+
t_start : t_start + overlap_t
|
316 |
+
] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
|
317 |
+
merged_rgb = result_rgb
|
318 |
+
|
319 |
+
# Align poses
|
320 |
+
window_raymap = window_result.raymap
|
321 |
+
window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
|
322 |
+
window_raymap, ray_o_scale_inv=0.1
|
323 |
+
)
|
324 |
+
rel_r, rel_t, rel_s = align_camera_extrinsics(
|
325 |
+
torch.from_numpy(window_poses[:overlap_t]),
|
326 |
+
torch.from_numpy(merged_poses[-overlap_t:]),
|
327 |
+
)
|
328 |
+
aligned_window_poses = (
|
329 |
+
apply_transformation(
|
330 |
+
torch.from_numpy(window_poses),
|
331 |
+
rel_r,
|
332 |
+
rel_t,
|
333 |
+
rel_s,
|
334 |
+
return_extri=True,
|
335 |
+
)
|
336 |
+
.cpu()
|
337 |
+
.numpy()
|
338 |
+
)
|
339 |
+
|
340 |
+
result_poses = np.ones((t_end, 4, 4))
|
341 |
+
result_poses[:t_start] = merged_poses[:t_start]
|
342 |
+
result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]
|
343 |
+
|
344 |
+
# Interpolate poses in overlap region
|
345 |
+
weights = np.linspace(1, 0, overlap_t)
|
346 |
+
for t in range(overlap_t):
|
347 |
+
weight = weights[t]
|
348 |
+
pose1 = merged_poses[t_start + t]
|
349 |
+
pose2 = aligned_window_poses[t]
|
350 |
+
result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)
|
351 |
+
|
352 |
+
merged_poses = result_poses
|
353 |
+
|
354 |
+
# Align intrinsics
|
355 |
+
window_intrinsics, _ = get_intrinsics(
|
356 |
+
batch_size=window_poses.shape[0],
|
357 |
+
h=window_result.disparity.shape[1],
|
358 |
+
w=window_result.disparity.shape[2],
|
359 |
+
fovx=window_Fov_x,
|
360 |
+
fovy=window_Fov_y,
|
361 |
+
)
|
362 |
+
window_focals = (
|
363 |
+
window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
|
364 |
+
) / 2
|
365 |
+
scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
|
366 |
+
window_focals = scale * window_focals
|
367 |
+
result_focals = np.ones((t_end,))
|
368 |
+
result_focals[:t_start] = merged_focals[:t_start]
|
369 |
+
result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
|
370 |
+
weight = np.linspace(1, 0, overlap_t)
|
371 |
+
result_focals[t_start : t_start + overlap_t] = merged_focals[
|
372 |
+
t_start : t_start + overlap_t
|
373 |
+
] * weight + window_focals[:overlap_t] * (1 - weight)
|
374 |
+
merged_focals = result_focals
|
375 |
+
|
376 |
+
if args.align_pointmaps:
|
377 |
+
# Align pointmaps
|
378 |
+
window_pointmaps = postprocess_pointmap(
|
379 |
+
result_disparity[t_start:],
|
380 |
+
window_raymap,
|
381 |
+
vae_downsample_scale=8,
|
382 |
+
camera_pose=aligned_window_poses,
|
383 |
+
focal=window_focals,
|
384 |
+
ray_o_scale_inv=0.1,
|
385 |
+
smooth_camera=args.smooth_camera,
|
386 |
+
smooth_method=args.smooth_method if args.smooth_camera else "none",
|
387 |
+
)
|
388 |
+
result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
|
389 |
+
result_pointmaps[:t_start] = merged_pointmaps[:t_start]
|
390 |
+
result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
|
391 |
+
overlap_t:
|
392 |
+
]
|
393 |
+
weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
|
394 |
+
result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
|
395 |
+
t_start : t_start + overlap_t
|
396 |
+
] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
|
397 |
+
merged_pointmaps = result_pointmaps
|
398 |
+
|
399 |
+
# project to pointmaps
|
400 |
+
intrinsics = [
|
401 |
+
np.array([[f, 0, 0.5 * args.width], [0, f, 0.5 * args.height], [0, 0, 1]])
|
402 |
+
for f in merged_focals
|
403 |
+
]
|
404 |
+
if args.align_pointmaps:
|
405 |
+
pointmaps = merged_pointmaps
|
406 |
+
else:
|
407 |
+
pointmaps = np.stack(
|
408 |
+
[
|
409 |
+
project(
|
410 |
+
1 / np.clip(merged_disparity[i], 1e-8, 1e8),
|
411 |
+
intrinsics[i],
|
412 |
+
merged_poses[i],
|
413 |
+
)
|
414 |
+
for i in range(merged_poses.shape[0])
|
415 |
+
]
|
416 |
+
)
|
417 |
+
|
418 |
+
return merged_rgb, merged_disparity, merged_poses, pointmaps
|
419 |
+
|
420 |
+
|
421 |
+
def save_output(
|
422 |
+
rgb: np.ndarray,
|
423 |
+
disparity: np.ndarray,
|
424 |
+
poses: Optional[np.ndarray] = None,
|
425 |
+
raymap: Optional[np.ndarray] = None,
|
426 |
+
pointmap: Optional[np.ndarray] = None,
|
427 |
+
args: argparse.Namespace = None,
|
428 |
+
) -> None:
|
429 |
+
output_dir = args.output_dir
|
430 |
+
os.makedirs(output_dir, exist_ok=True)
|
431 |
+
|
432 |
+
if pointmap is None:
|
433 |
+
assert raymap is not None, "Raymap is required for saving pointmap."
|
434 |
+
pointmap_dict = postprocess_pointmap(
|
435 |
+
disparity,
|
436 |
+
raymap,
|
437 |
+
vae_downsample_scale=8,
|
438 |
+
ray_o_scale_inv=0.1,
|
439 |
+
smooth_camera=args.smooth_camera,
|
440 |
+
smooth_method=args.smooth_method,
|
441 |
+
)
|
442 |
+
pointmap = pointmap_dict["pointmap"]
|
443 |
+
|
444 |
+
if poses is None:
|
445 |
+
assert raymap is not None, "Raymap is required for saving poses."
|
446 |
+
poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)
|
447 |
+
|
448 |
+
if args.task == "reconstruction":
|
449 |
+
filename = f"reconstruction_{args.video.split('/')[-1].split('.')[0]}"
|
450 |
+
elif args.task == "prediction":
|
451 |
+
filename = f"prediction_{args.image.split('/')[-1].split('.')[0]}"
|
452 |
+
elif args.task == "planning":
|
453 |
+
filename = f"planning_{args.image.split('/')[-1].split('.')[0]}_{args.goal.split('/')[-1].split('.')[0]}"
|
454 |
+
|
455 |
+
filename = os.path.join(output_dir, filename)
|
456 |
+
|
457 |
+
iio.imwrite(
|
458 |
+
f"{filename}_rgb.mp4",
|
459 |
+
(np.clip(rgb, 0, 1) * 255).astype(np.uint8),
|
460 |
+
fps=12,
|
461 |
+
)
|
462 |
+
iio.imwrite(
|
463 |
+
f"{filename}_disparity.mp4",
|
464 |
+
(colorize_depth(disparity) * 255).astype(np.uint8),
|
465 |
+
fps=12,
|
466 |
+
)
|
467 |
+
|
468 |
+
print("Building GLB scene")
|
469 |
+
for frame_idx in range(pointmap.shape[0])[:: args.pointcloud_save_frame_interval]:
|
470 |
+
predictions = {
|
471 |
+
"world_points": pointmap[frame_idx : frame_idx + 1],
|
472 |
+
"images": rgb[frame_idx : frame_idx + 1],
|
473 |
+
"depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
|
474 |
+
"camera_poses": poses[frame_idx : frame_idx + 1],
|
475 |
+
}
|
476 |
+
scene_3d = predictions_to_glb(
|
477 |
+
predictions,
|
478 |
+
filter_by_frames="all",
|
479 |
+
show_cam=True,
|
480 |
+
max_depth=args.max_depth,
|
481 |
+
rtol=args.rtol,
|
482 |
+
frame_rel_idx=float(frame_idx) / pointmap.shape[0],
|
483 |
+
)
|
484 |
+
scene_3d.export(f"{filename}_pointcloud_frame_{frame_idx}.glb")
|
485 |
+
print("GLB Scene built")
|
486 |
+
|
487 |
+
|
488 |
+
def main() -> None:
|
489 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
490 |
+
args = parse_args()
|
491 |
+
seed_all(args.seed)
|
492 |
+
|
493 |
+
if args.num_inference_steps is None:
|
494 |
+
args.num_inference_steps = 4 if args.task == "reconstruction" else 50
|
495 |
+
|
496 |
+
if args.guidance_scale is None:
|
497 |
+
args.guidance_scale = 1.0 if args.task == "reconstruction" else 3.0
|
498 |
+
|
499 |
+
pipeline = build_pipeline(args)
|
500 |
+
|
501 |
+
if args.task == "reconstruction":
|
502 |
+
assert args.video is not None, "Video is required for reconstruction task."
|
503 |
+
assert args.image is None, "Image is not required for reconstruction task."
|
504 |
+
assert args.goal is None, "Goal is not required for reconstruction task."
|
505 |
+
|
506 |
+
video = iio.imread(args.video).astype(np.float32) / 255.0
|
507 |
+
image, goal = None, None
|
508 |
+
elif args.task == "prediction":
|
509 |
+
assert args.image is not None, "Image is required for prediction task."
|
510 |
+
assert args.goal is None, "Goal is not required for prediction task."
|
511 |
+
|
512 |
+
image = PIL.Image.open(args.image)
|
513 |
+
video, goal = None, None
|
514 |
+
elif args.task == "planning":
|
515 |
+
assert args.image is not None, "Image is required for planning task."
|
516 |
+
assert args.goal is not None, "Goal is required for planning task."
|
517 |
+
|
518 |
+
image = PIL.Image.open(args.image)
|
519 |
+
goal = PIL.Image.open(args.goal)
|
520 |
+
|
521 |
+
video = None
|
522 |
+
|
523 |
+
if args.raymap_action is not None:
|
524 |
+
raymap = np.load(args.raymap_action)
|
525 |
+
else:
|
526 |
+
raymap = None
|
527 |
+
|
528 |
+
if args.task != "reconstruction":
|
529 |
+
output = pipeline(
|
530 |
+
task=args.task,
|
531 |
+
image=image,
|
532 |
+
video=video,
|
533 |
+
goal=goal,
|
534 |
+
raymap=raymap,
|
535 |
+
height=args.height,
|
536 |
+
width=args.width,
|
537 |
+
num_frames=args.num_frames,
|
538 |
+
fps=args.fps,
|
539 |
+
num_inference_steps=args.num_inference_steps,
|
540 |
+
guidance_scale=args.guidance_scale,
|
541 |
+
use_dynamic_cfg=args.use_dynamic_cfg,
|
542 |
+
generator=torch.Generator(device=device).manual_seed(args.seed),
|
543 |
+
return_dict=True,
|
544 |
+
)
|
545 |
+
if not args.post_reconstruction:
|
546 |
+
save_output(
|
547 |
+
rgb=output.rgb,
|
548 |
+
disparity=output.disparity,
|
549 |
+
raymap=output.raymap,
|
550 |
+
args=args,
|
551 |
+
)
|
552 |
+
else:
|
553 |
+
recon_output = pipeline(
|
554 |
+
task="reconstruction",
|
555 |
+
video=output.rgb,
|
556 |
+
height=args.height,
|
557 |
+
width=args.width,
|
558 |
+
num_frames=args.num_frames,
|
559 |
+
fps=args.fps,
|
560 |
+
num_inference_steps=4,
|
561 |
+
guidance_scale=1.0, # we don't need guidance scale for reconstruction task
|
562 |
+
use_dynamic_cfg=False,
|
563 |
+
generator=torch.Generator(device=device).manual_seed(args.seed),
|
564 |
+
)
|
565 |
+
save_output(
|
566 |
+
rgb=output.rgb,
|
567 |
+
disparity=recon_output.disparity,
|
568 |
+
raymap=recon_output.raymap,
|
569 |
+
args=args,
|
570 |
+
)
|
571 |
+
else:
|
572 |
+
# for reconstruction task, we have to employ sliding window on long videos
|
573 |
+
window_results = []
|
574 |
+
window_indices = get_window_starts(
|
575 |
+
len(video), args.num_frames, args.sliding_window_stride
|
576 |
+
)
|
577 |
+
for start_idx in window_indices:
|
578 |
+
output = pipeline(
|
579 |
+
task=args.task,
|
580 |
+
image=None,
|
581 |
+
goal=None,
|
582 |
+
video=video[start_idx : start_idx + args.num_frames],
|
583 |
+
raymap=raymap[start_idx : start_idx + args.num_frames]
|
584 |
+
if raymap is not None
|
585 |
+
else None,
|
586 |
+
height=args.height,
|
587 |
+
width=args.width,
|
588 |
+
num_frames=args.num_frames,
|
589 |
+
fps=args.fps,
|
590 |
+
num_inference_steps=args.num_inference_steps,
|
591 |
+
guidance_scale=1.0, # we don't need guidance scale for reconstruction task
|
592 |
+
use_dynamic_cfg=False,
|
593 |
+
generator=torch.Generator(device=device).manual_seed(args.seed),
|
594 |
+
)
|
595 |
+
window_results.append(output)
|
596 |
+
|
597 |
+
# merge window results
|
598 |
+
(
|
599 |
+
merged_rgb,
|
600 |
+
merged_disparity,
|
601 |
+
merged_poses,
|
602 |
+
pointmaps,
|
603 |
+
) = blend_and_merge_window_results(window_results, window_indices, args)
|
604 |
+
save_output(
|
605 |
+
rgb=merged_rgb,
|
606 |
+
disparity=merged_disparity,
|
607 |
+
poses=merged_poses,
|
608 |
+
pointmap=pointmaps,
|
609 |
+
args=args,
|
610 |
+
)
|
611 |
+
|
612 |
+
|
613 |
+
if __name__ == "__main__":
|
614 |
+
main()
|
scripts/demo_gradio.py
ADDED
@@ -0,0 +1,1470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
from datetime import datetime
|
6 |
+
from typing import Dict, List, Optional, Tuple
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import imageio.v3 as iio
|
10 |
+
import numpy as np
|
11 |
+
import PIL
|
12 |
+
import rootutils
|
13 |
+
import torch
|
14 |
+
from diffusers import (
|
15 |
+
AutoencoderKLCogVideoX,
|
16 |
+
CogVideoXDPMScheduler,
|
17 |
+
CogVideoXTransformer3DModel,
|
18 |
+
)
|
19 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
20 |
+
|
21 |
+
|
22 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
23 |
+
|
24 |
+
from aether.pipelines.aetherv1_pipeline_cogvideox import ( # noqa: E402
|
25 |
+
AetherV1PipelineCogVideoX,
|
26 |
+
AetherV1PipelineOutput,
|
27 |
+
)
|
28 |
+
from aether.utils.postprocess_utils import ( # noqa: E402
|
29 |
+
align_camera_extrinsics,
|
30 |
+
apply_transformation,
|
31 |
+
colorize_depth,
|
32 |
+
compute_scale,
|
33 |
+
get_intrinsics,
|
34 |
+
interpolate_poses,
|
35 |
+
postprocess_pointmap,
|
36 |
+
project,
|
37 |
+
raymap_to_poses,
|
38 |
+
)
|
39 |
+
from aether.utils.visualize_utils import predictions_to_glb # noqa: E402
|
40 |
+
|
41 |
+
|
42 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
43 |
+
|
44 |
+
|
45 |
+
def seed_all(seed: int = 0) -> None:
|
46 |
+
"""
|
47 |
+
Set random seeds of all components.
|
48 |
+
"""
|
49 |
+
random.seed(seed)
|
50 |
+
np.random.seed(seed)
|
51 |
+
torch.manual_seed(seed)
|
52 |
+
torch.cuda.manual_seed_all(seed)
|
53 |
+
|
54 |
+
|
55 |
+
# Global pipeline
|
56 |
+
cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
|
57 |
+
aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
|
58 |
+
pipeline = AetherV1PipelineCogVideoX(
|
59 |
+
tokenizer=AutoTokenizer.from_pretrained(
|
60 |
+
cogvideox_pretrained_model_name_or_path,
|
61 |
+
subfolder="tokenizer",
|
62 |
+
),
|
63 |
+
text_encoder=T5EncoderModel.from_pretrained(
|
64 |
+
cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
|
65 |
+
),
|
66 |
+
vae=AutoencoderKLCogVideoX.from_pretrained(
|
67 |
+
cogvideox_pretrained_model_name_or_path, subfolder="vae"
|
68 |
+
),
|
69 |
+
scheduler=CogVideoXDPMScheduler.from_pretrained(
|
70 |
+
cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
|
71 |
+
),
|
72 |
+
transformer=CogVideoXTransformer3DModel.from_pretrained(
|
73 |
+
aether_pretrained_model_name_or_path, subfolder="transformer"
|
74 |
+
),
|
75 |
+
)
|
76 |
+
pipeline.vae.enable_slicing()
|
77 |
+
pipeline.vae.enable_tiling()
|
78 |
+
pipeline.to(device)
|
79 |
+
|
80 |
+
|
81 |
+
def build_pipeline() -> AetherV1PipelineCogVideoX:
|
82 |
+
"""Initialize the model pipeline."""
|
83 |
+
return pipeline
|
84 |
+
|
85 |
+
|
86 |
+
def get_window_starts(
|
87 |
+
total_frames: int, sliding_window_size: int, temporal_stride: int
|
88 |
+
) -> List[int]:
|
89 |
+
"""Calculate window start indices."""
|
90 |
+
starts = list(
|
91 |
+
range(
|
92 |
+
0,
|
93 |
+
total_frames - sliding_window_size + 1,
|
94 |
+
temporal_stride,
|
95 |
+
)
|
96 |
+
)
|
97 |
+
if (
|
98 |
+
total_frames > sliding_window_size
|
99 |
+
and (total_frames - sliding_window_size) % temporal_stride != 0
|
100 |
+
):
|
101 |
+
starts.append(total_frames - sliding_window_size)
|
102 |
+
return starts
|
103 |
+
|
104 |
+
|
105 |
+
def blend_and_merge_window_results(
|
106 |
+
window_results: List[AetherV1PipelineOutput], window_indices: List[int], args: Dict
|
107 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
108 |
+
"""Blend and merge window results."""
|
109 |
+
merged_rgb = None
|
110 |
+
merged_disparity = None
|
111 |
+
merged_poses = None
|
112 |
+
merged_focals = None
|
113 |
+
align_pointmaps = args.get("align_pointmaps", True)
|
114 |
+
smooth_camera = args.get("smooth_camera", True)
|
115 |
+
smooth_method = args.get("smooth_method", "kalman") if smooth_camera else "none"
|
116 |
+
|
117 |
+
if align_pointmaps:
|
118 |
+
merged_pointmaps = None
|
119 |
+
|
120 |
+
w1 = window_results[0].disparity
|
121 |
+
|
122 |
+
for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
|
123 |
+
t_end = t_start + window_result.rgb.shape[0]
|
124 |
+
if idx == 0:
|
125 |
+
merged_rgb = window_result.rgb
|
126 |
+
merged_disparity = window_result.disparity
|
127 |
+
pointmap_dict = postprocess_pointmap(
|
128 |
+
window_result.disparity,
|
129 |
+
window_result.raymap,
|
130 |
+
vae_downsample_scale=8,
|
131 |
+
ray_o_scale_inv=0.1,
|
132 |
+
smooth_camera=smooth_camera,
|
133 |
+
smooth_method=smooth_method if smooth_camera else "none",
|
134 |
+
)
|
135 |
+
merged_poses = pointmap_dict["camera_pose"]
|
136 |
+
merged_focals = (
|
137 |
+
pointmap_dict["intrinsics"][:, 0, 0]
|
138 |
+
+ pointmap_dict["intrinsics"][:, 1, 1]
|
139 |
+
) / 2
|
140 |
+
if align_pointmaps:
|
141 |
+
merged_pointmaps = pointmap_dict["pointmap"]
|
142 |
+
else:
|
143 |
+
overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start
|
144 |
+
|
145 |
+
window_disparity = window_result.disparity
|
146 |
+
|
147 |
+
# Align disparity
|
148 |
+
disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
|
149 |
+
scale = compute_scale(
|
150 |
+
window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
|
151 |
+
merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
|
152 |
+
disp_mask.reshape(1, -1, w1.shape[-1]),
|
153 |
+
)
|
154 |
+
window_disparity = scale * window_disparity
|
155 |
+
|
156 |
+
# Blend disparity
|
157 |
+
result_disparity = np.ones((t_end, *w1.shape[1:]))
|
158 |
+
result_disparity[:t_start] = merged_disparity[:t_start]
|
159 |
+
result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
|
160 |
+
weight = np.linspace(1, 0, overlap_t)[:, None, None]
|
161 |
+
result_disparity[t_start : t_start + overlap_t] = merged_disparity[
|
162 |
+
t_start : t_start + overlap_t
|
163 |
+
] * weight + window_disparity[:overlap_t] * (1 - weight)
|
164 |
+
merged_disparity = result_disparity
|
165 |
+
|
166 |
+
# Blend RGB
|
167 |
+
result_rgb = np.ones((t_end, *w1.shape[1:], 3))
|
168 |
+
result_rgb[:t_start] = merged_rgb[:t_start]
|
169 |
+
result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
|
170 |
+
weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
|
171 |
+
result_rgb[t_start : t_start + overlap_t] = merged_rgb[
|
172 |
+
t_start : t_start + overlap_t
|
173 |
+
] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
|
174 |
+
merged_rgb = result_rgb
|
175 |
+
|
176 |
+
# Align poses
|
177 |
+
window_raymap = window_result.raymap
|
178 |
+
window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
|
179 |
+
window_raymap, ray_o_scale_inv=0.1
|
180 |
+
)
|
181 |
+
rel_r, rel_t, rel_s = align_camera_extrinsics(
|
182 |
+
torch.from_numpy(window_poses[:overlap_t]),
|
183 |
+
torch.from_numpy(merged_poses[-overlap_t:]),
|
184 |
+
)
|
185 |
+
aligned_window_poses = (
|
186 |
+
apply_transformation(
|
187 |
+
torch.from_numpy(window_poses),
|
188 |
+
rel_r,
|
189 |
+
rel_t,
|
190 |
+
rel_s,
|
191 |
+
return_extri=True,
|
192 |
+
)
|
193 |
+
.cpu()
|
194 |
+
.numpy()
|
195 |
+
)
|
196 |
+
|
197 |
+
result_poses = np.ones((t_end, 4, 4))
|
198 |
+
result_poses[:t_start] = merged_poses[:t_start]
|
199 |
+
result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]
|
200 |
+
|
201 |
+
# Interpolate poses in overlap region
|
202 |
+
weights = np.linspace(1, 0, overlap_t)
|
203 |
+
for t in range(overlap_t):
|
204 |
+
weight = weights[t]
|
205 |
+
pose1 = merged_poses[t_start + t]
|
206 |
+
pose2 = aligned_window_poses[t]
|
207 |
+
result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)
|
208 |
+
|
209 |
+
merged_poses = result_poses
|
210 |
+
|
211 |
+
# Align intrinsics
|
212 |
+
window_intrinsics, _ = get_intrinsics(
|
213 |
+
batch_size=window_poses.shape[0],
|
214 |
+
h=window_result.disparity.shape[1],
|
215 |
+
w=window_result.disparity.shape[2],
|
216 |
+
fovx=window_Fov_x,
|
217 |
+
fovy=window_Fov_y,
|
218 |
+
)
|
219 |
+
window_focals = (
|
220 |
+
window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
|
221 |
+
) / 2
|
222 |
+
scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
|
223 |
+
window_focals = scale * window_focals
|
224 |
+
result_focals = np.ones((t_end,))
|
225 |
+
result_focals[:t_start] = merged_focals[:t_start]
|
226 |
+
result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
|
227 |
+
weight = np.linspace(1, 0, overlap_t)
|
228 |
+
result_focals[t_start : t_start + overlap_t] = merged_focals[
|
229 |
+
t_start : t_start + overlap_t
|
230 |
+
] * weight + window_focals[:overlap_t] * (1 - weight)
|
231 |
+
merged_focals = result_focals
|
232 |
+
|
233 |
+
if align_pointmaps:
|
234 |
+
# Align pointmaps
|
235 |
+
window_pointmaps = postprocess_pointmap(
|
236 |
+
result_disparity[t_start:],
|
237 |
+
window_raymap,
|
238 |
+
vae_downsample_scale=8,
|
239 |
+
camera_pose=aligned_window_poses,
|
240 |
+
focal=window_focals,
|
241 |
+
ray_o_scale_inv=0.1,
|
242 |
+
smooth_camera=smooth_camera,
|
243 |
+
smooth_method=smooth_method if smooth_camera else "none",
|
244 |
+
)
|
245 |
+
result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
|
246 |
+
result_pointmaps[:t_start] = merged_pointmaps[:t_start]
|
247 |
+
result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
|
248 |
+
overlap_t:
|
249 |
+
]
|
250 |
+
weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
|
251 |
+
result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
|
252 |
+
t_start : t_start + overlap_t
|
253 |
+
] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
|
254 |
+
merged_pointmaps = result_pointmaps
|
255 |
+
|
256 |
+
# project to pointmaps
|
257 |
+
height = args.get("height", 480)
|
258 |
+
width = args.get("width", 720)
|
259 |
+
|
260 |
+
intrinsics = [
|
261 |
+
np.array([[f, 0, 0.5 * width], [0, f, 0.5 * height], [0, 0, 1]])
|
262 |
+
for f in merged_focals
|
263 |
+
]
|
264 |
+
if align_pointmaps:
|
265 |
+
pointmaps = merged_pointmaps
|
266 |
+
else:
|
267 |
+
pointmaps = np.stack(
|
268 |
+
[
|
269 |
+
project(
|
270 |
+
1 / np.clip(merged_disparity[i], 1e-8, 1e8),
|
271 |
+
intrinsics[i],
|
272 |
+
merged_poses[i],
|
273 |
+
)
|
274 |
+
for i in range(merged_poses.shape[0])
|
275 |
+
]
|
276 |
+
)
|
277 |
+
|
278 |
+
return merged_rgb, merged_disparity, merged_poses, pointmaps
|
279 |
+
|
280 |
+
|
281 |
+
def process_video_to_frames(video_path: str, fps_sample: int = 12) -> List[str]:
|
282 |
+
"""Process video into frames and save them locally."""
|
283 |
+
# Create a unique output directory
|
284 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
285 |
+
output_dir = f"temp_frames_{timestamp}"
|
286 |
+
os.makedirs(output_dir, exist_ok=True)
|
287 |
+
|
288 |
+
# Read video
|
289 |
+
video = iio.imread(video_path)
|
290 |
+
|
291 |
+
# Calculate frame interval based on original video fps
|
292 |
+
if isinstance(video, np.ndarray):
|
293 |
+
# For captured videos
|
294 |
+
total_frames = len(video)
|
295 |
+
frame_interval = max(
|
296 |
+
1, round(total_frames / (fps_sample * (total_frames / 30)))
|
297 |
+
)
|
298 |
+
else:
|
299 |
+
# Default if can't determine
|
300 |
+
frame_interval = 2
|
301 |
+
|
302 |
+
frame_paths = []
|
303 |
+
for i, frame in enumerate(video[::frame_interval]):
|
304 |
+
frame_path = os.path.join(output_dir, f"frame_{i:04d}.jpg")
|
305 |
+
if isinstance(frame, np.ndarray):
|
306 |
+
iio.imwrite(frame_path, frame)
|
307 |
+
frame_paths.append(frame_path)
|
308 |
+
|
309 |
+
return frame_paths, output_dir
|
310 |
+
|
311 |
+
|
312 |
+
def save_output_files(
|
313 |
+
rgb: np.ndarray,
|
314 |
+
disparity: np.ndarray,
|
315 |
+
poses: Optional[np.ndarray] = None,
|
316 |
+
raymap: Optional[np.ndarray] = None,
|
317 |
+
pointmap: Optional[np.ndarray] = None,
|
318 |
+
task: str = "reconstruction",
|
319 |
+
output_dir: str = "outputs",
|
320 |
+
**kwargs,
|
321 |
+
) -> Dict[str, str]:
|
322 |
+
"""
|
323 |
+
Save outputs and return paths to saved files.
|
324 |
+
"""
|
325 |
+
os.makedirs(output_dir, exist_ok=True)
|
326 |
+
|
327 |
+
if pointmap is None and raymap is not None:
|
328 |
+
# Generate pointmap from raymap and disparity
|
329 |
+
smooth_camera = kwargs.get("smooth_camera", True)
|
330 |
+
smooth_method = (
|
331 |
+
kwargs.get("smooth_method", "kalman") if smooth_camera else "none"
|
332 |
+
)
|
333 |
+
|
334 |
+
pointmap_dict = postprocess_pointmap(
|
335 |
+
disparity,
|
336 |
+
raymap,
|
337 |
+
vae_downsample_scale=8,
|
338 |
+
ray_o_scale_inv=0.1,
|
339 |
+
smooth_camera=smooth_camera,
|
340 |
+
smooth_method=smooth_method,
|
341 |
+
)
|
342 |
+
pointmap = pointmap_dict["pointmap"]
|
343 |
+
|
344 |
+
if poses is None and raymap is not None:
|
345 |
+
poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)
|
346 |
+
|
347 |
+
# Create a unique filename
|
348 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
349 |
+
base_filename = f"{task}_{timestamp}"
|
350 |
+
|
351 |
+
# Paths for saved files
|
352 |
+
paths = {}
|
353 |
+
|
354 |
+
# Save RGB video
|
355 |
+
rgb_path = os.path.join(output_dir, f"{base_filename}_rgb.mp4")
|
356 |
+
iio.imwrite(
|
357 |
+
rgb_path,
|
358 |
+
(np.clip(rgb, 0, 1) * 255).astype(np.uint8),
|
359 |
+
fps=kwargs.get("fps", 12),
|
360 |
+
)
|
361 |
+
paths["rgb"] = rgb_path
|
362 |
+
|
363 |
+
# Save depth/disparity video
|
364 |
+
depth_path = os.path.join(output_dir, f"{base_filename}_disparity.mp4")
|
365 |
+
iio.imwrite(
|
366 |
+
depth_path,
|
367 |
+
(colorize_depth(disparity) * 255).astype(np.uint8),
|
368 |
+
fps=kwargs.get("fps", 12),
|
369 |
+
)
|
370 |
+
paths["disparity"] = depth_path
|
371 |
+
|
372 |
+
# Save point cloud GLB files
|
373 |
+
if pointmap is not None and poses is not None:
|
374 |
+
pointcloud_save_frame_interval = kwargs.get(
|
375 |
+
"pointcloud_save_frame_interval", 10
|
376 |
+
)
|
377 |
+
max_depth = kwargs.get("max_depth", 100.0)
|
378 |
+
rtol = kwargs.get("rtol", 0.03)
|
379 |
+
|
380 |
+
glb_paths = []
|
381 |
+
# Determine which frames to save based on the interval
|
382 |
+
frames_to_save = list(
|
383 |
+
range(0, pointmap.shape[0], pointcloud_save_frame_interval)
|
384 |
+
)
|
385 |
+
|
386 |
+
# Always include the first and last frame
|
387 |
+
if 0 not in frames_to_save:
|
388 |
+
frames_to_save.insert(0, 0)
|
389 |
+
if pointmap.shape[0] - 1 not in frames_to_save:
|
390 |
+
frames_to_save.append(pointmap.shape[0] - 1)
|
391 |
+
|
392 |
+
# Sort the frames to ensure they're in order
|
393 |
+
frames_to_save = sorted(set(frames_to_save))
|
394 |
+
|
395 |
+
for frame_idx in frames_to_save:
|
396 |
+
if frame_idx >= pointmap.shape[0]:
|
397 |
+
continue
|
398 |
+
|
399 |
+
predictions = {
|
400 |
+
"world_points": pointmap[frame_idx : frame_idx + 1],
|
401 |
+
"images": rgb[frame_idx : frame_idx + 1],
|
402 |
+
"depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
|
403 |
+
"camera_poses": poses[frame_idx : frame_idx + 1],
|
404 |
+
}
|
405 |
+
|
406 |
+
glb_path = os.path.join(
|
407 |
+
output_dir, f"{base_filename}_pointcloud_frame_{frame_idx}.glb"
|
408 |
+
)
|
409 |
+
|
410 |
+
scene_3d = predictions_to_glb(
|
411 |
+
predictions,
|
412 |
+
filter_by_frames="all",
|
413 |
+
show_cam=True,
|
414 |
+
max_depth=max_depth,
|
415 |
+
rtol=rtol,
|
416 |
+
frame_rel_idx=float(frame_idx) / pointmap.shape[0],
|
417 |
+
)
|
418 |
+
scene_3d.export(glb_path)
|
419 |
+
glb_paths.append(glb_path)
|
420 |
+
|
421 |
+
paths["pointcloud_glbs"] = glb_paths
|
422 |
+
|
423 |
+
return paths
|
424 |
+
|
425 |
+
|
426 |
+
def process_reconstruction(
|
427 |
+
video_file,
|
428 |
+
height,
|
429 |
+
width,
|
430 |
+
num_frames,
|
431 |
+
num_inference_steps,
|
432 |
+
guidance_scale,
|
433 |
+
sliding_window_stride,
|
434 |
+
fps,
|
435 |
+
smooth_camera,
|
436 |
+
align_pointmaps,
|
437 |
+
max_depth,
|
438 |
+
rtol,
|
439 |
+
pointcloud_save_frame_interval,
|
440 |
+
seed,
|
441 |
+
progress=gr.Progress(),
|
442 |
+
):
|
443 |
+
"""
|
444 |
+
Process reconstruction task.
|
445 |
+
"""
|
446 |
+
try:
|
447 |
+
gc.collect()
|
448 |
+
torch.cuda.empty_cache()
|
449 |
+
|
450 |
+
# Set random seed
|
451 |
+
seed_all(seed)
|
452 |
+
|
453 |
+
# Build the pipeline
|
454 |
+
pipeline = build_pipeline()
|
455 |
+
|
456 |
+
progress(0.1, "Loading video")
|
457 |
+
# Check if video_file is a string or a file object
|
458 |
+
if isinstance(video_file, str):
|
459 |
+
video_path = video_file
|
460 |
+
else:
|
461 |
+
video_path = video_file.name
|
462 |
+
|
463 |
+
video = iio.imread(video_path).astype(np.float32) / 255.0
|
464 |
+
|
465 |
+
# Setup arguments
|
466 |
+
args = {
|
467 |
+
"height": height,
|
468 |
+
"width": width,
|
469 |
+
"num_frames": num_frames,
|
470 |
+
"sliding_window_stride": sliding_window_stride,
|
471 |
+
"smooth_camera": smooth_camera,
|
472 |
+
"smooth_method": "kalman" if smooth_camera else "none",
|
473 |
+
"align_pointmaps": align_pointmaps,
|
474 |
+
"max_depth": max_depth,
|
475 |
+
"rtol": rtol,
|
476 |
+
"pointcloud_save_frame_interval": pointcloud_save_frame_interval,
|
477 |
+
}
|
478 |
+
|
479 |
+
# Process in sliding windows
|
480 |
+
window_results = []
|
481 |
+
window_indices = get_window_starts(
|
482 |
+
len(video), num_frames, sliding_window_stride
|
483 |
+
)
|
484 |
+
|
485 |
+
progress(0.2, f"Processing video in {len(window_indices)} windows")
|
486 |
+
|
487 |
+
for i, start_idx in enumerate(window_indices):
|
488 |
+
progress_val = 0.2 + (0.6 * (i / len(window_indices)))
|
489 |
+
progress(progress_val, f"Processing window {i+1}/{len(window_indices)}")
|
490 |
+
|
491 |
+
output = pipeline(
|
492 |
+
task="reconstruction",
|
493 |
+
image=None,
|
494 |
+
goal=None,
|
495 |
+
video=video[start_idx : start_idx + num_frames],
|
496 |
+
raymap=None,
|
497 |
+
height=height,
|
498 |
+
width=width,
|
499 |
+
num_frames=num_frames,
|
500 |
+
fps=fps,
|
501 |
+
num_inference_steps=num_inference_steps,
|
502 |
+
guidance_scale=guidance_scale,
|
503 |
+
use_dynamic_cfg=False,
|
504 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
505 |
+
)
|
506 |
+
window_results.append(output)
|
507 |
+
|
508 |
+
progress(0.8, "Merging results from all windows")
|
509 |
+
# Merge window results
|
510 |
+
(
|
511 |
+
merged_rgb,
|
512 |
+
merged_disparity,
|
513 |
+
merged_poses,
|
514 |
+
pointmaps,
|
515 |
+
) = blend_and_merge_window_results(window_results, window_indices, args)
|
516 |
+
|
517 |
+
progress(0.9, "Saving output files")
|
518 |
+
# Save output files
|
519 |
+
output_dir = "outputs"
|
520 |
+
os.makedirs(output_dir, exist_ok=True)
|
521 |
+
output_paths = save_output_files(
|
522 |
+
rgb=merged_rgb,
|
523 |
+
disparity=merged_disparity,
|
524 |
+
poses=merged_poses,
|
525 |
+
pointmap=pointmaps,
|
526 |
+
task="reconstruction",
|
527 |
+
output_dir=output_dir,
|
528 |
+
fps=12,
|
529 |
+
**args,
|
530 |
+
)
|
531 |
+
|
532 |
+
progress(1.0, "Done!")
|
533 |
+
|
534 |
+
# Return paths for displaying
|
535 |
+
return (
|
536 |
+
output_paths["rgb"],
|
537 |
+
output_paths["disparity"],
|
538 |
+
output_paths.get("pointcloud_glbs", []),
|
539 |
+
)
|
540 |
+
|
541 |
+
except Exception:
|
542 |
+
import traceback
|
543 |
+
|
544 |
+
traceback.print_exc()
|
545 |
+
return None, None, []
|
546 |
+
|
547 |
+
|
548 |
+
def process_prediction(
|
549 |
+
image_file,
|
550 |
+
height,
|
551 |
+
width,
|
552 |
+
num_frames,
|
553 |
+
num_inference_steps,
|
554 |
+
guidance_scale,
|
555 |
+
use_dynamic_cfg,
|
556 |
+
raymap_option,
|
557 |
+
post_reconstruction,
|
558 |
+
fps,
|
559 |
+
smooth_camera,
|
560 |
+
align_pointmaps,
|
561 |
+
max_depth,
|
562 |
+
rtol,
|
563 |
+
pointcloud_save_frame_interval,
|
564 |
+
seed,
|
565 |
+
progress=gr.Progress(),
|
566 |
+
):
|
567 |
+
"""
|
568 |
+
Process prediction task.
|
569 |
+
"""
|
570 |
+
try:
|
571 |
+
gc.collect()
|
572 |
+
torch.cuda.empty_cache()
|
573 |
+
|
574 |
+
# Set random seed
|
575 |
+
seed_all(seed)
|
576 |
+
|
577 |
+
# Build the pipeline
|
578 |
+
pipeline = build_pipeline()
|
579 |
+
|
580 |
+
progress(0.1, "Loading image")
|
581 |
+
# Check if image_file is a string or a file object
|
582 |
+
if isinstance(image_file, str):
|
583 |
+
image_path = image_file
|
584 |
+
else:
|
585 |
+
image_path = image_file.name
|
586 |
+
|
587 |
+
image = PIL.Image.open(image_path)
|
588 |
+
|
589 |
+
progress(0.2, "Running prediction")
|
590 |
+
# Run prediction
|
591 |
+
output = pipeline(
|
592 |
+
task="prediction",
|
593 |
+
image=image,
|
594 |
+
video=None,
|
595 |
+
goal=None,
|
596 |
+
raymap=np.load(f"assets/example_raymaps/raymap_{raymap_option}.npy"),
|
597 |
+
height=height,
|
598 |
+
width=width,
|
599 |
+
num_frames=num_frames,
|
600 |
+
fps=fps,
|
601 |
+
num_inference_steps=num_inference_steps,
|
602 |
+
guidance_scale=guidance_scale,
|
603 |
+
use_dynamic_cfg=use_dynamic_cfg,
|
604 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
605 |
+
return_dict=True,
|
606 |
+
)
|
607 |
+
|
608 |
+
# Show RGB output immediately
|
609 |
+
rgb_output = output.rgb
|
610 |
+
|
611 |
+
# Setup arguments for saving
|
612 |
+
args = {
|
613 |
+
"height": height,
|
614 |
+
"width": width,
|
615 |
+
"smooth_camera": smooth_camera,
|
616 |
+
"smooth_method": "kalman" if smooth_camera else "none",
|
617 |
+
"align_pointmaps": align_pointmaps,
|
618 |
+
"max_depth": max_depth,
|
619 |
+
"rtol": rtol,
|
620 |
+
"pointcloud_save_frame_interval": pointcloud_save_frame_interval,
|
621 |
+
}
|
622 |
+
|
623 |
+
if post_reconstruction:
|
624 |
+
progress(0.5, "Running post-reconstruction for better quality")
|
625 |
+
recon_output = pipeline(
|
626 |
+
task="reconstruction",
|
627 |
+
video=output.rgb,
|
628 |
+
height=height,
|
629 |
+
width=width,
|
630 |
+
num_frames=num_frames,
|
631 |
+
fps=fps,
|
632 |
+
num_inference_steps=4,
|
633 |
+
guidance_scale=1.0,
|
634 |
+
use_dynamic_cfg=False,
|
635 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
636 |
+
)
|
637 |
+
|
638 |
+
disparity = recon_output.disparity
|
639 |
+
raymap = recon_output.raymap
|
640 |
+
else:
|
641 |
+
disparity = output.disparity
|
642 |
+
raymap = output.raymap
|
643 |
+
|
644 |
+
progress(0.8, "Saving output files")
|
645 |
+
# Save output files
|
646 |
+
output_dir = "outputs"
|
647 |
+
os.makedirs(output_dir, exist_ok=True)
|
648 |
+
output_paths = save_output_files(
|
649 |
+
rgb=rgb_output,
|
650 |
+
disparity=disparity,
|
651 |
+
raymap=raymap,
|
652 |
+
task="prediction",
|
653 |
+
output_dir=output_dir,
|
654 |
+
fps=12,
|
655 |
+
**args,
|
656 |
+
)
|
657 |
+
|
658 |
+
progress(1.0, "Done!")
|
659 |
+
|
660 |
+
# Return paths for displaying
|
661 |
+
return (
|
662 |
+
output_paths["rgb"],
|
663 |
+
output_paths["disparity"],
|
664 |
+
output_paths.get("pointcloud_glbs", []),
|
665 |
+
)
|
666 |
+
|
667 |
+
except Exception:
|
668 |
+
import traceback
|
669 |
+
|
670 |
+
traceback.print_exc()
|
671 |
+
return None, None, []
|
672 |
+
|
673 |
+
|
674 |
+
def process_planning(
|
675 |
+
image_file,
|
676 |
+
goal_file,
|
677 |
+
height,
|
678 |
+
width,
|
679 |
+
num_frames,
|
680 |
+
num_inference_steps,
|
681 |
+
guidance_scale,
|
682 |
+
use_dynamic_cfg,
|
683 |
+
post_reconstruction,
|
684 |
+
fps,
|
685 |
+
smooth_camera,
|
686 |
+
align_pointmaps,
|
687 |
+
max_depth,
|
688 |
+
rtol,
|
689 |
+
pointcloud_save_frame_interval,
|
690 |
+
seed,
|
691 |
+
progress=gr.Progress(),
|
692 |
+
):
|
693 |
+
"""
|
694 |
+
Process planning task.
|
695 |
+
"""
|
696 |
+
try:
|
697 |
+
gc.collect()
|
698 |
+
torch.cuda.empty_cache()
|
699 |
+
|
700 |
+
# Set random seed
|
701 |
+
seed_all(seed)
|
702 |
+
|
703 |
+
# Build the pipeline
|
704 |
+
pipeline = build_pipeline()
|
705 |
+
|
706 |
+
progress(0.1, "Loading images")
|
707 |
+
# Check if image_file and goal_file are strings or file objects
|
708 |
+
if isinstance(image_file, str):
|
709 |
+
image_path = image_file
|
710 |
+
else:
|
711 |
+
image_path = image_file.name
|
712 |
+
|
713 |
+
if isinstance(goal_file, str):
|
714 |
+
goal_path = goal_file
|
715 |
+
else:
|
716 |
+
goal_path = goal_file.name
|
717 |
+
|
718 |
+
image = PIL.Image.open(image_path)
|
719 |
+
goal = PIL.Image.open(goal_path)
|
720 |
+
|
721 |
+
progress(0.2, "Running planning")
|
722 |
+
# Run planning
|
723 |
+
output = pipeline(
|
724 |
+
task="planning",
|
725 |
+
image=image,
|
726 |
+
video=None,
|
727 |
+
goal=goal,
|
728 |
+
raymap=None,
|
729 |
+
height=height,
|
730 |
+
width=width,
|
731 |
+
num_frames=num_frames,
|
732 |
+
fps=fps,
|
733 |
+
num_inference_steps=num_inference_steps,
|
734 |
+
guidance_scale=guidance_scale,
|
735 |
+
use_dynamic_cfg=use_dynamic_cfg,
|
736 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
737 |
+
return_dict=True,
|
738 |
+
)
|
739 |
+
|
740 |
+
# Show RGB output immediately
|
741 |
+
rgb_output = output.rgb
|
742 |
+
|
743 |
+
# Setup arguments for saving
|
744 |
+
args = {
|
745 |
+
"height": height,
|
746 |
+
"width": width,
|
747 |
+
"smooth_camera": smooth_camera,
|
748 |
+
"smooth_method": "kalman" if smooth_camera else "none",
|
749 |
+
"align_pointmaps": align_pointmaps,
|
750 |
+
"max_depth": max_depth,
|
751 |
+
"rtol": rtol,
|
752 |
+
"pointcloud_save_frame_interval": pointcloud_save_frame_interval,
|
753 |
+
}
|
754 |
+
|
755 |
+
if post_reconstruction:
|
756 |
+
progress(0.5, "Running post-reconstruction for better quality")
|
757 |
+
recon_output = pipeline(
|
758 |
+
task="reconstruction",
|
759 |
+
video=output.rgb,
|
760 |
+
height=height,
|
761 |
+
width=width,
|
762 |
+
num_frames=num_frames,
|
763 |
+
fps=12,
|
764 |
+
num_inference_steps=4,
|
765 |
+
guidance_scale=1.0,
|
766 |
+
use_dynamic_cfg=False,
|
767 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
768 |
+
)
|
769 |
+
|
770 |
+
disparity = recon_output.disparity
|
771 |
+
raymap = recon_output.raymap
|
772 |
+
else:
|
773 |
+
disparity = output.disparity
|
774 |
+
raymap = output.raymap
|
775 |
+
|
776 |
+
progress(0.8, "Saving output files")
|
777 |
+
# Save output files
|
778 |
+
output_dir = "outputs"
|
779 |
+
os.makedirs(output_dir, exist_ok=True)
|
780 |
+
output_paths = save_output_files(
|
781 |
+
rgb=rgb_output,
|
782 |
+
disparity=disparity,
|
783 |
+
raymap=raymap,
|
784 |
+
task="planning",
|
785 |
+
output_dir=output_dir,
|
786 |
+
fps=fps,
|
787 |
+
**args,
|
788 |
+
)
|
789 |
+
|
790 |
+
progress(1.0, "Done!")
|
791 |
+
|
792 |
+
# Return paths for displaying
|
793 |
+
return (
|
794 |
+
output_paths["rgb"],
|
795 |
+
output_paths["disparity"],
|
796 |
+
output_paths.get("pointcloud_glbs", []),
|
797 |
+
)
|
798 |
+
|
799 |
+
except Exception:
|
800 |
+
import traceback
|
801 |
+
|
802 |
+
traceback.print_exc()
|
803 |
+
return None, None, []
|
804 |
+
|
805 |
+
|
806 |
+
def update_task_ui(task):
|
807 |
+
"""Update UI elements based on selected task."""
|
808 |
+
if task == "reconstruction":
|
809 |
+
return (
|
810 |
+
gr.update(visible=True), # video_input
|
811 |
+
gr.update(visible=False), # image_input
|
812 |
+
gr.update(visible=False), # goal_input
|
813 |
+
gr.update(visible=False), # image_preview
|
814 |
+
gr.update(visible=False), # goal_preview
|
815 |
+
gr.update(value=4), # num_inference_steps
|
816 |
+
gr.update(visible=True), # sliding_window_stride
|
817 |
+
gr.update(visible=False), # use_dynamic_cfg
|
818 |
+
gr.update(visible=False), # raymap_option
|
819 |
+
gr.update(visible=False), # post_reconstruction
|
820 |
+
gr.update(value=1.0), # guidance_scale
|
821 |
+
)
|
822 |
+
elif task == "prediction":
|
823 |
+
return (
|
824 |
+
gr.update(visible=False), # video_input
|
825 |
+
gr.update(visible=True), # image_input
|
826 |
+
gr.update(visible=False), # goal_input
|
827 |
+
gr.update(visible=True), # image_preview
|
828 |
+
gr.update(visible=False), # goal_preview
|
829 |
+
gr.update(value=50), # num_inference_steps
|
830 |
+
gr.update(visible=False), # sliding_window_stride
|
831 |
+
gr.update(visible=True), # use_dynamic_cfg
|
832 |
+
gr.update(visible=True), # raymap_option
|
833 |
+
gr.update(visible=True), # post_reconstruction
|
834 |
+
gr.update(value=3.0), # guidance_scale
|
835 |
+
)
|
836 |
+
elif task == "planning":
|
837 |
+
return (
|
838 |
+
gr.update(visible=False), # video_input
|
839 |
+
gr.update(visible=True), # image_input
|
840 |
+
gr.update(visible=True), # goal_input
|
841 |
+
gr.update(visible=True), # image_preview
|
842 |
+
gr.update(visible=True), # goal_preview
|
843 |
+
gr.update(value=50), # num_inference_steps
|
844 |
+
gr.update(visible=False), # sliding_window_stride
|
845 |
+
gr.update(visible=True), # use_dynamic_cfg
|
846 |
+
gr.update(visible=False), # raymap_option
|
847 |
+
gr.update(visible=True), # post_reconstruction
|
848 |
+
gr.update(value=3.0), # guidance_scale
|
849 |
+
)
|
850 |
+
|
851 |
+
|
852 |
+
def update_image_preview(image_file):
|
853 |
+
"""Update the image preview."""
|
854 |
+
if image_file:
|
855 |
+
return image_file.name
|
856 |
+
return None
|
857 |
+
|
858 |
+
|
859 |
+
def update_goal_preview(goal_file):
|
860 |
+
"""Update the goal preview."""
|
861 |
+
if goal_file:
|
862 |
+
return goal_file.name
|
863 |
+
return None
|
864 |
+
|
865 |
+
|
866 |
+
def get_download_link(selected_frame, all_paths):
|
867 |
+
"""Update the download button with the selected file path."""
|
868 |
+
if not selected_frame or not all_paths:
|
869 |
+
return gr.update(visible=False, value=None)
|
870 |
+
|
871 |
+
frame_num = int(re.search(r"Frame (\d+)", selected_frame).group(1))
|
872 |
+
|
873 |
+
for path in all_paths:
|
874 |
+
if f"frame_{frame_num}" in path:
|
875 |
+
# Make sure the file exists before setting it
|
876 |
+
if os.path.exists(path):
|
877 |
+
return gr.update(visible=True, value=path, interactive=True)
|
878 |
+
|
879 |
+
return gr.update(visible=False, value=None)
|
880 |
+
|
881 |
+
|
882 |
+
# Theme setup
|
883 |
+
theme = gr.themes.Default(
|
884 |
+
primary_hue="blue",
|
885 |
+
secondary_hue="cyan",
|
886 |
+
)
|
887 |
+
|
888 |
+
with gr.Blocks(
|
889 |
+
theme=theme,
|
890 |
+
css="""
|
891 |
+
.output-column {
|
892 |
+
min-height: 400px;
|
893 |
+
}
|
894 |
+
.warning {
|
895 |
+
color: #ff9800;
|
896 |
+
font-weight: bold;
|
897 |
+
}
|
898 |
+
.highlight {
|
899 |
+
background-color: rgba(0, 123, 255, 0.1);
|
900 |
+
padding: 10px;
|
901 |
+
border-radius: 8px;
|
902 |
+
border-left: 5px solid #007bff;
|
903 |
+
margin: 10px 0;
|
904 |
+
}
|
905 |
+
.task-header {
|
906 |
+
margin-top: 10px;
|
907 |
+
margin-bottom: 15px;
|
908 |
+
font-size: 1.2em;
|
909 |
+
font-weight: bold;
|
910 |
+
color: #007bff;
|
911 |
+
}
|
912 |
+
.flex-display {
|
913 |
+
display: flex;
|
914 |
+
flex-wrap: wrap;
|
915 |
+
gap: 10px;
|
916 |
+
}
|
917 |
+
.output-subtitle {
|
918 |
+
font-size: 1.1em;
|
919 |
+
margin-top: 5px;
|
920 |
+
margin-bottom: 5px;
|
921 |
+
color: #505050;
|
922 |
+
}
|
923 |
+
.input-section, .params-section, .advanced-section {
|
924 |
+
border: 1px solid #ddd;
|
925 |
+
padding: 15px;
|
926 |
+
border-radius: 8px;
|
927 |
+
margin-bottom: 15px;
|
928 |
+
}
|
929 |
+
.logo-container {
|
930 |
+
display: flex;
|
931 |
+
justify-content: center;
|
932 |
+
margin-bottom: 20px;
|
933 |
+
}
|
934 |
+
.logo-image {
|
935 |
+
max-width: 300px;
|
936 |
+
height: auto;
|
937 |
+
}
|
938 |
+
""",
|
939 |
+
) as demo:
|
940 |
+
with gr.Row(elem_classes=["logo-container"]):
|
941 |
+
gr.Image("assets/logo.png", show_label=False, elem_classes=["logo-image"])
|
942 |
+
|
943 |
+
gr.Markdown(
|
944 |
+
"""
|
945 |
+
# Aether: Geometric-Aware Unified World Modeling
|
946 |
+
|
947 |
+
Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with
|
948 |
+
generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:
|
949 |
+
|
950 |
+
1. **4D dynamic reconstruction** - Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
|
951 |
+
2. **Action-Conditioned Video Prediction** - Predict future frames based on initial observation images, with optional conditions of camera trajectory actions.
|
952 |
+
3. **Goal-Conditioned Visual Planning** - Generate planning paths from pairs of observation and goal images.
|
953 |
+
|
954 |
+
Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.
|
955 |
+
"""
|
956 |
+
)
|
957 |
+
|
958 |
+
with gr.Row():
|
959 |
+
with gr.Column(scale=1):
|
960 |
+
task = gr.Radio(
|
961 |
+
["reconstruction", "prediction", "planning"],
|
962 |
+
label="Select Task",
|
963 |
+
value="reconstruction",
|
964 |
+
info="Choose the task you want to perform",
|
965 |
+
)
|
966 |
+
|
967 |
+
with gr.Group(elem_classes=["input-section"]):
|
968 |
+
# Input section - changes based on task
|
969 |
+
gr.Markdown("## 📥 Input", elem_classes=["task-header"])
|
970 |
+
|
971 |
+
# Task-specific inputs
|
972 |
+
video_input = gr.Video(
|
973 |
+
label="Upload Input Video",
|
974 |
+
sources=["upload"],
|
975 |
+
visible=True,
|
976 |
+
interactive=True,
|
977 |
+
elem_id="video_input",
|
978 |
+
)
|
979 |
+
|
980 |
+
image_input = gr.File(
|
981 |
+
label="Upload Start Image",
|
982 |
+
file_count="single",
|
983 |
+
file_types=["image"],
|
984 |
+
visible=False,
|
985 |
+
interactive=True,
|
986 |
+
elem_id="image_input",
|
987 |
+
)
|
988 |
+
|
989 |
+
goal_input = gr.File(
|
990 |
+
label="Upload Goal Image",
|
991 |
+
file_count="single",
|
992 |
+
file_types=["image"],
|
993 |
+
visible=False,
|
994 |
+
interactive=True,
|
995 |
+
elem_id="goal_input",
|
996 |
+
)
|
997 |
+
|
998 |
+
with gr.Row(visible=False) as preview_row:
|
999 |
+
image_preview = gr.Image(
|
1000 |
+
label="Start Image Preview",
|
1001 |
+
elem_id="image_preview",
|
1002 |
+
visible=False,
|
1003 |
+
)
|
1004 |
+
goal_preview = gr.Image(
|
1005 |
+
label="Goal Image Preview",
|
1006 |
+
elem_id="goal_preview",
|
1007 |
+
visible=False,
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
with gr.Group(elem_classes=["params-section"]):
|
1011 |
+
gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])
|
1012 |
+
|
1013 |
+
with gr.Row():
|
1014 |
+
with gr.Column(scale=1):
|
1015 |
+
height = gr.Dropdown(
|
1016 |
+
choices=[480],
|
1017 |
+
value=480,
|
1018 |
+
label="Height",
|
1019 |
+
info="Height of the output video",
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
with gr.Column(scale=1):
|
1023 |
+
width = gr.Dropdown(
|
1024 |
+
choices=[720],
|
1025 |
+
value=720,
|
1026 |
+
label="Width",
|
1027 |
+
info="Width of the output video",
|
1028 |
+
)
|
1029 |
+
|
1030 |
+
with gr.Row():
|
1031 |
+
with gr.Column(scale=1):
|
1032 |
+
num_frames = gr.Dropdown(
|
1033 |
+
choices=[17, 25, 33, 41],
|
1034 |
+
value=41,
|
1035 |
+
label="Number of Frames",
|
1036 |
+
info="Number of frames to predict",
|
1037 |
+
)
|
1038 |
+
|
1039 |
+
with gr.Column(scale=1):
|
1040 |
+
fps = gr.Dropdown(
|
1041 |
+
choices=[8, 10, 12, 15, 24],
|
1042 |
+
value=12,
|
1043 |
+
label="FPS",
|
1044 |
+
info="Frames per second",
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
with gr.Row():
|
1048 |
+
with gr.Column(scale=1):
|
1049 |
+
num_inference_steps = gr.Slider(
|
1050 |
+
minimum=1,
|
1051 |
+
maximum=60,
|
1052 |
+
value=4,
|
1053 |
+
step=1,
|
1054 |
+
label="Inference Steps",
|
1055 |
+
info="Number of inference step",
|
1056 |
+
)
|
1057 |
+
|
1058 |
+
sliding_window_stride = gr.Slider(
|
1059 |
+
minimum=1,
|
1060 |
+
maximum=40,
|
1061 |
+
value=24,
|
1062 |
+
step=1,
|
1063 |
+
label="Sliding Window Stride",
|
1064 |
+
info="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task",
|
1065 |
+
visible=True,
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
use_dynamic_cfg = gr.Checkbox(
|
1069 |
+
label="Use Dynamic CFG",
|
1070 |
+
value=True,
|
1071 |
+
info="Use dynamic CFG",
|
1072 |
+
visible=False,
|
1073 |
+
)
|
1074 |
+
|
1075 |
+
raymap_option = gr.Radio(
|
1076 |
+
choices=["backward", "forward_right", "left_forward", "right"],
|
1077 |
+
label="Camera Movement Direction",
|
1078 |
+
value="forward_right",
|
1079 |
+
info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
|
1080 |
+
visible=False,
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
post_reconstruction = gr.Checkbox(
|
1084 |
+
label="Post-Reconstruction",
|
1085 |
+
value=True,
|
1086 |
+
info="Run reconstruction after prediction for better quality",
|
1087 |
+
visible=False,
|
1088 |
+
)
|
1089 |
+
|
1090 |
+
with gr.Accordion(
|
1091 |
+
"Advanced Options", open=False, visible=True
|
1092 |
+
) as advanced_options:
|
1093 |
+
with gr.Group(elem_classes=["advanced-section"]):
|
1094 |
+
with gr.Row():
|
1095 |
+
with gr.Column(scale=1):
|
1096 |
+
guidance_scale = gr.Slider(
|
1097 |
+
minimum=1.0,
|
1098 |
+
maximum=10.0,
|
1099 |
+
value=1.0,
|
1100 |
+
step=0.1,
|
1101 |
+
label="Guidance Scale",
|
1102 |
+
info="Guidance scale (only for prediction / planning)",
|
1103 |
+
)
|
1104 |
+
|
1105 |
+
with gr.Row():
|
1106 |
+
with gr.Column(scale=1):
|
1107 |
+
seed = gr.Number(
|
1108 |
+
value=42,
|
1109 |
+
label="Random Seed",
|
1110 |
+
info="Set a seed for reproducible results",
|
1111 |
+
precision=0,
|
1112 |
+
minimum=0,
|
1113 |
+
maximum=2147483647,
|
1114 |
+
)
|
1115 |
+
|
1116 |
+
with gr.Row():
|
1117 |
+
with gr.Column(scale=1):
|
1118 |
+
smooth_camera = gr.Checkbox(
|
1119 |
+
label="Smooth Camera",
|
1120 |
+
value=True,
|
1121 |
+
info="Apply smoothing to camera trajectory",
|
1122 |
+
)
|
1123 |
+
|
1124 |
+
with gr.Column(scale=1):
|
1125 |
+
align_pointmaps = gr.Checkbox(
|
1126 |
+
label="Align Point Maps",
|
1127 |
+
value=False,
|
1128 |
+
info="Align point maps across frames",
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
with gr.Row():
|
1132 |
+
with gr.Column(scale=1):
|
1133 |
+
max_depth = gr.Slider(
|
1134 |
+
minimum=10,
|
1135 |
+
maximum=200,
|
1136 |
+
value=60,
|
1137 |
+
step=10,
|
1138 |
+
label="Max Depth",
|
1139 |
+
info="Maximum depth for point cloud (higher = more distant points)",
|
1140 |
+
)
|
1141 |
+
|
1142 |
+
with gr.Column(scale=1):
|
1143 |
+
rtol = gr.Slider(
|
1144 |
+
minimum=0.01,
|
1145 |
+
maximum=2.0,
|
1146 |
+
value=0.03,
|
1147 |
+
step=0.01,
|
1148 |
+
label="Relative Tolerance",
|
1149 |
+
info="Used for depth edge detection. Lower = remove more edges",
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
pointcloud_save_frame_interval = gr.Slider(
|
1153 |
+
minimum=1,
|
1154 |
+
maximum=20,
|
1155 |
+
value=10,
|
1156 |
+
step=1,
|
1157 |
+
label="Point Cloud Frame Interval",
|
1158 |
+
info="Save point cloud every N frames (higher = fewer files but less complete representation)",
|
1159 |
+
)
|
1160 |
+
|
1161 |
+
run_button = gr.Button("Run Aether", variant="primary")
|
1162 |
+
|
1163 |
+
with gr.Column(scale=1, elem_classes=["output-column"]):
|
1164 |
+
with gr.Group():
|
1165 |
+
gr.Markdown("## 📤 Output", elem_classes=["task-header"])
|
1166 |
+
|
1167 |
+
gr.Markdown("### RGB Video", elem_classes=["output-subtitle"])
|
1168 |
+
rgb_output = gr.Video(
|
1169 |
+
label="RGB Output", interactive=False, elem_id="rgb_output"
|
1170 |
+
)
|
1171 |
+
|
1172 |
+
gr.Markdown("### Depth Video", elem_classes=["output-subtitle"])
|
1173 |
+
depth_output = gr.Video(
|
1174 |
+
label="Depth Output", interactive=False, elem_id="depth_output"
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
gr.Markdown("### Point Clouds", elem_classes=["output-subtitle"])
|
1178 |
+
with gr.Row(elem_classes=["flex-display"]):
|
1179 |
+
pointcloud_frames = gr.Dropdown(
|
1180 |
+
label="Select Frame",
|
1181 |
+
choices=[],
|
1182 |
+
value=None,
|
1183 |
+
interactive=True,
|
1184 |
+
elem_id="pointcloud_frames",
|
1185 |
+
)
|
1186 |
+
pointcloud_download = gr.DownloadButton(
|
1187 |
+
label="Download Point Cloud",
|
1188 |
+
visible=False,
|
1189 |
+
elem_id="pointcloud_download",
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
model_output = gr.Model3D(
|
1193 |
+
label="Point Cloud Viewer", interactive=True, elem_id="model_output"
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
with gr.Tab("About Results"):
|
1197 |
+
gr.Markdown(
|
1198 |
+
"""
|
1199 |
+
### Understanding the Outputs
|
1200 |
+
|
1201 |
+
- **RGB Video**: Shows the predicted or reconstructed RGB frames
|
1202 |
+
- **Depth Video**: Visualizes the disparity maps in color (closer = red, further = blue)
|
1203 |
+
- **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids
|
1204 |
+
|
1205 |
+
<p class="warning">Note: 3D point clouds take a long time to visualize, and we show the keyframes only.
|
1206 |
+
You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`.</p>
|
1207 |
+
"""
|
1208 |
+
)
|
1209 |
+
|
1210 |
+
# Event handlers
|
1211 |
+
task.change(
|
1212 |
+
fn=update_task_ui,
|
1213 |
+
inputs=[task],
|
1214 |
+
outputs=[
|
1215 |
+
video_input,
|
1216 |
+
image_input,
|
1217 |
+
goal_input,
|
1218 |
+
image_preview,
|
1219 |
+
goal_preview,
|
1220 |
+
num_inference_steps,
|
1221 |
+
sliding_window_stride,
|
1222 |
+
use_dynamic_cfg,
|
1223 |
+
raymap_option,
|
1224 |
+
post_reconstruction,
|
1225 |
+
guidance_scale,
|
1226 |
+
],
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
image_input.change(
|
1230 |
+
fn=update_image_preview, inputs=[image_input], outputs=[image_preview]
|
1231 |
+
).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
|
1232 |
+
|
1233 |
+
goal_input.change(
|
1234 |
+
fn=update_goal_preview, inputs=[goal_input], outputs=[goal_preview]
|
1235 |
+
).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
|
1236 |
+
|
1237 |
+
def update_pointcloud_frames(pointcloud_paths):
|
1238 |
+
"""Update the pointcloud frames dropdown with available frames."""
|
1239 |
+
if not pointcloud_paths:
|
1240 |
+
return gr.update(choices=[], value=None), None, gr.update(visible=False)
|
1241 |
+
|
1242 |
+
# Extract frame numbers from filenames
|
1243 |
+
frame_info = []
|
1244 |
+
for path in pointcloud_paths:
|
1245 |
+
filename = os.path.basename(path)
|
1246 |
+
match = re.search(r"frame_(\d+)", filename)
|
1247 |
+
if match:
|
1248 |
+
frame_num = int(match.group(1))
|
1249 |
+
frame_info.append((f"Frame {frame_num}", path))
|
1250 |
+
|
1251 |
+
# Sort by frame number
|
1252 |
+
frame_info.sort(key=lambda x: int(re.search(r"Frame (\d+)", x[0]).group(1)))
|
1253 |
+
|
1254 |
+
choices = [label for label, _ in frame_info]
|
1255 |
+
paths = [path for _, path in frame_info]
|
1256 |
+
|
1257 |
+
if not choices:
|
1258 |
+
return gr.update(choices=[], value=None), None, gr.update(visible=False)
|
1259 |
+
|
1260 |
+
# Make download button visible when we have point cloud files
|
1261 |
+
return (
|
1262 |
+
gr.update(choices=choices, value=choices[0]),
|
1263 |
+
paths[0],
|
1264 |
+
gr.update(visible=True),
|
1265 |
+
)
|
1266 |
+
|
1267 |
+
def select_pointcloud_frame(frame_label, all_paths):
|
1268 |
+
"""Select a specific pointcloud frame."""
|
1269 |
+
if not frame_label or not all_paths:
|
1270 |
+
return None
|
1271 |
+
|
1272 |
+
frame_num = int(re.search(r"Frame (\d+)", frame_label).group(1))
|
1273 |
+
|
1274 |
+
for path in all_paths:
|
1275 |
+
if f"frame_{frame_num}" in path:
|
1276 |
+
return path
|
1277 |
+
|
1278 |
+
return None
|
1279 |
+
|
1280 |
+
# Then in the run button click handler:
|
1281 |
+
def process_task(task_type, *args):
|
1282 |
+
"""Process selected task with appropriate function."""
|
1283 |
+
if task_type == "reconstruction":
|
1284 |
+
rgb_path, depth_path, pointcloud_paths = process_reconstruction(*args)
|
1285 |
+
# Update the pointcloud frames dropdown
|
1286 |
+
frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
|
1287 |
+
pointcloud_paths
|
1288 |
+
)
|
1289 |
+
return (
|
1290 |
+
rgb_path,
|
1291 |
+
depth_path,
|
1292 |
+
initial_path,
|
1293 |
+
frame_dropdown,
|
1294 |
+
pointcloud_paths,
|
1295 |
+
download_visible,
|
1296 |
+
)
|
1297 |
+
elif task_type == "prediction":
|
1298 |
+
rgb_path, depth_path, pointcloud_paths = process_prediction(*args)
|
1299 |
+
frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
|
1300 |
+
pointcloud_paths
|
1301 |
+
)
|
1302 |
+
return (
|
1303 |
+
rgb_path,
|
1304 |
+
depth_path,
|
1305 |
+
initial_path,
|
1306 |
+
frame_dropdown,
|
1307 |
+
pointcloud_paths,
|
1308 |
+
download_visible,
|
1309 |
+
)
|
1310 |
+
elif task_type == "planning":
|
1311 |
+
rgb_path, depth_path, pointcloud_paths = process_planning(*args)
|
1312 |
+
frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
|
1313 |
+
pointcloud_paths
|
1314 |
+
)
|
1315 |
+
return (
|
1316 |
+
rgb_path,
|
1317 |
+
depth_path,
|
1318 |
+
initial_path,
|
1319 |
+
frame_dropdown,
|
1320 |
+
pointcloud_paths,
|
1321 |
+
download_visible,
|
1322 |
+
)
|
1323 |
+
return (
|
1324 |
+
None,
|
1325 |
+
None,
|
1326 |
+
None,
|
1327 |
+
gr.update(choices=[], value=None),
|
1328 |
+
[],
|
1329 |
+
gr.update(visible=False),
|
1330 |
+
)
|
1331 |
+
|
1332 |
+
# Store all pointcloud paths for later use
|
1333 |
+
all_pointcloud_paths = gr.State([])
|
1334 |
+
|
1335 |
+
run_button.click(
|
1336 |
+
fn=lambda task_type,
|
1337 |
+
video_file,
|
1338 |
+
image_file,
|
1339 |
+
goal_file,
|
1340 |
+
height,
|
1341 |
+
width,
|
1342 |
+
num_frames,
|
1343 |
+
num_inference_steps,
|
1344 |
+
guidance_scale,
|
1345 |
+
sliding_window_stride,
|
1346 |
+
use_dynamic_cfg,
|
1347 |
+
raymap_option,
|
1348 |
+
post_reconstruction,
|
1349 |
+
fps,
|
1350 |
+
smooth_camera,
|
1351 |
+
align_pointmaps,
|
1352 |
+
max_depth,
|
1353 |
+
rtol,
|
1354 |
+
pointcloud_save_frame_interval,
|
1355 |
+
seed: process_task(
|
1356 |
+
task_type,
|
1357 |
+
*(
|
1358 |
+
[
|
1359 |
+
video_file,
|
1360 |
+
height,
|
1361 |
+
width,
|
1362 |
+
num_frames,
|
1363 |
+
num_inference_steps,
|
1364 |
+
guidance_scale,
|
1365 |
+
sliding_window_stride,
|
1366 |
+
fps,
|
1367 |
+
smooth_camera,
|
1368 |
+
align_pointmaps,
|
1369 |
+
max_depth,
|
1370 |
+
rtol,
|
1371 |
+
pointcloud_save_frame_interval,
|
1372 |
+
seed,
|
1373 |
+
]
|
1374 |
+
if task_type == "reconstruction"
|
1375 |
+
else [
|
1376 |
+
image_file,
|
1377 |
+
height,
|
1378 |
+
width,
|
1379 |
+
num_frames,
|
1380 |
+
num_inference_steps,
|
1381 |
+
guidance_scale,
|
1382 |
+
use_dynamic_cfg,
|
1383 |
+
raymap_option,
|
1384 |
+
post_reconstruction,
|
1385 |
+
fps,
|
1386 |
+
smooth_camera,
|
1387 |
+
align_pointmaps,
|
1388 |
+
max_depth,
|
1389 |
+
rtol,
|
1390 |
+
pointcloud_save_frame_interval,
|
1391 |
+
seed,
|
1392 |
+
]
|
1393 |
+
if task_type == "prediction"
|
1394 |
+
else [
|
1395 |
+
image_file,
|
1396 |
+
goal_file,
|
1397 |
+
height,
|
1398 |
+
width,
|
1399 |
+
num_frames,
|
1400 |
+
num_inference_steps,
|
1401 |
+
guidance_scale,
|
1402 |
+
use_dynamic_cfg,
|
1403 |
+
post_reconstruction,
|
1404 |
+
fps,
|
1405 |
+
smooth_camera,
|
1406 |
+
align_pointmaps,
|
1407 |
+
max_depth,
|
1408 |
+
rtol,
|
1409 |
+
pointcloud_save_frame_interval,
|
1410 |
+
seed,
|
1411 |
+
]
|
1412 |
+
),
|
1413 |
+
),
|
1414 |
+
inputs=[
|
1415 |
+
task,
|
1416 |
+
video_input,
|
1417 |
+
image_input,
|
1418 |
+
goal_input,
|
1419 |
+
height,
|
1420 |
+
width,
|
1421 |
+
num_frames,
|
1422 |
+
num_inference_steps,
|
1423 |
+
guidance_scale,
|
1424 |
+
sliding_window_stride,
|
1425 |
+
use_dynamic_cfg,
|
1426 |
+
raymap_option,
|
1427 |
+
post_reconstruction,
|
1428 |
+
fps,
|
1429 |
+
smooth_camera,
|
1430 |
+
align_pointmaps,
|
1431 |
+
max_depth,
|
1432 |
+
rtol,
|
1433 |
+
pointcloud_save_frame_interval,
|
1434 |
+
seed,
|
1435 |
+
],
|
1436 |
+
outputs=[
|
1437 |
+
rgb_output,
|
1438 |
+
depth_output,
|
1439 |
+
model_output,
|
1440 |
+
pointcloud_frames,
|
1441 |
+
all_pointcloud_paths,
|
1442 |
+
pointcloud_download,
|
1443 |
+
],
|
1444 |
+
)
|
1445 |
+
|
1446 |
+
pointcloud_frames.change(
|
1447 |
+
fn=select_pointcloud_frame,
|
1448 |
+
inputs=[pointcloud_frames, all_pointcloud_paths],
|
1449 |
+
outputs=[model_output],
|
1450 |
+
).then(
|
1451 |
+
fn=get_download_link,
|
1452 |
+
inputs=[pointcloud_frames, all_pointcloud_paths],
|
1453 |
+
outputs=[pointcloud_download],
|
1454 |
+
)
|
1455 |
+
|
1456 |
+
# Example Accordion
|
1457 |
+
with gr.Accordion("Examples"):
|
1458 |
+
gr.Markdown(
|
1459 |
+
"""
|
1460 |
+
### Examples will be added soon
|
1461 |
+
Check back for example inputs for each task type.
|
1462 |
+
"""
|
1463 |
+
)
|
1464 |
+
|
1465 |
+
# Load the model at startup
|
1466 |
+
demo.load(lambda: build_pipeline(), inputs=None, outputs=None)
|
1467 |
+
|
1468 |
+
if __name__ == "__main__":
|
1469 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
1470 |
+
demo.queue(max_size=20).launch(show_error=True, share=True)
|
setup.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import pathlib
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import pkg_resources
|
6 |
+
from setuptools import find_packages, setup
|
7 |
+
|
8 |
+
|
9 |
+
PKG_NAME = "aether"
|
10 |
+
VERSION = "0.1"
|
11 |
+
EXTRAS = {}
|
12 |
+
|
13 |
+
|
14 |
+
def _read_file(fname):
|
15 |
+
with pathlib.Path(fname).open() as fp:
|
16 |
+
return fp.read()
|
17 |
+
|
18 |
+
|
19 |
+
def _read_install_requires():
|
20 |
+
with pathlib.Path("requirements.txt").open() as fp:
|
21 |
+
return [
|
22 |
+
str(requirement) for requirement in pkg_resources.parse_requirements(fp)
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
def _fill_extras(extras):
|
27 |
+
if extras:
|
28 |
+
extras["all"] = list({item for group in extras.values() for item in group})
|
29 |
+
return extras
|
30 |
+
|
31 |
+
|
32 |
+
version_range_max = max(sys.version_info[1], 10) + 1
|
33 |
+
setup(
|
34 |
+
name=PKG_NAME,
|
35 |
+
version=VERSION,
|
36 |
+
author="Aether Team",
|
37 |
+
author_email="[email protected]",
|
38 |
+
url="https://github.com/OpenRobotLab/Aether",
|
39 |
+
description="",
|
40 |
+
long_description=_read_file("README.md"),
|
41 |
+
long_description_content_type="text/markdown",
|
42 |
+
keywords=[
|
43 |
+
"Deep Learning",
|
44 |
+
"Machine Learning",
|
45 |
+
"World Model",
|
46 |
+
"3D Vision",
|
47 |
+
"Reconstruction",
|
48 |
+
"Sythetic Data",
|
49 |
+
"Embodied AI",
|
50 |
+
],
|
51 |
+
license="MIT License",
|
52 |
+
packages=find_packages(include=f"{PKG_NAME}.*"),
|
53 |
+
include_package_data=True,
|
54 |
+
zip_safe=False,
|
55 |
+
install_requires=_read_install_requires(),
|
56 |
+
extras_require=_fill_extras(EXTRAS),
|
57 |
+
python_requires=">=3.8",
|
58 |
+
classifiers=[
|
59 |
+
"Development Status :: 5 - Production/Stable",
|
60 |
+
"Intended Audience :: Developers",
|
61 |
+
"Intended Audience :: Education",
|
62 |
+
"Intended Audience :: Science/Research",
|
63 |
+
"License :: OSI Approved :: MIT License",
|
64 |
+
"Operating System :: OS Independent",
|
65 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
66 |
+
"Programming Language :: Python :: 3",
|
67 |
+
]
|
68 |
+
+ [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
|
69 |
+
)
|