JUGGHM commited on
Commit
0b1902c
·
verified ·
1 Parent(s): 36b8fb6

Delete mono

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. mono/configs/HourglassDecoder/convlarge.0.3_150.py +0 -25
  2. mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py +0 -25
  3. mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py +0 -25
  4. mono/configs/HourglassDecoder/vit.raft5.large.py +0 -33
  5. mono/configs/HourglassDecoder/vit.raft5.small.py +0 -33
  6. mono/configs/__init__.py +0 -1
  7. mono/configs/_base_/_data_base_.py +0 -13
  8. mono/configs/_base_/datasets/_data_base_.py +0 -12
  9. mono/configs/_base_/default_runtime.py +0 -4
  10. mono/configs/_base_/models/backbones/convnext_large.py +0 -16
  11. mono/configs/_base_/models/backbones/dino_vit_large.py +0 -7
  12. mono/configs/_base_/models/backbones/dino_vit_large_reg.py +0 -7
  13. mono/configs/_base_/models/backbones/dino_vit_small_reg.py +0 -7
  14. mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py +0 -10
  15. mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py +0 -20
  16. mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py +0 -19
  17. mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py +0 -19
  18. mono/model/__init__.py +0 -5
  19. mono/model/__pycache__/__init__.cpython-39.pyc +0 -0
  20. mono/model/__pycache__/monodepth_model.cpython-39.pyc +0 -0
  21. mono/model/backbones/ConvNeXt.py +0 -271
  22. mono/model/backbones/ViT_DINO.py +0 -1504
  23. mono/model/backbones/ViT_DINO_reg.py +0 -1293
  24. mono/model/backbones/__init__.py +0 -11
  25. mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc +0 -0
  26. mono/model/backbones/__pycache__/__init__.cpython-39.pyc +0 -0
  27. mono/model/decode_heads/HourGlassDecoder.py +0 -274
  28. mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py +0 -1033
  29. mono/model/decode_heads/__init__.py +0 -4
  30. mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc +0 -0
  31. mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc +0 -0
  32. mono/model/model_pipelines/__base_model__.py +0 -20
  33. mono/model/model_pipelines/__init__.py +0 -6
  34. mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc +0 -0
  35. mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc +0 -0
  36. mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc +0 -0
  37. mono/model/model_pipelines/dense_pipeline.py +0 -16
  38. mono/model/monodepth_model.py +0 -37
  39. mono/tools/test_scale_cano.py +0 -158
  40. mono/utils/__init__.py +0 -1
  41. mono/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  42. mono/utils/__pycache__/avg_meter.cpython-39.pyc +0 -0
  43. mono/utils/__pycache__/comm.cpython-39.pyc +0 -0
  44. mono/utils/__pycache__/custom_data.cpython-39.pyc +0 -0
  45. mono/utils/__pycache__/do_test.cpython-39.pyc +0 -0
  46. mono/utils/__pycache__/logger.cpython-39.pyc +0 -0
  47. mono/utils/__pycache__/mldb.cpython-39.pyc +0 -0
  48. mono/utils/__pycache__/running.cpython-39.pyc +0 -0
  49. mono/utils/__pycache__/transform.cpython-39.pyc +0 -0
  50. mono/utils/__pycache__/unproj_pcd.cpython-39.pyc +0 -0
mono/configs/HourglassDecoder/convlarge.0.3_150.py DELETED
@@ -1,25 +0,0 @@
1
- _base_=[
2
- '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
3
- '../_base_/datasets/_data_base_.py',
4
- '../_base_/default_runtime.py',
5
- ]
6
-
7
- model = dict(
8
- backbone=dict(
9
- pretrained=False,
10
- )
11
- )
12
-
13
- # configs of the canonical space
14
- data_basic=dict(
15
- canonical_space = dict(
16
- img_size=(512, 960),
17
- focal_length=1000.0,
18
- ),
19
- depth_range=(0, 1),
20
- depth_normalize=(0.3, 150),
21
- crop_size = (544, 1216),
22
- )
23
-
24
- batchsize_per_gpu = 2
25
- thread_per_gpu = 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py DELETED
@@ -1,25 +0,0 @@
1
- _base_=[
2
- '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
3
- '../_base_/datasets/_data_base_.py',
4
- '../_base_/default_runtime.py',
5
- ]
6
-
7
- model = dict(
8
- backbone=dict(
9
- pretrained=False,
10
- )
11
- )
12
-
13
- # configs of the canonical space
14
- data_basic=dict(
15
- canonical_space = dict(
16
- img_size=(512, 960),
17
- focal_length=1000.0,
18
- ),
19
- depth_range=(0, 1),
20
- depth_normalize=(0.3, 150),
21
- crop_size = (512, 1088),
22
- )
23
-
24
- batchsize_per_gpu = 2
25
- thread_per_gpu = 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py DELETED
@@ -1,25 +0,0 @@
1
- _base_=[
2
- '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
3
- '../_base_/datasets/_data_base_.py',
4
- '../_base_/default_runtime.py',
5
- ]
6
-
7
- model = dict(
8
- backbone=dict(
9
- pretrained=False,
10
- )
11
- )
12
-
13
- # configs of the canonical space
14
- data_basic=dict(
15
- canonical_space = dict(
16
- img_size=(512, 960),
17
- focal_length=1000.0,
18
- ),
19
- depth_range=(0, 1),
20
- depth_normalize=(0.3, 150),
21
- crop_size = (480, 1216),
22
- )
23
-
24
- batchsize_per_gpu = 2
25
- thread_per_gpu = 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/HourglassDecoder/vit.raft5.large.py DELETED
@@ -1,33 +0,0 @@
1
- _base_=[
2
- '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
3
- '../_base_/datasets/_data_base_.py',
4
- '../_base_/default_runtime.py',
5
- ]
6
-
7
- import numpy as np
8
- model=dict(
9
- decode_head=dict(
10
- type='RAFTDepthNormalDPT5',
11
- iters=8,
12
- n_downsample=2,
13
- detach=False,
14
- )
15
- )
16
-
17
-
18
- max_value = 200
19
- # configs of the canonical space
20
- data_basic=dict(
21
- canonical_space = dict(
22
- # img_size=(540, 960),
23
- focal_length=1000.0,
24
- ),
25
- depth_range=(0, 1),
26
- depth_normalize=(0.1, max_value),
27
- crop_size = (616, 1064), # %28 = 0
28
- clip_depth_range=(0.1, 200),
29
- vit_size=(616,1064)
30
- )
31
-
32
- batchsize_per_gpu = 1
33
- thread_per_gpu = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/HourglassDecoder/vit.raft5.small.py DELETED
@@ -1,33 +0,0 @@
1
- _base_=[
2
- '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
3
- '../_base_/datasets/_data_base_.py',
4
- '../_base_/default_runtime.py',
5
- ]
6
-
7
- import numpy as np
8
- model=dict(
9
- decode_head=dict(
10
- type='RAFTDepthNormalDPT5',
11
- iters=4,
12
- n_downsample=2,
13
- detach=False,
14
- )
15
- )
16
-
17
-
18
- max_value = 200
19
- # configs of the canonical space
20
- data_basic=dict(
21
- canonical_space = dict(
22
- # img_size=(540, 960),
23
- focal_length=1000.0,
24
- ),
25
- depth_range=(0, 1),
26
- depth_normalize=(0.1, max_value),
27
- crop_size = (616, 1064), # %28 = 0
28
- clip_depth_range=(0.1, 200),
29
- vit_size=(616,1064)
30
- )
31
-
32
- batchsize_per_gpu = 1
33
- thread_per_gpu = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/__init__.py DELETED
@@ -1 +0,0 @@
1
-
 
 
mono/configs/_base_/_data_base_.py DELETED
@@ -1,13 +0,0 @@
1
- # canonical camera setting and basic data setting
2
- # we set it same as the E300 camera (crop version)
3
- #
4
- data_basic=dict(
5
- canonical_space = dict(
6
- img_size=(540, 960),
7
- focal_length=1196.0,
8
- ),
9
- depth_range=(0.9, 150),
10
- depth_normalize=(0.006, 1.001),
11
- crop_size = (512, 960),
12
- clip_depth_range=(0.9, 150),
13
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/_base_/datasets/_data_base_.py DELETED
@@ -1,12 +0,0 @@
1
- # canonical camera setting and basic data setting
2
- #
3
- data_basic=dict(
4
- canonical_space = dict(
5
- img_size=(540, 960),
6
- focal_length=1196.0,
7
- ),
8
- depth_range=(0.9, 150),
9
- depth_normalize=(0.006, 1.001),
10
- crop_size = (512, 960),
11
- clip_depth_range=(0.9, 150),
12
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/_base_/default_runtime.py DELETED
@@ -1,4 +0,0 @@
1
-
2
- load_from = None
3
- cudnn_benchmark = True
4
- test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3','rmse_log', 'log10', 'sq_rel']
 
 
 
 
 
mono/configs/_base_/models/backbones/convnext_large.py DELETED
@@ -1,16 +0,0 @@
1
- #_base_ = ['./_model_base_.py',]
2
-
3
- #'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth'
4
- model = dict(
5
- #type='EncoderDecoderAuxi',
6
- backbone=dict(
7
- type='convnext_large',
8
- pretrained=True,
9
- in_22k=True,
10
- out_indices=[0, 1, 2, 3],
11
- drop_path_rate=0.4,
12
- layer_scale_init_value=1.0,
13
- checkpoint='data/pretrained_weight_repo/convnext/convnext_large_22k_1k_384.pth',
14
- prefix='backbones.',
15
- out_channels=[192, 384, 768, 1536]),
16
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/_base_/models/backbones/dino_vit_large.py DELETED
@@ -1,7 +0,0 @@
1
- model = dict(
2
- backbone=dict(
3
- type='vit_large',
4
- prefix='backbones.',
5
- out_channels=[1024, 1024, 1024, 1024],
6
- drop_path_rate = 0.0),
7
- )
 
 
 
 
 
 
 
 
mono/configs/_base_/models/backbones/dino_vit_large_reg.py DELETED
@@ -1,7 +0,0 @@
1
- model = dict(
2
- backbone=dict(
3
- type='vit_large_reg',
4
- prefix='backbones.',
5
- out_channels=[1024, 1024, 1024, 1024],
6
- drop_path_rate = 0.0),
7
- )
 
 
 
 
 
 
 
 
mono/configs/_base_/models/backbones/dino_vit_small_reg.py DELETED
@@ -1,7 +0,0 @@
1
- model = dict(
2
- backbone=dict(
3
- type='vit_small_reg',
4
- prefix='backbones.',
5
- out_channels=[384, 384, 384, 384],
6
- drop_path_rate = 0.0),
7
- )
 
 
 
 
 
 
 
 
mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py DELETED
@@ -1,10 +0,0 @@
1
- # model settings
2
- _base_ = ['../backbones/convnext_large.py',]
3
- model = dict(
4
- type='DensePredModel',
5
- decode_head=dict(
6
- type='HourglassDecoder',
7
- in_channels=[192, 384, 768, 1536],
8
- decoder_channel=[128, 128, 256, 512],
9
- prefix='decode_heads.'),
10
- )
 
 
 
 
 
 
 
 
 
 
 
mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py DELETED
@@ -1,20 +0,0 @@
1
- # model settings
2
- _base_ = ['../backbones/dino_vit_large.py']
3
- model = dict(
4
- type='DensePredModel',
5
- decode_head=dict(
6
- type='RAFTDepthDPT',
7
- in_channels=[1024, 1024, 1024, 1024],
8
- use_cls_token=True,
9
- feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
10
- decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
11
- up_scale = 7,
12
- hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
13
- n_gru_layers=3,
14
- n_downsample=2,
15
- iters=12,
16
- slow_fast_gru=True,
17
- corr_radius=4,
18
- corr_levels=4,
19
- prefix='decode_heads.'),
20
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py DELETED
@@ -1,19 +0,0 @@
1
- # model settings
2
- _base_ = ['../backbones/dino_vit_large_reg.py']
3
- model = dict(
4
- type='DensePredModel',
5
- decode_head=dict(
6
- type='RAFTDepthDPT',
7
- in_channels=[1024, 1024, 1024, 1024],
8
- use_cls_token=True,
9
- feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
10
- decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
11
- up_scale = 7,
12
- hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
13
- n_gru_layers=3,
14
- n_downsample=2,
15
- iters=3,
16
- slow_fast_gru=True,
17
- num_register_tokens=4,
18
- prefix='decode_heads.'),
19
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py DELETED
@@ -1,19 +0,0 @@
1
- # model settings
2
- _base_ = ['../backbones/dino_vit_small_reg.py']
3
- model = dict(
4
- type='DensePredModel',
5
- decode_head=dict(
6
- type='RAFTDepthDPT',
7
- in_channels=[384, 384, 384, 384],
8
- use_cls_token=True,
9
- feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14]
10
- decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14]
11
- up_scale = 7,
12
- hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -]
13
- n_gru_layers=3,
14
- n_downsample=2,
15
- iters=3,
16
- slow_fast_gru=True,
17
- num_register_tokens=4,
18
- prefix='decode_heads.'),
19
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/model/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from .monodepth_model import DepthModel
2
- # from .__base_model__ import BaseDepthModel
3
-
4
-
5
- __all__ = ['DepthModel', 'BaseDepthModel']
 
 
 
 
 
 
mono/model/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (250 Bytes)
 
mono/model/__pycache__/monodepth_model.cpython-39.pyc DELETED
Binary file (1.62 kB)
 
mono/model/backbones/ConvNeXt.py DELETED
@@ -1,271 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from timm.models.layers import trunc_normal_, DropPath
5
- from timm.models.registry import register_model
6
-
7
- class Block(nn.Module):
8
- r""" ConvNeXt Block. There are two equivalent implementations:
9
- (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
10
- (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
11
- We use (2) as we find it slightly faster in PyTorch
12
-
13
- Args:
14
- dim (int): Number of input channels.
15
- drop_path (float): Stochastic depth rate. Default: 0.0
16
- layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
17
- """
18
- def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
19
- super().__init__()
20
- self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
21
- self.norm = LayerNorm(dim, eps=1e-6)
22
- self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
23
- self.act = nn.GELU()
24
- self.pwconv2 = nn.Linear(4 * dim, dim)
25
- self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
26
- requires_grad=True) if layer_scale_init_value > 0 else None
27
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
28
-
29
- def forward(self, x):
30
- input = x
31
- x = self.dwconv(x)
32
- x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
33
- x = self.norm(x)
34
- x = self.pwconv1(x)
35
- x = self.act(x)
36
- x = self.pwconv2(x)
37
- if self.gamma is not None:
38
- x = self.gamma * x
39
- x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
40
-
41
- x = input + self.drop_path(x)
42
- return x
43
-
44
- class ConvNeXt(nn.Module):
45
- r""" ConvNeXt
46
- A PyTorch impl of : `A ConvNet for the 2020s` -
47
- https://arxiv.org/pdf/2201.03545.pdf
48
- Args:
49
- in_chans (int): Number of input image channels. Default: 3
50
- num_classes (int): Number of classes for classification head. Default: 1000
51
- depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
52
- dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
53
- drop_path_rate (float): Stochastic depth rate. Default: 0.
54
- layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
55
- head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
56
- """
57
- def __init__(self, in_chans=3, num_classes=1000,
58
- depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
59
- layer_scale_init_value=1e-6, head_init_scale=1.,
60
- **kwargs,):
61
- super().__init__()
62
-
63
- self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
64
- stem = nn.Sequential(
65
- nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
66
- LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
67
- )
68
- self.downsample_layers.append(stem)
69
- for i in range(3):
70
- downsample_layer = nn.Sequential(
71
- LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
72
- nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
73
- )
74
- self.downsample_layers.append(downsample_layer)
75
-
76
- self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
77
- dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
78
- cur = 0
79
- for i in range(4):
80
- stage = nn.Sequential(
81
- *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
82
- layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
83
- )
84
- self.stages.append(stage)
85
- cur += depths[i]
86
-
87
- #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
88
- #self.head = nn.Linear(dims[-1], num_classes)
89
-
90
- self.apply(self._init_weights)
91
- #self.head.weight.data.mul_(head_init_scale)
92
- #self.head.bias.data.mul_(head_init_scale)
93
-
94
- def _init_weights(self, m):
95
- if isinstance(m, (nn.Conv2d, nn.Linear)):
96
- trunc_normal_(m.weight, std=.02)
97
- nn.init.constant_(m.bias, 0)
98
-
99
- def forward_features(self, x):
100
- features = []
101
- for i in range(4):
102
- x = self.downsample_layers[i](x)
103
- x = self.stages[i](x)
104
- features.append(x)
105
- return features # global average pooling, (N, C, H, W) -> (N, C)
106
-
107
- def forward(self, x):
108
- #x = self.forward_features(x)
109
- #x = self.head(x)
110
- features = self.forward_features(x)
111
- return features
112
-
113
- class LayerNorm(nn.Module):
114
- r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
115
- The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
116
- shape (batch_size, height, width, channels) while channels_first corresponds to inputs
117
- with shape (batch_size, channels, height, width).
118
- """
119
- def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
120
- super().__init__()
121
- self.weight = nn.Parameter(torch.ones(normalized_shape))
122
- self.bias = nn.Parameter(torch.zeros(normalized_shape))
123
- self.eps = eps
124
- self.data_format = data_format
125
- if self.data_format not in ["channels_last", "channels_first"]:
126
- raise NotImplementedError
127
- self.normalized_shape = (normalized_shape, )
128
-
129
- def forward(self, x):
130
- if self.data_format == "channels_last":
131
- return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
132
- elif self.data_format == "channels_first":
133
- u = x.mean(1, keepdim=True)
134
- s = (x - u).pow(2).mean(1, keepdim=True)
135
- x = (x - u) / torch.sqrt(s + self.eps)
136
- x = self.weight[:, None, None] * x + self.bias[:, None, None]
137
- return x
138
-
139
-
140
- model_urls = {
141
- "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
142
- "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
143
- "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
144
- "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
145
- "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
146
- "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
147
- "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
148
- "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
149
- "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
150
- }
151
-
152
- def convnext_tiny(pretrained=True,in_22k=False, **kwargs):
153
- model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
154
- if pretrained:
155
- checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
156
- #url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
157
- #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
158
- model_dict = model.state_dict()
159
- pretrained_dict = {}
160
- unmatched_pretrained_dict = {}
161
- for k, v in checkpoint['model'].items():
162
- if k in model_dict:
163
- pretrained_dict[k] = v
164
- else:
165
- unmatched_pretrained_dict[k] = v
166
- model_dict.update(pretrained_dict)
167
- model.load_state_dict(model_dict)
168
- print(
169
- 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
170
- %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
171
- print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
172
- return model
173
-
174
- def convnext_small(pretrained=True,in_22k=False, **kwargs):
175
- model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
176
- if pretrained:
177
- checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
178
- #url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
179
- #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
180
- model_dict = model.state_dict()
181
- pretrained_dict = {}
182
- unmatched_pretrained_dict = {}
183
- for k, v in checkpoint['model'].items():
184
- if k in model_dict:
185
- pretrained_dict[k] = v
186
- else:
187
- unmatched_pretrained_dict[k] = v
188
- model_dict.update(pretrained_dict)
189
- model.load_state_dict(model_dict)
190
- print(
191
- 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
192
- %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
193
- print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
194
- return model
195
-
196
- def convnext_base(pretrained=True, in_22k=False, **kwargs):
197
- model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
198
- if pretrained:
199
- checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
200
- #url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
201
- #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
202
- model_dict = model.state_dict()
203
- pretrained_dict = {}
204
- unmatched_pretrained_dict = {}
205
- for k, v in checkpoint['model'].items():
206
- if k in model_dict:
207
- pretrained_dict[k] = v
208
- else:
209
- unmatched_pretrained_dict[k] = v
210
- model_dict.update(pretrained_dict)
211
- model.load_state_dict(model_dict)
212
- print(
213
- 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
214
- %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
215
- print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
216
- return model
217
-
218
- def convnext_large(pretrained=True, in_22k=False, **kwargs):
219
- model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
220
- if pretrained:
221
- checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
222
- #url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
223
- #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
224
- model_dict = model.state_dict()
225
- pretrained_dict = {}
226
- unmatched_pretrained_dict = {}
227
- for k, v in checkpoint['model'].items():
228
- if k in model_dict:
229
- pretrained_dict[k] = v
230
- else:
231
- unmatched_pretrained_dict[k] = v
232
- model_dict.update(pretrained_dict)
233
- model.load_state_dict(model_dict)
234
- print(
235
- 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
236
- %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
237
- print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
238
- return model
239
-
240
- def convnext_xlarge(pretrained=True, in_22k=False, **kwargs):
241
- model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
242
- if pretrained:
243
- assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
244
- checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
245
- #url = model_urls['convnext_xlarge_22k']
246
- #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
247
- model_dict = model.state_dict()
248
- pretrained_dict = {}
249
- unmatched_pretrained_dict = {}
250
- for k, v in checkpoint['model'].items():
251
- if k in model_dict:
252
- pretrained_dict[k] = v
253
- else:
254
- unmatched_pretrained_dict[k] = v
255
- model_dict.update(pretrained_dict)
256
- model.load_state_dict(model_dict)
257
- print(
258
- 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
259
- %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
260
- print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
261
- return model
262
-
263
- if __name__ == '__main__':
264
- import torch
265
- model = convnext_base(True, in_22k=False).cuda()
266
-
267
- rgb = torch.rand((2, 3, 256, 256)).cuda()
268
- out = model(rgb)
269
- print(len(out))
270
- for i, ft in enumerate(out):
271
- print(i, ft.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mono/model/backbones/ViT_DINO.py DELETED
@@ -1,1504 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # References:
8
- # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
-
11
- from functools import partial
12
- import math
13
- import logging
14
- from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
15
-
16
- import torch
17
- import torch.nn as nn
18
- from torch import Tensor
19
- import torch.utils.checkpoint
20
- from torch.nn.init import trunc_normal_
21
-
22
- #from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
23
-
24
- logger = logging.getLogger("dinov2")
25
-
26
- class ConvBlock(nn.Module):
27
- def __init__(self, channels):
28
- super(ConvBlock, self).__init__()
29
-
30
- self.act = nn.ReLU(inplace=True)
31
- self.conv1 = nn.Conv2d(
32
- channels,
33
- channels,
34
- kernel_size=3,
35
- stride=1,
36
- padding=1
37
- )
38
- self.norm1 = nn.BatchNorm2d(channels)
39
- self.conv2 = nn.Conv2d(
40
- channels,
41
- channels,
42
- kernel_size=3,
43
- stride=1,
44
- padding=1
45
- )
46
- self.norm2 = nn.BatchNorm2d(channels)
47
-
48
- def forward(self, x):
49
-
50
- out = self.norm1(x)
51
- out = self.act(out)
52
- out = self.conv1(out)
53
- out = self.norm2(out)
54
- out = self.act(out)
55
- out = self.conv2(out)
56
- return x + out
57
-
58
- def make_2tuple(x):
59
- if isinstance(x, tuple):
60
- assert len(x) == 2
61
- return x
62
-
63
- assert isinstance(x, int)
64
- return (x, x)
65
-
66
- def drop_path(x, drop_prob: float = 0.0, training: bool = False):
67
- if drop_prob == 0.0 or not training:
68
- return x
69
- keep_prob = 1 - drop_prob
70
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
71
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
72
- if keep_prob > 0.0:
73
- random_tensor.div_(keep_prob)
74
- output = x * random_tensor
75
- return output
76
-
77
- class DropPath(nn.Module):
78
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
79
-
80
- def __init__(self, drop_prob=None):
81
- super(DropPath, self).__init__()
82
- self.drop_prob = drop_prob
83
-
84
- def forward(self, x):
85
- return drop_path(x, self.drop_prob, self.training)
86
-
87
- class LayerScale(nn.Module):
88
- def __init__(
89
- self,
90
- dim: int,
91
- init_values: Union[float, Tensor] = 1e-5,
92
- inplace: bool = False,
93
- ) -> None:
94
- super().__init__()
95
- self.inplace = inplace
96
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
97
-
98
- def forward(self, x: Tensor) -> Tensor:
99
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
100
-
101
-
102
- class PatchEmbed(nn.Module):
103
- """
104
- 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
105
-
106
- Args:
107
- img_size: Image size.
108
- patch_size: Patch token size.
109
- in_chans: Number of input image channels.
110
- embed_dim: Number of linear projection output channels.
111
- norm_layer: Normalization layer.
112
- """
113
-
114
- def __init__(
115
- self,
116
- img_size: Union[int, Tuple[int, int]] = 224,
117
- patch_size: Union[int, Tuple[int, int]] = 16,
118
- in_chans: int = 3,
119
- embed_dim: int = 768,
120
- norm_layer: Optional[Callable] = None,
121
- flatten_embedding: bool = True,
122
- ) -> None:
123
- super().__init__()
124
-
125
- image_HW = make_2tuple(img_size)
126
- patch_HW = make_2tuple(patch_size)
127
- patch_grid_size = (
128
- image_HW[0] // patch_HW[0],
129
- image_HW[1] // patch_HW[1],
130
- )
131
-
132
- self.img_size = image_HW
133
- self.patch_size = patch_HW
134
- self.patches_resolution = patch_grid_size
135
- self.num_patches = patch_grid_size[0] * patch_grid_size[1]
136
-
137
- self.in_chans = in_chans
138
- self.embed_dim = embed_dim
139
-
140
- self.flatten_embedding = flatten_embedding
141
-
142
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
143
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
144
-
145
- def forward(self, x: Tensor) -> Tensor:
146
- _, _, H, W = x.shape
147
- patch_H, patch_W = self.patch_size
148
-
149
- assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
150
- assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
151
-
152
- x = self.proj(x) # B C H W
153
- H, W = x.size(2), x.size(3)
154
- x = x.flatten(2).transpose(1, 2) # B HW C
155
- x = self.norm(x)
156
- if not self.flatten_embedding:
157
- x = x.reshape(-1, H, W, self.embed_dim) # B H W C
158
- return x
159
-
160
- def flops(self) -> float:
161
- Ho, Wo = self.patches_resolution
162
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
163
- if self.norm is not None:
164
- flops += Ho * Wo * self.embed_dim
165
- return flops
166
-
167
- class Mlp(nn.Module):
168
- def __init__(
169
- self,
170
- in_features: int,
171
- hidden_features: Optional[int] = None,
172
- out_features: Optional[int] = None,
173
- act_layer: Callable[..., nn.Module] = nn.GELU,
174
- drop: float = 0.0,
175
- bias: bool = True,
176
- ) -> None:
177
- super().__init__()
178
- out_features = out_features or in_features
179
- hidden_features = hidden_features or in_features
180
- self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
181
- self.act = act_layer()
182
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
183
- self.drop = nn.Dropout(drop)
184
-
185
- def forward(self, x: Tensor) -> Tensor:
186
- x = self.fc1(x)
187
- x = self.act(x)
188
- x = self.drop(x)
189
- x = self.fc2(x)
190
- x = self.drop(x)
191
- return x
192
-
193
-
194
- class SwiGLUFFN(nn.Module):
195
- def __init__(
196
- self,
197
- in_features: int,
198
- hidden_features: Optional[int] = None,
199
- out_features: Optional[int] = None,
200
- act_layer: Callable[..., nn.Module] = None,
201
- drop: float = 0.0,
202
- bias: bool = True,
203
- ) -> None:
204
- super().__init__()
205
- out_features = out_features or in_features
206
- hidden_features = hidden_features or in_features
207
- self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
208
- self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
209
-
210
- def forward(self, x: Tensor) -> Tensor:
211
- x12 = self.w12(x)
212
- x1, x2 = x12.chunk(2, dim=-1)
213
- hidden = F.silu(x1) * x2
214
- return self.w3(hidden)
215
-
216
-
217
- try:
218
- from xformers.ops import SwiGLU
219
- #import numpy.bool
220
- XFORMERS_AVAILABLE = True
221
- except ImportError:
222
- SwiGLU = SwiGLUFFN
223
- XFORMERS_AVAILABLE = False
224
-
225
- class SwiGLUFFNFused(SwiGLU):
226
- def __init__(
227
- self,
228
- in_features: int,
229
- hidden_features: Optional[int] = None,
230
- out_features: Optional[int] = None,
231
- act_layer: Callable[..., nn.Module] = None,
232
- drop: float = 0.0,
233
- bias: bool = True,
234
- ) -> None:
235
- out_features = out_features or in_features
236
- hidden_features = hidden_features or in_features
237
- hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
238
- super().__init__(
239
- in_features=in_features,
240
- hidden_features=hidden_features,
241
- out_features=out_features,
242
- bias=bias,
243
- )
244
-
245
-
246
- try:
247
- from xformers.ops import memory_efficient_attention, unbind, fmha
248
- from xformers.components.attention import ScaledDotProduct
249
- from xformers.components import MultiHeadDispatch
250
- #import numpy.bool
251
- XFORMERS_AVAILABLE = True
252
- except ImportError:
253
- logger.warning("xFormers not available")
254
- XFORMERS_AVAILABLE = False
255
-
256
-
257
- class Attention(nn.Module):
258
- def __init__(
259
- self,
260
- dim: int,
261
- num_heads: int = 8,
262
- qkv_bias: bool = False,
263
- proj_bias: bool = True,
264
- attn_drop: float = 0.0,
265
- proj_drop: float = 0.0,
266
- window_size: int = 0,
267
- ) -> None:
268
- super().__init__()
269
- self.num_heads = num_heads
270
- head_dim = dim // num_heads
271
- self.scale = head_dim**-0.5
272
-
273
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
274
- self.attn_drop = nn.Dropout(attn_drop)
275
- self.proj = nn.Linear(dim, dim, bias=proj_bias)
276
- self.proj_drop = nn.Dropout(proj_drop)
277
-
278
- #if not self.training:
279
- #
280
- # self.attn = ScaledDotProduct()
281
- #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
282
-
283
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
284
- B, N, C = x.shape
285
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
286
-
287
- q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
288
- attn = q @ k.transpose(-2, -1)
289
-
290
- if attn_bias is not None:
291
- attn = attn + attn_bias[:, :, :N]
292
-
293
- attn = attn.softmax(dim=-1)
294
- attn = self.attn_drop(attn)
295
-
296
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
297
- x = self.proj(x)
298
- x = self.proj_drop(x)
299
- return x
300
-
301
-
302
- class MemEffAttention(Attention):
303
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
304
- if not XFORMERS_AVAILABLE:
305
- #if True:
306
- assert attn_bias is None, "xFormers is required for nested tensors usage"
307
- return super().forward(x, attn_bias)
308
-
309
- B, N, C = x.shape
310
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
311
-
312
- q, k, v = unbind(qkv, 2)
313
- if attn_bias is not None:
314
- x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
315
- else:
316
- x = memory_efficient_attention(q, k, v)
317
- x = x.reshape([B, N, C])
318
-
319
- x = self.proj(x)
320
- x = self.proj_drop(x)
321
- return x
322
-
323
- try:
324
- from xformers.ops import fmha
325
- from xformers.ops import scaled_index_add, index_select_cat
326
- #import numpy.bool
327
- XFORMERS_AVAILABLE = True
328
- except ImportError:
329
- logger.warning("xFormers not available")
330
- XFORMERS_AVAILABLE = False
331
-
332
- class Block(nn.Module):
333
- def __init__(
334
- self,
335
- dim: int,
336
- num_heads: int,
337
- mlp_ratio: float = 4.0,
338
- qkv_bias: bool = False,
339
- proj_bias: bool = True,
340
- ffn_bias: bool = True,
341
- drop: float = 0.0,
342
- attn_drop: float = 0.0,
343
- init_values = None,
344
- drop_path: float = 0.0,
345
- act_layer: Callable[..., nn.Module] = nn.GELU,
346
- norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
347
- attn_class: Callable[..., nn.Module] = Attention,
348
- ffn_layer: Callable[..., nn.Module] = Mlp,
349
- ) -> None:
350
- super().__init__()
351
- # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
352
- self.norm1 = norm_layer(dim)
353
- self.attn = attn_class(
354
- dim,
355
- num_heads=num_heads,
356
- qkv_bias=qkv_bias,
357
- proj_bias=proj_bias,
358
- attn_drop=attn_drop,
359
- proj_drop=drop,
360
- )
361
- self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
362
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
363
-
364
- self.norm2 = norm_layer(dim)
365
- mlp_hidden_dim = int(dim * mlp_ratio)
366
- self.mlp = ffn_layer(
367
- in_features=dim,
368
- hidden_features=mlp_hidden_dim,
369
- act_layer=act_layer,
370
- drop=drop,
371
- bias=ffn_bias,
372
- )
373
- self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
374
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
375
-
376
- self.sample_drop_ratio = drop_path
377
-
378
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
379
- def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
380
- return self.ls1(self.attn(self.norm1(x), attn_bias))
381
-
382
- def ffn_residual_func(x: Tensor) -> Tensor:
383
- return self.ls2(self.mlp(self.norm2(x)))
384
-
385
- if self.training and self.sample_drop_ratio > 0.1:
386
- # the overhead is compensated only for a drop path rate larger than 0.1
387
- x = drop_add_residual_stochastic_depth(
388
- x,
389
- residual_func=attn_residual_func,
390
- sample_drop_ratio=self.sample_drop_ratio,
391
- attn_bias=attn_bias
392
- )
393
- x = drop_add_residual_stochastic_depth(
394
- x,
395
- residual_func=ffn_residual_func,
396
- sample_drop_ratio=self.sample_drop_ratio,
397
- )
398
- elif self.training and self.sample_drop_ratio > 0.0:
399
- x = x + self.drop_path1(attn_residual_func(x, attn_bias))
400
- x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
401
- else:
402
- x = x + attn_residual_func(x, attn_bias)
403
- x = x + ffn_residual_func(x)
404
- return x
405
-
406
-
407
- def drop_add_residual_stochastic_depth(
408
- x: Tensor,
409
- residual_func: Callable[[Tensor], Tensor],
410
- sample_drop_ratio: float = 0.0, attn_bias=None
411
- ) -> Tensor:
412
- # 1) extract subset using permutation
413
- b, n, d = x.shape
414
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
415
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
416
- x_subset = x[brange]
417
-
418
- # 2) apply residual_func to get residual
419
- residual = residual_func(x_subset, attn_bias)
420
-
421
- x_flat = x.flatten(1)
422
- residual = residual.flatten(1)
423
-
424
- residual_scale_factor = b / sample_subset_size
425
-
426
- # 3) add the residual
427
- x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
428
- return x_plus_residual.view_as(x)
429
-
430
-
431
- def get_branges_scales(x, sample_drop_ratio=0.0):
432
- b, n, d = x.shape
433
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
434
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
435
- residual_scale_factor = b / sample_subset_size
436
- return brange, residual_scale_factor
437
-
438
-
439
- def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
440
- if scaling_vector is None:
441
- x_flat = x.flatten(1)
442
- residual = residual.flatten(1)
443
- x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
444
- else:
445
- x_plus_residual = scaled_index_add(
446
- x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
447
- )
448
- return x_plus_residual
449
-
450
-
451
- attn_bias_cache: Dict[Tuple, Any] = {}
452
-
453
-
454
- def get_attn_bias_and_cat(x_list, branges=None):
455
- """
456
- this will perform the index select, cat the tensors, and provide the attn_bias from cache
457
- """
458
- batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
459
- all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
460
- if all_shapes not in attn_bias_cache.keys():
461
- seqlens = []
462
- for b, x in zip(batch_sizes, x_list):
463
- for _ in range(b):
464
- seqlens.append(x.shape[1])
465
- attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
466
- attn_bias._batch_sizes = batch_sizes
467
- attn_bias_cache[all_shapes] = attn_bias
468
-
469
- if branges is not None:
470
- cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
471
- else:
472
- tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
473
- cat_tensors = torch.cat(tensors_bs1, dim=1)
474
-
475
- return attn_bias_cache[all_shapes], cat_tensors
476
-
477
-
478
- def drop_add_residual_stochastic_depth_list(
479
- x_list: List[Tensor],
480
- residual_func: Callable[[Tensor, Any], Tensor],
481
- sample_drop_ratio: float = 0.0,
482
- scaling_vector=None,
483
- ) -> Tensor:
484
- # 1) generate random set of indices for dropping samples in the batch
485
- branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
486
- branges = [s[0] for s in branges_scales]
487
- residual_scale_factors = [s[1] for s in branges_scales]
488
-
489
- # 2) get attention bias and index+concat the tensors
490
- attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
491
-
492
- # 3) apply residual_func to get residual, and split the result
493
- residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
494
-
495
- outputs = []
496
- for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
497
- outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
498
- return outputs
499
-
500
-
501
- class NestedTensorBlock(Block):
502
- def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
503
- """
504
- x_list contains a list of tensors to nest together and run
505
- """
506
- assert isinstance(self.attn, MemEffAttention)
507
-
508
- if self.training and self.sample_drop_ratio > 0.0:
509
-
510
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
511
- return self.attn(self.norm1(x), attn_bias=attn_bias)
512
-
513
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
514
- return self.mlp(self.norm2(x))
515
-
516
- x_list = drop_add_residual_stochastic_depth_list(
517
- x_list,
518
- residual_func=attn_residual_func,
519
- sample_drop_ratio=self.sample_drop_ratio,
520
- scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
521
- )
522
- x_list = drop_add_residual_stochastic_depth_list(
523
- x_list,
524
- residual_func=ffn_residual_func,
525
- sample_drop_ratio=self.sample_drop_ratio,
526
- scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
527
- )
528
- return x_list
529
- else:
530
-
531
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
532
- return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
533
-
534
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
535
- return self.ls2(self.mlp(self.norm2(x)))
536
-
537
- attn_bias, x = get_attn_bias_and_cat(x_list)
538
- x = x + attn_residual_func(x, attn_bias=attn_bias)
539
- x = x + ffn_residual_func(x)
540
- return attn_bias.split(x)
541
-
542
- def forward(self, x_or_x_list, attn_bias=None):
543
- if isinstance(x_or_x_list, Tensor):
544
- return super().forward(x_or_x_list, attn_bias)
545
- elif isinstance(x_or_x_list, list):
546
- assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
547
- return self.forward_nested(x_or_x_list)
548
- else:
549
- raise AssertionError
550
-
551
-
552
- def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
553
- if not depth_first and include_root:
554
- fn(module=module, name=name)
555
- for child_name, child_module in module.named_children():
556
- child_name = ".".join((name, child_name)) if name else child_name
557
- named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
558
- if depth_first and include_root:
559
- fn(module=module, name=name)
560
- return module
561
-
562
-
563
- class BlockChunk(nn.ModuleList):
564
- def forward(self, x, others=None):
565
- for b in self:
566
- if others == None:
567
- x = b(x)
568
- else:
569
- x = b(x, others)
570
- return x
571
-
572
-
573
- class DinoVisionTransformer(nn.Module):
574
- def __init__(
575
- self,
576
- img_size=224,
577
- patch_size=16,
578
- in_chans=3,
579
- embed_dim=768,
580
- depth=12,
581
- num_heads=12,
582
- mlp_ratio=4.0,
583
- qkv_bias=True,
584
- ffn_bias=True,
585
- proj_bias=True,
586
- drop_path_rate=0.0,
587
- drop_path_uniform=False,
588
- #init_values=None, # for layerscale: None or 0 => no layerscale
589
- init_values=1e-5, # for layerscale: None or 0 => no layerscale
590
- embed_layer=PatchEmbed,
591
- act_layer=nn.GELU,
592
- block_fn=NestedTensorBlock,
593
- ffn_layer="mlp",
594
- block_chunks=1,
595
- window_size=37,
596
- **kwargs
597
- ):
598
- """
599
- Args:
600
- img_size (int, tuple): input image size
601
- patch_size (int, tuple): patch size
602
- in_chans (int): number of input channels
603
- embed_dim (int): embedding dimension
604
- depth (int): depth of transformer
605
- num_heads (int): number of attention heads
606
- mlp_ratio (int): ratio of mlp hidden dim to embedding dim
607
- qkv_bias (bool): enable bias for qkv if True
608
- proj_bias (bool): enable bias for proj in attn if True
609
- ffn_bias (bool): enable bias for ffn if True
610
- drop_path_rate (float): stochastic depth rate
611
- drop_path_uniform (bool): apply uniform drop rate across blocks
612
- weight_init (str): weight init scheme
613
- init_values (float): layer-scale init values
614
- embed_layer (nn.Module): patch embedding layer
615
- act_layer (nn.Module): MLP activation layer
616
- block_fn (nn.Module): transformer block class
617
- ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
618
- block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
619
- """
620
- super().__init__()
621
- norm_layer = partial(nn.LayerNorm, eps=1e-6)
622
-
623
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
624
- self.num_tokens = 1
625
- self.n_blocks = depth
626
- self.num_heads = num_heads
627
- self.patch_size = patch_size
628
- self.window_size = window_size
629
-
630
- self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
631
- num_patches = self.patch_embed.num_patches
632
-
633
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
634
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
635
-
636
- if drop_path_uniform is True:
637
- dpr = [drop_path_rate] * depth
638
- else:
639
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
640
-
641
- if ffn_layer == "mlp":
642
- logger.info("using MLP layer as FFN")
643
- ffn_layer = Mlp
644
- elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
645
- logger.info("using SwiGLU layer as FFN")
646
- ffn_layer = SwiGLUFFNFused
647
- elif ffn_layer == "identity":
648
- logger.info("using Identity layer as FFN")
649
-
650
- def f(*args, **kwargs):
651
- return nn.Identity()
652
-
653
- ffn_layer = f
654
- else:
655
- raise NotImplementedError
656
-
657
- blocks_list = [
658
- block_fn(
659
- dim=embed_dim,
660
- num_heads=num_heads,
661
- mlp_ratio=mlp_ratio,
662
- qkv_bias=qkv_bias,
663
- proj_bias=proj_bias,
664
- ffn_bias=ffn_bias,
665
- drop_path=dpr[i],
666
- norm_layer=norm_layer,
667
- act_layer=act_layer,
668
- ffn_layer=ffn_layer,
669
- init_values=init_values,
670
- )
671
- for i in range(depth)
672
- ]
673
- if block_chunks > 0:
674
- self.chunked_blocks = True
675
- chunked_blocks = []
676
- chunksize = depth // block_chunks
677
- for i in range(0, depth, chunksize):
678
- # this is to keep the block index consistent if we chunk the block list
679
- chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
680
- self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
681
- else:
682
- self.chunked_blocks = False
683
- self.blocks = nn.ModuleList(blocks_list)
684
-
685
- self.norm = norm_layer(embed_dim)
686
- self.head = nn.Identity()
687
-
688
- self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
689
-
690
- self.init_weights()
691
-
692
- def init_weights(self):
693
- trunc_normal_(self.pos_embed, std=0.02)
694
- nn.init.normal_(self.cls_token, std=1e-6)
695
- named_apply(init_weights_vit_timm, self)
696
-
697
- def interpolate_pos_encoding(self, x, w, h):
698
- previous_dtype = x.dtype
699
- npatch = x.shape[1] - 1
700
- N = self.pos_embed.shape[1] - 1
701
- if npatch == N and w == h:
702
- return self.pos_embed
703
- pos_embed = self.pos_embed.float()
704
- class_pos_embed = pos_embed[:, 0]
705
- patch_pos_embed = pos_embed[:, 1:]
706
- dim = x.shape[-1]
707
- w0 = w // self.patch_size
708
- h0 = h // self.patch_size
709
- # we add a small number to avoid floating point error in the interpolation
710
- # see discussion at https://github.com/facebookresearch/dino/issues/8
711
- w0, h0 = w0 + 0.1, h0 + 0.1
712
-
713
- patch_pos_embed = nn.functional.interpolate(
714
- patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
715
- scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
716
- mode="bicubic",
717
- )
718
-
719
- assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
720
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
721
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
722
-
723
- def prepare_tokens_with_masks(self, x, masks=None):
724
- B, nc, w, h = x.shape
725
- x = self.patch_embed(x)
726
- if masks is not None:
727
- x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
728
-
729
- x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
730
- x = x + self.interpolate_pos_encoding(x, w, h)
731
-
732
- return x
733
-
734
- def forward_features_list(self, x_list, masks_list):
735
- x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
736
- for blk in self.blocks:
737
- x = blk(x)
738
-
739
- all_x = x
740
- output = []
741
- for x, masks in zip(all_x, masks_list):
742
- x_norm = self.norm(x)
743
- output.append(
744
- {
745
- "x_norm_clstoken": x_norm[:, 0],
746
- "x_norm_patchtokens": x_norm[:, 1:],
747
- "x_prenorm": x,
748
- "masks": masks,
749
- }
750
- )
751
- return output
752
-
753
- def forward_features(self, x, masks=None):
754
- if isinstance(x, list):
755
- return self.forward_features_list(x, masks)
756
-
757
- B, C, H, W = x.size()
758
- pad_h = (self.patch_size - H % self.patch_size)
759
- pad_w = (self.patch_size - W % self.patch_size)
760
- if pad_h == self.patch_size:
761
- pad_h = 0
762
- if pad_w == self.patch_size:
763
- pad_w = 0
764
- #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
765
- if pad_h + pad_w > 0:
766
- x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
767
-
768
- x = self.prepare_tokens_with_masks(x, masks)
769
-
770
- features = []
771
- for blk in self.blocks:
772
- x = blk(x)
773
- # for idx in range(len(self.blocks[0])):
774
- # x = self.blocks[0][idx](x)
775
- # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
776
- # features.append(x)
777
-
778
- #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
779
-
780
- x_norm = self.norm(x)
781
- # return {
782
- # "x_norm_clstoken": x_norm[:, 0],
783
- # "x_norm_patchtokens": x_norm[:, 1:],
784
- # "x_prenorm": x,
785
- # "masks": masks,
786
- # }
787
- features = []
788
- features.append(x_norm)
789
- features.append(x_norm)
790
- features.append(x_norm)
791
- features.append(x_norm)
792
- return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
793
-
794
- def _get_intermediate_layers_not_chunked(self, x, n=1):
795
- x = self.prepare_tokens_with_masks(x)
796
- # If n is an int, take the n last blocks. If it's a list, take them
797
- output, total_block_len = [], len(self.blocks)
798
- blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
799
- for i, blk in enumerate(self.blocks):
800
- x = blk(x)
801
- if i in blocks_to_take:
802
- output.append(x)
803
- assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
804
- return output
805
-
806
- def _get_intermediate_layers_chunked(self, x, n=1):
807
- x = self.prepare_tokens_with_masks(x)
808
- output, i, total_block_len = [], 0, len(self.blocks[-1])
809
- # If n is an int, take the n last blocks. If it's a list, take them
810
- blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
811
- for block_chunk in self.blocks:
812
- for blk in block_chunk[i:]: # Passing the nn.Identity()
813
- x = blk(x)
814
- if i in blocks_to_take:
815
- output.append(x)
816
- i += 1
817
- assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
818
- return output
819
-
820
- def get_intermediate_layers(
821
- self,
822
- x: torch.Tensor,
823
- n: Union[int, Sequence] = 1, # Layers or n last layers to take
824
- reshape: bool = False,
825
- return_class_token: bool = False,
826
- norm=True,
827
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
828
- if self.chunked_blocks:
829
- outputs = self._get_intermediate_layers_chunked(x, n)
830
- else:
831
- outputs = self._get_intermediate_layers_not_chunked(x, n)
832
- if norm:
833
- outputs = [self.norm(out) for out in outputs]
834
- class_tokens = [out[:, 0] for out in outputs]
835
- outputs = [out[:, 1:] for out in outputs]
836
- if reshape:
837
- B, _, w, h = x.shape
838
- outputs = [
839
- out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
840
- for out in outputs
841
- ]
842
- if return_class_token:
843
- return tuple(zip(outputs, class_tokens))
844
- return tuple(outputs)
845
-
846
- def forward(self, *args, is_training=False, **kwargs):
847
- ret = self.forward_features(*args, **kwargs)
848
- return ret
849
- # if is_training:
850
- # return ret
851
- # else:
852
- # return self.head(ret["x_norm_clstoken"])
853
-
854
-
855
- class PosConv(nn.Module):
856
- # PEG from https://arxiv.org/abs/2102.10882
857
- def __init__(self, in_chans, embed_dim=768, stride=1):
858
- super(PosConv, self).__init__()
859
- self.proj = nn.Sequential(
860
- nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim),
861
- )
862
- self.stride = stride
863
-
864
- def forward(self, x, size):
865
- B, N, C = x.shape
866
- cnn_feat_token = x.transpose(1, 2).view(B, C, *size)
867
- x = self.proj(cnn_feat_token)
868
- if self.stride == 1:
869
- x += cnn_feat_token
870
- x = x.flatten(2).transpose(1, 2)
871
- return x
872
-
873
- #def no_weight_decay(self):
874
- #return ['proj.%d.weight' % i for i in range(4)]
875
-
876
- class DinoWindowVisionTransformer(nn.Module):
877
- def __init__(
878
- self,
879
- img_size=224,
880
- patch_size=16,
881
- in_chans=3,
882
- embed_dim=768,
883
- depth=12,
884
- num_heads=12,
885
- mlp_ratio=4.0,
886
- qkv_bias=True,
887
- ffn_bias=True,
888
- proj_bias=True,
889
- drop_path_rate=0.0,
890
- drop_path_uniform=False,
891
- #init_values=None, # for layerscale: None or 0 => no layerscale
892
- init_values=1e-5, # for layerscale: None or 0 => no layerscale
893
- embed_layer=PatchEmbed,
894
- act_layer=nn.GELU,
895
- block_fn=NestedTensorBlock,
896
- ffn_layer="mlp",
897
- block_chunks=1,
898
- window_size=7,
899
- **kwargs
900
- ):
901
- """
902
- Args:
903
- img_size (int, tuple): input image size
904
- patch_size (int, tuple): patch size
905
- in_chans (int): number of input channels
906
- embed_dim (int): embedding dimension
907
- depth (int): depth of transformer
908
- num_heads (int): number of attention heads
909
- mlp_ratio (int): ratio of mlp hidden dim to embedding dim
910
- qkv_bias (bool): enable bias for qkv if True
911
- proj_bias (bool): enable bias for proj in attn if True
912
- ffn_bias (bool): enable bias for ffn if True
913
- drop_path_rate (float): stochastic depth rate
914
- drop_path_uniform (bool): apply uniform drop rate across blocks
915
- weight_init (str): weight init scheme
916
- init_values (float): layer-scale init values
917
- embed_layer (nn.Module): patch embedding layer
918
- act_layer (nn.Module): MLP activation layer
919
- block_fn (nn.Module): transformer block class
920
- ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
921
- block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
922
- """
923
- super().__init__()
924
- norm_layer = partial(nn.LayerNorm, eps=1e-6)
925
-
926
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
927
- self.num_tokens = 1
928
- self.n_blocks = depth
929
- self.num_heads = num_heads
930
- self.patch_size = patch_size
931
-
932
- self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
933
- num_patches = self.patch_embed.num_patches
934
-
935
- #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
936
- #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
937
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
938
-
939
- self.pos_conv = PosConv(self.embed_dim, self.embed_dim)
940
-
941
- self.window_size = window_size
942
- #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)])
943
- #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)])
944
-
945
- if drop_path_uniform is True:
946
- dpr = [drop_path_rate] * depth
947
- else:
948
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
949
-
950
- if ffn_layer == "mlp":
951
- logger.info("using MLP layer as FFN")
952
- ffn_layer = Mlp
953
- elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
954
- logger.info("using SwiGLU layer as FFN")
955
- ffn_layer = SwiGLUFFNFused
956
- elif ffn_layer == "identity":
957
- logger.info("using Identity layer as FFN")
958
-
959
- def f(*args, **kwargs):
960
- return nn.Identity()
961
-
962
- ffn_layer = f
963
- else:
964
- raise NotImplementedError
965
-
966
- blocks_list = [
967
- block_fn(
968
- dim=embed_dim,
969
- num_heads=num_heads,
970
- mlp_ratio=mlp_ratio,
971
- qkv_bias=qkv_bias,
972
- proj_bias=proj_bias,
973
- ffn_bias=ffn_bias,
974
- drop_path=dpr[i],
975
- norm_layer=norm_layer,
976
- act_layer=act_layer,
977
- ffn_layer=ffn_layer,
978
- init_values=init_values,
979
- )
980
- for i in range(depth)
981
- ]
982
- if block_chunks > 0:
983
- self.chunked_blocks = True
984
- chunked_blocks = []
985
- chunksize = depth // block_chunks
986
- for i in range(0, depth, chunksize):
987
- # this is to keep the block index consistent if we chunk the block list
988
- chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
989
- self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
990
- else:
991
- self.chunked_blocks = False
992
- self.blocks = nn.ModuleList(blocks_list)
993
-
994
- self.norm = norm_layer(embed_dim)
995
- self.head = nn.Identity()
996
-
997
- self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
998
-
999
- self.nh = -1
1000
- self.nw = -1
1001
- try:
1002
- H = cfg.data_basic['crop_size'][0]
1003
- W = cfg.data_basic['crop_size'][1]
1004
- pad_h = (self.patch_size - H % self.patch_size)
1005
- pad_w = (self.patch_size - W % self.patch_size)
1006
- if pad_h == self.patch_size:
1007
- pad_h = 0
1008
- if pad_w == self.patch_size:
1009
- pad_w = 0
1010
- self.nh = (H + pad_h) // self.patch_size
1011
- self.nw = (W + pad_w) // self.patch_size
1012
- self.prepare_attn_bias((self.nh, self.nw))
1013
- except:
1014
- pass
1015
- self.init_weights()
1016
-
1017
- self.total_step = 10000 # For PE -> GPE transfer
1018
- self.start_step = 2000
1019
- self.current_step = 20000
1020
-
1021
- def init_weights(self):
1022
- #trunc_normal_(self.pos_embed, std=0.02)
1023
- #nn.init.normal_(self.cls_token, std=1e-6)
1024
- named_apply(init_weights_vit_timm, self)
1025
- for i in range(4):
1026
- try:
1027
- nn.init.constant_(self.conv_block[i].conv2.weight, 0.0)
1028
- except:
1029
- pass
1030
-
1031
- def interpolate_pos_encoding(self, x, w, h):
1032
- previous_dtype = x.dtype
1033
- #npatch = x.shape[1] - 1
1034
- #N = self.pos_embed.shape[1] - 1
1035
- npatch = x.shape[1]
1036
- N = self.pos_embed.shape[1]
1037
- if npatch == N and w == h:
1038
- return self.pos_embed
1039
- pos_embed = self.pos_embed.float()
1040
- #class_pos_embed = pos_embed[:, 0]
1041
- #patch_pos_embed = pos_embed[:, 1:]
1042
- patch_pos_embed = pos_embed
1043
- dim = x.shape[-1]
1044
- w0 = w // self.patch_size
1045
- h0 = h // self.patch_size
1046
- # we add a small number to avoid floating point error in the interpolation
1047
- # see discussion at https://github.com/facebookresearch/dino/issues/8
1048
- w0, h0 = w0 + 0.1, h0 + 0.1
1049
-
1050
- patch_pos_embed = nn.functional.interpolate(
1051
- patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
1052
- scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
1053
- mode="bicubic",
1054
- )
1055
-
1056
- assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
1057
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
1058
- return patch_pos_embed.to(previous_dtype)
1059
- #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
1060
-
1061
- def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]:
1062
- """
1063
- Partition into non-overlapping windows with padding if needed.
1064
- Args:
1065
- x (tensor): input tokens with [B, H*W, C], or a [B, C, H, W] feature map if conv_feature is True.
1066
- window_size (int): window size.
1067
-
1068
- Returns:
1069
- windows: windows after partition with [B * num_windows, window_size * window_size, C].
1070
- (the padded size (Hp, Wp) is no longer returned; see the commented-out return below)
1071
- """
1072
- if conv_feature == False:
1073
- B, N, C = x.shape
1074
- H, W = hw[0], hw[1]
1075
-
1076
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
1077
-
1078
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
1079
- else:
1080
- B, C, H, W = x.shape
1081
-
1082
- x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
1083
-
1084
- windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C)
1085
-
1086
- #y = torch.cat((x_cls, windows), dim=1)
1087
- return windows #, (Hp, Wp)
1088
-
1089
-
1090
- def window_unpartition(self,
1091
- windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False
1092
- ) -> torch.Tensor:
1093
- """
1094
- Window unpartition into original sequences and removing padding.
1095
- Args:
1096
- windows (tensor): input tokens with [B * num_windows, window_size * window_size, C].
1097
- window_size (int): window size.
1098
- conv_feature (bool): if True, treat inputs/outputs as [B, C, H, W] feature maps instead of token sequences.
1099
- hw (Tuple): original height and width (H, W) before padding.
1100
-
1101
- Returns:
1102
- x: unpartitioned sequences with [B, H*W, C] (or [B, C, H, W] if conv_feature is True).
1103
- """
1104
- H, W = hw
1105
-
1106
- B = windows.shape[0] // (H * W // window_size // window_size)
1107
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
1108
-
1109
- if conv_feature == False:
1110
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H * W, -1)
1111
- else:
1112
- C = windows.shape[-1]
1113
- x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W)
1114
-
1115
- # if Hp > H or Wp > W:
1116
- # x = x[:, :H, :W, :].contiguous()
1117
- return x
1118
-
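A self-contained round-trip check of the token-path (`conv_feature=False`) partition/unpartition logic above, assuming H and W are exact multiples of the window size:

```python
import torch

B, H, W, C, ws = 2, 8, 8, 16, 4
x = torch.randn(B, H * W, C)

# partition: (B, H*W, C) -> (B * num_windows, ws*ws, C)
win = (x.view(B, H // ws, ws, W // ws, ws, C)
        .permute(0, 1, 3, 2, 4, 5).contiguous()
        .view(-1, ws * ws, C))

# unpartition: invert the permutation back to (B, H*W, C)
back = (win.view(B, H // ws, W // ws, ws, ws, C)
           .permute(0, 1, 3, 2, 4, 5).contiguous()
           .view(B, H * W, C))
print(torch.equal(back, x))  # True
```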
1119
- def prepare_tokens_with_masks(self, x, masks=None, step=-1):
1120
- B, nc, w, h = x.shape
1121
- x = self.patch_embed(x)
1122
- if masks is not None:
1123
- x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
1124
-
1125
- #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
1126
- if step == -1:
1127
- step = self.current_step
1128
- else:
1129
- self.current_step = step
1130
-
1131
- if step < self.start_step:
1132
- coef = 0.0
1133
- elif step < self.total_step:
1134
- coef = (step - self.start_step) / (self.total_step - self.start_step)
1135
- else:
1136
- coef = 1.0
1137
-
1138
- x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw))
1139
-
1140
- return x
1141
-
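The PE -> GPE handover above is controlled by a linear ramp of `coef` between `start_step` and `total_step`; a small sketch of that schedule:

```python
def blend_coef(step, start_step=2000, total_step=10000):
    """0 before start_step, linear ramp to 1 at total_step, 1 afterwards."""
    if step < start_step:
        return 0.0
    if step < total_step:
        return (step - start_step) / (total_step - start_step)
    return 1.0

for s in (0, 2000, 6000, 10000, 20000):
    print(s, blend_coef(s))  # 0.0, 0.0, 0.5, 1.0, 1.0
```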
1142
- def prepare_attn_bias(self, shape):
1143
- window_size = self.window_size
1144
- if window_size <= 0:
1145
- return
1146
-
1147
- import xformers.components.attention.attention_patterns as AP
1148
-
1149
- nh, nw = shape
1150
- radius = (window_size-1)//2
1151
- mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
1152
-
1153
- pad = (8 - (nh * nw) % 8)
1154
- if pad == 8:
1155
- pad = 0
1156
- mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous()
1157
- if pad > 0:
1158
- mask = mask_pad[:, :-pad].view(nh, nw, nh, nw)
1159
- else:
1160
- mask = mask_pad[:, :].view(nh, nw, nh, nw)
1161
-
1162
- # angle
1163
- mask[:radius+1, :radius+1, :window_size, :window_size] = True
1164
- mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
1165
- mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
1166
- mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
1167
-
1168
- # edge
1169
- mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
1170
- mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
1171
- mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
1172
- mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
1173
-
1174
- mask = mask.view(nh*nw, nh*nw)
1175
- bias_pad = torch.log(mask_pad)
1176
- #bias = bias_pad[:, :-pad]
1177
- self.register_buffer('attn_bias', bias_pad)
1178
-
1179
- return bias_pad
1180
-
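A standalone sketch of the local-window attention bias built above, without xformers and without the corner/edge special cases or the pad-to-multiple-of-8 step: tokens attend only to neighbours within a Chebyshev radius, and the boolean mask becomes an additive bias via log (True -> 0, False -> -inf).

```python
import torch

nh, nw, window_size = 6, 6, 3
radius = (window_size - 1) // 2
ys, xs = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing="ij")
coords = torch.stack([ys.flatten(), xs.flatten()], dim=1).float()   # (nh*nw, 2)
dist = (coords[:, None, :] - coords[None, :, :]).abs().amax(-1)     # L_inf distance
mask = dist <= radius                                                # (nh*nw, nh*nw)
bias = torch.log(mask.float())                                       # 0 where allowed, -inf elsewhere
print(mask.shape, bias[0, :8])
```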
1181
- def forward_features_list(self, x_list, masks_list):
1182
- x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
1183
- for blk in self.blocks:
1184
- x = blk(x)
1185
-
1186
- all_x = x
1187
- output = []
1188
- for x, masks in zip(all_x, masks_list):
1189
- x_norm = self.norm(x)
1190
- output.append(
1191
- {
1192
- "x_norm_clstoken": x_norm[:, 0],
1193
- "x_norm_patchtokens": x_norm[:, 1:],
1194
- "x_prenorm": x,
1195
- "masks": masks,
1196
- }
1197
- )
1198
- return output
1199
-
1200
- def forward_features(self, x, masks=None, **kwargs):
1201
- if isinstance(x, list):
1202
- return self.forward_features_list(x, masks)
1203
-
1204
- B, C, H, W = x.size()
1205
- pad_h = (self.patch_size - H % self.patch_size)
1206
- pad_w = (self.patch_size - W % self.patch_size)
1207
- if pad_h == self.patch_size:
1208
- pad_h = 0
1209
- if pad_w == self.patch_size:
1210
- pad_w = 0
1211
- #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
1212
- if pad_h + pad_w > 0:
1213
- x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
1214
-
1215
- nh = (H+pad_h)//self.patch_size
1216
- nw = (W+pad_w)//self.patch_size
1217
-
1218
- if self.window_size > 0:
1219
- if nh == self.nh and nw == self.nw:
1220
- attn_bias = self.attn_bias
1221
- else:
1222
- attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size))
1223
- self.nh = nh
1224
- self.nw = nw
1225
- attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1)
1226
- else:
1227
- attn_bias = None
1228
-
1229
- x = self.prepare_tokens_with_masks(x, masks)
1230
- #x = self.patch_embed(x)
1231
-
1232
- features = []
1233
- #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
1234
- for blk in self.blocks:
1235
- x = blk(x, attn_bias)
1236
- #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
1237
-
1238
- # for idx in range(len(self.blocks[0])):
1239
- # x = self.blocks[0][idx](x, attn_bias)
1240
-
1241
- # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
1242
- # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
1243
- # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x)
1244
- # if idx + 1 != len(self.blocks[0]):
1245
- # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
1246
- # else:
1247
- # b, c, h, w = x.size()
1248
- # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c)
1249
- #features.append(x)
1250
-
1251
- #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
1252
-
1253
- x_norm = self.norm(x)
1254
- # return {
1255
- # "x_norm_clstoken": x_norm[:, 0],
1256
- # "x_norm_patchtokens": x_norm[:, 1:],
1257
- # "x_prenorm": x,
1258
- # "masks": masks,
1259
- # }
1260
- features = []
1261
- features.append(x_norm)
1262
- features.append(x_norm)
1263
- features.append(x_norm)
1264
- features.append(x_norm)
1265
- return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
1266
-
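The resize-to-patch-multiple step at the top of `forward_features` reduces to simple modular arithmetic; a sketch (the `% patch_size` form below folds in the `pad == patch_size -> 0` special case handled explicitly above):

```python
def padded_hw(H, W, patch_size=14):
    pad_h = (patch_size - H % patch_size) % patch_size
    pad_w = (patch_size - W % patch_size) % patch_size
    return H + pad_h, W + pad_w

print(padded_hw(1400, 1680))  # (1400, 1680) -- already multiples of 14
print(padded_hw(1080, 1920))  # (1092, 1932)
```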
1267
- def _get_intermediate_layers_not_chunked(self, x, n=1):
1268
- x = self.prepare_tokens_with_masks(x)
1269
- # If n is an int, take the n last blocks. If it's a list, take them
1270
- output, total_block_len = [], len(self.blocks)
1271
- blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1272
- for i, blk in enumerate(self.blocks):
1273
- x = blk(x)
1274
- if i in blocks_to_take:
1275
- output.append(x)
1276
- assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1277
- return output
1278
-
1279
- def _get_intermediate_layers_chunked(self, x, n=1):
1280
- x = self.prepare_tokens_with_masks(x)
1281
- output, i, total_block_len = [], 0, len(self.blocks[-1])
1282
- # If n is an int, take the n last blocks. If it's a list, take them
1283
- blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1284
- for block_chunk in self.blocks:
1285
- for blk in block_chunk[i:]: # Passing the nn.Identity()
1286
- x = blk(x)
1287
- if i in blocks_to_take:
1288
- output.append(x)
1289
- i += 1
1290
- assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1291
- return output
1292
-
1293
- def get_intermediate_layers(
1294
- self,
1295
- x: torch.Tensor,
1296
- n: Union[int, Sequence] = 1, # Layers or n last layers to take
1297
- reshape: bool = False,
1298
- return_class_token: bool = False,
1299
- norm=True,
1300
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
1301
- if self.chunked_blocks:
1302
- outputs = self._get_intermediate_layers_chunked(x, n)
1303
- else:
1304
- outputs = self._get_intermediate_layers_not_chunked(x, n)
1305
- if norm:
1306
- outputs = [self.norm(out) for out in outputs]
1307
- class_tokens = [out[:, 0] for out in outputs]
1308
- outputs = [out[:, 1:] for out in outputs]
1309
- if reshape:
1310
- B, _, w, h = x.shape
1311
- outputs = [
1312
- out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
1313
- for out in outputs
1314
- ]
1315
- if return_class_token:
1316
- return tuple(zip(outputs, class_tokens))
1317
- return tuple(outputs)
1318
-
1319
- def forward(self, *args, is_training=False, **kwargs):
1320
- ret = self.forward_features(*args, **kwargs)
1321
- return ret
1322
- # if is_training:
1323
- # return ret
1324
- # else:
1325
- # return self.head(ret["x_norm_clstoken"])
1326
-
1327
-
1328
-
1329
-
1330
- def init_weights_vit_timm(module: nn.Module, name: str = ""):
1331
- """ViT weight initialization, original timm impl (for reproducibility)"""
1332
- if isinstance(module, nn.Linear):
1333
- trunc_normal_(module.weight, std=0.02)
1334
- if module.bias is not None:
1335
- nn.init.zeros_(module.bias)
1336
-
1337
-
1338
- def vit_small(patch_size=14, **kwargs):
1339
- model = DinoVisionTransformer(
1340
- patch_size=patch_size,
1341
- embed_dim=384,
1342
- depth=12,
1343
- num_heads=6,
1344
- mlp_ratio=4,
1345
- block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1346
- **kwargs,
1347
- )
1348
- return model
1349
-
1350
-
1351
- def vit_base(patch_size=14, **kwargs):
1352
- model = DinoWindowVisionTransformer(
1353
- patch_size=patch_size,
1354
- embed_dim=768,
1355
- depth=12,
1356
- num_heads=12,
1357
- mlp_ratio=4,
1358
- block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1359
- **kwargs,
1360
- )
1361
- return model
1362
-
1363
-
1364
- def vit_large(patch_size=14, checkpoint=None, **kwargs):
1365
- model = DinoVisionTransformer(
1366
- img_size = 518,
1367
- patch_size=patch_size,
1368
- embed_dim=1024,
1369
- depth=24,
1370
- num_heads=16,
1371
- mlp_ratio=4,
1372
- block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1373
- **kwargs,
1374
- )
1375
-
1376
- if checkpoint is not None:
1377
- with open(checkpoint, "rb") as f:
1378
- state_dict = torch.load(f)
1379
- try:
1380
- model.load_state_dict(state_dict, strict=True)
1381
- except:
1382
- new_state_dict = {}
1383
- for key, value in state_dict.items():
1384
- if 'blocks' in key:
1385
- key_new = 'blocks.0' + key[len('blocks'):]
1386
- else:
1387
- key_new = key
1388
- new_state_dict[key_new] = value
1389
-
1390
- model.load_state_dict(new_state_dict, strict=True)
1391
- #del model.norm
1392
- del model.mask_token
1393
- return model
1394
-
1395
- # model = DinoWindowVisionTransformer(
1396
- # img_size = 518,
1397
- # patch_size=patch_size,
1398
- # embed_dim=1024,
1399
- # depth=24,
1400
- # num_heads=16,
1401
- # mlp_ratio=4,
1402
- # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
1403
- # window_size=37,
1404
- # **kwargs,
1405
- # )
1406
-
1407
- # if checkpoint is not None:
1408
- # with open(checkpoint, "rb") as f:
1409
- # state_dict = torch.load(f)
1410
- # try:
1411
- # model.load_state_dict(state_dict, strict=True)
1412
- # except:
1413
- # new_state_dict = {}
1414
- # for key, value in state_dict.items():
1415
- # if 'blocks' in key:
1416
- # key_new = 'blocks.0' + key[len('blocks'):]
1417
- # else:
1418
- # key_new = key
1419
- # if 'pos_embed' in key:
1420
- # value = value[:, 1:, :]
1421
- # new_state_dict[key_new] = value
1422
-
1423
- # model.load_state_dict(new_state_dict, strict=False)
1424
- # #del model.norm
1425
- # del model.mask_token
1426
- return model
1427
-
1428
-
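The fallback branch in `vit_large` above rewrites plain DINOv2 checkpoint keys (`blocks.*`) into the chunked layout (`blocks.0.*`) used when `block_chunks > 0`; a minimal sketch of that remapping:

```python
def remap_chunked_keys(state_dict):
    """Rewrite 'blocks.<i>...' keys to 'blocks.0.<i>...' for a single BlockChunk."""
    out = {}
    for key, value in state_dict.items():
        if key.startswith("blocks."):
            out["blocks.0." + key[len("blocks."):]] = value
        else:
            out[key] = value
    return out

sd = {"blocks.3.attn.qkv.weight": 0, "pos_embed": 1}
print(list(remap_chunked_keys(sd)))  # ['blocks.0.3.attn.qkv.weight', 'pos_embed']
```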
1429
- def vit_giant2(patch_size=16, **kwargs):
1430
- """
1431
- Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
1432
- """
1433
- model = DinoVisionTransformer(
1434
- patch_size=patch_size,
1435
- embed_dim=1536,
1436
- depth=40,
1437
- num_heads=24,
1438
- mlp_ratio=4,
1439
- block_fn=partial(Block, attn_class=MemEffAttention),
1440
- **kwargs,
1441
- )
1442
- return model
1443
-
1444
- if __name__ == '__main__':
1445
- try:
1446
- from mmcv.utils import Config
1447
- except:
1448
- from mmengine import Config
1449
-
1450
- #rgb = torch.rand((2, 3, 518, 518)).cuda()
1451
-
1452
- #cfg.data_basic['crop_size']['0']
1453
- #cfg.data_basic['crop_size']['1']
1454
- cfg = Config.fromfile('/cpfs01/user/mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py')
1455
-
1456
- #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
1457
- rgb = torch.zeros(1, 3, 1400, 1680).cuda()
1458
- model = vit_large(checkpoint="/cpfs02/shared/public/custom/group_local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda()
1459
-
1460
- #import timm
1461
- #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
1462
- #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
1463
-
1464
- out1 = model(rgb)
1465
- #out2 = model2(rgb)
1466
- temp = 0
1467
-
1468
-
1469
-
1470
- # import time
1471
- # window_size = 37
1472
- # def prepare_window_masks(shape):
1473
- # if window_size <= 0:
1474
- # return None
1475
- # import xformers.components.attention.attention_patterns as AP
1476
-
1477
- # B, nh, nw, _, _ = shape
1478
- # radius = (window_size-1)//2
1479
- # #time0 = time.time()
1480
- # d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
1481
- # #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
1482
- # # mask = mask.view(nh, nw, nh, nw)
1483
- # # #time1 = time.time() - time0
1484
-
1485
- # # # angle
1486
- # # mask[:radius+1, :radius+1, :window_size, :window_size] = True
1487
- # # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
1488
- # # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
1489
- # # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
1490
- # # time2 = time.time() - time0 - time1
1491
-
1492
- # # # edge
1493
- # # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
1494
- # # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
1495
- # # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
1496
- # # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
1497
- # # time3 = time.time() - time0 - time2
1498
- # # print(time1, time2, time3)
1499
-
1500
- # # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1)
1501
-
1502
- # shape = (1, 55, 55, None, None)
1503
- # mask = prepare_window_masks(shape)
1504
- # # temp = 1

mono/model/backbones/ViT_DINO_reg.py DELETED
@@ -1,1293 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # References:
8
- # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
-
11
- from functools import partial
12
- import math
13
- import logging
14
- from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
15
-
16
- import torch
17
- import torch.nn as nn
18
- from torch import Tensor
19
- import torch.utils.checkpoint
20
- from torch.nn.init import trunc_normal_
21
- import torch.nn.init
22
- import torch.nn.functional as F
23
-
24
- #from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
25
-
26
- logger = logging.getLogger("dinov2")
27
-
28
- # SSF finetuning originally by dongzelian
29
- def init_ssf_scale_shift(dim):
30
- scale = nn.Parameter(torch.ones(dim))
31
- shift = nn.Parameter(torch.zeros(dim))
32
-
33
- nn.init.normal_(scale, mean=1, std=.02)
34
- nn.init.normal_(shift, std=.02)
35
-
36
- return scale, shift
37
-
38
- def ssf_ada(x, scale, shift):
39
- assert scale.shape == shift.shape
40
- if x.shape[-1] == scale.shape[0]:
41
- return x * scale + shift
42
- elif x.shape[1] == scale.shape[0]:
43
- return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
44
- else:
45
- raise ValueError('the input tensor shape does not match the shape of the scale factor.')
46
-
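A standalone sketch of the SSF adapter above: a learned per-channel scale and shift, broadcast over the last dimension for (B, N, C) tokens or over dim 1 for (B, C, H, W) feature maps, matching the two branches of `ssf_ada`.

```python
import torch

C = 8
scale, shift = torch.full((C,), 1.1), torch.full((C,), 0.2)

tokens = torch.randn(2, 5, C)              # (B, N, C): broadcast over the last dim
out_tokens = tokens * scale + shift

fmap = torch.randn(2, C, 4, 4)             # (B, C, H, W): reshape scale/shift for dim 1
out_fmap = fmap * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)

print(out_tokens.shape, out_fmap.shape)    # torch.Size([2, 5, 8]) torch.Size([2, 8, 4, 4])
```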
47
- # LoRA finetuning originally by edwardjhu
48
- class LoRALayer():
49
- def __init__(
50
- self,
51
- r: int,
52
- lora_alpha: int,
53
- lora_dropout: float,
54
- merge_weights: bool,
55
- ):
56
- self.r = r
57
- self.lora_alpha = lora_alpha
58
- # Optional dropout
59
- if lora_dropout > 0.:
60
- self.lora_dropout = nn.Dropout(p=lora_dropout)
61
- else:
62
- self.lora_dropout = lambda x: x
63
- # Mark the weight as unmerged
64
- self.merged = False
65
- self.merge_weights = merge_weights
66
-
67
- class LoRALinear(nn.Linear, LoRALayer):
68
- # LoRA implemented in a dense layer
69
- def __init__(
70
- self,
71
- in_features: int,
72
- out_features: int,
73
- r: int = 0,
74
- lora_alpha: int = 1,
75
- lora_dropout: float = 0.,
76
- fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
77
- merge_weights: bool = True,
78
- **kwargs
79
- ):
80
- nn.Linear.__init__(self, in_features, out_features, **kwargs)
81
- LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
82
- merge_weights=merge_weights)
83
-
84
- self.fan_in_fan_out = fan_in_fan_out
85
- # Actual trainable parameters
86
- if r > 0:
87
- self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
88
- self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
89
- self.scaling = self.lora_alpha / self.r
90
- # Freezing the pre-trained weight matrix
91
- self.weight.requires_grad = False
92
- self.reset_parameters()
93
- if fan_in_fan_out:
94
- self.weight.data = self.weight.data.transpose(0, 1)
95
-
96
- def reset_parameters(self):
97
- #nn.Linear.reset_parameters(self)
98
- if hasattr(self, 'lora_A'):
99
- # initialize B the same way as the default for nn.Linear and A to zero
100
- # this is different than what is described in the paper but should not affect performance
101
- nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
102
- nn.init.zeros_(self.lora_B)
103
-
104
- # def train(self, mode: bool = True):
105
- # def T(w):
106
- # return w.transpose(0, 1) if self.fan_in_fan_out else w
107
- # nn.Linear.train(self, mode)
108
- # if mode:
109
- # if self.merge_weights and self.merged:
110
- # # Make sure that the weights are not merged
111
- # if self.r > 0:
112
- # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
113
- # self.merged = False
114
- # else:
115
- # if self.merge_weights and not self.merged:
116
- # # Merge the weights and mark it
117
- # if self.r > 0:
118
- # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
119
- # self.merged = True
120
-
121
- def forward(self, x: torch.Tensor):
122
- def T(w):
123
- return w.transpose(0, 1) if self.fan_in_fan_out else w
124
- if self.r > 0 and not self.merged:
125
- result = F.linear(x, T(self.weight), bias=self.bias)
126
- result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
127
- return result
128
- else:
129
- return F.linear(x, T(self.weight), bias=self.bias)
130
-
131
-
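A hedged usage sketch for the `LoRALinear` adapter above (it assumes the class is in scope; `r=8` matches how the attention layers below instantiate it). With `r > 0` the base weight is frozen, so only `lora_A`, `lora_B`, and the bias remain trainable.

```python
import torch

layer = LoRALinear(64, 64, r=8, lora_alpha=16)   # assumes LoRALinear from this file
x = torch.randn(4, 64)
y = layer(x)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(y.shape, trainable)  # torch.Size([4, 64]) 1088  (8*64 + 64*8 + 64-dim bias)
```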
132
-
133
- def make_2tuple(x):
134
- if isinstance(x, tuple):
135
- assert len(x) == 2
136
- return x
137
-
138
- assert isinstance(x, int)
139
- return (x, x)
140
-
141
- def drop_path(x, drop_prob: float = 0.0, training: bool = False):
142
- if drop_prob == 0.0 or not training:
143
- return x
144
- keep_prob = 1 - drop_prob
145
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
146
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
147
- if keep_prob > 0.0:
148
- random_tensor.div_(keep_prob)
149
- output = x * random_tensor
150
- return output
151
-
152
- class DropPath(nn.Module):
153
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
154
-
155
- def __init__(self, drop_prob=None):
156
- super(DropPath, self).__init__()
157
- self.drop_prob = drop_prob
158
-
159
- def forward(self, x):
160
- return drop_path(x, self.drop_prob, self.training)
161
-
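A quick numerical check of the stochastic-depth rescaling in `drop_path` above: dividing the kept samples by `keep_prob` keeps the expected output equal to the input.

```python
import torch

torch.manual_seed(0)
x = torch.ones(100_000, 1)
keep_prob = 0.7                                    # i.e. drop_prob = 0.3
mask = x.new_empty(x.shape).bernoulli_(keep_prob) / keep_prob
print((x * mask).mean().item())                    # close to 1.0
```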
162
- class LayerScale(nn.Module):
163
- def __init__(
164
- self,
165
- dim: int,
166
- init_values: Union[float, Tensor] = 1e-5,
167
- inplace: bool = False,
168
- ) -> None:
169
- super().__init__()
170
- self.inplace = inplace
171
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
172
-
173
- def forward(self, x: Tensor) -> Tensor:
174
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
175
-
176
-
177
- class PatchEmbed(nn.Module):
178
- """
179
- 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
180
-
181
- Args:
182
- img_size: Image size.
183
- patch_size: Patch token size.
184
- in_chans: Number of input image channels.
185
- embed_dim: Number of linear projection output channels.
186
- norm_layer: Normalization layer.
187
- """
188
-
189
- def __init__(
190
- self,
191
- img_size: Union[int, Tuple[int, int]] = 224,
192
- patch_size: Union[int, Tuple[int, int]] = 16,
193
- in_chans: int = 3,
194
- embed_dim: int = 768,
195
- norm_layer: Optional[Callable] = None,
196
- flatten_embedding: bool = True,
197
- tuning_mode: Optional[str] = None
198
- ) -> None:
199
- super().__init__()
200
-
201
- image_HW = make_2tuple(img_size)
202
- patch_HW = make_2tuple(patch_size)
203
- patch_grid_size = (
204
- image_HW[0] // patch_HW[0],
205
- image_HW[1] // patch_HW[1],
206
- )
207
-
208
- self.img_size = image_HW
209
- self.patch_size = patch_HW
210
- self.patches_resolution = patch_grid_size
211
- self.num_patches = patch_grid_size[0] * patch_grid_size[1]
212
-
213
- self.in_chans = in_chans
214
- self.embed_dim = embed_dim
215
-
216
- self.flatten_embedding = flatten_embedding
217
-
218
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
219
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
220
-
221
- if tuning_mode != None:
222
- self.tuning_mode = tuning_mode
223
- if tuning_mode == 'ssf':
224
- self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
225
- else:
226
- pass
227
- #raise NotImplementedError()
228
- else:
229
- self.tuning_mode = None
230
-
231
- def forward(self, x: Tensor) -> Tensor:
232
- _, _, H, W = x.shape
233
- patch_H, patch_W = self.patch_size
234
-
235
- assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
236
- assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
237
-
238
- x = self.proj(x) # B C H W
239
- H, W = x.size(2), x.size(3)
240
- x = x.flatten(2).transpose(1, 2) # B HW C
241
- x = self.norm(x)
242
- if self.tuning_mode == 'ssf':
243
- x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
244
- if not self.flatten_embedding:
245
- x = x.reshape(-1, H, W, self.embed_dim) # B H W C
246
- return x
247
-
248
- def flops(self) -> float:
249
- Ho, Wo = self.patches_resolution
250
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
251
- if self.norm is not None:
252
- flops += Ho * Wo * self.embed_dim
253
- return flops
254
-
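Shape arithmetic for the patch embedding above, assuming DINOv2-style settings (518-pixel crops, 14-pixel patches) rather than the 224/16 defaults of the class: the grid is 37x37, i.e. 1369 patch tokens.

```python
img_size, patch_size = 518, 14
grid = img_size // patch_size
num_patches = grid * grid
print(grid, num_patches)  # 37 1369
```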
255
- class Mlp(nn.Module):
256
- def __init__(
257
- self,
258
- in_features: int,
259
- hidden_features: Optional[int] = None,
260
- out_features: Optional[int] = None,
261
- act_layer: Callable[..., nn.Module] = nn.GELU,
262
- drop: float = 0.0,
263
- bias: bool = True,
264
- tuning_mode: Optional[int] = None
265
- ) -> None:
266
- super().__init__()
267
- out_features = out_features or in_features
268
- hidden_features = hidden_features or in_features
269
- self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
270
- self.act = act_layer()
271
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
272
- self.drop = nn.Dropout(drop)
273
-
274
- if tuning_mode != None:
275
- self.tuning_mode = tuning_mode
276
- if tuning_mode == 'ssf':
277
- self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
278
- self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
279
- else:
280
- pass
281
- #raise NotImplementedError()
282
- else:
283
- self.tuning_mode = None
284
-
285
- def forward(self, x: Tensor) -> Tensor:
286
- x = self.fc1(x)
287
- if self.tuning_mode == 'ssf':
288
- x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
289
-
290
- x = self.act(x)
291
- x = self.drop(x)
292
- x = self.fc2(x)
293
- if self.tuning_mode == 'ssf':
294
- x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
295
-
296
- x = self.drop(x)
297
- return x
298
-
299
-
300
- class SwiGLUFFN(nn.Module):
301
- def __init__(
302
- self,
303
- in_features: int,
304
- hidden_features: Optional[int] = None,
305
- out_features: Optional[int] = None,
306
- act_layer: Callable[..., nn.Module] = None,
307
- drop: float = 0.0,
308
- bias: bool = True,
309
- tuning_mode: Optional[int] = None
310
- ) -> None:
311
- super().__init__()
312
- out_features = out_features or in_features
313
- hidden_features = hidden_features or in_features
314
- self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
315
- self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
316
-
317
- if tuning_mode != None:
318
- self.tuning_mode = tuning_mode
319
- if tuning_mode == 'ssf':
320
- self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features)
321
- self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
322
- else:
323
- pass
324
- #raise NotImplementedError()
325
- else:
326
- self.tuning_mode = None
327
-
328
-
329
- def forward(self, x: Tensor) -> Tensor:
330
- x12 = self.w12(x)
331
- if self.tuning_mode == 'ssf':
332
- x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1)
333
-
334
- x1, x2 = x12.chunk(2, dim=-1)
335
- hidden = F.silu(x1) * x2
336
- out = self.w3(hidden)
337
-
338
- if self.tuning_mode == 'ssf':
339
- out = ssf_ada(out, self.ssf_scale_2, self.ssf_shift_2)
340
-
341
- return out
342
-
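A minimal sketch of the SwiGLU gating computed above: `w12` produces 2*hidden features, split into (x1, x2), and the gate output is silu(x1) * x2 before the final `w3` projection.

```python
import torch
import torch.nn.functional as F

x12 = torch.randn(2, 10, 2 * 32)     # what w12 would produce for hidden_features=32
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2             # gated hidden of width 32
print(hidden.shape)                  # torch.Size([2, 10, 32])
```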
343
-
344
- try:
345
- from xformers.ops import SwiGLU
346
- #import numpy.bool
347
- XFORMERS_AVAILABLE = True
348
- except ImportError:
349
- SwiGLU = SwiGLUFFN
350
- XFORMERS_AVAILABLE = False
351
-
352
- class SwiGLUFFNFused(SwiGLU):
353
- def __init__(
354
- self,
355
- in_features: int,
356
- hidden_features: Optional[int] = None,
357
- out_features: Optional[int] = None,
358
- act_layer: Callable[..., nn.Module] = None,
359
- drop: float = 0.0,
360
- bias: bool = True,
361
- ) -> None:
362
- out_features = out_features or in_features
363
- hidden_features = hidden_features or in_features
364
- hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
365
- super().__init__(
366
- in_features=in_features,
367
- hidden_features=hidden_features,
368
- out_features=out_features,
369
- bias=bias,
370
- )
371
-
372
-
373
- try:
374
- from xformers.ops import memory_efficient_attention, unbind, fmha
375
- from xformers.components.attention import ScaledDotProduct
376
- from xformers.components import MultiHeadDispatch
377
- #import numpy.bool
378
- XFORMERS_AVAILABLE = True
379
- except ImportError:
380
- logger.warning("xFormers not available")
381
- XFORMERS_AVAILABLE = False
382
-
383
-
384
- class Attention(nn.Module):
385
- def __init__(
386
- self,
387
- dim: int,
388
- num_heads: int = 8,
389
- qkv_bias: bool = False,
390
- proj_bias: bool = True,
391
- attn_drop: float = 0.0,
392
- proj_drop: float = 0.0,
393
- window_size: int = 0,
394
- tuning_mode: Optional[int] = None
395
- ) -> None:
396
- super().__init__()
397
- self.num_heads = num_heads
398
- head_dim = dim // num_heads
399
- self.scale = head_dim**-0.5
400
-
401
- if tuning_mode == 'lora':
402
- self.tuning_mode = tuning_mode
403
- self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8)
404
- else:
405
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
406
-
407
- self.attn_drop = nn.Dropout(attn_drop)
408
-
409
- if tuning_mode == 'lora':
410
- self.tuning_mode = tuning_mode
411
- self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8)
412
- else:
413
- self.proj = nn.Linear(dim, dim, bias=proj_bias)
414
- self.proj_drop = nn.Dropout(proj_drop)
415
-
416
- if tuning_mode != None:
417
- self.tuning_mode = tuning_mode
418
- if tuning_mode == 'ssf':
419
- self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)
420
- self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
421
- else:
422
- pass
423
- #raise NotImplementedError()
424
- else:
425
- self.tuning_mode = None
426
-
427
- #if not self.training:
428
- #
429
- # self.attn = ScaledDotProduct()
430
- #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
431
-
432
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
433
- B, N, C = x.shape
434
- if self.tuning_mode == 'ssf':
435
- qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
436
- else:
437
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
438
-
439
- q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
440
- attn = q @ k.transpose(-2, -1)
441
-
442
- if attn_bias is not None:
443
- attn = attn + attn_bias[:, :, :N]
444
-
445
- attn = attn.softmax(dim=-1)
446
- attn = self.attn_drop(attn)
447
-
448
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
449
- x = self.proj(x)
450
-
451
- if self.tuning_mode == 'ssf':
452
- x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
453
-
454
- x = self.proj_drop(x)
455
- return x
456
-
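A shape walkthrough of the attention above (without the SSF/LoRA branches or the additive bias): the qkv projection is reshaped to (3, B, heads, N, head_dim) and a scaled dot-product is taken over tokens.

```python
import torch

B, N, C, heads = 2, 197, 768, 12
head_dim = C // heads
qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)        # (B, heads, N, N)
out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
print(out.shape)  # torch.Size([2, 197, 768])
```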
457
-
458
- class MemEffAttention(Attention):
459
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
460
- if not XFORMERS_AVAILABLE:
461
- #if True:
462
- assert attn_bias is None, "xFormers is required for nested tensors usage"
463
- return super().forward(x, attn_bias)
464
-
465
- B, N, C = x.shape
466
- if self.tuning_mode == 'ssf':
467
- qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads)
468
- else:
469
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
470
-
471
- q, k, v = unbind(qkv, 2)
472
- if attn_bias is not None:
473
- x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
474
- else:
475
- x = memory_efficient_attention(q, k, v)
476
- x = x.reshape([B, N, C])
477
-
478
- x = self.proj(x)
479
- if self.tuning_mode == 'ssf':
480
- x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
481
-
482
- x = self.proj_drop(x)
483
- return x
484
-
485
- try:
486
- from xformers.ops import fmha
487
- from xformers.ops import scaled_index_add, index_select_cat
488
- #import numpy.bool
489
- XFORMERS_AVAILABLE = True
490
- except ImportError:
491
- logger.warning("xFormers not available")
492
- XFORMERS_AVAILABLE = False
493
-
494
- class Block(nn.Module):
495
- def __init__(
496
- self,
497
- dim: int,
498
- num_heads: int,
499
- mlp_ratio: float = 4.0,
500
- qkv_bias: bool = False,
501
- proj_bias: bool = True,
502
- ffn_bias: bool = True,
503
- drop: float = 0.0,
504
- attn_drop: float = 0.0,
505
- init_values = None,
506
- drop_path: float = 0.0,
507
- act_layer: Callable[..., nn.Module] = nn.GELU,
508
- norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
509
- attn_class: Callable[..., nn.Module] = Attention,
510
- ffn_layer: Callable[..., nn.Module] = Mlp,
511
- tuning_mode: Optional[int] = None
512
- ) -> None:
513
- super().__init__()
514
- # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
515
- self.norm1 = norm_layer(dim)
516
- self.attn = attn_class(
517
- dim,
518
- num_heads=num_heads,
519
- qkv_bias=qkv_bias,
520
- proj_bias=proj_bias,
521
- attn_drop=attn_drop,
522
- proj_drop=drop,
523
- tuning_mode=tuning_mode
524
- )
525
-
526
- if tuning_mode != None:
527
- self.tuning_mode = tuning_mode
528
- if tuning_mode == 'ssf':
529
- self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
530
- self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
531
- else:
532
- pass
533
- #raise NotImplementedError()
534
- else:
535
- self.tuning_mode = None
536
-
537
- self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
538
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
539
-
540
- self.norm2 = norm_layer(dim)
541
- mlp_hidden_dim = int(dim * mlp_ratio)
542
- self.mlp = ffn_layer(
543
- in_features=dim,
544
- hidden_features=mlp_hidden_dim,
545
- act_layer=act_layer,
546
- drop=drop,
547
- bias=ffn_bias,
548
- )
549
- self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
550
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
551
-
552
- self.sample_drop_ratio = drop_path
553
-
554
- def forward(self, x: Tensor, attn_bias=None) -> Tensor:
555
- def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
556
- if self.tuning_mode == 'ssf':
557
- return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias))
558
- else:
559
- return self.ls1(self.attn(self.norm1(x), attn_bias))
560
-
561
- def ffn_residual_func(x: Tensor) -> Tensor:
562
- if self.tuning_mode == 'ssf':
563
- return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))
564
- else:
565
- return self.ls2(self.mlp(self.norm2(x)))
566
-
567
- if self.training and self.sample_drop_ratio > 0.1:
568
- # the overhead is compensated only for a drop path rate larger than 0.1
569
- x = drop_add_residual_stochastic_depth(
570
- x,
571
- residual_func=attn_residual_func,
572
- sample_drop_ratio=self.sample_drop_ratio,
573
- attn_bias=attn_bias
574
- )
575
- x = drop_add_residual_stochastic_depth(
576
- x,
577
- residual_func=ffn_residual_func,
578
- sample_drop_ratio=self.sample_drop_ratio,
579
- )
580
- elif self.training and self.sample_drop_ratio > 0.0:
581
- x = x + self.drop_path1(attn_residual_func(x, attn_bias))
582
- x = x + self.drop_path2(ffn_residual_func(x))
583
- else:
584
- x = x + attn_residual_func(x, attn_bias)
585
- x = x + ffn_residual_func(x)
586
- return x
587
-
588
-
589
- def drop_add_residual_stochastic_depth(
590
- x: Tensor,
591
- residual_func: Callable[[Tensor], Tensor],
592
- sample_drop_ratio: float = 0.0, attn_bias=None
593
- ) -> Tensor:
594
- # 1) extract subset using permutation
595
- b, n, d = x.shape
596
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
597
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
598
- x_subset = x[brange]
599
-
600
- # 2) apply residual_func to get residual
601
- residual = residual_func(x_subset, attn_bias)
602
-
603
- x_flat = x.flatten(1)
604
- residual = residual.flatten(1)
605
-
606
- residual_scale_factor = b / sample_subset_size
607
-
608
- # 3) add the residual
609
- x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
610
- return x_plus_residual.view_as(x)
611
-
612
-
613
- def get_branges_scales(x, sample_drop_ratio=0.0):
614
- b, n, d = x.shape
615
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
616
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
617
- residual_scale_factor = b / sample_subset_size
618
- return brange, residual_scale_factor
619
-
620
-
621
- def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
622
- if scaling_vector is None:
623
- x_flat = x.flatten(1)
624
- residual = residual.flatten(1)
625
- x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
626
- else:
627
- x_plus_residual = scaled_index_add(
628
- x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
629
- )
630
- return x_plus_residual
631
-
632
-
633
- attn_bias_cache: Dict[Tuple, Any] = {}
634
-
635
-
636
- def get_attn_bias_and_cat(x_list, branges=None):
637
- """
638
- this will perform the index select, cat the tensors, and provide the attn_bias from cache
639
- """
640
- batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
641
- all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
642
- if all_shapes not in attn_bias_cache.keys():
643
- seqlens = []
644
- for b, x in zip(batch_sizes, x_list):
645
- for _ in range(b):
646
- seqlens.append(x.shape[1])
647
- attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
648
- attn_bias._batch_sizes = batch_sizes
649
- attn_bias_cache[all_shapes] = attn_bias
650
-
651
- if branges is not None:
652
- cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
653
- else:
654
- tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
655
- cat_tensors = torch.cat(tensors_bs1, dim=1)
656
-
657
- return attn_bias_cache[all_shapes], cat_tensors
658
-
659
-
660
- def drop_add_residual_stochastic_depth_list(
661
- x_list: List[Tensor],
662
- residual_func: Callable[[Tensor, Any], Tensor],
663
- sample_drop_ratio: float = 0.0,
664
- scaling_vector=None,
665
- ) -> Tensor:
666
- # 1) generate random set of indices for dropping samples in the batch
667
- branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
668
- branges = [s[0] for s in branges_scales]
669
- residual_scale_factors = [s[1] for s in branges_scales]
670
-
671
- # 2) get attention bias and index+concat the tensors
672
- attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
673
-
674
- # 3) apply residual_func to get residual, and split the result
675
- residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
676
-
677
- outputs = []
678
- for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
679
- outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
680
- return outputs
681
-
682
-
683
- class NestedTensorBlock(Block):
684
- def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
685
- """
686
- x_list contains a list of tensors to nest together and run
687
- """
688
- assert isinstance(self.attn, MemEffAttention)
689
-
690
- if self.training and self.sample_drop_ratio > 0.0:
691
-
692
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
693
- return self.attn(self.norm1(x), attn_bias=attn_bias)
694
-
695
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
696
- return self.mlp(self.norm2(x))
697
-
698
- x_list = drop_add_residual_stochastic_depth_list(
699
- x_list,
700
- residual_func=attn_residual_func,
701
- sample_drop_ratio=self.sample_drop_ratio,
702
- scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
703
- )
704
- x_list = drop_add_residual_stochastic_depth_list(
705
- x_list,
706
- residual_func=ffn_residual_func,
707
- sample_drop_ratio=self.sample_drop_ratio,
708
- scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
709
- )
710
- return x_list
711
- else:
712
-
713
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
714
- return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
715
-
716
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
717
- return self.ls2(self.mlp(self.norm2(x)))
718
-
719
- attn_bias, x = get_attn_bias_and_cat(x_list)
720
- x = x + attn_residual_func(x, attn_bias=attn_bias)
721
- x = x + ffn_residual_func(x)
722
- return attn_bias.split(x)
723
-
724
- def forward(self, x_or_x_list, attn_bias=None):
725
- if isinstance(x_or_x_list, Tensor):
726
- return super().forward(x_or_x_list, attn_bias)
727
- elif isinstance(x_or_x_list, list):
728
- assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
729
- return self.forward_nested(x_or_x_list)
730
- else:
731
- raise AssertionError
732
-
733
-
734
- def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
735
- if not depth_first and include_root:
736
- fn(module=module, name=name)
737
- for child_name, child_module in module.named_children():
738
- child_name = ".".join((name, child_name)) if name else child_name
739
- named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
740
- if depth_first and include_root:
741
- fn(module=module, name=name)
742
- return module
743
-
744
-
745
- class BlockChunk(nn.ModuleList):
746
- def forward(self, x, others=None):
747
- for b in self:
748
- if others == None:
749
- x = b(x)
750
- else:
751
- x = b(x, others)
752
- return x
753
-
754
-
755
- class DinoVisionTransformer(nn.Module):
756
- def __init__(
757
- self,
758
- img_size=518,
759
- patch_size=16,
760
- in_chans=3,
761
- embed_dim=768,
762
- depth=12,
763
- num_heads=12,
764
- mlp_ratio=4.0,
765
- qkv_bias=True,
766
- ffn_bias=True,
767
- proj_bias=True,
768
- drop_path_rate=0.0,
769
- drop_path_uniform=False,
770
- init_values=1e-5, # for layerscale: None or 0 => no layerscale
771
- embed_layer=PatchEmbed,
772
- act_layer=nn.GELU,
773
- block_fn=Block,
774
- ffn_layer="mlp",
775
- block_chunks=1,
776
- num_register_tokens=0,
777
- interpolate_antialias=False,
778
- interpolate_offset=0.1,
779
- tuning_mode=None,
780
- **kwargs
781
- ):
782
- """
783
- Args:
784
- img_size (int, tuple): input image size
785
- patch_size (int, tuple): patch size
786
- in_chans (int): number of input channels
787
- embed_dim (int): embedding dimension
788
- depth (int): depth of transformer
789
- num_heads (int): number of attention heads
790
- mlp_ratio (int): ratio of mlp hidden dim to embedding dim
791
- qkv_bias (bool): enable bias for qkv if True
792
- proj_bias (bool): enable bias for proj in attn if True
793
- ffn_bias (bool): enable bias for ffn if True
794
- drop_path_rate (float): stochastic depth rate
795
- drop_path_uniform (bool): apply uniform drop rate across blocks
796
- weight_init (str): weight init scheme
797
- init_values (float): layer-scale init values
798
- embed_layer (nn.Module): patch embedding layer
799
- act_layer (nn.Module): MLP activation layer
800
- block_fn (nn.Module): transformer block class
801
- ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
802
- block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
803
- num_register_tokens: (int) number of extra cls tokens (so-called "registers")
804
- interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
805
- interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
806
- """
807
- super().__init__()
808
- norm_layer = partial(nn.LayerNorm, eps=1e-6)
809
-
810
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
811
- self.num_tokens = 1
812
- self.n_blocks = depth
813
- self.num_heads = num_heads
814
- self.patch_size = patch_size
815
- self.num_register_tokens = num_register_tokens
816
- self.interpolate_antialias = interpolate_antialias
817
- self.interpolate_offset = interpolate_offset
818
-
819
- if tuning_mode != None:
820
- self.tuning_mode = tuning_mode
821
- if tuning_mode == 'ssf':
822
- self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
823
- else:
824
- pass
825
- #raise NotImplementedError()
826
- else:
827
- self.tuning_mode = None
828
- tuning_mode_list = [tuning_mode] * depth
829
-
830
- self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode)
831
- num_patches = self.patch_embed.num_patches
832
-
833
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
834
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
835
- assert num_register_tokens >= 0
836
- self.register_tokens = (
837
- nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
838
- )
839
-
840
- if drop_path_uniform is True:
841
- dpr = [drop_path_rate] * depth
842
- else:
843
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
844
-
845
- if ffn_layer == "mlp":
846
- logger.info("using MLP layer as FFN")
847
- ffn_layer = Mlp
848
- elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
849
- logger.info("using SwiGLU layer as FFN")
850
- ffn_layer = SwiGLUFFNFused
851
- elif ffn_layer == "identity":
852
- logger.info("using Identity layer as FFN")
853
-
854
- def f(*args, **kwargs):
855
- return nn.Identity()
856
-
857
- ffn_layer = f
858
- else:
859
- raise NotImplementedError
860
-
861
- blocks_list = [
862
- block_fn(
863
- dim=embed_dim,
864
- num_heads=num_heads,
865
- mlp_ratio=mlp_ratio,
866
- qkv_bias=qkv_bias,
867
- proj_bias=proj_bias,
868
- ffn_bias=ffn_bias,
869
- drop_path=dpr[i],
870
- norm_layer=norm_layer,
871
- act_layer=act_layer,
872
- ffn_layer=ffn_layer,
873
- init_values=init_values,
874
- tuning_mode=tuning_mode_list[i]
875
- )
876
- for i in range(depth)
877
- ]
878
- if block_chunks > 0:
879
- self.chunked_blocks = True
880
- chunked_blocks = []
881
- chunksize = depth // block_chunks
882
- for i in range(0, depth, chunksize):
883
- # this is to keep the block index consistent if we chunk the block list
884
- chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
885
- self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
886
- else:
887
- self.chunked_blocks = False
888
- self.blocks = nn.ModuleList(blocks_list)
889
-
890
- self.norm = norm_layer(embed_dim)
891
- self.head = nn.Identity()
892
-
893
- self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
894
-
895
- self.init_weights()
896
-
897
- def init_weights(self):
898
- trunc_normal_(self.pos_embed, std=0.02)
899
- nn.init.normal_(self.cls_token, std=1e-6)
900
- if self.register_tokens is not None:
901
- nn.init.normal_(self.register_tokens, std=1e-6)
902
- named_apply(init_weights_vit_timm, self)
903
-
904
- def interpolate_pos_encoding(self, x, w, h):
905
- previous_dtype = x.dtype
906
- npatch = x.shape[1] - 1
907
- N = self.pos_embed.shape[1] - 1
908
- if npatch == N and w == h:
909
- return self.pos_embed
910
- pos_embed = self.pos_embed.float()
911
- class_pos_embed = pos_embed[:, 0]
912
- patch_pos_embed = pos_embed[:, 1:]
913
- dim = x.shape[-1]
914
- w0 = w // self.patch_size
915
- h0 = h // self.patch_size
916
- # we add a small number to avoid floating point error in the interpolation
917
- # see discussion at https://github.com/facebookresearch/dino/issues/8
918
- w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
919
-
920
- sqrt_N = math.sqrt(N)
921
- sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
922
- patch_pos_embed = nn.functional.interpolate(
923
- patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
924
- scale_factor=(sx, sy),
925
- mode="bicubic",
926
- antialias=self.interpolate_antialias,
927
- )
928
-
929
- assert int(w0) == patch_pos_embed.shape[-2]
930
- assert int(h0) == patch_pos_embed.shape[-1]
931
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
932
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
933
-
934
- def prepare_tokens_with_masks(self, x, masks=None):
935
- B, nc, w, h = x.shape
936
- x = self.patch_embed(x)
937
- if masks is not None:
938
- x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
939
-
940
- x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
941
- x = x + self.interpolate_pos_encoding(x, w, h)
942
-
943
- if self.register_tokens is not None:
944
- x = torch.cat(
945
- (
946
- x[:, :1],
947
- self.register_tokens.expand(x.shape[0], -1, -1),
948
- x[:, 1:],
949
- ),
950
- dim=1,
951
- )
952
-
953
- return x
954
-
955
- def forward_features_list(self, x_list, masks_list):
956
- x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
957
- for blk in self.blocks:
958
- x = blk(x)
959
-
960
- all_x = x
961
- output = []
962
- for x, masks in zip(all_x, masks_list):
963
- x_norm = self.norm(x)
964
- output.append(
965
- {
966
- "x_norm_clstoken": x_norm[:, 0],
967
- "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
968
- "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
969
- "x_prenorm": x,
970
- "masks": masks,
971
- }
972
- )
973
- return output
974
-
975
- def forward_features(self, x, masks=None):
976
- if isinstance(x, list):
977
- return self.forward_features_list(x, masks)
978
-
979
- B, C, H, W = x.size()
980
- pad_h = (self.patch_size - H % self.patch_size)
981
- pad_w = (self.patch_size - W % self.patch_size)
982
- if pad_h == self.patch_size:
983
- pad_h = 0
984
- if pad_w == self.patch_size:
985
- pad_w = 0
986
- #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
987
- if pad_h + pad_w > 0:
988
- x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
989
-
990
- x = self.prepare_tokens_with_masks(x, masks)
991
-
992
- for blk in self.blocks:
993
- x = blk(x)
994
-
995
- x_norm = self.norm(x)
996
- if self.tuning_mode == 'ssf':
997
- x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1)
998
-
999
- # return {
1000
- # "x_norm_clstoken": x_norm[:, 0],
1001
- # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
1002
- # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
1003
- # "x_prenorm": x,
1004
- # "masks": masks,
1005
- # }
1006
- features = []
1007
- features.append(x_norm)
1008
- features.append(x_norm)
1009
- features.append(x_norm)
1010
- features.append(x_norm)
1011
- return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
1012
-
1013
-
1014
- def _get_intermediate_layers_not_chunked(self, x, n=1):
1015
- x = self.prepare_tokens_with_masks(x)
1016
- # If n is an int, take the n last blocks. If it's a list, take them
1017
- output, total_block_len = [], len(self.blocks)
1018
- blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1019
- for i, blk in enumerate(self.blocks):
1020
- x = blk(x)
1021
- if i in blocks_to_take:
1022
- output.append(x)
1023
- assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1024
- return output
1025
-
1026
- def _get_intermediate_layers_chunked(self, x, n=1):
1027
- x = self.prepare_tokens_with_masks(x)
1028
- output, i, total_block_len = [], 0, len(self.blocks[-1])
1029
- # If n is an int, take the n last blocks. If it's a list, take them
1030
- blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
1031
- for block_chunk in self.blocks:
1032
- for blk in block_chunk[i:]: # Passing the nn.Identity()
1033
- x = blk(x)
1034
- if i in blocks_to_take:
1035
- output.append(x)
1036
- i += 1
1037
- assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
1038
- return output
1039
-
1040
- def get_intermediate_layers(
1041
- self,
1042
- x: torch.Tensor,
1043
- n: Union[int, Sequence] = 1, # Layers or n last layers to take
1044
- reshape: bool = False,
1045
- return_class_token: bool = False,
1046
- norm=True,
1047
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
1048
- if self.chunked_blocks:
1049
- outputs = self._get_intermediate_layers_chunked(x, n)
1050
- else:
1051
- outputs = self._get_intermediate_layers_not_chunked(x, n)
1052
- if norm:
1053
- outputs = [self.norm(out) for out in outputs]
1054
- class_tokens = [out[:, 0] for out in outputs]
1055
- outputs = [out[:, 1:] for out in outputs]
1056
- if reshape:
1057
- B, _, w, h = x.shape
1058
- outputs = [
1059
- out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
1060
- for out in outputs
1061
- ]
1062
- if return_class_token:
1063
- return tuple(zip(outputs, class_tokens))
1064
- return tuple(outputs)
1065
-
1066
- def forward(self, *args, is_training=False, **kwargs):
1067
- ret = self.forward_features(*args, **kwargs)
1068
- return ret
1069
- # if is_training:
1070
- # return ret
1071
- # else:
1072
- # return self.head(ret["x_norm_clstoken"])
1073
-
1074
-
1075
- def init_weights_vit_timm(module: nn.Module, name: str = ""):
1076
- """ViT weight initialization, original timm impl (for reproducibility)"""
1077
- if isinstance(module, nn.Linear):
1078
- trunc_normal_(module.weight, std=0.02)
1079
- if module.bias is not None:
1080
- nn.init.zeros_(module.bias)
1081
-
1082
-
1083
- def load_ckpt_dino(checkpoint, model):
1084
- if checkpoint is not None:
1085
- try:
1086
- with open(checkpoint, "rb") as f:
1087
- state_dict = torch.load(f)
1088
- except:
1089
- print('NO pretrained imagenet ckpt available! Check your path!')
1090
- del model.mask_token
1091
- return
1092
-
1093
- try:
1094
- model.load_state_dict(state_dict, strict=True)
1095
- except:
1096
- new_state_dict = {}
1097
- for key, value in state_dict.items():
1098
- if 'blocks' in key:
1099
- key_new = 'blocks.0' + key[len('blocks'):]
1100
- else:
1101
- key_new = key
1102
- new_state_dict[key_new] = value
1103
-
1104
- model.load_state_dict(new_state_dict, strict=True)
1105
- del model.mask_token
1106
- return
1107
- else:
1108
- return
1109
-
1110
-
1111
- def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1112
- model = DinoVisionTransformer(
1113
- patch_size=patch_size,
1114
- embed_dim=384,
1115
- depth=12,
1116
- num_heads=6,
1117
- mlp_ratio=4,
1118
- block_fn=partial(Block, attn_class=MemEffAttention),
1119
- num_register_tokens=num_register_tokens,
1120
- **kwargs,
1121
- )
1122
-
1123
- load_ckpt_dino(checkpoint, model)
1124
-
1125
- return model
1126
-
1127
-
1128
- def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1129
- model = DinoVisionTransformer(
1130
- patch_size=patch_size,
1131
- embed_dim=768,
1132
- depth=12,
1133
- num_heads=12,
1134
- mlp_ratio=4,
1135
- block_fn=partial(Block, attn_class=MemEffAttention),
1136
- num_register_tokens=num_register_tokens,
1137
- **kwargs,
1138
- )
1139
- return model
1140
-
1141
-
1142
- def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1143
- model = DinoVisionTransformer(
1144
- patch_size=patch_size,
1145
- embed_dim=1024,
1146
- depth=24,
1147
- num_heads=16,
1148
- mlp_ratio=4,
1149
- block_fn=partial(Block, attn_class=MemEffAttention),
1150
- num_register_tokens=num_register_tokens,
1151
- **kwargs,
1152
- )
1153
-
1154
- if checkpoint is not None:
1155
- with open(checkpoint, "rb") as f:
1156
- state_dict = torch.load(f)
1157
- try:
1158
- model.load_state_dict(state_dict, strict=True)
1159
- except:
1160
- new_state_dict = {}
1161
- for key, value in state_dict.items():
1162
- if 'blocks' in key:
1163
- key_new = 'blocks.0' + key[len('blocks'):]
1164
- else:
1165
- key_new = key
1166
- new_state_dict[key_new] = value
1167
-
1168
- model.load_state_dict(new_state_dict, strict=True)
1169
- del model.mask_token
1170
- return model
1171
-
1172
-
1173
- def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
1174
- """
1175
- Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
1176
- """
1177
- model = DinoVisionTransformer(
1178
- patch_size=patch_size,
1179
- embed_dim=1536,
1180
- depth=40,
1181
- num_heads=24,
1182
- mlp_ratio=4,
1183
- block_fn=partial(Block, attn_class=MemEffAttention),
1184
- num_register_tokens=num_register_tokens,
1185
- ffn_layer='swiglu',
1186
- **kwargs,
1187
- )
1188
- return model
1189
-
1190
-
1191
-
1192
- def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
1193
- model = DinoVisionTransformer(
1194
- patch_size=patch_size,
1195
- embed_dim=384,
1196
- depth=12,
1197
- num_heads=6,
1198
- mlp_ratio=4,
1199
- block_fn=partial(Block, attn_class=MemEffAttention),
1200
- num_register_tokens=num_register_tokens,
1201
- tuning_mode=tuning_mode,
1202
- **kwargs,
1203
- )
1204
-
1205
- load_ckpt_dino(checkpoint, model)
1206
-
1207
- return model
1208
-
1209
-
1210
- def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
1211
- model = DinoVisionTransformer(
1212
- patch_size=patch_size,
1213
- embed_dim=768,
1214
- depth=12,
1215
- num_heads=12,
1216
- mlp_ratio=4,
1217
- block_fn=partial(Block, attn_class=MemEffAttention),
1218
- num_register_tokens=num_register_tokens,
1219
- **kwargs,
1220
- )
1221
-
1222
- load_ckpt_dino(checkpoint, model)
1223
-
1224
- return model
1225
-
1226
-
1227
- def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
1228
- model = DinoVisionTransformer(
1229
- img_size = 518,
1230
- patch_size=patch_size,
1231
- embed_dim=1024,
1232
- depth=24,
1233
- num_heads=16,
1234
- mlp_ratio=4,
1235
- block_fn=partial(Block, attn_class=MemEffAttention),
1236
- num_register_tokens=num_register_tokens,
1237
- tuning_mode=tuning_mode,
1238
- **kwargs,
1239
- )
1240
-
1241
- load_ckpt_dino(checkpoint, model)
1242
-
1243
- return model
1244
-
1245
-
1246
- def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
1247
- """
1248
- Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
1249
- """
1250
- model = DinoVisionTransformer(
1251
- patch_size=patch_size,
1252
- embed_dim=1536,
1253
- depth=40,
1254
- num_heads=24,
1255
- mlp_ratio=4,
1256
- block_fn=partial(Block, attn_class=MemEffAttention),
1257
- num_register_tokens=num_register_tokens,
1258
- ffn_layer='swiglu',
1259
- tuning_mode=tuning_mode,
1260
- **kwargs,
1261
- )
1262
-
1263
- load_ckpt_dino(checkpoint, model)
1264
-
1265
- return model
1266
-
1267
- if __name__ == '__main__':
1268
- try:
1269
- from mmcv.utils import Config
1270
- except:
1271
- from mmengine import Config
1272
-
1273
- #rgb = torch.rand((2, 3, 518, 518)).cuda()
1274
-
1275
- #cfg.data_basic['crop_size']['0']
1276
- #cfg.data_basic['crop_size']['1']
1277
- cfg = Config.fromfile('/opt/ml/project/mu.hu/projects/monodepth_vit/mono/configs/RAFTDecoder/vit.raft5.large.kitti.py')
1278
-
1279
- #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
1280
- rgb = torch.zeros(1, 3, 616, 1064).cuda()
1281
- cfg['tuning_mode'] = 'ssf'
1282
- #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda()
1283
- model = vit_large_reg(tuning_mode='ssf').cuda()
1284
-
1285
- #import timm
1286
- #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
1287
- #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
1288
-
1289
- out1 = model(rgb)
1290
- #out2 = model2(rgb)
1291
- temp = 0
1292
-
1293
-
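The interpolate_pos_encoding method in the deleted ViT_DINO_reg.py above rescales the pretrained patch position embeddings with bicubic interpolation so the backbone can accept input resolutions other than its 518x518 pretraining grid. Below is a minimal standalone sketch of that step with illustrative shapes only; the 37x37 grid and 1024-dim embeddings are assumptions matching a ViT-L/14, not values read from any checkpoint.

import math
import torch
import torch.nn.functional as F

def resize_patch_pos_embed(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    # pos_embed: (1, N, C) patch position embeddings laid out on a square sqrt(N) x sqrt(N) grid
    n, c = pos_embed.shape[1], pos_embed.shape[2]
    side = int(math.sqrt(n))
    grid = pos_embed.reshape(1, side, side, c).permute(0, 3, 1, 2)      # (1, C, side, side)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, c)        # (1, new_h*new_w, C)

# Illustrative numbers: a ViT-L/14 pretrained at 518x518 has a 37x37 patch grid.
pos = torch.randn(1, 37 * 37, 1024)
resized = resize_patch_pos_embed(pos, 44, 76)   # token grid for a 616x1064 input with patch size 14
print(resized.shape)                            # torch.Size([1, 3344, 1024])

The deleted implementation additionally splits off the class token and passes scale factors rather than a target size, but the resize itself is the same operation.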
 
mono/model/backbones/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- from .ConvNeXt import convnext_xlarge
2
- from .ConvNeXt import convnext_small
3
- from .ConvNeXt import convnext_base
4
- from .ConvNeXt import convnext_large
5
- from .ConvNeXt import convnext_tiny
6
- from .ViT_DINO import vit_large
7
- from .ViT_DINO_reg import vit_small_reg, vit_large_reg
8
-
9
- __all__ = [
10
- 'convnext_xlarge', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_tiny', 'vit_small_reg', 'vit_large_reg'
11
- ]
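As a usage sketch only, assuming the deleted mono package were still on the import path and a CUDA device is available, the backbone factories re-exported here could be exercised much like the self-test at the bottom of ViT_DINO_reg.py above:

import torch
from mono.model.backbones import vit_large_reg   # import path assumes the deleted package layout

model = vit_large_reg(checkpoint=None).cuda().eval()   # random weights; no DINOv2 checkpoint is loaded
rgb = torch.zeros(1, 3, 616, 1064).cuda()              # same test resolution as the __main__ block above

with torch.no_grad():
    features, (B, h_tok, w_tok, H, W, num_reg) = model(rgb)

# forward_features returns the final normalized token map (repeated four times) plus grid metadata
print(len(features), features[0].shape, (B, h_tok, w_tok, H, W, num_reg))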
 
mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc DELETED
Binary file (9.37 kB)
 
mono/model/backbones/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (410 Bytes)
 
mono/model/decode_heads/HourGlassDecoder.py DELETED
@@ -1,274 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import math
5
- import torch.nn.functional as F
6
-
7
- def compute_depth_expectation(prob, depth_values):
8
- depth_values = depth_values.view(*depth_values.shape, 1, 1)
9
- depth = torch.sum(prob * depth_values, 1)
10
- return depth
11
-
12
- class ConvBlock(nn.Module):
13
- def __init__(self, in_channels, out_channels, kernel_size=3):
14
- super(ConvBlock, self).__init__()
15
-
16
- if kernel_size == 3:
17
- self.conv = nn.Sequential(
18
- nn.ReflectionPad2d(1),
19
- nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
20
- )
21
- elif kernel_size == 1:
22
- self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
23
-
24
- self.nonlin = nn.ELU(inplace=True)
25
-
26
- def forward(self, x):
27
- out = self.conv(x)
28
- out = self.nonlin(out)
29
- return out
30
-
31
-
32
- class ConvBlock_double(nn.Module):
33
- def __init__(self, in_channels, out_channels, kernel_size=3):
34
- super(ConvBlock_double, self).__init__()
35
-
36
- if kernel_size == 3:
37
- self.conv = nn.Sequential(
38
- nn.ReflectionPad2d(1),
39
- nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
40
- )
41
- elif kernel_size == 1:
42
- self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
43
-
44
- self.nonlin = nn.ELU(inplace=True)
45
- self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1)
46
- self.nonlin_2 =nn.ELU(inplace=True)
47
-
48
- def forward(self, x):
49
- out = self.conv(x)
50
- out = self.nonlin(out)
51
- out = self.conv_2(out)
52
- out = self.nonlin_2(out)
53
- return out
54
-
55
- class DecoderFeature(nn.Module):
56
- def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]):
57
- super(DecoderFeature, self).__init__()
58
- self.num_ch_dec = num_ch_dec
59
- self.feat_channels = feat_channels
60
-
61
- self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1)
62
- self.upconv_3_1 = ConvBlock_double(
63
- self.feat_channels[2] + self.num_ch_dec[3],
64
- self.num_ch_dec[3],
65
- kernel_size=1)
66
-
67
- self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3)
68
- self.upconv_2_1 = ConvBlock_double(
69
- self.feat_channels[1] + self.num_ch_dec[2],
70
- self.num_ch_dec[2],
71
- kernel_size=3)
72
-
73
- self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3)
74
- self.upconv_1_1 = ConvBlock_double(
75
- self.feat_channels[0] + self.num_ch_dec[1],
76
- self.num_ch_dec[1],
77
- kernel_size=3)
78
- self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
79
-
80
- def forward(self, ref_feature):
81
- x = ref_feature[3]
82
-
83
- x = self.upconv_3_0(x)
84
- x = torch.cat((self.upsample(x), ref_feature[2]), 1)
85
- x = self.upconv_3_1(x)
86
-
87
- x = self.upconv_2_0(x)
88
- x = torch.cat((self.upsample(x), ref_feature[1]), 1)
89
- x = self.upconv_2_1(x)
90
-
91
- x = self.upconv_1_0(x)
92
- x = torch.cat((self.upsample(x), ref_feature[0]), 1)
93
- x = self.upconv_1_1(x)
94
- return x
95
-
96
-
97
- class UNet(nn.Module):
98
- def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'):
99
- super(UNet, self).__init__()
100
- basic_block = ConvBnReLU
101
- num_depth = 128
102
-
103
- self.conv0 = basic_block(inp_ch, num_depth)
104
- if channel_mode == 'v0':
105
- channels = [num_depth, num_depth//2, num_depth//4, num_depth//8, num_depth // 8]
106
- elif channel_mode == 'v1':
107
- channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth]
108
- self.down_sample_times = down_sample_times
109
- for i in range(down_sample_times):
110
- setattr(
111
- self, 'conv_%d' % i,
112
- nn.Sequential(
113
- basic_block(channels[i], channels[i+1], stride=2),
114
- basic_block(channels[i+1], channels[i+1])
115
- )
116
- )
117
- for i in range(down_sample_times-1,-1,-1):
118
- setattr(self, 'deconv_%d' % i,
119
- nn.Sequential(
120
- nn.ConvTranspose2d(
121
- channels[i+1],
122
- channels[i],
123
- kernel_size=3,
124
- padding=1,
125
- output_padding=1,
126
- stride=2,
127
- bias=False),
128
- nn.BatchNorm2d(channels[i]),
129
- nn.ReLU(inplace=True)
130
- )
131
- )
132
- self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0)
133
-
134
- def forward(self, x):
135
- features = {}
136
- conv0 = self.conv0(x)
137
- x = conv0
138
- features[0] = conv0
139
- for i in range(self.down_sample_times):
140
- x = getattr(self, 'conv_%d' % i)(x)
141
- features[i+1] = x
142
- for i in range(self.down_sample_times-1,-1,-1):
143
- x = features[i] + getattr(self, 'deconv_%d' % i)(x)
144
- x = self.prob(x)
145
- return x
146
-
147
- class ConvBnReLU(nn.Module):
148
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
149
- super(ConvBnReLU, self).__init__()
150
- self.conv = nn.Conv2d(
151
- in_channels,
152
- out_channels,
153
- kernel_size,
154
- stride=stride,
155
- padding=pad,
156
- bias=False
157
- )
158
- self.bn = nn.BatchNorm2d(out_channels)
159
-
160
- def forward(self, x):
161
- return F.relu(self.bn(self.conv(x)), inplace=True)
162
-
163
-
164
- class HourglassDecoder(nn.Module):
165
- def __init__(self, cfg):
166
- super(HourglassDecoder, self).__init__()
167
- self.inchannels = cfg.model.decode_head.in_channels # [256, 512, 1024, 2048]
168
- self.decoder_channels = cfg.model.decode_head.decoder_channel # [64, 64, 128, 256]
169
- self.min_val = cfg.data_basic.depth_normalize[0]
170
- self.max_val = cfg.data_basic.depth_normalize[1]
171
-
172
- self.num_ch_dec = self.decoder_channels # [64, 64, 128, 256]
173
- self.num_depth_regressor_anchor = 512
174
- self.feat_channels = self.inchannels
175
- unet_in_channel = self.num_ch_dec[1]
176
- unet_out_channel = 256
177
-
178
- self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec)
179
- self.conv_out_2 = UNet(inp_ch=unet_in_channel,
180
- output_chal=unet_out_channel + 1,
181
- down_sample_times=3,
182
- channel_mode='v0',
183
- )
184
-
185
- self.depth_regressor_2 = nn.Sequential(
186
- nn.Conv2d(unet_out_channel,
187
- self.num_depth_regressor_anchor,
188
- kernel_size=3,
189
- padding=1,
190
- ),
191
- nn.BatchNorm2d(self.num_depth_regressor_anchor),
192
- nn.ReLU(inplace=True),
193
- nn.Conv2d(
194
- self.num_depth_regressor_anchor,
195
- self.num_depth_regressor_anchor,
196
- kernel_size=1,
197
- )
198
- )
199
- self.residual_channel = 16
200
- self.conv_up_2 = nn.Sequential(
201
- nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1),
202
- nn.BatchNorm2d(self.residual_channel),
203
- nn.ReLU(),
204
- nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
205
- nn.Upsample(scale_factor=4),
206
- nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
207
- nn.ReLU(),
208
- nn.Conv2d(self.residual_channel, 1, 1, padding=0),
209
- )
210
-
211
- def get_bins(self, bins_num):
212
- depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda')
213
- depth_bins_vec = torch.exp(depth_bins_vec)
214
- return depth_bins_vec
215
-
216
- def register_depth_expectation_anchor(self, bins_num, B):
217
- depth_bins_vec = self.get_bins(bins_num)
218
- depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
219
- self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
220
-
221
- def upsample(self, x, scale_factor=2):
222
- return F.interpolate(x, scale_factor=scale_factor, mode='nearest')
223
-
224
- def regress_depth_2(self, feature_map_d):
225
- prob = self.depth_regressor_2(feature_map_d).softmax(dim=1)
226
- B = prob.shape[0]
227
- if "depth_expectation_anchor" not in self._buffers:
228
- self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
229
- d = compute_depth_expectation(
230
- prob,
231
- self.depth_expectation_anchor[:B, ...]
232
- ).unsqueeze(1)
233
- return d
234
-
235
- def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
236
- y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
237
- torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
238
- meshgrid = torch.stack((x, y))
239
- meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
240
- return meshgrid
241
-
242
- def forward(self, features_mono, **kwargs):
243
- '''
244
- trans_ref2src: list of transformation matrices from the reference view to the source views. [B, 4, 4]
245
- inv_intrinsic_pool: list of inverse intrinsic matrices.
246
- features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...].
247
- '''
248
- outputs = {}
249
- # get encoder feature of the reference view
250
- ref_feat = features_mono
251
-
252
- feature_map_mono = self.decoder_mono(ref_feat)
253
- feature_map_mono_pred = self.conv_out_2(feature_map_mono)
254
- confidence_map_2 = feature_map_mono_pred[:, -1:, :, :]
255
- feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :]
256
-
257
- depth_pred_2 = self.regress_depth_2(feature_map_d_2)
258
-
259
- B, _, H, W = depth_pred_2.shape
260
-
261
- meshgrid = self.create_mesh_grid(H, W, B)
262
-
263
- depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \
264
- self.conv_up_2(
265
- torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1)
266
- )
267
- confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4)
268
-
269
- outputs=dict(
270
- prediction=depth_pred_mono,
271
- confidence=confidence_map_mono,
272
- pred_logit=None,
273
- )
274
- return outputs
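The decoder above (and the RAFT-style decoder that follows) turns a per-pixel probability volume into metric depth by taking an expectation over log-spaced depth anchors, via get_bins and compute_depth_expectation. Below is a minimal CPU-only sketch of that step; the 0.3-150 m range is an illustrative assumption standing in for cfg.data_basic.depth_normalize, while the 512 anchors match num_depth_regressor_anchor in the deleted HourglassDecoder.

import math
import torch

def log_spaced_bins(min_val: float, max_val: float, num_bins: int) -> torch.Tensor:
    # same idea as HourglassDecoder.get_bins: uniform in log-depth, then exponentiated
    return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), num_bins))

def depth_expectation(logits: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
    # logits: (B, num_bins, H, W); bins: (num_bins,) -> returns (B, 1, H, W)
    prob = logits.softmax(dim=1)
    return (prob * bins.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)

bins = log_spaced_bins(0.3, 150.0, 512)
logits = torch.randn(2, 512, 38, 66)        # stand-in for the depth_regressor_2 output
depth = depth_expectation(logits, bins)
print(depth.shape, float(depth.min()), float(depth.max()))   # every value lies inside (0.3, 150)

Because the expectation is a convex combination of the anchors, the prediction is automatically bounded by the normalization range; the deleted HourglassDecoder then refines this low-resolution estimate with a learned residual and a 4x upsample.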
 
mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py DELETED
@@ -1,1033 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import math
5
- import torch.nn.functional as F
6
-
7
- # LORA finetuning originally by edwardjhu
8
- class LoRALayer():
9
- def __init__(
10
- self,
11
- r: int,
12
- lora_alpha: int,
13
- lora_dropout: float,
14
- merge_weights: bool,
15
- ):
16
- self.r = r
17
- self.lora_alpha = lora_alpha
18
- # Optional dropout
19
- if lora_dropout > 0.:
20
- self.lora_dropout = nn.Dropout(p=lora_dropout)
21
- else:
22
- self.lora_dropout = lambda x: x
23
- # Mark the weight as unmerged
24
- self.merged = False
25
- self.merge_weights = merge_weights
26
-
27
- class LoRALinear(nn.Linear, LoRALayer):
28
- # LoRA implemented in a dense layer
29
- def __init__(
30
- self,
31
- in_features: int,
32
- out_features: int,
33
- r: int = 0,
34
- lora_alpha: int = 1,
35
- lora_dropout: float = 0.,
36
- fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
37
- merge_weights: bool = True,
38
- **kwargs
39
- ):
40
- nn.Linear.__init__(self, in_features, out_features, **kwargs)
41
- LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
42
- merge_weights=merge_weights)
43
-
44
- self.fan_in_fan_out = fan_in_fan_out
45
- # Actual trainable parameters
46
- if r > 0:
47
- self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
48
- self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
49
- self.scaling = self.lora_alpha / self.r
50
- # Freezing the pre-trained weight matrix
51
- self.weight.requires_grad = False
52
- self.reset_parameters()
53
- if fan_in_fan_out:
54
- self.weight.data = self.weight.data.transpose(0, 1)
55
-
56
- def reset_parameters(self):
57
- #nn.Linear.reset_parameters(self)
58
- if hasattr(self, 'lora_A'):
59
- # initialize B the same way as the default for nn.Linear and A to zero
60
- # this is different than what is described in the paper but should not affect performance
61
- nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
62
- nn.init.zeros_(self.lora_B)
63
-
64
- # def train(self, mode: bool = True):
65
- # def T(w):
66
- # return w.transpose(0, 1) if self.fan_in_fan_out else w
67
- # nn.Linear.train(self, mode)
68
- # if mode:
69
- # if self.merge_weights and self.merged:
70
- # # Make sure that the weights are not merged
71
- # if self.r > 0:
72
- # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
73
- # self.merged = False
74
- # else:
75
- # if self.merge_weights and not self.merged:
76
- # # Merge the weights and mark it
77
- # if self.r > 0:
78
- # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
79
- # self.merged = True
80
-
81
- def forward(self, x: torch.Tensor):
82
- def T(w):
83
- return w.transpose(0, 1) if self.fan_in_fan_out else w
84
- if self.r > 0 and not self.merged:
85
- result = F.linear(x, T(self.weight), bias=self.bias)
86
- result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
87
- return result
88
- else:
89
- return F.linear(x, T(self.weight), bias=self.bias)
90
-
91
- class ConvLoRA(nn.Conv2d, LoRALayer):
92
- def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
93
- #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
94
- nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
95
- LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
96
- assert isinstance(kernel_size, int)
97
-
98
- # Actual trainable parameters
99
- if r > 0:
100
- self.lora_A = nn.Parameter(
101
- self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
102
- )
103
- self.lora_B = nn.Parameter(
104
- self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
105
- )
106
- self.scaling = self.lora_alpha / self.r
107
- # Freezing the pre-trained weight matrix
108
- self.weight.requires_grad = False
109
- self.reset_parameters()
110
- self.merged = False
111
-
112
- def reset_parameters(self):
113
- #self.conv.reset_parameters()
114
- if hasattr(self, 'lora_A'):
115
- # initialize A the same way as the default for nn.Linear and B to zero
116
- nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
117
- nn.init.zeros_(self.lora_B)
118
-
119
- # def train(self, mode=True):
120
- # super(ConvLoRA, self).train(mode)
121
- # if mode:
122
- # if self.merge_weights and self.merged:
123
- # if self.r > 0:
124
- # # Make sure that the weights are not merged
125
- # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
126
- # self.merged = False
127
- # else:
128
- # if self.merge_weights and not self.merged:
129
- # if self.r > 0:
130
- # # Merge the weights and mark it
131
- # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
132
- # self.merged = True
133
-
134
- def forward(self, x):
135
- if self.r > 0 and not self.merged:
136
- # return self.conv._conv_forward(
137
- # x,
138
- # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
139
- # self.conv.bias
140
- # )
141
- weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
142
- bias = self.bias
143
-
144
- return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
145
- else:
146
- return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
147
-
148
- class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer):
149
- def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
150
- #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
151
- nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
152
- LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
153
- assert isinstance(kernel_size, int)
154
-
155
- # Actual trainable parameters
156
- if r > 0:
157
- self.lora_A = nn.Parameter(
158
- self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
159
- )
160
- self.lora_B = nn.Parameter(
161
- self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
162
- )
163
- self.scaling = self.lora_alpha / self.r
164
- # Freezing the pre-trained weight matrix
165
- self.weight.requires_grad = False
166
- self.reset_parameters()
167
- self.merged = False
168
-
169
- def reset_parameters(self):
170
- #self.conv.reset_parameters()
171
- if hasattr(self, 'lora_A'):
172
- # initialize A the same way as the default for nn.Linear and B to zero
173
- nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
174
- nn.init.zeros_(self.lora_B)
175
-
176
- # def train(self, mode=True):
177
- # super(ConvTransposeLoRA, self).train(mode)
178
- # if mode:
179
- # if self.merge_weights and self.merged:
180
- # if self.r > 0:
181
- # # Make sure that the weights are not merged
182
- # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
183
- # self.merged = False
184
- # else:
185
- # if self.merge_weights and not self.merged:
186
- # if self.r > 0:
187
- # # Merge the weights and mark it
188
- # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
189
- # self.merged = True
190
-
191
- def forward(self, x):
192
- if self.r > 0 and not self.merged:
193
- weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
194
- bias = self.bias
195
- return F.conv_transpose2d(x, weight,
196
- bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
197
- groups=self.groups, dilation=self.dilation)
198
- else:
199
- return F.conv_transpose2d(x, self.weight,
200
- bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
201
- groups=self.groups, dilation=self.dilation)
202
- #return self.conv(x)
203
-
204
- class Conv2dLoRA(ConvLoRA):
205
- def __init__(self, *args, **kwargs):
206
- super(Conv2dLoRA, self).__init__(*args, **kwargs)
207
-
208
- class ConvTranspose2dLoRA(ConvTransposeLoRA):
209
- def __init__(self, *args, **kwargs):
210
- super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs)
211
-
212
-
213
- def compute_depth_expectation(prob, depth_values):
214
- depth_values = depth_values.view(*depth_values.shape, 1, 1)
215
- depth = torch.sum(prob * depth_values, 1)
216
- return depth
217
-
218
- def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
219
- with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
220
- return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
221
-
222
- # def upflow8(flow, mode='bilinear'):
223
- # new_size = (8 * flow.shape[2], 8 * flow.shape[3])
224
- # return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
225
-
226
- def upflow4(flow, mode='bilinear'):
227
- new_size = (4 * flow.shape[2], 4 * flow.shape[3])
228
- with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
229
- return F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
230
-
231
- def coords_grid(batch, ht, wd):
232
- # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
233
- coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)))
234
- coords = torch.stack(coords[::-1], dim=0).float()
235
- return coords[None].repeat(batch, 1, 1, 1)
236
-
237
- def norm_normalize(norm_out):
238
- min_kappa = 0.01
239
- norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
240
- norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
241
- kappa = F.elu(kappa) + 1.0 + min_kappa
242
- final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
243
- return final_out
244
-
245
- # uncertainty-guided sampling (only used during training)
246
- @torch.no_grad()
247
- def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
248
- device = init_normal.device
249
- B, _, H, W = init_normal.shape
250
- N = int(sampling_ratio * H * W)
251
- beta = beta
252
-
253
- # uncertainty map
254
- uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W
255
-
256
- # gt_invalid_mask (B, H, W)
257
- if gt_norm_mask is not None:
258
- gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
259
- gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
260
- uncertainty_map[gt_invalid_mask] = -1e4
261
-
262
- # (B, H*W)
263
- _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
264
-
265
- # importance sampling
266
- if int(beta * N) > 0:
267
- importance = idx[:, :int(beta * N)] # B, beta*N
268
-
269
- # remaining
270
- remaining = idx[:, int(beta * N):] # B, H*W - beta*N
271
-
272
- # coverage
273
- num_coverage = N - int(beta * N)
274
-
275
- if num_coverage <= 0:
276
- samples = importance
277
- else:
278
- coverage_list = []
279
- for i in range(B):
280
- idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
281
- coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
282
- coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
283
- samples = torch.cat((importance, coverage), dim=1) # B, N
284
-
285
- else:
286
- # remaining
287
- remaining = idx[:, :] # B, H*W
288
-
289
- # coverage
290
- num_coverage = N
291
-
292
- coverage_list = []
293
- for i in range(B):
294
- idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
295
- coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
296
- coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
297
- samples = coverage
298
-
299
- # point coordinates
300
- rows_int = samples // W # 0 for first row, H-1 for last row
301
- rows_float = rows_int / float(H-1) # 0 to 1.0
302
- rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
303
-
304
- cols_int = samples % W # 0 for first column, W-1 for last column
305
- cols_float = cols_int / float(W-1) # 0 to 1.0
306
- cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
307
-
308
- point_coords = torch.zeros(B, 1, N, 2)
309
- point_coords[:, 0, :, 0] = cols_float # x coord
310
- point_coords[:, 0, :, 1] = rows_float # y coord
311
- point_coords = point_coords.to(device)
312
- return point_coords, rows_int, cols_int
313
-
314
- class FlowHead(nn.Module):
315
- def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None):
316
- super(FlowHead, self).__init__()
317
- self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
318
- self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
319
-
320
- self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
321
- self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
322
- self.relu = nn.ReLU(inplace=True)
323
-
324
- def forward(self, x):
325
- depth = self.conv2d(self.relu(self.conv1d(x)))
326
- normal = self.conv2n(self.relu(self.conv1n(x)))
327
- return torch.cat((depth, normal), dim=1)
328
-
329
-
330
- class ConvGRU(nn.Module):
331
- def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None):
332
- super(ConvGRU, self).__init__()
333
- self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
334
- self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
335
- self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
336
-
337
- def forward(self, h, cz, cr, cq, *x_list):
338
- x = torch.cat(x_list, dim=1)
339
- hx = torch.cat([h, x], dim=1)
340
-
341
- z = torch.sigmoid((self.convz(hx) + cz))
342
- r = torch.sigmoid((self.convr(hx) + cr))
343
- q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq))
344
-
345
- # z = torch.sigmoid((self.convz(hx) + cz).float())
346
- # r = torch.sigmoid((self.convr(hx) + cr).float())
347
- # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float())
348
-
349
- h = (1-z) * h + z * q
350
- return h
351
-
352
- def pool2x(x):
353
- return F.avg_pool2d(x, 3, stride=2, padding=1)
354
-
355
- def pool4x(x):
356
- return F.avg_pool2d(x, 5, stride=4, padding=1)
357
-
358
- def interp(x, dest):
359
- interp_args = {'mode': 'bilinear', 'align_corners': True}
360
- return interpolate_float32(x, dest.shape[2:], **interp_args)
361
-
362
- class BasicMultiUpdateBlock(nn.Module):
363
- def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None):
364
- super().__init__()
365
- self.args = args
366
- self.n_gru_layers = args.model.decode_head.n_gru_layers # 3
367
- self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
368
-
369
- # self.encoder = BasicMotionEncoder(args)
370
- # encoder_output_dim = 128 # if there is corr volume
371
- encoder_output_dim = 6 # no corr volume
372
-
373
- self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode)
374
- self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode)
375
- self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode)
376
- self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode)
377
- factor = 2**self.n_downsample
378
-
379
- self.mask = nn.Sequential(
380
- Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0),
381
- nn.ReLU(inplace=True),
382
- Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0))
383
-
384
- def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
385
-
386
- if iter32:
387
- net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))
388
- if iter16:
389
- if self.n_gru_layers > 2:
390
- net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1]))
391
- else:
392
- net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]))
393
- if iter08:
394
- if corr is not None:
395
- motion_features = self.encoder(flow, corr)
396
- else:
397
- motion_features = flow
398
- if self.n_gru_layers > 1:
399
- net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
400
- else:
401
- net[0] = self.gru08(net[0], *(inp[0]), motion_features)
402
-
403
- if not update:
404
- return net
405
-
406
- delta_flow = self.flow_head(net[0])
407
-
408
- # scale mask to balance gradients
409
- mask = .25 * self.mask(net[0])
410
- return net, mask, delta_flow
411
-
412
- class LayerNorm2d(nn.LayerNorm):
413
- def __init__(self, dim):
414
- super(LayerNorm2d, self).__init__(dim)
415
-
416
- def forward(self, x):
417
- x = x.permute(0, 2, 3, 1).contiguous()
418
- x = super(LayerNorm2d, self).forward(x)
419
- x = x.permute(0, 3, 1, 2).contiguous()
420
- return x
421
-
422
- class ResidualBlock(nn.Module):
423
- def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None):
424
- super(ResidualBlock, self).__init__()
425
-
426
- self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0)
427
- self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
428
- self.relu = nn.ReLU(inplace=True)
429
-
430
- num_groups = planes // 8
431
-
432
- if norm_fn == 'group':
433
- self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
434
- self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
435
- if not (stride == 1 and in_planes == planes):
436
- self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
437
-
438
- elif norm_fn == 'batch':
439
- self.norm1 = nn.BatchNorm2d(planes)
440
- self.norm2 = nn.BatchNorm2d(planes)
441
- if not (stride == 1 and in_planes == planes):
442
- self.norm3 = nn.BatchNorm2d(planes)
443
-
444
- elif norm_fn == 'instance':
445
- self.norm1 = nn.InstanceNorm2d(planes)
446
- self.norm2 = nn.InstanceNorm2d(planes)
447
- if not (stride == 1 and in_planes == planes):
448
- self.norm3 = nn.InstanceNorm2d(planes)
449
-
450
- elif norm_fn == 'layer':
451
- self.norm1 = LayerNorm2d(planes)
452
- self.norm2 = LayerNorm2d(planes)
453
- if not (stride == 1 and in_planes == planes):
454
- self.norm3 = LayerNorm2d(planes)
455
-
456
- elif norm_fn == 'none':
457
- self.norm1 = nn.Sequential()
458
- self.norm2 = nn.Sequential()
459
- if not (stride == 1 and in_planes == planes):
460
- self.norm3 = nn.Sequential()
461
-
462
- if stride == 1 and in_planes == planes:
463
- self.downsample = None
464
-
465
- else:
466
- self.downsample = nn.Sequential(
467
- Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3)
468
-
469
- def forward(self, x):
470
- y = x
471
- y = self.conv1(y)
472
- y = self.norm1(y)
473
- y = self.relu(y)
474
- y = self.conv2(y)
475
- y = self.norm2(y)
476
- y = self.relu(y)
477
-
478
- if self.downsample is not None:
479
- x = self.downsample(x)
480
-
481
- return self.relu(x+y)
482
-
483
-
484
- class ContextFeatureEncoder(nn.Module):
485
- '''
486
- Encoder features are used to:
487
- 1. initialize the hidden state of the update operator
488
- 2. and also injected into the GRU during each iteration of the update operator
489
- '''
490
- def __init__(self, in_dim, output_dim, tuning_mode=None):
491
- '''
492
- in_dim = [x4, x8, x16, x32]
493
- # output_dim = [hidden_dims, context_dims]
494
- [[x4,x8,x16,x32],[x4,x8,x16,x32]]
495
- '''
496
- super().__init__()
497
-
498
- output_list = []
499
- for dim in output_dim:
500
- conv_out = nn.Sequential(
501
- ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode),
502
- Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
503
- output_list.append(conv_out)
504
-
505
- self.outputs04 = nn.ModuleList(output_list)
506
-
507
- output_list = []
508
- for dim in output_dim:
509
- conv_out = nn.Sequential(
510
- ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode),
511
- Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
512
- output_list.append(conv_out)
513
-
514
- self.outputs08 = nn.ModuleList(output_list)
515
-
516
- output_list = []
517
- for dim in output_dim:
518
- conv_out = nn.Sequential(
519
- ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode),
520
- Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
521
- output_list.append(conv_out)
522
-
523
- self.outputs16 = nn.ModuleList(output_list)
524
-
525
- # output_list = []
526
- # for dim in output_dim:
527
- # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1)
528
- # output_list.append(conv_out)
529
-
530
- # self.outputs32 = nn.ModuleList(output_list)
531
-
532
- def forward(self, encoder_features):
533
- x_4, x_8, x_16, x_32 = encoder_features
534
-
535
- outputs04 = [f(x_4) for f in self.outputs04]
536
- outputs08 = [f(x_8) for f in self.outputs08]
537
- outputs16 = [f(x_16)for f in self.outputs16]
538
- # outputs32 = [f(x_32) for f in self.outputs32]
539
-
540
- return (outputs04, outputs08, outputs16)
541
-
542
- class ConvBlock(nn.Module):
543
- # reimplementation of DPT
544
- def __init__(self, channels, tuning_mode=None):
545
- super(ConvBlock, self).__init__()
546
-
547
- self.act = nn.ReLU(inplace=True)
548
- self.conv1 = Conv2dLoRA(
549
- channels,
550
- channels,
551
- kernel_size=3,
552
- stride=1,
553
- padding=1,
554
- r = 8 if tuning_mode == 'lora' else 0
555
- )
556
- self.conv2 = Conv2dLoRA(
557
- channels,
558
- channels,
559
- kernel_size=3,
560
- stride=1,
561
- padding=1,
562
- r = 8 if tuning_mode == 'lora' else 0
563
- )
564
-
565
- def forward(self, x):
566
- out = self.act(x)
567
- out = self.conv1(out)
568
- out = self.act(out)
569
- out = self.conv2(out)
570
- return x + out
571
-
572
- class FuseBlock(nn.Module):
573
- # reimplementation of DPT
574
- def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None):
575
- super(FuseBlock, self).__init__()
576
-
577
- self.fuse = fuse
578
- self.scale_factor = scale_factor
579
- self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode)
580
- if self.fuse:
581
- self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode)
582
-
583
- self.out_conv = Conv2dLoRA(
584
- in_channels,
585
- out_channels,
586
- kernel_size=1,
587
- stride=1,
588
- padding=0,
589
- r = 8 if tuning_mode == 'lora' else 0
590
- )
591
- self.upsample = upsample
592
-
593
- def forward(self, x1, x2=None):
594
- if x2 is not None:
595
- x2 = self.way_branch(x2)
596
- x1 = x1 + x2
597
-
598
- out = self.way_trunk(x1)
599
-
600
- if self.upsample:
601
- out = interpolate_float32(
602
- out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
603
- )
604
- out = self.out_conv(out)
605
- return out
606
-
607
- class Readout(nn.Module):
608
- # From DPT
609
- def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
610
- super(Readout, self).__init__()
611
- self.use_cls_token = use_cls_token
612
- if self.use_cls_token == True:
613
- self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' else 0)
614
- self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0)
615
- self.act = nn.GELU()
616
- else:
617
- self.project = nn.Identity()
618
-
619
- def forward(self, x):
620
-
621
- if self.use_cls_token == True:
622
- x_patch = self.project_patch(x[0])
623
- x_learn = self.project_learn(x[1])
624
- x_learn = x_learn.expand_as(x_patch).contiguous()
625
- features = x_patch + x_learn
626
- return self.act(features)
627
- else:
628
- return self.project(x)
629
-
630
- class Token2Feature(nn.Module):
631
- # From DPT
632
- def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
633
- super(Token2Feature, self).__init__()
634
- self.scale_factor = scale_factor
635
- self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
636
- if scale_factor > 1 and isinstance(scale_factor, int):
637
- self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
638
- in_channels=vit_channel,
639
- out_channels=feature_channel,
640
- kernel_size=scale_factor,
641
- stride=scale_factor,
642
- padding=0,
643
- )
644
-
645
- elif scale_factor > 1:
646
- self.sample = nn.Sequential(
647
- # Upsample2(upscale=scale_factor),
648
- # nn.Upsample(scale_factor=scale_factor),
649
- Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
650
- in_channels=vit_channel,
651
- out_channels=feature_channel,
652
- kernel_size=1,
653
- stride=1,
654
- padding=0,
655
- ),
656
- )
657
-
658
-
659
- elif scale_factor < 1:
660
- scale_factor = int(1.0 / scale_factor)
661
- self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
662
- in_channels=vit_channel,
663
- out_channels=feature_channel,
664
- kernel_size=scale_factor+1,
665
- stride=scale_factor,
666
- padding=1,
667
- )
668
-
669
- else:
670
- self.sample = nn.Identity()
671
-
672
- def forward(self, x):
673
- x = self.readoper(x)
674
- #if use_cls_token == True:
675
- x = x.permute(0, 3, 1, 2).contiguous()
676
- if isinstance(self.scale_factor, float):
677
- x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest')
678
- x = self.sample(x)
679
- return x
680
-
681
- class EncoderFeature(nn.Module):
682
- def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None):
683
- super(EncoderFeature, self).__init__()
684
- self.vit_channel = vit_channel
685
- self.num_ch_dec = num_ch_dec
686
-
687
- self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
688
- self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
689
- self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
690
- self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
691
-
692
- def forward(self, ref_feature):
693
- x = self.read_3(ref_feature[3]) # 1/14
694
- x2 = self.read_2(ref_feature[2]) # 1/14
695
- x1 = self.read_1(ref_feature[1]) # 1/7
696
- x0 = self.read_0(ref_feature[0]) # 1/4
697
-
698
- return x, x2, x1, x0
699
-
700
- class DecoderFeature(nn.Module):
701
- def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None):
702
- super(DecoderFeature, self).__init__()
703
- self.vit_channel = vit_channel
704
- self.num_ch_dec = num_ch_dec
705
-
706
- self.upconv_3 = FuseBlock(
707
- self.num_ch_dec[4],
708
- self.num_ch_dec[3],
709
- fuse=False, upsample=False, tuning_mode=tuning_mode)
710
-
711
- self.upconv_2 = FuseBlock(
712
- self.num_ch_dec[3],
713
- self.num_ch_dec[2],
714
- tuning_mode=tuning_mode)
715
-
716
- self.upconv_1 = FuseBlock(
717
- self.num_ch_dec[2],
718
- self.num_ch_dec[1] + 2,
719
- scale_factor=7/4,
720
- tuning_mode=tuning_mode)
721
-
722
- # self.upconv_0 = FuseBlock(
723
- # self.num_ch_dec[1],
724
- # self.num_ch_dec[0] + 1,
725
- # )
726
-
727
- def forward(self, ref_feature):
728
- x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4
729
-
730
- x = self.upconv_3(x) # 1/14
731
- x = self.upconv_2(x, x2) # 1/7
732
- x = self.upconv_1(x, x1) # 1/4
733
- # x = self.upconv_0(x, x0) # 4/7
734
- return x
735
-
736
- class RAFTDepthNormalDPT5(nn.Module):
737
- def __init__(self, cfg):
738
- super().__init__()
739
- self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024]
740
- self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14]
741
- self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14]
742
- self.use_cls_token = cfg.model.decode_head.use_cls_token
743
- self.up_scale = cfg.model.decode_head.up_scale
744
- self.num_register_tokens = cfg.model.decode_head.num_register_tokens
745
- self.min_val = cfg.data_basic.depth_normalize[0]
746
- self.max_val = cfg.data_basic.depth_normalize[1]
747
- self.regress_scale = 100.0\
748
-
749
- try:
750
- tuning_mode = cfg.model.decode_head.tuning_mode
751
- except:
752
- tuning_mode = None
753
- self.tuning_mode = tuning_mode
754
-
755
- self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128]
756
- self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3
757
- self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
758
- self.iters = cfg.model.decode_head.iters # 22
759
- self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True
760
-
761
- self.num_depth_regressor_anchor = 256 # 512
762
- self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res
763
- self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode)
764
- self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode)
765
- self.depth_regressor = nn.Sequential(
766
- Conv2dLoRA(self.used_res_channel,
767
- self.num_depth_regressor_anchor,
768
- kernel_size=3,
769
- padding=1, r = 8 if tuning_mode == 'lora' else 0),
770
- # nn.BatchNorm2d(self.num_depth_regressor_anchor),
771
- nn.ReLU(inplace=True),
772
- Conv2dLoRA(self.num_depth_regressor_anchor,
773
- self.num_depth_regressor_anchor,
774
- kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
775
- )
776
- self.normal_predictor = nn.Sequential(
777
- Conv2dLoRA(self.used_res_channel,
778
- 128,
779
- kernel_size=3,
780
- padding=1, r = 8 if tuning_mode == 'lora' else 0,),
781
- # nn.BatchNorm2d(128),
782
- nn.ReLU(inplace=True),
783
- Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
784
- Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
785
- Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
786
- )
787
-
788
- self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode)
789
- self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)])
790
- self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode)
791
-
792
- self.relu = nn.ReLU(inplace=True)
793
-
794
- def get_bins(self, bins_num):
795
- depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
796
- depth_bins_vec = torch.exp(depth_bins_vec)
797
- return depth_bins_vec
798
-
799
- def register_depth_expectation_anchor(self, bins_num, B):
800
- depth_bins_vec = self.get_bins(bins_num)
801
- depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
802
- self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
803
-
804
- def clamp(self, x):
805
- y = self.relu(x - self.min_val) + self.min_val
806
- y = self.max_val - self.relu(self.max_val - y)
807
- return y
808
-
809
- def regress_depth(self, feature_map_d):
810
- prob_feature = self.depth_regressor(feature_map_d)
811
- prob = prob_feature.softmax(dim=1)
812
- #prob = prob_feature.float().softmax(dim=1)
813
-
814
- ## Error logging
815
- if torch.isnan(prob).any():
816
- print('prob_feat_nan!!!')
817
- if torch.isinf(prob).any():
818
- print('prob_feat_inf!!!')
819
-
820
- # h = prob[0,:,0,0].cpu().numpy().reshape(-1)
821
- # import matplotlib.pyplot as plt
822
- # plt.bar(range(len(h)), h)
823
- B = prob.shape[0]
824
- if "depth_expectation_anchor" not in self._buffers:
825
- self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
826
- d = compute_depth_expectation(
827
- prob,
828
- self.depth_expectation_anchor[:B, ...]).unsqueeze(1)
829
-
830
- ## Error logging
831
- if torch.isnan(d ).any():
832
- print('d_nan!!!')
833
- if torch.isinf(d ).any():
834
- print('d_inf!!!')
835
-
836
- return (self.clamp(d) - self.max_val)/ self.regress_scale, prob_feature
837
-
838
- def pred_normal(self, feature_map, confidence):
839
- normal_out = self.normal_predictor(feature_map)
840
-
841
- ## Error logging
842
- if torch.isnan(normal_out).any():
843
- print('norm_nan!!!')
844
- if torch.isinf(normal_out).any():
845
- print('norm_feat_inf!!!')
846
-
847
- return norm_normalize(torch.cat([normal_out, confidence], dim=1))
848
- #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
849
-
850
- def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
851
- y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
852
- torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
853
- meshgrid = torch.stack((x, y))
854
- meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
855
- #self.register_buffer('meshgrid', meshgrid, persistent=False)
856
- return meshgrid
857
-
858
- def upsample_flow(self, flow, mask):
859
- """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
860
- N, D, H, W = flow.shape
861
- factor = 2 ** self.n_downsample
862
- mask = mask.view(N, 1, 9, factor, factor, H, W)
863
- mask = torch.softmax(mask, dim=2)
864
- #mask = torch.softmax(mask.float(), dim=2)
865
-
866
- #up_flow = F.unfold(factor * flow, [3,3], padding=1)
867
- up_flow = F.unfold(flow, [3,3], padding=1)
868
- up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
869
-
870
- up_flow = torch.sum(mask * up_flow, dim=2)
871
- up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
872
- return up_flow.reshape(N, D, factor*H, factor*W)
873
-
874
- def initialize_flow(self, img):
875
- """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
876
- N, _, H, W = img.shape
877
-
878
- coords0 = coords_grid(N, H, W).to(img.device)
879
- coords1 = coords_grid(N, H, W).to(img.device)
880
-
881
- return coords0, coords1
882
-
883
- def upsample(self, x, scale_factor=2):
884
- """Upsample input tensor by a factor of 2
885
- """
886
- return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest")
887
-
888
- def forward(self, vit_features, **kwargs):
889
- ## read vit token to multi-scale features
890
- B, H, W, _, _, num_register_tokens = vit_features[1]
891
- vit_features = vit_features[0]
892
-
893
- ## Error logging
894
- if torch.isnan(vit_features[0]).any():
895
- print('vit_feature_nan!!!')
896
- if torch.isinf(vit_features[0]).any():
897
- print('vit_feature_inf!!!')
898
-
899
- if self.use_cls_token == True:
900
- vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \
901
- ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features]
902
- else:
903
- vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features]
904
- encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4
905
-
906
- ## Error logging
907
- for en_ft in encoder_features:
908
- if torch.isnan(en_ft).any():
909
- print('decoder_feature_nan!!!')
910
- print(en_ft.shape)
911
- if torch.isinf(en_ft).any():
912
- print('decoder_feature_inf!!!')
913
- print(en_ft.shape)
914
-
915
- ## decode features to init-depth (and confidence)
916
- ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth
917
-
918
- ## Error logging
919
- if torch.isnan(ref_feat).any():
920
- print('ref_feat_nan!!!')
921
- if torch.isinf(ref_feat).any():
922
- print('ref_feat_inf!!!')
923
-
924
- feature_map = ref_feat[:, :-2, :, :] # feature map shared by depth and normal prediction
925
- depth_confidence_map = ref_feat[:, -2:-1, :, :]
926
- normal_confidence_map = ref_feat[:, -1:, :, :]
927
- depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth
928
- normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal
929
-
930
- depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W)
931
-
932
- ## encoder features to context-feature for init-hidden-state and contex-features
933
- cnet_list = self.context_feature_encoder(encoder_features[::-1])
934
- net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state
935
- inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features
936
-
937
- # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
938
- inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
939
-
940
- coords0, coords1 = self.initialize_flow(net_list[0])
941
- if depth_init is not None:
942
- coords1 = coords1 + depth_init
943
-
944
- if self.training:
945
- low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())]
946
- init_depth = upflow4(depth_init)
947
- flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)]
948
- conf_predictions = [init_depth[:,1:2]]
949
- normal_outs = [norm_normalize(init_depth[:,2:].clone())]
950
-
951
- else:
952
- flow_predictions = []
953
- conf_predictions = []
954
- samples_pred_list = []
955
- coord_list = []
956
- normal_outs = []
957
- low_resolution_init = []
958
-
959
- for itr in range(self.iters):
960
- # coords1 = coords1.detach()
961
- flow = coords1 - coords0
962
- if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU
963
- net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
964
- if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU
965
- net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False)
966
- net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2)
967
-
968
- # F(t+1) = F(t) + \Delta(t)
969
- coords1 = coords1 + delta_flow
970
-
971
- # We do not need to upsample or output intermediate results in test_mode
972
- #if (not self.training) and itr < self.iters-1:
973
- #continue
974
-
975
- # upsample predictions
976
- if up_mask is None:
977
- flow_up = self.upsample(coords1-coords0, 4)
978
- else:
979
- flow_up = self.upsample_flow(coords1 - coords0, up_mask)
980
- # flow_up = self.upsample(coords1-coords0, 4)
981
-
982
- flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val))
983
- conf_predictions.append(flow_up[:,1:2])
984
- normal_outs.append(norm_normalize(flow_up[:,2:].clone()))
985
-
986
- outputs=dict(
987
- prediction=flow_predictions[-1],
988
- predictions_list=flow_predictions,
989
- confidence=conf_predictions[-1],
990
- confidence_list=conf_predictions,
991
- pred_logit=None,
992
- # samples_pred_list=samples_pred_list,
993
- # coord_list=coord_list,
994
- prediction_normal=normal_outs[-1],
995
- normal_out_list=normal_outs,
996
- low_resolution_init=low_resolution_init,
997
- )
998
-
999
- return outputs
1000
-
1001
-
1002
- if __name__ == "__main__":
1003
- try:
1004
- from mmcv.utils import Config
1005
- except:
1006
- from mmengine import Config
1007
- cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py')
1008
- cfg.model.decode_head.in_channels = [384, 384, 384, 384]
1009
- cfg.model.decode_head.feature_channels = [96, 192, 384, 768]
1010
- cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384]
1011
- cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48]
1012
- cfg.model.decode_head.up_scale = 7
1013
-
1014
- # cfg.model.decode_head.use_cls_token = True
1015
- # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
1016
- # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
1017
- # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
1018
- # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]]
1019
-
1020
- cfg.model.decode_head.use_cls_token = True
1021
- cfg.model.decode_head.num_register_tokens = 4
1022
- vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\
1023
- torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
1024
- torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
1025
- torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)]
1026
-
1027
- decoder = RAFTDepthNormalDPT5(cfg).cuda()
1028
- output = decoder(vit_feature)
1029
- temp = 1
1030
-
1031
-
1032
-
1033
-
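
The decoder deleted above produces its initial depth map as a soft-argmax over log-spaced depth anchors (`get_bins` feeding `compute_depth_expectation`, whose definition lives elsewhere in the repository) and then upsamples its iteratively refined field with RAFT-style convex combination (`upsample_flow`). For anyone re-implementing the file, here is a minimal, self-contained sketch of those two pieces; the 0.3–150 m range is only a placeholder for `cfg.data_basic.depth_normalize`, and the clamping/rescaling that `regress_depth` applies on top of the expectation is omitted.

import math
import torch
import torch.nn.functional as F

def log_spaced_bins(min_val, max_val, num_bins=256):
    # Same idea as get_bins above: depth anchors spaced evenly in log-depth.
    return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), num_bins))

def soft_argmax_depth(logits, bins):
    # logits: (N, num_bins, H, W) raw output of the depth-regressor head.
    # bins:   (num_bins,) anchor depths. Returns the (N, 1, H, W) expected depth.
    prob = logits.softmax(dim=1)
    return (prob * bins.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)

def convex_upsample(x, mask, factor):
    # RAFT-style convex upsampling as in upsample_flow: each fine-resolution pixel
    # is a softmax-weighted combination of its 3x3 coarse neighbourhood.
    N, D, H, W = x.shape
    weights = mask.view(N, 1, 9, factor, factor, H, W).softmax(dim=2)
    patches = F.unfold(x, [3, 3], padding=1).view(N, D, 9, 1, 1, H, W)
    up = (weights * patches).sum(dim=2)            # (N, D, factor, factor, H, W)
    up = up.permute(0, 1, 4, 2, 5, 3)              # (N, D, H, factor, W, factor)
    return up.reshape(N, D, factor * H, factor * W)

bins = log_spaced_bins(0.3, 150.0)                                  # placeholder range
depth = soft_argmax_depth(torch.randn(2, 256, 64, 96), bins)        # (2, 1, 64, 96)
up = convex_upsample(depth, torch.randn(2, 9 * 16, 64, 96), 4)      # (2, 1, 256, 384)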
 
mono/model/decode_heads/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- from .HourGlassDecoder import HourglassDecoder
2
- from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5
3
-
4
- __all__=['HourglassDecoder', 'RAFTDepthNormalDPT5']
 
 
 
 
 
mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc DELETED
Binary file (8.65 kB)
 
mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (252 Bytes)
 
mono/model/model_pipelines/__base_model__.py DELETED
@@ -1,20 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from mono.utils.comm import get_func
4
-
5
-
6
- class BaseDepthModel(nn.Module):
7
- def __init__(self, cfg, **kwargs) -> None:
8
- super(BaseDepthModel, self).__init__()
9
- model_type = cfg.model.type
10
- self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)
11
-
12
- def forward(self, data):
13
- output = self.depth_model(**data)
14
-
15
- return output['prediction'], output['confidence'], output
16
-
17
- def inference(self, data):
18
- with torch.no_grad():
19
- pred_depth, confidence, _ = self.forward(data)
20
- return pred_depth, confidence
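
BaseDepthModel is only a dispatcher: `get_func` resolves `cfg.model.type` to one of the pipelines in this package, and `forward`/`inference` unpack the pipeline's output dictionary into `(prediction, confidence, full outputs)`. A rough sketch of that calling contract, using a toy stand-in for the resolved pipeline (the real `DensePredModel` needs a complete backbone/decoder config):

import torch
import torch.nn as nn

class ToyPipeline(nn.Module):
    # Stand-in for DensePredModel: consumes `input` and returns the keys BaseDepthModel reads.
    def forward(self, input, **kwargs):
        depth = torch.rand(input.shape[0], 1, *input.shape[2:])
        return {'prediction': depth, 'confidence': torch.ones_like(depth)}

data = {'input': torch.rand(2, 3, 224, 224)}   # key name matches DensePredModel.forward
output = ToyPipeline()(**data)
prediction, confidence = output['prediction'], output['confidence']
print(prediction.shape, confidence.shape)      # torch.Size([2, 1, 224, 224]) for both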
 
mono/model/model_pipelines/__init__.py DELETED
@@ -1,6 +0,0 @@
1
-
2
- from .dense_pipeline import DensePredModel
3
- from .__base_model__ import BaseDepthModel
4
- __all__ = [
5
- 'DensePredModel', 'BaseDepthModel',
6
- ]
 
 
 
 
 
 
 
mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc DELETED
Binary file (1.19 kB)
 
mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (313 Bytes)
 
mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc DELETED
Binary file (1.01 kB)
 
mono/model/model_pipelines/dense_pipeline.py DELETED
@@ -1,16 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from mono.utils.comm import get_func
4
-
5
- class DensePredModel(nn.Module):
6
- def __init__(self, cfg) -> None:
7
- super(DensePredModel, self).__init__()
8
-
9
- self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone)
10
- self.decoder = get_func('mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg)
11
-
12
- def forward(self, input, **kwargs):
13
- # [f_32, f_16, f_8, f_4]
14
- features = self.encoder(input)
15
- out = self.decoder(features, **kwargs)
16
- return out
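
Both halves of this pipeline are resolved at runtime: `get_func` (imported from `mono.utils.comm`, which is not part of this diff) turns the dotted path assembled from `cfg.model.backbone` and `cfg.model.decode_head` into the corresponding class. A minimal, assumed equivalent of that helper, shown only to make the lookup explicit:

import importlib

def get_func(func_name):
    # Resolve a dotted path like 'mono.model.<prefix><type>' to the object it names.
    # Hypothetical re-implementation; the real helper lives in mono/utils/comm.py.
    module_name, _, attr_name = func_name.rpartition('.')
    return getattr(importlib.import_module(module_name), attr_name)

# Example with a module that is guaranteed to exist:
conv_cls = get_func('torch.nn.Conv2d')   # -> torch.nn.Conv2d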
 
mono/model/monodepth_model.py DELETED
@@ -1,37 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from .model_pipelines.__base_model__ import BaseDepthModel
4
-
5
- class DepthModel(BaseDepthModel):
6
- def __init__(self, cfg, **kwargs):
7
- super(DepthModel, self).__init__(cfg)
8
- model_type = cfg.model.type
9
-
10
- def inference(self, data):
11
- with torch.no_grad():
12
- pred_depth, confidence, output_dict = self.forward(data)
13
- return pred_depth, confidence, output_dict
14
-
15
- def get_monodepth_model(
16
- cfg : dict,
17
- **kwargs
18
- ) -> nn.Module:
19
- # config depth model
20
- model = DepthModel(cfg, **kwargs)
21
- #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath)
22
- assert isinstance(model, nn.Module)
23
- return model
24
-
25
- def get_configured_monodepth_model(
26
- cfg: dict,
27
- ) -> nn.Module:
28
- """
29
- Args:
30
- @ configs: configures for the network.
31
- @ load_imagenet_model: whether to initialize from ImageNet-pretrained model.
32
- @ imagenet_ckpt_fpath: string representing path to file with weights to initialize model with.
33
- Returns:
34
- # model: depth model.
35
- """
36
- model = get_monodepth_model(cfg)
37
- return model
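
Taken together with the test script below, constructing a runnable model only requires a parsed config. A rough usage sketch follows; the paths are placeholders, and the `reset_ckpt_path`/`DataParallel` bookkeeping performed by `test_scale_cano.py` is skipped here:

try:
    from mmcv.utils import Config
except ImportError:
    from mmengine import Config

from mono.model.monodepth_model import get_configured_monodepth_model
from mono.utils.running import load_ckpt

cfg = Config.fromfile('path/to/model_config.py')                     # placeholder
model = get_configured_monodepth_model(cfg).cuda().eval()
model, _, _, _ = load_ckpt('path/to/checkpoint.pth', model, strict_match=False)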
 
mono/tools/test_scale_cano.py DELETED
@@ -1,158 +0,0 @@
1
- import os
2
- import os.path as osp
3
- import cv2
4
- import time
5
- import sys
6
- CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
- sys.path.append(CODE_SPACE)
8
- import argparse
9
- import mmcv
10
- import torch
11
- import torch.distributed as dist
12
- import torch.multiprocessing as mp
13
-
14
- try:
15
- from mmcv.utils import Config, DictAction
16
- except:
17
- from mmengine import Config, DictAction
18
- from datetime import timedelta
19
- import random
20
- import numpy as np
21
- from mono.utils.logger import setup_logger
22
- import glob
23
- from mono.utils.comm import init_env
24
- from mono.model.monodepth_model import get_configured_monodepth_model
25
- from mono.utils.running import load_ckpt
26
- from mono.utils.do_test import do_scalecano_test_with_custom_data
27
- from mono.utils.mldb import load_data_info, reset_ckpt_path
28
- from mono.utils.custom_data import load_from_annos, load_data
29
-
30
- def parse_args():
31
- parser = argparse.ArgumentParser(description='Test a monodepth model on custom data')
32
- parser.add_argument('config', help='train config file path')
33
- parser.add_argument('--show-dir', help='the dir to save logs and visualization results')
34
- parser.add_argument('--load-from', help='the checkpoint file to load weights from')
35
- parser.add_argument('--node_rank', type=int, default=0)
36
- parser.add_argument('--nnodes', type=int, default=1, help='number of nodes')
37
- parser.add_argument('--options', nargs='+', action=DictAction, help='custom options')
38
- parser.add_argument('--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm', help='job launcher')
39
- parser.add_argument('--test_data_path', default='None', type=str, help='the path of test data')
40
- args = parser.parse_args()
41
- return args
42
-
43
- def main(args):
44
- os.chdir(CODE_SPACE)
45
- cfg = Config.fromfile(args.config)
46
-
47
- if args.options is not None:
48
- cfg.merge_from_dict(args.options)
49
-
50
- # show_dir is determined in this priority: CLI > segment in file > filename
51
- if args.show_dir is not None:
52
- # update configs according to CLI args if args.show_dir is not None
53
- cfg.show_dir = args.show_dir
54
- else:
55
- # use condig filename + timestamp as default show_dir if args.show_dir is None
56
- cfg.show_dir = osp.join('./show_dirs',
57
- osp.splitext(osp.basename(args.config))[0],
58
- args.timestamp)
59
-
60
- # ckpt path
61
- if args.load_from is None:
62
- raise RuntimeError('Please set model path!')
63
- cfg.load_from = args.load_from
64
-
65
- # load data info
66
- data_info = {}
67
- load_data_info('data_info', data_info=data_info)
68
- cfg.mldb_info = data_info
69
- # update check point info
70
- reset_ckpt_path(cfg.model, data_info)
71
-
72
- # create show dir
73
- os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True)
74
-
75
- # init the logger before other steps
76
- cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log')
77
- logger = setup_logger(cfg.log_file)
78
-
79
- # log some basic info
80
- logger.info(f'Config:\n{cfg.pretty_text}')
81
-
82
- # init distributed env dirst, since logger depends on the dist info
83
- if args.launcher == 'None':
84
- cfg.distributed = False
85
- else:
86
- cfg.distributed = True
87
- init_env(args.launcher, cfg)
88
- logger.info(f'Distributed training: {cfg.distributed}')
89
-
90
- # dump config
91
- cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config)))
92
- test_data_path = args.test_data_path
93
- if not os.path.isabs(test_data_path):
94
- test_data_path = osp.join(CODE_SPACE, test_data_path)
95
-
96
- if 'json' in test_data_path:
97
- test_data = load_from_annos(test_data_path)
98
- else:
99
- test_data = load_data(test_data_path)
100
-
101
- if not cfg.distributed:
102
- main_worker(0, cfg, args.launcher, test_data)
103
- else:
104
- # distributed training
105
- if args.launcher == 'ror':
106
- local_rank = cfg.dist_params.local_rank
107
- main_worker(local_rank, cfg, args.launcher, test_data)
108
- else:
109
- mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher, test_data))
110
-
111
- def main_worker(local_rank: int, cfg: dict, launcher: str, test_data: list):
112
- if cfg.distributed:
113
- cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank
114
- cfg.dist_params.local_rank = local_rank
115
-
116
- if launcher == 'ror':
117
- init_torch_process_group(use_hvd=False)
118
- else:
119
- torch.cuda.set_device(local_rank)
120
- default_timeout = timedelta(minutes=30)
121
- dist.init_process_group(
122
- backend=cfg.dist_params.backend,
123
- init_method=cfg.dist_params.dist_url,
124
- world_size=cfg.dist_params.world_size,
125
- rank=cfg.dist_params.global_rank,
126
- timeout=default_timeout)
127
-
128
- logger = setup_logger(cfg.log_file)
129
- # build model
130
- model = get_configured_monodepth_model(cfg, )
131
-
132
- # config distributed training
133
- if cfg.distributed:
134
- model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
135
- device_ids=[local_rank],
136
- output_device=local_rank,
137
- find_unused_parameters=True)
138
- else:
139
- model = torch.nn.DataParallel(model).cuda()
140
-
141
- # load ckpt
142
- model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
143
- model.eval()
144
-
145
- do_scalecano_test_with_custom_data(
146
- model,
147
- cfg,
148
- test_data,
149
- logger,
150
- cfg.distributed,
151
- local_rank
152
- )
153
-
154
- if __name__ == '__main__':
155
- args = parse_args()
156
- timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
157
- args.timestamp = timestamp
158
- main(args)
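
For reference, `parse_args` and `main` above imply a single-GPU invocation of the form `python mono/tools/test_scale_cano.py <config.py> --load-from <checkpoint.pth> --test_data_path <image_dir or annos.json> --launcher None`; `--load-from` is mandatory (the script raises otherwise), a JSON path is routed through `load_from_annos`, and any other path through `load_data`.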
 
mono/utils/__init__.py DELETED
@@ -1 +0,0 @@
1
-
 
 
mono/utils/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (160 Bytes)
 
mono/utils/__pycache__/avg_meter.cpython-39.pyc DELETED
Binary file (10.1 kB)
 
mono/utils/__pycache__/comm.cpython-39.pyc DELETED
Binary file (9.72 kB)
 
mono/utils/__pycache__/custom_data.cpython-39.pyc DELETED
Binary file (1.21 kB)
 
mono/utils/__pycache__/do_test.cpython-39.pyc DELETED
Binary file (8.71 kB)
 
mono/utils/__pycache__/logger.cpython-39.pyc DELETED
Binary file (3.17 kB)
 
mono/utils/__pycache__/mldb.cpython-39.pyc DELETED
Binary file (1.18 kB)
 
mono/utils/__pycache__/running.cpython-39.pyc DELETED
Binary file (2.09 kB)
 
mono/utils/__pycache__/transform.cpython-39.pyc DELETED
Binary file (11.5 kB)
 
mono/utils/__pycache__/unproj_pcd.cpython-39.pyc DELETED
Binary file (2.61 kB)