Upload models
Browse files- README.md +133 -3
- config.json +53 -0
- configuration_internimage.py +121 -0
- dcnv3.py +356 -0
- dcnv3_func.py +222 -0
- model.safetensors +3 -0
- modeling_internimage.py +899 -0
- preprocessor_config.json +19 -0
README.md
CHANGED
@@ -1,3 +1,133 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
pipeline_tag: image-classification
|
4 |
+
library_name: transformers
|
5 |
+
tags:
|
6 |
+
- internimage
|
7 |
+
- custom_code
|
8 |
+
datasets:
|
9 |
+
- ILSVRC/imagenet-1k
|
10 |
+
---
|
11 |
+
|
12 |
+
# InternImage Model Card
|
13 |
+
|
14 |
+
## Introduction
|
15 |
+
|
16 |
+
InternImage is an advanced vision foundation model developed by researchers from Shanghai AI Laboratory, Tsinghua University, and other institutions. Unlike models based on Transformers, InternImage employs DCNv3 as its core operator. This approach equips the model with dynamic and effective receptive fields required for downstream tasks like object detection and segmentation, while enabling adaptive spatial aggregation.
|
17 |
+
|
18 |
+
<div style="text-align: center;"> <img src="https://github.com/OpenGVLab/InternImage/raw/master/docs/figs/arch.png" style="width:60%;" /> </div>
|
19 |
+
|
20 |
+
## Performance
|
21 |
+
|
22 |
+
- InternImage achieved an impressive Top-1 accuracy of 90.1% on the ImageNet benchmark dataset using only publicly available data for image classification. Apart from two undisclosed models trained with additional datasets by Google and Microsoft, InternImage is the only open-source model that achieves a Top-1 accuracy of over 90.0%, and it is also the largest model in scale worldwide.
|
23 |
+
- InternImage outperformed all other models worldwide on the COCO object detection benchmark dataset with a remarkable mAP of 65.5, making it the only model that surpasses 65 mAP in the world.
|
24 |
+
- InternImage also demonstrated world's best performance on 16 other important visual benchmark datasets, covering a wide range of tasks such as classification, detection, and segmentation, making it the top-performing model across multiple domains.
|
25 |
+
|
26 |
+
## Released Models
|
27 |
+
|
28 |
+
### Open‑Source Visual Pretrained Models
|
29 |
+
|
30 |
+
| huggingface name | model name | pretrain | resolution | #param |
|
31 |
+
| :-------------------------------------------------------------------------------------------: | :------------: | :------------------: | :--------: | :----: |
|
32 |
+
| [internimage_l_22k_384](https://huggingface.co/OpenGVLab/internimage_l_22k_384) | InternImage-L | IN-22K | 384x384 | 223M |
|
33 |
+
| [internimage_xl_22k_384](https://huggingface.co/OpenGVLab/internimage_xl_22k_384) | InternImage-XL | IN-22K | 384x384 | 335M |
|
34 |
+
| [internimage_h_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_h_jointto22k_384) | InternImage-H | Joint 427M -> IN-22K | 384x384 | 1.08B |
|
35 |
+
| [internimage_g_jointto22k_384](https://huggingface.co/OpenGVLab/internimage_g_jointto22k_384) | InternImage-G | Joint 427M -> IN-22K | 384x384 | 3B |
|
36 |
+
|
37 |
+
### ImageNet-1K Image Classification
|
38 |
+
|
39 |
+
| huggingface name | model name | pretrain | resolution | acc@1 | #param | FLOPs |
|
40 |
+
| :---------------------------------------------------------------------------------------: | :------------: | :------------------: | :--------: | :---: | :----: | :---: |
|
41 |
+
| [internimage_t_1k_224](https://huggingface.co/OpenGVLab/internimage_t_1k_224) | InternImage-T | IN-1K | 224x224 | 83.5 | 30M | 5G |
|
42 |
+
| [internimage_s_1k_224](https://huggingface.co/OpenGVLab/internimage_s_1k_224) | InternImage-S | IN-1K | 224x224 | 84.2 | 50M | 8G |
|
43 |
+
| [internimage_b_1k_224](https://huggingface.co/OpenGVLab/internimage_b_1k_224) | InternImage-B | IN-1K | 224x224 | 84.9 | 97M | 16G |
|
44 |
+
| [internimage_l_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_l_22kto1k_384) | InternImage-L | IN-22K | 384x384 | 87.7 | 223M | 108G |
|
45 |
+
| [internimage_xl_22kto1k_384](https://huggingface.co/OpenGVLab/internimage_xl_22kto1k_384) | InternImage-XL | IN-22K | 384x384 | 88.0 | 335M | 163G |
|
46 |
+
| [internimage_h_22kto1k_640](https://huggingface.co/OpenGVLab/internimage_h_22kto1k_640) | InternImage-H | Joint 427M -> IN-22K | 640x640 | 89.6 | 1.08B | 1478G |
|
47 |
+
| [internimage_g_22kto1k_512](https://huggingface.co/OpenGVLab/internimage_g_22kto1k_512) | InternImage-G | Joint 427M -> IN-22K | 512x512 | 90.1 | 3B | 2700G |
|
48 |
+
|
49 |
+
## DCNv3 CUDA Kernel Installation
|
50 |
+
|
51 |
+
If you do not install the CUDA version of DCNv3, InternImage will automatically fall back to a PyTorch implementation. However, the CUDA implementation can significantly reduce GPU memory usage and improve inference efficiency.
|
52 |
+
|
53 |
+
**Installation Tutorial:**
|
54 |
+
|
55 |
+
1. Open your terminal and run:
|
56 |
+
|
57 |
+
```bash
|
58 |
+
git clone https://github.com/OpenGVLab/InternImage.git
|
59 |
+
cd InternImage/classification/ops_dcnv3
|
60 |
+
```
|
61 |
+
|
62 |
+
2. Make sure you have an available GPU for compilation, then run:
|
63 |
+
|
64 |
+
```bash
|
65 |
+
sh make.sh
|
66 |
+
```
|
67 |
+
|
68 |
+
This will compile the CUDA version of DCNv3. Once installed, InternImage will automatically leverage the optimized CUDA implementation for better performance.
|
69 |
+
|
70 |
+
## Usage with Transformers
|
71 |
+
|
72 |
+
Below are two usage examples for InternImage with the Transformers framework:
|
73 |
+
|
74 |
+
### Example 1: Using InternImage as an Image Backbone
|
75 |
+
|
76 |
+
```python
|
77 |
+
import torch
|
78 |
+
from PIL import Image
|
79 |
+
from transformers import AutoModel, CLIPImageProcessor
|
80 |
+
|
81 |
+
# Replace 'model_name' with the appropriate model identifier
|
82 |
+
model_name = "OpenGVLab/internimage_t_1k_224" # example model
|
83 |
+
|
84 |
+
# Prepare the image
|
85 |
+
image_path = 'img.png'
|
86 |
+
image_processor = CLIPImageProcessor.from_pretrained(model_name)
|
87 |
+
image = Image.open(image_path)
|
88 |
+
image = image_processor(images=image, return_tensors='pt').pixel_values
|
89 |
+
print('image shape:', image.shape)
|
90 |
+
|
91 |
+
# Load the model as a backbone
|
92 |
+
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
93 |
+
# 'hidden_states' contains the outputs from the 4 stages of the InternImage backbone
|
94 |
+
hidden_states = model(image).hidden_states
|
95 |
+
```
|
96 |
+
|
97 |
+
### Example 2: Using InternImage for Image Classification
|
98 |
+
|
99 |
+
```python
|
100 |
+
import torch
|
101 |
+
from PIL import Image
|
102 |
+
from transformers import AutoModelForImageClassification, CLIPImageProcessor
|
103 |
+
|
104 |
+
# Replace 'model_name' with the appropriate model identifier
|
105 |
+
model_name = "OpenGVLab/internimage_t_1k_224" # example model
|
106 |
+
|
107 |
+
# Prepare the image
|
108 |
+
image_path = 'img.png'
|
109 |
+
image_processor = CLIPImageProcessor.from_pretrained(model_name)
|
110 |
+
image = Image.open(image_path)
|
111 |
+
image = image_processor(images=image, return_tensors='pt').pixel_values
|
112 |
+
print('image shape:', image.shape)
|
113 |
+
|
114 |
+
# Load the model as an image classifier
|
115 |
+
model = AutoModelForImageClassification.from_pretrained(model_name, trust_remote_code=True)
|
116 |
+
logits = model(image).logits
|
117 |
+
label = torch.argmax(logits, dim=1)
|
118 |
+
print("Predicted label:", label.item())
|
119 |
+
```
|
120 |
+
|
121 |
+
## Citation
|
122 |
+
|
123 |
+
If this work is helpful for your research, please consider citing the following BibTeX entry.
|
124 |
+
|
125 |
+
```Bibtex
|
126 |
+
@inproceedings{wang2023internimage,
|
127 |
+
title={Internimage: Exploring large-scale vision foundation models with deformable convolutions},
|
128 |
+
author={Wang, Wenhai and Dai, Jifeng and Chen, Zhe and Huang, Zhenhang and Li, Zhiqi and Zhu, Xizhou and Hu, Xiaowei and Lu, Tong and Lu, Lewei and Li, Hongsheng and others},
|
129 |
+
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
130 |
+
pages={14408--14419},
|
131 |
+
year={2023}
|
132 |
+
}
|
133 |
+
```
|
config.json
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "OpenGVLab/internimage_h_jointto22k_384",
|
3 |
+
"act_layer": "GELU",
|
4 |
+
"architectures": [
|
5 |
+
"InternImageModel"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_internimage.InternImageConfig",
|
9 |
+
"AutoModel": "modeling_internimage.InternImageModel",
|
10 |
+
"AutoModelForImageClassification": "modeling_internimage.InternImageModelForImageClassification"
|
11 |
+
},
|
12 |
+
"center_feature_scale": true,
|
13 |
+
"channels": 320,
|
14 |
+
"cls_scale": 1.5,
|
15 |
+
"core_op": "DCNv3",
|
16 |
+
"depths": [
|
17 |
+
6,
|
18 |
+
6,
|
19 |
+
32,
|
20 |
+
6
|
21 |
+
],
|
22 |
+
"drop_path_rate": 0.0,
|
23 |
+
"drop_path_type": "linear",
|
24 |
+
"drop_rate": 0.0,
|
25 |
+
"dw_kernel_size": 5,
|
26 |
+
"groups": [
|
27 |
+
10,
|
28 |
+
20,
|
29 |
+
40,
|
30 |
+
80
|
31 |
+
],
|
32 |
+
"layer_scale": null,
|
33 |
+
"level2_post_norm": true,
|
34 |
+
"level2_post_norm_block_ids": [
|
35 |
+
5,
|
36 |
+
11,
|
37 |
+
17,
|
38 |
+
23,
|
39 |
+
29
|
40 |
+
],
|
41 |
+
"mlp_ratio": 4.0,
|
42 |
+
"model_type": "internimage",
|
43 |
+
"norm_layer": "LN",
|
44 |
+
"num_classes": 21841,
|
45 |
+
"offset_scale": 1.0,
|
46 |
+
"post_norm": false,
|
47 |
+
"remove_center": false,
|
48 |
+
"res_post_norm": true,
|
49 |
+
"torch_dtype": "float32",
|
50 |
+
"transformers_version": "4.37.2",
|
51 |
+
"use_clip_projector": true,
|
52 |
+
"with_cp": false
|
53 |
+
}
|
configuration_internimage.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternImage
|
3 |
+
# Copyright (c) 2025 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
from transformers import PretrainedConfig
|
8 |
+
|
9 |
+
|
10 |
+
class InternImageConfig(PretrainedConfig):
|
11 |
+
r"""
|
12 |
+
This is the configuration class to store the configuration of a [`~InternImageModel`].
|
13 |
+
It is used to instantiate an internimage model according to the specified arguments, defining the model
|
14 |
+
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
15 |
+
the internimage [OpenGVLab/internimage](https://huggingface.co/OpenGVLab/internimage) architecture.
|
16 |
+
|
17 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used
|
18 |
+
to control the model outputs. Read the documentation from [`PretrainedConfig`]
|
19 |
+
for more information.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
core_op (`str`, *optional*, defaults to `"DCNv3"`):
|
23 |
+
Core operation used in the InternImageModel.
|
24 |
+
depths (`tuple`, *optional*, defaults to `(4, 4, 18, 4)`):
|
25 |
+
Tuple specifying the depth of layers in the InternImageModel.
|
26 |
+
groups (`tuple`, *optional*, defaults to `(4, 8, 16, 32)`):
|
27 |
+
Tuple specifying the group of layers in the InternImageModel.
|
28 |
+
channels (`int`, *optional*, defaults to `64`):
|
29 |
+
Number of channels in the InternImageModel.
|
30 |
+
dw_kernel_size (`int`, *optional*, defaults to `None`):
|
31 |
+
Kernel size for depthwise convolutions.
|
32 |
+
layer_scale (`float`, *optional*, defaults to `None`):
|
33 |
+
Scale of the layers in the model.
|
34 |
+
offset_scale (`float`, *optional*, defaults to `1.0`):
|
35 |
+
Offset scale in the model.
|
36 |
+
mlp_ratio (`float`, *optional*, defaults to `4.0`):
|
37 |
+
Ratio of mlp layers in the InternImageModel.
|
38 |
+
post_norm (`bool`, *optional*, defaults to `False`):
|
39 |
+
Whether to use post normalization in the model.
|
40 |
+
level2_post_norm (`bool`, *optional*, defaults to `False`):
|
41 |
+
Whether to use level 2 post normalization.
|
42 |
+
level2_post_norm_block_ids (`list`, *optional*, defaults to `None`):
|
43 |
+
Specific block IDs for level 2 post normalization.
|
44 |
+
center_feature_scale (`bool`, *optional*, defaults to `False`):
|
45 |
+
Whether to apply center feature scaling.
|
46 |
+
use_clip_projector (`bool`, *optional*, defaults to `False`):
|
47 |
+
Whether to use CLIP projector.
|
48 |
+
remove_center (`bool`, *optional*, defaults to `False`):
|
49 |
+
Whether to remove center pixels in some operations.
|
50 |
+
num_classes (`int`, *optional*, defaults to `1000`):
|
51 |
+
Number of classes for the model output.
|
52 |
+
drop_rate (`float`, *optional*, defaults to `0.0`):
|
53 |
+
Dropout rate in the model.
|
54 |
+
drop_path_rate (`float`, *optional*, defaults to `0.0`):
|
55 |
+
Dropout path rate in the model.
|
56 |
+
drop_path_type (`str`, *optional*, defaults to `"linear"`):
|
57 |
+
Type of dropout path used in the model.
|
58 |
+
act_layer (`str`, *optional*, defaults to `"GELU"`):
|
59 |
+
Activation function used in the model.
|
60 |
+
norm_layer (`str`, *optional*, defaults to `"LN"`):
|
61 |
+
Normalization layer used in the model.
|
62 |
+
cls_scale (`float`, *optional*, defaults to `1.5`):
|
63 |
+
Scale of the classification layer in the model.
|
64 |
+
with_cp (`bool`, *optional*, defaults to `False`):
|
65 |
+
Whether to use checkpointing in the model.
|
66 |
+
"""
|
67 |
+
model_type = 'internimage'
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
core_op='DCNv3',
|
72 |
+
depths=(4, 4, 18, 4),
|
73 |
+
groups=(4, 8, 16, 32),
|
74 |
+
channels=64,
|
75 |
+
dw_kernel_size=None,
|
76 |
+
layer_scale=None,
|
77 |
+
offset_scale=1.0,
|
78 |
+
mlp_ratio=4.0,
|
79 |
+
post_norm=False,
|
80 |
+
res_post_norm=False,
|
81 |
+
level2_post_norm=False,
|
82 |
+
level2_post_norm_block_ids=None,
|
83 |
+
center_feature_scale=False,
|
84 |
+
use_clip_projector=False,
|
85 |
+
remove_center=False,
|
86 |
+
num_classes=1000,
|
87 |
+
drop_rate=0.0,
|
88 |
+
drop_path_rate=0.0,
|
89 |
+
drop_path_type='linear',
|
90 |
+
act_layer='GELU',
|
91 |
+
norm_layer='LN',
|
92 |
+
cls_scale=1.5,
|
93 |
+
with_cp=False,
|
94 |
+
**kwargs,
|
95 |
+
):
|
96 |
+
super().__init__(**kwargs)
|
97 |
+
|
98 |
+
# Model configuration parameters
|
99 |
+
self.core_op = core_op
|
100 |
+
self.depths = depths
|
101 |
+
self.groups = groups
|
102 |
+
self.channels = channels
|
103 |
+
self.dw_kernel_size = dw_kernel_size
|
104 |
+
self.layer_scale = layer_scale
|
105 |
+
self.offset_scale = offset_scale
|
106 |
+
self.mlp_ratio = mlp_ratio
|
107 |
+
self.post_norm = post_norm
|
108 |
+
self.res_post_norm = res_post_norm
|
109 |
+
self.level2_post_norm = level2_post_norm
|
110 |
+
self.level2_post_norm_block_ids = level2_post_norm_block_ids
|
111 |
+
self.center_feature_scale = center_feature_scale
|
112 |
+
self.use_clip_projector = use_clip_projector
|
113 |
+
self.remove_center = remove_center
|
114 |
+
self.num_classes = num_classes
|
115 |
+
self.drop_rate = drop_rate
|
116 |
+
self.drop_path_rate = drop_path_rate
|
117 |
+
self.drop_path_type = drop_path_type
|
118 |
+
self.act_layer = act_layer
|
119 |
+
self.norm_layer = norm_layer
|
120 |
+
self.cls_scale = cls_scale
|
121 |
+
self.with_cp = with_cp
|
dcnv3.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternImage
|
3 |
+
# Copyright (c) 2025 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
from __future__ import absolute_import, division, print_function
|
8 |
+
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn.init import constant_, xavier_uniform_
|
15 |
+
|
16 |
+
from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch, has_cuda_kernel
|
17 |
+
|
18 |
+
|
19 |
+
class to_channels_first(nn.Module):
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return x.permute(0, 3, 1, 2)
|
26 |
+
|
27 |
+
|
28 |
+
class to_channels_last(nn.Module):
|
29 |
+
|
30 |
+
def __init__(self):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
return x.permute(0, 2, 3, 1)
|
35 |
+
|
36 |
+
|
37 |
+
def build_norm_layer(dim,
|
38 |
+
norm_layer,
|
39 |
+
in_format='channels_last',
|
40 |
+
out_format='channels_last',
|
41 |
+
eps=1e-6):
|
42 |
+
layers = []
|
43 |
+
if norm_layer == 'BN':
|
44 |
+
if in_format == 'channels_last':
|
45 |
+
layers.append(to_channels_first())
|
46 |
+
layers.append(nn.BatchNorm2d(dim))
|
47 |
+
if out_format == 'channels_last':
|
48 |
+
layers.append(to_channels_last())
|
49 |
+
elif norm_layer == 'LN':
|
50 |
+
if in_format == 'channels_first':
|
51 |
+
layers.append(to_channels_last())
|
52 |
+
layers.append(nn.LayerNorm(dim, eps=eps))
|
53 |
+
if out_format == 'channels_first':
|
54 |
+
layers.append(to_channels_first())
|
55 |
+
else:
|
56 |
+
raise NotImplementedError(
|
57 |
+
f'build_norm_layer does not support {norm_layer}')
|
58 |
+
return nn.Sequential(*layers)
|
59 |
+
|
60 |
+
|
61 |
+
def build_act_layer(act_layer):
|
62 |
+
if act_layer == 'ReLU':
|
63 |
+
return nn.ReLU(inplace=True)
|
64 |
+
elif act_layer == 'SiLU':
|
65 |
+
return nn.SiLU(inplace=True)
|
66 |
+
elif act_layer == 'GELU':
|
67 |
+
return nn.GELU()
|
68 |
+
|
69 |
+
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
|
70 |
+
|
71 |
+
|
72 |
+
def _is_power_of_2(n):
|
73 |
+
if (not isinstance(n, int)) or (n < 0):
|
74 |
+
raise ValueError(
|
75 |
+
'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n)))
|
76 |
+
|
77 |
+
return (n & (n - 1) == 0) and n != 0
|
78 |
+
|
79 |
+
|
80 |
+
class CenterFeatureScaleModule(nn.Module):
|
81 |
+
def forward(self,
|
82 |
+
query,
|
83 |
+
center_feature_scale_proj_weight,
|
84 |
+
center_feature_scale_proj_bias):
|
85 |
+
center_feature_scale = F.linear(query,
|
86 |
+
weight=center_feature_scale_proj_weight,
|
87 |
+
bias=center_feature_scale_proj_bias).sigmoid()
|
88 |
+
return center_feature_scale
|
89 |
+
|
90 |
+
|
91 |
+
class DCNv3_pytorch(nn.Module):
|
92 |
+
def __init__(
|
93 |
+
self,
|
94 |
+
channels=64,
|
95 |
+
kernel_size=3,
|
96 |
+
dw_kernel_size=None,
|
97 |
+
stride=1,
|
98 |
+
pad=1,
|
99 |
+
dilation=1,
|
100 |
+
group=4,
|
101 |
+
offset_scale=1.0,
|
102 |
+
act_layer='GELU',
|
103 |
+
norm_layer='LN',
|
104 |
+
center_feature_scale=False,
|
105 |
+
remove_center=False,
|
106 |
+
):
|
107 |
+
"""
|
108 |
+
DCNv3 Module
|
109 |
+
:param channels
|
110 |
+
:param kernel_size
|
111 |
+
:param stride
|
112 |
+
:param pad
|
113 |
+
:param dilation
|
114 |
+
:param group
|
115 |
+
:param offset_scale
|
116 |
+
:param act_layer
|
117 |
+
:param norm_layer
|
118 |
+
"""
|
119 |
+
super().__init__()
|
120 |
+
if channels % group != 0:
|
121 |
+
raise ValueError(
|
122 |
+
f'channels must be divisible by group, but got {channels} and {group}')
|
123 |
+
_d_per_group = channels // group
|
124 |
+
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
|
125 |
+
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
|
126 |
+
if not _is_power_of_2(_d_per_group):
|
127 |
+
warnings.warn(
|
128 |
+
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
|
129 |
+
'which is more efficient in our CUDA implementation.')
|
130 |
+
|
131 |
+
self.offset_scale = offset_scale
|
132 |
+
self.channels = channels
|
133 |
+
self.kernel_size = kernel_size
|
134 |
+
self.dw_kernel_size = dw_kernel_size
|
135 |
+
self.stride = stride
|
136 |
+
self.dilation = dilation
|
137 |
+
self.pad = pad
|
138 |
+
self.group = group
|
139 |
+
self.group_channels = channels // group
|
140 |
+
self.offset_scale = offset_scale
|
141 |
+
self.center_feature_scale = center_feature_scale
|
142 |
+
self.remove_center = int(remove_center)
|
143 |
+
|
144 |
+
self.dw_conv = nn.Sequential(
|
145 |
+
nn.Conv2d(
|
146 |
+
channels,
|
147 |
+
channels,
|
148 |
+
kernel_size=dw_kernel_size,
|
149 |
+
stride=1,
|
150 |
+
padding=(dw_kernel_size - 1) // 2,
|
151 |
+
groups=channels),
|
152 |
+
build_norm_layer(
|
153 |
+
channels,
|
154 |
+
norm_layer,
|
155 |
+
'channels_first',
|
156 |
+
'channels_last'),
|
157 |
+
build_act_layer(act_layer))
|
158 |
+
self.offset = nn.Linear(
|
159 |
+
channels,
|
160 |
+
group * (kernel_size * kernel_size - remove_center) * 2)
|
161 |
+
self.mask = nn.Linear(
|
162 |
+
channels,
|
163 |
+
group * (kernel_size * kernel_size - remove_center))
|
164 |
+
self.input_proj = nn.Linear(channels, channels)
|
165 |
+
self.output_proj = nn.Linear(channels, channels)
|
166 |
+
self._reset_parameters()
|
167 |
+
|
168 |
+
if center_feature_scale:
|
169 |
+
self.center_feature_scale_proj_weight = nn.Parameter(
|
170 |
+
torch.zeros((group, channels), dtype=torch.float))
|
171 |
+
self.center_feature_scale_proj_bias = nn.Parameter(
|
172 |
+
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
|
173 |
+
self.center_feature_scale_module = CenterFeatureScaleModule()
|
174 |
+
|
175 |
+
def _reset_parameters(self):
|
176 |
+
constant_(self.offset.weight.data, 0.)
|
177 |
+
constant_(self.offset.bias.data, 0.)
|
178 |
+
constant_(self.mask.weight.data, 0.)
|
179 |
+
constant_(self.mask.bias.data, 0.)
|
180 |
+
xavier_uniform_(self.input_proj.weight.data)
|
181 |
+
constant_(self.input_proj.bias.data, 0.)
|
182 |
+
xavier_uniform_(self.output_proj.weight.data)
|
183 |
+
constant_(self.output_proj.bias.data, 0.)
|
184 |
+
|
185 |
+
def forward(self, input):
|
186 |
+
"""
|
187 |
+
:param query (N, H, W, C)
|
188 |
+
:return output (N, H, W, C)
|
189 |
+
"""
|
190 |
+
N, H, W, _ = input.shape
|
191 |
+
|
192 |
+
x = self.input_proj(input)
|
193 |
+
x_proj = x
|
194 |
+
|
195 |
+
x1 = input.permute(0, 3, 1, 2)
|
196 |
+
x1 = self.dw_conv(x1)
|
197 |
+
offset = self.offset(x1)
|
198 |
+
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
|
199 |
+
mask = F.softmax(mask, -1).reshape(N, H, W, -1)
|
200 |
+
|
201 |
+
x = dcnv3_core_pytorch(
|
202 |
+
x, offset, mask,
|
203 |
+
self.kernel_size, self.kernel_size,
|
204 |
+
self.stride, self.stride,
|
205 |
+
self.pad, self.pad,
|
206 |
+
self.dilation, self.dilation,
|
207 |
+
self.group, self.group_channels,
|
208 |
+
self.offset_scale, self.remove_center)
|
209 |
+
if self.center_feature_scale:
|
210 |
+
center_feature_scale = self.center_feature_scale_module(
|
211 |
+
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
|
212 |
+
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
|
213 |
+
center_feature_scale = center_feature_scale[..., None].repeat(
|
214 |
+
1, 1, 1, 1, self.channels // self.group).flatten(-2)
|
215 |
+
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
|
216 |
+
x = self.output_proj(x)
|
217 |
+
|
218 |
+
return x
|
219 |
+
|
220 |
+
|
221 |
+
class DCNv3(nn.Module):
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
channels=64,
|
225 |
+
kernel_size=3,
|
226 |
+
dw_kernel_size=None,
|
227 |
+
stride=1,
|
228 |
+
pad=1,
|
229 |
+
dilation=1,
|
230 |
+
group=4,
|
231 |
+
offset_scale=1.0,
|
232 |
+
act_layer='GELU',
|
233 |
+
norm_layer='LN',
|
234 |
+
center_feature_scale=False,
|
235 |
+
remove_center=False,
|
236 |
+
):
|
237 |
+
"""
|
238 |
+
DCNv3 Module
|
239 |
+
:param channels
|
240 |
+
:param kernel_size
|
241 |
+
:param stride
|
242 |
+
:param pad
|
243 |
+
:param dilation
|
244 |
+
:param group
|
245 |
+
:param offset_scale
|
246 |
+
:param act_layer
|
247 |
+
:param norm_layer
|
248 |
+
"""
|
249 |
+
super().__init__()
|
250 |
+
if channels % group != 0:
|
251 |
+
raise ValueError(
|
252 |
+
f'channels must be divisible by group, but got {channels} and {group}')
|
253 |
+
_d_per_group = channels // group
|
254 |
+
dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size
|
255 |
+
# you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
|
256 |
+
if not _is_power_of_2(_d_per_group):
|
257 |
+
warnings.warn(
|
258 |
+
"You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
|
259 |
+
'which is more efficient in our CUDA implementation.')
|
260 |
+
|
261 |
+
self.offset_scale = offset_scale
|
262 |
+
self.channels = channels
|
263 |
+
self.kernel_size = kernel_size
|
264 |
+
self.dw_kernel_size = dw_kernel_size
|
265 |
+
self.stride = stride
|
266 |
+
self.dilation = dilation
|
267 |
+
self.pad = pad
|
268 |
+
self.group = group
|
269 |
+
self.group_channels = channels // group
|
270 |
+
self.offset_scale = offset_scale
|
271 |
+
self.center_feature_scale = center_feature_scale
|
272 |
+
self.remove_center = int(remove_center)
|
273 |
+
|
274 |
+
if self.remove_center and self.kernel_size % 2 == 0:
|
275 |
+
raise ValueError('remove_center is only compatible with odd kernel size.')
|
276 |
+
|
277 |
+
self.dw_conv = nn.Sequential(
|
278 |
+
nn.Conv2d(
|
279 |
+
channels,
|
280 |
+
channels,
|
281 |
+
kernel_size=dw_kernel_size,
|
282 |
+
stride=1,
|
283 |
+
padding=(dw_kernel_size - 1) // 2,
|
284 |
+
groups=channels),
|
285 |
+
build_norm_layer(
|
286 |
+
channels,
|
287 |
+
norm_layer,
|
288 |
+
'channels_first',
|
289 |
+
'channels_last'),
|
290 |
+
build_act_layer(act_layer))
|
291 |
+
self.offset = nn.Linear(
|
292 |
+
channels,
|
293 |
+
group * (kernel_size * kernel_size - remove_center) * 2)
|
294 |
+
self.mask = nn.Linear(
|
295 |
+
channels,
|
296 |
+
group * (kernel_size * kernel_size - remove_center))
|
297 |
+
self.input_proj = nn.Linear(channels, channels)
|
298 |
+
self.output_proj = nn.Linear(channels, channels)
|
299 |
+
self._reset_parameters()
|
300 |
+
|
301 |
+
if center_feature_scale:
|
302 |
+
self.center_feature_scale_proj_weight = nn.Parameter(
|
303 |
+
torch.zeros((group, channels), dtype=torch.float))
|
304 |
+
self.center_feature_scale_proj_bias = nn.Parameter(
|
305 |
+
torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ))
|
306 |
+
self.center_feature_scale_module = CenterFeatureScaleModule()
|
307 |
+
|
308 |
+
def _reset_parameters(self):
|
309 |
+
constant_(self.offset.weight.data, 0.)
|
310 |
+
constant_(self.offset.bias.data, 0.)
|
311 |
+
constant_(self.mask.weight.data, 0.)
|
312 |
+
constant_(self.mask.bias.data, 0.)
|
313 |
+
xavier_uniform_(self.input_proj.weight.data)
|
314 |
+
constant_(self.input_proj.bias.data, 0.)
|
315 |
+
xavier_uniform_(self.output_proj.weight.data)
|
316 |
+
constant_(self.output_proj.bias.data, 0.)
|
317 |
+
|
318 |
+
def forward(self, input):
|
319 |
+
"""
|
320 |
+
:param query (N, H, W, C)
|
321 |
+
:return output (N, H, W, C)
|
322 |
+
"""
|
323 |
+
N, H, W, _ = input.shape
|
324 |
+
|
325 |
+
x = self.input_proj(input)
|
326 |
+
x_proj = x
|
327 |
+
dtype = x.dtype
|
328 |
+
|
329 |
+
x1 = input.permute(0, 3, 1, 2)
|
330 |
+
x1 = self.dw_conv(x1)
|
331 |
+
offset = self.offset(x1)
|
332 |
+
mask = self.mask(x1).reshape(N, H, W, self.group, -1)
|
333 |
+
mask = F.softmax(mask, -1)
|
334 |
+
mask = mask.reshape(N, H, W, -1).type(dtype)
|
335 |
+
|
336 |
+
x = DCNv3Function.apply(
|
337 |
+
x, offset, mask,
|
338 |
+
self.kernel_size, self.kernel_size,
|
339 |
+
self.stride, self.stride,
|
340 |
+
self.pad, self.pad,
|
341 |
+
self.dilation, self.dilation,
|
342 |
+
self.group, self.group_channels,
|
343 |
+
self.offset_scale,
|
344 |
+
256,
|
345 |
+
self.remove_center)
|
346 |
+
|
347 |
+
if self.center_feature_scale:
|
348 |
+
center_feature_scale = self.center_feature_scale_module(
|
349 |
+
x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
|
350 |
+
# N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
|
351 |
+
center_feature_scale = center_feature_scale[..., None].repeat(
|
352 |
+
1, 1, 1, 1, self.channels // self.group).flatten(-2)
|
353 |
+
x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
|
354 |
+
x = self.output_proj(x)
|
355 |
+
|
356 |
+
return x
|
dcnv3_func.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternImage
|
3 |
+
# Copyright (c) 2025 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
from __future__ import absolute_import, division, print_function
|
8 |
+
|
9 |
+
try:
|
10 |
+
import DCNv3
|
11 |
+
dcn_version = float(pkg_resources.get_distribution('DCNv3').version)
|
12 |
+
has_cuda_kernel = True
|
13 |
+
except:
|
14 |
+
has_cuda_kernel = False
|
15 |
+
import pkg_resources
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch.autograd import Function
|
19 |
+
from torch.autograd.function import once_differentiable
|
20 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
21 |
+
|
22 |
+
|
23 |
+
class DCNv3Function(Function):
|
24 |
+
@staticmethod
|
25 |
+
@custom_fwd
|
26 |
+
def forward(
|
27 |
+
ctx, input, offset, mask,
|
28 |
+
kernel_h, kernel_w, stride_h, stride_w,
|
29 |
+
pad_h, pad_w, dilation_h, dilation_w,
|
30 |
+
group, group_channels, offset_scale, im2col_step, remove_center):
|
31 |
+
ctx.kernel_h = kernel_h
|
32 |
+
ctx.kernel_w = kernel_w
|
33 |
+
ctx.stride_h = stride_h
|
34 |
+
ctx.stride_w = stride_w
|
35 |
+
ctx.pad_h = pad_h
|
36 |
+
ctx.pad_w = pad_w
|
37 |
+
ctx.dilation_h = dilation_h
|
38 |
+
ctx.dilation_w = dilation_w
|
39 |
+
ctx.group = group
|
40 |
+
ctx.group_channels = group_channels
|
41 |
+
ctx.offset_scale = offset_scale
|
42 |
+
ctx.im2col_step = im2col_step
|
43 |
+
ctx.remove_center = remove_center
|
44 |
+
|
45 |
+
args = [
|
46 |
+
input, offset, mask, kernel_h,
|
47 |
+
kernel_w, stride_h, stride_w, pad_h,
|
48 |
+
pad_w, dilation_h, dilation_w, group,
|
49 |
+
group_channels, offset_scale, ctx.im2col_step
|
50 |
+
]
|
51 |
+
if remove_center or dcn_version > 1.0:
|
52 |
+
args.append(remove_center)
|
53 |
+
|
54 |
+
output = DCNv3.dcnv3_forward(*args)
|
55 |
+
ctx.save_for_backward(input, offset, mask)
|
56 |
+
|
57 |
+
return output
|
58 |
+
|
59 |
+
@staticmethod
|
60 |
+
@once_differentiable
|
61 |
+
@custom_bwd
|
62 |
+
def backward(ctx, grad_output):
|
63 |
+
input, offset, mask = ctx.saved_tensors
|
64 |
+
|
65 |
+
args = [
|
66 |
+
input, offset, mask, ctx.kernel_h,
|
67 |
+
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
|
68 |
+
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
|
69 |
+
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step
|
70 |
+
]
|
71 |
+
if ctx.remove_center or dcn_version > 1.0:
|
72 |
+
args.append(ctx.remove_center)
|
73 |
+
|
74 |
+
grad_input, grad_offset, grad_mask = \
|
75 |
+
DCNv3.dcnv3_backward(*args)
|
76 |
+
|
77 |
+
return grad_input, grad_offset, grad_mask, \
|
78 |
+
None, None, None, None, None, None, None, None, None, None, None, None, None
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
|
82 |
+
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
|
83 |
+
group_channels, offset_scale, im2col_step, remove_center):
|
84 |
+
"""Symbolic function for mmdeploy::DCNv3.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
DCNv3 op for onnx.
|
88 |
+
"""
|
89 |
+
return g.op(
|
90 |
+
'mmdeploy::TRTDCNv3',
|
91 |
+
input,
|
92 |
+
offset,
|
93 |
+
mask,
|
94 |
+
kernel_h_i=int(kernel_h),
|
95 |
+
kernel_w_i=int(kernel_w),
|
96 |
+
stride_h_i=int(stride_h),
|
97 |
+
stride_w_i=int(stride_w),
|
98 |
+
pad_h_i=int(pad_h),
|
99 |
+
pad_w_i=int(pad_w),
|
100 |
+
dilation_h_i=int(dilation_h),
|
101 |
+
dilation_w_i=int(dilation_w),
|
102 |
+
group_i=int(group),
|
103 |
+
group_channels_i=int(group_channels),
|
104 |
+
offset_scale_f=float(offset_scale),
|
105 |
+
im2col_step_i=int(im2col_step),
|
106 |
+
remove_center_i=int(remove_center),
|
107 |
+
)
|
108 |
+
|
109 |
+
|
110 |
+
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
|
111 |
+
_, H_, W_, _ = spatial_shapes
|
112 |
+
H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
|
113 |
+
W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1
|
114 |
+
|
115 |
+
ref_y, ref_x = torch.meshgrid(
|
116 |
+
torch.linspace(
|
117 |
+
# pad_h + 0.5,
|
118 |
+
# H_ - pad_h - 0.5,
|
119 |
+
(dilation_h * (kernel_h - 1)) // 2 + 0.5,
|
120 |
+
(dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
|
121 |
+
H_out,
|
122 |
+
dtype=torch.float32,
|
123 |
+
device=device),
|
124 |
+
torch.linspace(
|
125 |
+
# pad_w + 0.5,
|
126 |
+
# W_ - pad_w - 0.5,
|
127 |
+
(dilation_w * (kernel_w - 1)) // 2 + 0.5,
|
128 |
+
(dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
|
129 |
+
W_out,
|
130 |
+
dtype=torch.float32,
|
131 |
+
device=device))
|
132 |
+
ref_y = ref_y.reshape(-1)[None] / H_
|
133 |
+
ref_x = ref_x.reshape(-1)[None] / W_
|
134 |
+
|
135 |
+
ref = torch.stack((ref_x, ref_y), -1).reshape(
|
136 |
+
1, H_out, W_out, 1, 2)
|
137 |
+
|
138 |
+
return ref
|
139 |
+
|
140 |
+
|
141 |
+
def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
|
142 |
+
_, H_, W_, _ = spatial_shapes
|
143 |
+
points_list = []
|
144 |
+
x, y = torch.meshgrid(
|
145 |
+
torch.linspace(
|
146 |
+
-((dilation_w * (kernel_w - 1)) // 2),
|
147 |
+
-((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w,
|
148 |
+
kernel_w,
|
149 |
+
dtype=torch.float32,
|
150 |
+
device=device),
|
151 |
+
torch.linspace(
|
152 |
+
-((dilation_h * (kernel_h - 1)) // 2),
|
153 |
+
-((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h,
|
154 |
+
kernel_h,
|
155 |
+
dtype=torch.float32,
|
156 |
+
device=device))
|
157 |
+
|
158 |
+
points_list.extend([x / W_, y / H_])
|
159 |
+
grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\
|
160 |
+
repeat(1, group, 1).permute(1, 0, 2)
|
161 |
+
grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)
|
162 |
+
|
163 |
+
return grid
|
164 |
+
|
165 |
+
|
166 |
+
def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
|
167 |
+
idx = list(range(sampling_locations.shape[-2]))
|
168 |
+
C = (kernel_w * kernel_h - 1)//2
|
169 |
+
idx = [i for i in idx if i != C and (i-C) % (C*2+1) != 0]
|
170 |
+
sampling_locations = sampling_locations[:,:,:,idx, :]
|
171 |
+
return sampling_locations
|
172 |
+
|
173 |
+
|
174 |
+
def dcnv3_core_pytorch(
|
175 |
+
input, offset, mask, kernel_h,
|
176 |
+
kernel_w, stride_h, stride_w, pad_h,
|
177 |
+
pad_w, dilation_h, dilation_w, group,
|
178 |
+
group_channels, offset_scale, remove_center):
|
179 |
+
# for debug and test only,
|
180 |
+
# need to use cuda version instead
|
181 |
+
|
182 |
+
if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 or kernel_w != kernel_h):
|
183 |
+
raise ValueError('remove_center is only compatible with square odd kernel size.')
|
184 |
+
|
185 |
+
input = F.pad(
|
186 |
+
input,
|
187 |
+
[0, 0, pad_h, pad_h, pad_w, pad_w])
|
188 |
+
N_, H_in, W_in, _ = input.shape
|
189 |
+
_, H_out, W_out, _ = offset.shape
|
190 |
+
|
191 |
+
ref = _get_reference_points(
|
192 |
+
input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
|
193 |
+
grid = _generate_dilation_grids(
|
194 |
+
input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
|
195 |
+
spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
|
196 |
+
repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).to(input.device)
|
197 |
+
|
198 |
+
sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
|
199 |
+
if remove_center:
|
200 |
+
sampling_locations = remove_center_sampling_locations(sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
|
201 |
+
sampling_locations = sampling_locations.flatten(3, 4)
|
202 |
+
sampling_locations = sampling_locations + offset * offset_scale / spatial_norm
|
203 |
+
|
204 |
+
P_ = kernel_h * kernel_w - remove_center
|
205 |
+
sampling_grids = 2 * sampling_locations - 1
|
206 |
+
# N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
|
207 |
+
input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
|
208 |
+
reshape(N_*group, group_channels, H_in, W_in)
|
209 |
+
# N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
|
210 |
+
sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
|
211 |
+
flatten(0, 1)
|
212 |
+
# N_*group, group_channels, H_out*W_out, P_
|
213 |
+
sampling_input_ = F.grid_sample(
|
214 |
+
input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)
|
215 |
+
|
216 |
+
# (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_)
|
217 |
+
mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\
|
218 |
+
reshape(N_*group, 1, H_out*W_out, P_)
|
219 |
+
output = (sampling_input_ * mask).sum(-1).view(N_,
|
220 |
+
group*group_channels, H_out*W_out)
|
221 |
+
|
222 |
+
return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5e38da0af5a2f47f30ba7518ce9598ba865f9441de7e7c13d685b99e7dc09e53
|
3 |
+
size 4422811076
|
modeling_internimage.py
ADDED
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternImage
|
3 |
+
# Copyright (c) 2025 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from typing import Optional, Tuple
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.utils.checkpoint as checkpoint
|
14 |
+
from timm.models.layers import DropPath, trunc_normal_
|
15 |
+
from torch import nn
|
16 |
+
from transformers import PreTrainedModel
|
17 |
+
from transformers.modeling_outputs import ModelOutput
|
18 |
+
|
19 |
+
from .configuration_internimage import InternImageConfig
|
20 |
+
from .dcnv3 import DCNv3, DCNv3_pytorch, has_cuda_kernel
|
21 |
+
from .dcnv3_func import dcnv3_core_pytorch
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class BackboneOutput(ModelOutput):
|
26 |
+
"""
|
27 |
+
Base class for outputs of backbones.
|
28 |
+
"""
|
29 |
+
|
30 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
31 |
+
pooler_output: Optional[torch.FloatTensor] = None
|
32 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
33 |
+
logits: Optional[torch.FloatTensor] = None
|
34 |
+
loss: Optional[torch.FloatTensor] = None
|
35 |
+
|
36 |
+
|
37 |
+
class to_channels_first(nn.Module):
|
38 |
+
|
39 |
+
def __init__(self):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return x.permute(0, 3, 1, 2)
|
44 |
+
|
45 |
+
|
46 |
+
class to_channels_last(nn.Module):
|
47 |
+
|
48 |
+
def __init__(self):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
return x.permute(0, 2, 3, 1)
|
53 |
+
|
54 |
+
|
55 |
+
def build_norm_layer(dim,
|
56 |
+
norm_layer,
|
57 |
+
in_format='channels_last',
|
58 |
+
out_format='channels_last',
|
59 |
+
eps=1e-6):
|
60 |
+
layers = []
|
61 |
+
if norm_layer == 'BN':
|
62 |
+
if in_format == 'channels_last':
|
63 |
+
layers.append(to_channels_first())
|
64 |
+
layers.append(nn.BatchNorm2d(dim))
|
65 |
+
if out_format == 'channels_last':
|
66 |
+
layers.append(to_channels_last())
|
67 |
+
elif norm_layer == 'LN':
|
68 |
+
if in_format == 'channels_first':
|
69 |
+
layers.append(to_channels_last())
|
70 |
+
layers.append(nn.LayerNorm(dim, eps=eps))
|
71 |
+
if out_format == 'channels_first':
|
72 |
+
layers.append(to_channels_first())
|
73 |
+
else:
|
74 |
+
raise NotImplementedError(
|
75 |
+
f'build_norm_layer does not support {norm_layer}')
|
76 |
+
return nn.Sequential(*layers)
|
77 |
+
|
78 |
+
|
79 |
+
def build_act_layer(act_layer):
|
80 |
+
if act_layer == 'ReLU':
|
81 |
+
return nn.ReLU(inplace=True)
|
82 |
+
elif act_layer == 'SiLU':
|
83 |
+
return nn.SiLU(inplace=True)
|
84 |
+
elif act_layer == 'GELU':
|
85 |
+
return nn.GELU()
|
86 |
+
|
87 |
+
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
|
88 |
+
|
89 |
+
|
90 |
+
class CrossAttention(nn.Module):
|
91 |
+
r""" Cross Attention Module
|
92 |
+
Args:
|
93 |
+
dim (int): Number of input channels.
|
94 |
+
num_heads (int): Number of attention heads. Default: 8
|
95 |
+
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
96 |
+
Default: False.
|
97 |
+
qk_scale (float | None, optional): Override default qk scale of
|
98 |
+
head_dim ** -0.5 if set. Default: None.
|
99 |
+
attn_drop (float, optional): Dropout ratio of attention weight.
|
100 |
+
Default: 0.0
|
101 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
102 |
+
attn_head_dim (int, optional): Dimension of attention head.
|
103 |
+
out_dim (int, optional): Dimension of output.
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(self,
|
107 |
+
dim,
|
108 |
+
num_heads=8,
|
109 |
+
qkv_bias=False,
|
110 |
+
qk_scale=None,
|
111 |
+
attn_drop=0.,
|
112 |
+
proj_drop=0.,
|
113 |
+
attn_head_dim=None,
|
114 |
+
out_dim=None):
|
115 |
+
super().__init__()
|
116 |
+
if out_dim is None:
|
117 |
+
out_dim = dim
|
118 |
+
self.num_heads = num_heads
|
119 |
+
head_dim = dim // num_heads
|
120 |
+
if attn_head_dim is not None:
|
121 |
+
head_dim = attn_head_dim
|
122 |
+
all_head_dim = head_dim * self.num_heads
|
123 |
+
self.scale = qk_scale or head_dim ** -0.5
|
124 |
+
assert all_head_dim == dim
|
125 |
+
|
126 |
+
self.q = nn.Linear(dim, all_head_dim, bias=False)
|
127 |
+
self.k = nn.Linear(dim, all_head_dim, bias=False)
|
128 |
+
self.v = nn.Linear(dim, all_head_dim, bias=False)
|
129 |
+
|
130 |
+
if qkv_bias:
|
131 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
132 |
+
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
|
133 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
134 |
+
else:
|
135 |
+
self.q_bias = None
|
136 |
+
self.k_bias = None
|
137 |
+
self.v_bias = None
|
138 |
+
|
139 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
140 |
+
self.proj = nn.Linear(all_head_dim, out_dim)
|
141 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
142 |
+
|
143 |
+
def forward(self, x, k=None, v=None):
|
144 |
+
B, N, C = x.shape
|
145 |
+
N_k = k.shape[1]
|
146 |
+
N_v = v.shape[1]
|
147 |
+
|
148 |
+
q_bias, k_bias, v_bias = None, None, None
|
149 |
+
if self.q_bias is not None:
|
150 |
+
q_bias = self.q_bias
|
151 |
+
k_bias = self.k_bias
|
152 |
+
v_bias = self.v_bias
|
153 |
+
|
154 |
+
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
|
155 |
+
q = q.reshape(B, N, 1, self.num_heads,
|
156 |
+
-1).permute(2, 0, 3, 1,
|
157 |
+
4).squeeze(0) # (B, N_head, N_q, dim)
|
158 |
+
|
159 |
+
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
|
160 |
+
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
|
161 |
+
4).squeeze(0)
|
162 |
+
|
163 |
+
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
|
164 |
+
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
|
165 |
+
4).squeeze(0)
|
166 |
+
|
167 |
+
q = q * self.scale
|
168 |
+
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
|
169 |
+
|
170 |
+
attn = attn.softmax(dim=-1)
|
171 |
+
attn = self.attn_drop(attn)
|
172 |
+
|
173 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
174 |
+
x = self.proj(x)
|
175 |
+
x = self.proj_drop(x)
|
176 |
+
|
177 |
+
return x
|
178 |
+
|
179 |
+
|
180 |
+
class AttentiveBlock(nn.Module):
|
181 |
+
r"""Attentive Block
|
182 |
+
Args:
|
183 |
+
dim (int): Number of input channels.
|
184 |
+
num_heads (int): Number of attention heads. Default: 8
|
185 |
+
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
186 |
+
Default: False.
|
187 |
+
qk_scale (float | None, optional): Override default qk scale of
|
188 |
+
head_dim ** -0.5 if set. Default: None.
|
189 |
+
drop (float, optional): Dropout rate. Default: 0.0.
|
190 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0.
|
191 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate.
|
192 |
+
Default: 0.0.
|
193 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
|
194 |
+
attn_head_dim (int, optional): Dimension of attention head. Default: None.
|
195 |
+
out_dim (int, optional): Dimension of output. Default: None.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self,
|
199 |
+
dim,
|
200 |
+
num_heads,
|
201 |
+
qkv_bias=False,
|
202 |
+
qk_scale=None,
|
203 |
+
drop=0.,
|
204 |
+
attn_drop=0.,
|
205 |
+
drop_path=0.,
|
206 |
+
norm_layer='LN',
|
207 |
+
attn_head_dim=None,
|
208 |
+
out_dim=None):
|
209 |
+
super().__init__()
|
210 |
+
|
211 |
+
self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
|
212 |
+
self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
|
213 |
+
self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
|
214 |
+
self.cross_dcn = CrossAttention(dim,
|
215 |
+
num_heads=num_heads,
|
216 |
+
qkv_bias=qkv_bias,
|
217 |
+
qk_scale=qk_scale,
|
218 |
+
attn_drop=attn_drop,
|
219 |
+
proj_drop=drop,
|
220 |
+
attn_head_dim=attn_head_dim,
|
221 |
+
out_dim=out_dim)
|
222 |
+
|
223 |
+
self.drop_path = DropPath(
|
224 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
225 |
+
|
226 |
+
def forward(self,
|
227 |
+
x_q,
|
228 |
+
x_kv,
|
229 |
+
pos_q,
|
230 |
+
pos_k,
|
231 |
+
bool_masked_pos,
|
232 |
+
rel_pos_bias=None):
|
233 |
+
x_q = self.norm1_q(x_q + pos_q)
|
234 |
+
x_k = self.norm1_k(x_kv + pos_k)
|
235 |
+
x_v = self.norm1_v(x_kv)
|
236 |
+
|
237 |
+
x = self.cross_dcn(x_q, k=x_k, v=x_v)
|
238 |
+
|
239 |
+
return x
|
240 |
+
|
241 |
+
|
242 |
+
class AttentionPoolingBlock(AttentiveBlock):
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
x_q = x.mean(1, keepdim=True)
|
246 |
+
x_kv = x
|
247 |
+
pos_q, pos_k = 0, 0
|
248 |
+
x = super().forward(x_q, x_kv, pos_q, pos_k,
|
249 |
+
bool_masked_pos=None,
|
250 |
+
rel_pos_bias=None)
|
251 |
+
x = x.squeeze(1)
|
252 |
+
return x
|
253 |
+
|
254 |
+
|
255 |
+
class StemLayer(nn.Module):
|
256 |
+
r"""Stem layer of InternImage
|
257 |
+
Args:
|
258 |
+
in_chans (int): number of input channels
|
259 |
+
out_chans (int): number of output channels
|
260 |
+
act_layer (str): activation layer
|
261 |
+
norm_layer (str): normalization layer
|
262 |
+
"""
|
263 |
+
|
264 |
+
def __init__(self,
|
265 |
+
in_chans=3,
|
266 |
+
out_chans=96,
|
267 |
+
act_layer='GELU',
|
268 |
+
norm_layer='BN'):
|
269 |
+
super().__init__()
|
270 |
+
self.conv1 = nn.Conv2d(in_chans,
|
271 |
+
out_chans // 2,
|
272 |
+
kernel_size=3,
|
273 |
+
stride=2,
|
274 |
+
padding=1)
|
275 |
+
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
|
276 |
+
'channels_first', 'channels_first')
|
277 |
+
self.act = build_act_layer(act_layer)
|
278 |
+
self.conv2 = nn.Conv2d(out_chans // 2,
|
279 |
+
out_chans,
|
280 |
+
kernel_size=3,
|
281 |
+
stride=2,
|
282 |
+
padding=1)
|
283 |
+
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
|
284 |
+
'channels_last')
|
285 |
+
|
286 |
+
def forward(self, x):
|
287 |
+
x = self.conv1(x)
|
288 |
+
x = self.norm1(x)
|
289 |
+
x = self.act(x)
|
290 |
+
x = self.conv2(x)
|
291 |
+
x = self.norm2(x)
|
292 |
+
return x
|
293 |
+
|
294 |
+
|
295 |
+
class DownsampleLayer(nn.Module):
|
296 |
+
r"""Downsample layer of InternImage
|
297 |
+
Args:
|
298 |
+
channels (int): number of input channels
|
299 |
+
norm_layer (str): normalization layer
|
300 |
+
"""
|
301 |
+
|
302 |
+
def __init__(self, channels, norm_layer='LN'):
|
303 |
+
super().__init__()
|
304 |
+
self.conv = nn.Conv2d(channels,
|
305 |
+
2 * channels,
|
306 |
+
kernel_size=3,
|
307 |
+
stride=2,
|
308 |
+
padding=1,
|
309 |
+
bias=False)
|
310 |
+
self.norm = build_norm_layer(2 * channels, norm_layer,
|
311 |
+
'channels_first', 'channels_last')
|
312 |
+
|
313 |
+
def forward(self, x):
|
314 |
+
x = self.conv(x.permute(0, 3, 1, 2))
|
315 |
+
x = self.norm(x)
|
316 |
+
return x
|
317 |
+
|
318 |
+
|
319 |
+
class MLPLayer(nn.Module):
|
320 |
+
r"""MLP layer of InternImage
|
321 |
+
Args:
|
322 |
+
in_features (int): number of input features
|
323 |
+
hidden_features (int): number of hidden features
|
324 |
+
out_features (int): number of output features
|
325 |
+
act_layer (str): activation layer
|
326 |
+
drop (float): dropout rate
|
327 |
+
"""
|
328 |
+
|
329 |
+
def __init__(self,
|
330 |
+
in_features,
|
331 |
+
hidden_features=None,
|
332 |
+
out_features=None,
|
333 |
+
act_layer='GELU',
|
334 |
+
drop=0.):
|
335 |
+
super().__init__()
|
336 |
+
out_features = out_features or in_features
|
337 |
+
hidden_features = hidden_features or in_features
|
338 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
339 |
+
self.act = build_act_layer(act_layer)
|
340 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
341 |
+
self.drop = nn.Dropout(drop)
|
342 |
+
|
343 |
+
def forward(self, x):
|
344 |
+
x = self.fc1(x)
|
345 |
+
x = self.act(x)
|
346 |
+
x = self.drop(x)
|
347 |
+
x = self.fc2(x)
|
348 |
+
x = self.drop(x)
|
349 |
+
return x
|
350 |
+
|
351 |
+
|
352 |
+
class InternImageLayer(nn.Module):
|
353 |
+
r"""Basic layer of InternImage
|
354 |
+
Args:
|
355 |
+
core_op (nn.Module): core operation of InternImage
|
356 |
+
channels (int): number of input channels
|
357 |
+
groups (list): Groups of each block.
|
358 |
+
mlp_ratio (float): ratio of mlp hidden features to input channels
|
359 |
+
drop (float): dropout rate
|
360 |
+
drop_path (float): drop path rate
|
361 |
+
act_layer (str): activation layer
|
362 |
+
norm_layer (str): normalization layer
|
363 |
+
post_norm (bool): whether to use post normalization
|
364 |
+
layer_scale (float): layer scale
|
365 |
+
offset_scale (float): offset scale
|
366 |
+
with_cp (bool): whether to use checkpoint
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self,
|
370 |
+
core_op,
|
371 |
+
channels,
|
372 |
+
groups,
|
373 |
+
mlp_ratio=4.,
|
374 |
+
drop=0.,
|
375 |
+
drop_path=0.,
|
376 |
+
act_layer='GELU',
|
377 |
+
norm_layer='LN',
|
378 |
+
post_norm=False,
|
379 |
+
layer_scale=None,
|
380 |
+
offset_scale=1.0,
|
381 |
+
with_cp=False,
|
382 |
+
dw_kernel_size=None, # for InternImage-H/G
|
383 |
+
res_post_norm=False, # for InternImage-H/G
|
384 |
+
center_feature_scale=False, # for InternImage-H/G
|
385 |
+
remove_center=False, # for InternImage-H/G
|
386 |
+
):
|
387 |
+
super().__init__()
|
388 |
+
self.channels = channels
|
389 |
+
self.groups = groups
|
390 |
+
self.mlp_ratio = mlp_ratio
|
391 |
+
self.with_cp = with_cp
|
392 |
+
|
393 |
+
self.norm1 = build_norm_layer(channels, 'LN')
|
394 |
+
self.post_norm = post_norm
|
395 |
+
self.dcn = core_op(
|
396 |
+
channels=channels,
|
397 |
+
kernel_size=3,
|
398 |
+
stride=1,
|
399 |
+
pad=1,
|
400 |
+
dilation=1,
|
401 |
+
group=groups,
|
402 |
+
offset_scale=offset_scale,
|
403 |
+
act_layer=act_layer,
|
404 |
+
norm_layer=norm_layer,
|
405 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
406 |
+
center_feature_scale=center_feature_scale, # for InternImage-H/G
|
407 |
+
remove_center=remove_center, # for InternImage-H/G
|
408 |
+
)
|
409 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
410 |
+
else nn.Identity()
|
411 |
+
self.norm2 = build_norm_layer(channels, 'LN')
|
412 |
+
self.mlp = MLPLayer(in_features=channels,
|
413 |
+
hidden_features=int(channels * mlp_ratio),
|
414 |
+
act_layer=act_layer,
|
415 |
+
drop=drop)
|
416 |
+
self.layer_scale = layer_scale is not None
|
417 |
+
if self.layer_scale:
|
418 |
+
self.layer_scale1 = nn.Parameter(layer_scale * torch.ones(channels),
|
419 |
+
requires_grad=True)
|
420 |
+
self.layer_scale2 = nn.Parameter(layer_scale * torch.ones(channels),
|
421 |
+
requires_grad=True)
|
422 |
+
self.res_post_norm = res_post_norm
|
423 |
+
if res_post_norm:
|
424 |
+
self.res_post_norm1 = build_norm_layer(channels, 'LN')
|
425 |
+
self.res_post_norm2 = build_norm_layer(channels, 'LN')
|
426 |
+
|
427 |
+
def forward(self, x):
|
428 |
+
|
429 |
+
def _inner_forward(x):
|
430 |
+
if not self.layer_scale:
|
431 |
+
if self.post_norm:
|
432 |
+
x = x + self.drop_path(self.norm1(self.dcn(x)))
|
433 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
434 |
+
elif self.res_post_norm: # for InternImage-H/G
|
435 |
+
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
|
436 |
+
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
|
437 |
+
else:
|
438 |
+
x = x + self.drop_path(self.dcn(self.norm1(x)))
|
439 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
440 |
+
return x
|
441 |
+
if self.post_norm:
|
442 |
+
x = x + self.drop_path(self.layer_scale1 * self.norm1(self.dcn(x)))
|
443 |
+
x = x + self.drop_path(self.layer_scale2 * self.norm2(self.mlp(x)))
|
444 |
+
else:
|
445 |
+
x = x + self.drop_path(self.layer_scale1 * self.dcn(self.norm1(x)))
|
446 |
+
x = x + self.drop_path(self.layer_scale2 * self.mlp(self.norm2(x)))
|
447 |
+
return x
|
448 |
+
|
449 |
+
if self.with_cp and x.requires_grad:
|
450 |
+
x = checkpoint.checkpoint(_inner_forward, x)
|
451 |
+
else:
|
452 |
+
x = _inner_forward(x)
|
453 |
+
return x
|
454 |
+
|
455 |
+
|
456 |
+
class InternImageBlock(nn.Module):
|
457 |
+
r"""Block of InternImage
|
458 |
+
Args:
|
459 |
+
core_op (nn.Module): core operation of InternImage
|
460 |
+
channels (int): number of input channels
|
461 |
+
depths (list): Depth of each block.
|
462 |
+
groups (list): Groups of each block.
|
463 |
+
mlp_ratio (float): ratio of mlp hidden features to input channels
|
464 |
+
drop (float): dropout rate
|
465 |
+
drop_path (float): drop path rate
|
466 |
+
act_layer (str): activation layer
|
467 |
+
norm_layer (str): normalization layer
|
468 |
+
post_norm (bool): whether to use post normalization
|
469 |
+
layer_scale (float): layer scale
|
470 |
+
offset_scale (float): offset scale
|
471 |
+
with_cp (bool): whether to use checkpoint
|
472 |
+
"""
|
473 |
+
|
474 |
+
def __init__(self,
|
475 |
+
core_op,
|
476 |
+
channels,
|
477 |
+
depth,
|
478 |
+
groups,
|
479 |
+
downsample=True,
|
480 |
+
mlp_ratio=4.,
|
481 |
+
drop=0.,
|
482 |
+
drop_path=0.,
|
483 |
+
act_layer='GELU',
|
484 |
+
norm_layer='LN',
|
485 |
+
post_norm=False,
|
486 |
+
offset_scale=1.0,
|
487 |
+
layer_scale=None,
|
488 |
+
with_cp=False,
|
489 |
+
dw_kernel_size=None, # for InternImage-H/G
|
490 |
+
post_norm_block_ids=None, # for InternImage-H/G
|
491 |
+
res_post_norm=False, # for InternImage-H/G
|
492 |
+
center_feature_scale=False, # for InternImage-H/G
|
493 |
+
remove_center=False, # for InternImage-H/G
|
494 |
+
):
|
495 |
+
super().__init__()
|
496 |
+
self.channels = channels
|
497 |
+
self.depth = depth
|
498 |
+
self.post_norm = post_norm
|
499 |
+
self.center_feature_scale = center_feature_scale
|
500 |
+
|
501 |
+
self.blocks = nn.ModuleList([
|
502 |
+
InternImageLayer(
|
503 |
+
core_op=core_op,
|
504 |
+
channels=channels,
|
505 |
+
groups=groups,
|
506 |
+
mlp_ratio=mlp_ratio,
|
507 |
+
drop=drop,
|
508 |
+
drop_path=drop_path[i] if isinstance(
|
509 |
+
drop_path, list) else drop_path,
|
510 |
+
act_layer=act_layer,
|
511 |
+
norm_layer=norm_layer,
|
512 |
+
post_norm=post_norm,
|
513 |
+
layer_scale=layer_scale,
|
514 |
+
offset_scale=offset_scale,
|
515 |
+
with_cp=with_cp,
|
516 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
517 |
+
res_post_norm=res_post_norm, # for InternImage-H/G
|
518 |
+
center_feature_scale=center_feature_scale, # for InternImage-H/G
|
519 |
+
remove_center=remove_center, # for InternImage-H/G
|
520 |
+
) for i in range(depth)
|
521 |
+
])
|
522 |
+
if not self.post_norm or center_feature_scale:
|
523 |
+
self.norm = build_norm_layer(channels, 'LN')
|
524 |
+
self.post_norm_block_ids = post_norm_block_ids
|
525 |
+
if post_norm_block_ids is not None: # for InternImage-H/G
|
526 |
+
self.post_norms = nn.ModuleList(
|
527 |
+
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
|
528 |
+
)
|
529 |
+
self.downsample = DownsampleLayer(
|
530 |
+
channels=channels, norm_layer=norm_layer) if downsample else None
|
531 |
+
|
532 |
+
def forward(self, x, return_wo_downsample=False):
|
533 |
+
for i, blk in enumerate(self.blocks):
|
534 |
+
x = blk(x)
|
535 |
+
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
|
536 |
+
index = self.post_norm_block_ids.index(i)
|
537 |
+
x = self.post_norms[index](x) # for InternImage-H/G
|
538 |
+
if not self.post_norm or self.center_feature_scale:
|
539 |
+
x = self.norm(x)
|
540 |
+
if return_wo_downsample:
|
541 |
+
x_ = x
|
542 |
+
if self.downsample is not None:
|
543 |
+
x = self.downsample(x)
|
544 |
+
|
545 |
+
if return_wo_downsample:
|
546 |
+
return x, x_
|
547 |
+
return x
|
548 |
+
|
549 |
+
|
550 |
+
class InternImage(nn.Module):
|
551 |
+
r"""InternImage
|
552 |
+
A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
|
553 |
+
https://arxiv.org/pdf/2103.14030
|
554 |
+
Args:
|
555 |
+
core_op (str): Core operator. Default: 'DCNv3'
|
556 |
+
channels (int): Number of the first stage. Default: 64
|
557 |
+
depths (list): Depth of each block. Default: [3, 4, 18, 5]
|
558 |
+
groups (list): Groups of each block. Default: [3, 6, 12, 24]
|
559 |
+
num_classes (int): Number of classes. Default: 1000
|
560 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
561 |
+
drop_rate (float): Probability of an element to be zeroed. Default: 0.
|
562 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
563 |
+
act_layer (str): Activation layer. Default: 'GELU'
|
564 |
+
norm_layer (str): Normalization layer. Default: 'LN'
|
565 |
+
layer_scale (float): The initial value of layer scale. Default: None
|
566 |
+
cls_scale (float): Whether to use class scale. Default: 1.5
|
567 |
+
with_cp (bool): Use gradient checkpointing or not. Default: False
|
568 |
+
dw_kernel_size (int): Size of the dwconv. Default: None
|
569 |
+
use_clip_projector (bool): Whether to use clip projector. Default: False
|
570 |
+
level2_post_norm (bool): Whether to use level2 post norm. Default: False
|
571 |
+
level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
|
572 |
+
res_post_norm (bool): Whether to use res post norm. Default: False
|
573 |
+
center_feature_scale (bool): Whether to use center feature scale. Default: False
|
574 |
+
"""
|
575 |
+
|
576 |
+
def __init__(self,
|
577 |
+
core_op='DCNv3',
|
578 |
+
channels=64,
|
579 |
+
depths=[3, 4, 18, 5],
|
580 |
+
groups=[3, 6, 12, 24],
|
581 |
+
num_classes=1000,
|
582 |
+
mlp_ratio=4.,
|
583 |
+
drop_rate=0.,
|
584 |
+
drop_path_rate=0.2,
|
585 |
+
drop_path_type='linear',
|
586 |
+
act_layer='GELU',
|
587 |
+
norm_layer='LN',
|
588 |
+
layer_scale=None,
|
589 |
+
offset_scale=1.0,
|
590 |
+
post_norm=False,
|
591 |
+
cls_scale=1.5,
|
592 |
+
with_cp=False,
|
593 |
+
dw_kernel_size=None, # for InternImage-H/G
|
594 |
+
use_clip_projector=False, # for InternImage-H/G
|
595 |
+
level2_post_norm=False, # for InternImage-H/G
|
596 |
+
level2_post_norm_block_ids=None, # for InternImage-H/G
|
597 |
+
res_post_norm=False, # for InternImage-H/G
|
598 |
+
center_feature_scale=False, # for InternImage-H/G
|
599 |
+
remove_center=False, # for InternImage-H/G
|
600 |
+
**kwargs):
|
601 |
+
super().__init__()
|
602 |
+
if core_op == 'DCNv3' and has_cuda_kernel:
|
603 |
+
self.core_op = DCNv3
|
604 |
+
print('DCNv3 is installed, using CUDA implementation.')
|
605 |
+
elif core_op == 'DCNv3' and not has_cuda_kernel:
|
606 |
+
self.core_op = DCNv3_pytorch
|
607 |
+
print('DCNv3 is not installed, using PyTorch implementation.')
|
608 |
+
else:
|
609 |
+
self.core_op = DCNv3_pytorch
|
610 |
+
print('Using DCNv3 PyTorch implementation.')
|
611 |
+
self.num_classes = num_classes
|
612 |
+
self.num_levels = len(depths)
|
613 |
+
self.depths = depths
|
614 |
+
self.channels = channels
|
615 |
+
self.num_features = int(channels * 2 ** (self.num_levels - 1))
|
616 |
+
self.post_norm = post_norm
|
617 |
+
self.mlp_ratio = mlp_ratio
|
618 |
+
self.use_clip_projector = use_clip_projector
|
619 |
+
self.level2_post_norm_block_ids = level2_post_norm_block_ids
|
620 |
+
self.remove_center = remove_center
|
621 |
+
|
622 |
+
print(f'using core type: {core_op}')
|
623 |
+
print(f'level2_post_norm: {level2_post_norm}')
|
624 |
+
print(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
|
625 |
+
print(f'res_post_norm: {res_post_norm}')
|
626 |
+
print(f'remove_center: {remove_center}')
|
627 |
+
|
628 |
+
in_chans = 3
|
629 |
+
self.patch_embed = StemLayer(in_chans=in_chans,
|
630 |
+
out_chans=channels,
|
631 |
+
act_layer=act_layer,
|
632 |
+
norm_layer=norm_layer)
|
633 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
634 |
+
|
635 |
+
dpr = [
|
636 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
637 |
+
]
|
638 |
+
if drop_path_type == 'uniform':
|
639 |
+
for i in range(len(dpr)):
|
640 |
+
dpr[i] = drop_path_rate
|
641 |
+
|
642 |
+
self.levels = nn.ModuleList()
|
643 |
+
for i in range(self.num_levels):
|
644 |
+
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
|
645 |
+
i == 2) else None # for InternImage-H/G
|
646 |
+
level = InternImageBlock(
|
647 |
+
core_op=self.core_op,
|
648 |
+
channels=int(channels * 2 ** i),
|
649 |
+
depth=depths[i],
|
650 |
+
groups=groups[i],
|
651 |
+
mlp_ratio=self.mlp_ratio,
|
652 |
+
drop=drop_rate,
|
653 |
+
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
654 |
+
act_layer=act_layer,
|
655 |
+
norm_layer=norm_layer,
|
656 |
+
post_norm=post_norm,
|
657 |
+
downsample=(i < self.num_levels - 1),
|
658 |
+
layer_scale=layer_scale,
|
659 |
+
offset_scale=offset_scale,
|
660 |
+
with_cp=with_cp,
|
661 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
662 |
+
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
|
663 |
+
res_post_norm=res_post_norm, # for InternImage-H/G
|
664 |
+
center_feature_scale=center_feature_scale, # for InternImage-H/G
|
665 |
+
remove_center=remove_center, # for InternImage-H/G
|
666 |
+
)
|
667 |
+
self.levels.append(level)
|
668 |
+
|
669 |
+
if self.num_classes > 0:
|
670 |
+
if not use_clip_projector: # for InternImage-T/S/B/L/XL
|
671 |
+
self.conv_head = nn.Sequential(
|
672 |
+
nn.Conv2d(self.num_features,
|
673 |
+
int(self.num_features * cls_scale),
|
674 |
+
kernel_size=1,
|
675 |
+
bias=False),
|
676 |
+
build_norm_layer(int(self.num_features * cls_scale), 'BN',
|
677 |
+
'channels_first', 'channels_first'),
|
678 |
+
build_act_layer(act_layer))
|
679 |
+
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
|
680 |
+
if num_classes > 0 else nn.Identity()
|
681 |
+
else: # for InternImage-H/G
|
682 |
+
pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768
|
683 |
+
self.dcnv3_head_x4 = nn.Sequential(
|
684 |
+
nn.Conv2d(in_channels=self.num_features,
|
685 |
+
out_channels=pretrain_embed_dim * (_stride ** 2),
|
686 |
+
kernel_size=1), nn.PixelShuffle(_stride))
|
687 |
+
self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
|
688 |
+
out_channels=pretrain_embed_dim,
|
689 |
+
kernel_size=1)
|
690 |
+
self.clip_projector = AttentionPoolingBlock(
|
691 |
+
dim=pretrain_embed_dim,
|
692 |
+
num_heads=attnpool_num_heads,
|
693 |
+
qkv_bias=True,
|
694 |
+
qk_scale=None,
|
695 |
+
drop=0.,
|
696 |
+
attn_drop=0.,
|
697 |
+
norm_layer=norm_layer,
|
698 |
+
out_dim=clip_embed_dim)
|
699 |
+
self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6)
|
700 |
+
self.head = nn.Linear(
|
701 |
+
clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
702 |
+
|
703 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
704 |
+
self.num_layers = len(depths)
|
705 |
+
self.apply(self._init_weights)
|
706 |
+
self.apply(self._init_deform_weights)
|
707 |
+
|
708 |
+
def _init_weights(self, m):
|
709 |
+
if isinstance(m, nn.Linear):
|
710 |
+
trunc_normal_(m.weight, std=.02)
|
711 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
712 |
+
nn.init.constant_(m.bias, 0)
|
713 |
+
elif isinstance(m, nn.LayerNorm):
|
714 |
+
nn.init.constant_(m.bias, 0)
|
715 |
+
nn.init.constant_(m.weight, 1.0)
|
716 |
+
|
717 |
+
def _init_deform_weights(self, m):
|
718 |
+
if isinstance(m, self.core_op):
|
719 |
+
m._reset_parameters()
|
720 |
+
|
721 |
+
@torch.jit.ignore
|
722 |
+
def lr_decay_keywords(self, decay_ratio=0.87):
|
723 |
+
lr_ratios = {}
|
724 |
+
|
725 |
+
# blocks
|
726 |
+
idx = 0
|
727 |
+
for i in range(4):
|
728 |
+
layer_num = 3 - i # 3 2 1 0
|
729 |
+
for j in range(self.depths[layer_num]):
|
730 |
+
block_num = self.depths[layer_num] - j - 1
|
731 |
+
tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
|
732 |
+
decay = 1.0 * (decay_ratio ** idx)
|
733 |
+
lr_ratios[tag] = decay
|
734 |
+
idx += 1
|
735 |
+
# patch_embed (before stage-1)
|
736 |
+
lr_ratios['patch_embed'] = lr_ratios['levels.0.blocks.0.']
|
737 |
+
# levels.0.downsample (between stage-1 and stage-2)
|
738 |
+
lr_ratios['levels.0.downsample'] = lr_ratios['levels.1.blocks.0.']
|
739 |
+
lr_ratios['levels.0.norm'] = lr_ratios['levels.1.blocks.0.']
|
740 |
+
# levels.1.downsample (between stage-2 and stage-3)
|
741 |
+
lr_ratios['levels.1.downsample'] = lr_ratios['levels.2.blocks.0.']
|
742 |
+
lr_ratios['levels.1.norm'] = lr_ratios['levels.2.blocks.0.']
|
743 |
+
# levels.2.downsample (between stage-3 and stage-4)
|
744 |
+
lr_ratios['levels.2.downsample'] = lr_ratios['levels.3.blocks.0.']
|
745 |
+
lr_ratios['levels.2.norm'] = lr_ratios['levels.3.blocks.0.']
|
746 |
+
return lr_ratios
|
747 |
+
|
748 |
+
def forward_features_seq_out(self, x):
|
749 |
+
x = self.patch_embed(x)
|
750 |
+
x = self.pos_drop(x)
|
751 |
+
|
752 |
+
seq_out = []
|
753 |
+
for level in self.levels:
|
754 |
+
x, x_ = level(x, return_wo_downsample=True)
|
755 |
+
seq_out.append(x_)
|
756 |
+
return seq_out
|
757 |
+
|
758 |
+
def forward_features(self, x):
|
759 |
+
xs = self.forward_features_seq_out(x)
|
760 |
+
x1, x2, x3, x4 = xs
|
761 |
+
|
762 |
+
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
|
763 |
+
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
|
764 |
+
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
|
765 |
+
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
|
766 |
+
hidden_states = [x1, x2, x3, x4]
|
767 |
+
|
768 |
+
if self.num_classes > 0:
|
769 |
+
x = self.conv_head(x4)
|
770 |
+
x = self.avgpool(x)
|
771 |
+
x = torch.flatten(x, 1)
|
772 |
+
|
773 |
+
return {
|
774 |
+
'hidden_states': hidden_states,
|
775 |
+
'pooler_output': x if self.num_classes > 0 else None
|
776 |
+
}
|
777 |
+
|
778 |
+
def forward_clip_projector(self, x): # for InternImage-H/G
|
779 |
+
xs = self.forward_features_seq_out(x)
|
780 |
+
x1, x2, x3, x4 = xs
|
781 |
+
|
782 |
+
x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW
|
783 |
+
x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW
|
784 |
+
x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW
|
785 |
+
x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW
|
786 |
+
hidden_states = [x1, x2, x3, x4]
|
787 |
+
|
788 |
+
if self.num_classes > 0:
|
789 |
+
x4 = self.dcnv3_head_x4(x4)
|
790 |
+
x = x4
|
791 |
+
x3 = self.dcnv3_head_x3(x3)
|
792 |
+
x = x + x3
|
793 |
+
|
794 |
+
x = x.flatten(-2).transpose(1, 2).contiguous()
|
795 |
+
x = self.clip_projector(x)
|
796 |
+
x = self.fc_norm(x)
|
797 |
+
|
798 |
+
return {
|
799 |
+
'hidden_states': hidden_states,
|
800 |
+
'pooler_output': x if self.num_classes > 0 else None
|
801 |
+
}
|
802 |
+
|
803 |
+
def forward(self, x):
|
804 |
+
if self.use_clip_projector: # for InternImage-H/G
|
805 |
+
outputs = self.forward_clip_projector(x)
|
806 |
+
else: # for InternImage-T/S/B/L/XL
|
807 |
+
outputs = self.forward_features(x)
|
808 |
+
|
809 |
+
hidden_states = outputs['hidden_states']
|
810 |
+
pooler_output = outputs['pooler_output']
|
811 |
+
|
812 |
+
if self.num_classes > 0:
|
813 |
+
logits = self.head(pooler_output)
|
814 |
+
else:
|
815 |
+
logits = None
|
816 |
+
|
817 |
+
return BackboneOutput(
|
818 |
+
hidden_states=hidden_states,
|
819 |
+
last_hidden_state=hidden_states[-1],
|
820 |
+
pooler_output=pooler_output,
|
821 |
+
logits=logits
|
822 |
+
)
|
823 |
+
|
824 |
+
|
825 |
+
class InternImageModel(PreTrainedModel):
|
826 |
+
config_class = InternImageConfig
|
827 |
+
|
828 |
+
def __init__(self, config):
|
829 |
+
super().__init__(config)
|
830 |
+
self.model = InternImage(
|
831 |
+
core_op=config.core_op,
|
832 |
+
channels=config.channels,
|
833 |
+
depths=config.depths,
|
834 |
+
groups=config.groups,
|
835 |
+
num_classes=0,
|
836 |
+
mlp_ratio=config.mlp_ratio,
|
837 |
+
drop_rate=config.drop_rate,
|
838 |
+
drop_path_rate=config.drop_path_rate,
|
839 |
+
drop_path_type=config.drop_path_type,
|
840 |
+
act_layer=config.act_layer,
|
841 |
+
norm_layer=config.norm_layer,
|
842 |
+
layer_scale=config.layer_scale,
|
843 |
+
offset_scale=config.offset_scale,
|
844 |
+
post_norm=config.post_norm,
|
845 |
+
cls_scale=config.cls_scale,
|
846 |
+
with_cp=config.with_cp,
|
847 |
+
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
|
848 |
+
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
|
849 |
+
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
|
850 |
+
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
|
851 |
+
res_post_norm=config.res_post_norm, # for InternImage-H/G
|
852 |
+
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
|
853 |
+
remove_center=config.remove_center, # for InternImage-H/G
|
854 |
+
)
|
855 |
+
|
856 |
+
def forward(self, tensor):
|
857 |
+
return self.model.forward_features(tensor)
|
858 |
+
|
859 |
+
|
860 |
+
class InternImageModelForImageClassification(PreTrainedModel):
|
861 |
+
config_class = InternImageConfig
|
862 |
+
|
863 |
+
def __init__(self, config):
|
864 |
+
super().__init__(config)
|
865 |
+
self.model = InternImage(
|
866 |
+
core_op=config.core_op,
|
867 |
+
channels=config.channels,
|
868 |
+
depths=config.depths,
|
869 |
+
groups=config.groups,
|
870 |
+
num_classes=config.num_classes,
|
871 |
+
mlp_ratio=config.mlp_ratio,
|
872 |
+
drop_rate=config.drop_rate,
|
873 |
+
drop_path_rate=config.drop_path_rate,
|
874 |
+
drop_path_type=config.drop_path_type,
|
875 |
+
act_layer=config.act_layer,
|
876 |
+
norm_layer=config.norm_layer,
|
877 |
+
layer_scale=config.layer_scale,
|
878 |
+
offset_scale=config.offset_scale,
|
879 |
+
post_norm=config.post_norm,
|
880 |
+
cls_scale=config.cls_scale,
|
881 |
+
with_cp=config.with_cp,
|
882 |
+
dw_kernel_size=config.dw_kernel_size, # for InternImage-H/G
|
883 |
+
use_clip_projector=config.use_clip_projector, # for InternImage-H/G
|
884 |
+
level2_post_norm=config.level2_post_norm, # for InternImage-H/G
|
885 |
+
level2_post_norm_block_ids=config.level2_post_norm_block_ids, # for InternImage-H/G
|
886 |
+
res_post_norm=config.res_post_norm, # for InternImage-H/G
|
887 |
+
center_feature_scale=config.center_feature_scale, # for InternImage-H/G
|
888 |
+
remove_center=config.remove_center, # for InternImage-H/G
|
889 |
+
)
|
890 |
+
|
891 |
+
def forward(self, tensor, labels=None):
|
892 |
+
outputs = self.model.forward(tensor)
|
893 |
+
|
894 |
+
if labels is not None:
|
895 |
+
logits = outputs['logits']
|
896 |
+
loss = F.cross_entropy(logits, labels)
|
897 |
+
outputs['loss'] = loss
|
898 |
+
|
899 |
+
return outputs
|
preprocessor_config.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"crop_size": 384,
|
3 |
+
"do_center_crop": true,
|
4 |
+
"do_normalize": true,
|
5 |
+
"do_resize": true,
|
6 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
7 |
+
"image_mean": [
|
8 |
+
0.485,
|
9 |
+
0.456,
|
10 |
+
0.406
|
11 |
+
],
|
12 |
+
"image_std": [
|
13 |
+
0.229,
|
14 |
+
0.224,
|
15 |
+
0.225
|
16 |
+
],
|
17 |
+
"resample": 3,
|
18 |
+
"size": 384
|
19 |
+
}
|