diff --git a/mono/configs/HourglassDecoder/convlarge.0.3_150.py b/mono/configs/HourglassDecoder/convlarge.0.3_150.py
deleted file mode 100644
index 37b91c80284d6db3df3017ec636f18198e42dc08..0000000000000000000000000000000000000000
--- a/mono/configs/HourglassDecoder/convlarge.0.3_150.py
+++ /dev/null
@@ -1,25 +0,0 @@
-_base_=[
-    '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
-    '../_base_/datasets/_data_base_.py',
-    '../_base_/default_runtime.py',
-    ]
-
-model = dict(
-    backbone=dict(
-        pretrained=False,
-    )
-)
-
-# configs of the canonical space
-data_basic=dict(
-    canonical_space = dict(
-        img_size=(512, 960),
-        focal_length=1000.0,
-    ),
-    depth_range=(0, 1),
-    depth_normalize=(0.3, 150),
-    crop_size = (544, 1216),
-)
-
-batchsize_per_gpu = 2
-thread_per_gpu = 4
diff --git a/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py b/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py
deleted file mode 100644
index cdd9156b7f2f0921fb01b1adaf9a2a7447332d6e..0000000000000000000000000000000000000000
--- a/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py
+++ /dev/null
@@ -1,25 +0,0 @@
-_base_=[
-    '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
-    '../_base_/datasets/_data_base_.py',
-    '../_base_/default_runtime.py',
-    ]
-
-model = dict(
-    backbone=dict(
-        pretrained=False,
-    )
-)
-
-# configs of the canonical space
-data_basic=dict(
-    canonical_space = dict(
-        img_size=(512, 960),
-        focal_length=1000.0,
-    ),
-    depth_range=(0, 1),
-    depth_normalize=(0.3, 150),
-    crop_size = (512, 1088),
-)
-
-batchsize_per_gpu = 2
-thread_per_gpu = 4
diff --git a/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py b/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py
deleted file mode 100644
index 6601f5cdfad07c5fad8b89fbf959e67039126dfa..0000000000000000000000000000000000000000
--- a/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py
+++ /dev/null
@@ -1,25 +0,0 @@
-_base_=[
-    '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
-    '../_base_/datasets/_data_base_.py',
-    '../_base_/default_runtime.py',
-    ]
-
-model = dict(
-    backbone=dict(
-        pretrained=False,
-    )
-)
-
-# configs of the canonical space
-data_basic=dict(
-    canonical_space = dict(
-        img_size=(512, 960),
-        focal_length=1000.0,
-    ),
-    depth_range=(0, 1),
-    depth_normalize=(0.3, 150),
-    crop_size = (480, 1216),
-)
-
-batchsize_per_gpu = 2
-thread_per_gpu = 4
diff --git a/mono/configs/HourglassDecoder/vit.raft5.large.py b/mono/configs/HourglassDecoder/vit.raft5.large.py
deleted file mode 100644
index 4febdcb2867513008496f394ce8dc513230fb480..0000000000000000000000000000000000000000
--- a/mono/configs/HourglassDecoder/vit.raft5.large.py
+++ /dev/null
@@ -1,33 +0,0 @@
-_base_=[
-    '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
-    '../_base_/datasets/_data_base_.py',
-    '../_base_/default_runtime.py',
-    ]
-
-import numpy as np
-model=dict(
-    decode_head=dict(
-        type='RAFTDepthNormalDPT5',
-        iters=8,
-        n_downsample=2,
-        detach=False,
-    )
-)
-
-
-max_value = 200
-# configs of the canonical space
-data_basic=dict(
-    canonical_space = dict(
-        # img_size=(540, 960),
-        focal_length=1000.0,
-    ),
-    depth_range=(0, 1),
-    depth_normalize=(0.1, max_value),
-    crop_size = (616, 1064), # %28 = 0
-    clip_depth_range=(0.1, 200),
-    vit_size=(616,1064)
-)
-
-batchsize_per_gpu = 1
-thread_per_gpu = 1
diff --git a/mono/configs/HourglassDecoder/vit.raft5.small.py b/mono/configs/HourglassDecoder/vit.raft5.small.py
deleted file mode 100644
index 25eb68cc151f090c7654b7ebbcaf9dfc6a478570..0000000000000000000000000000000000000000
--- a/mono/configs/HourglassDecoder/vit.raft5.small.py
+++ /dev/null
@@ -1,33 +0,0 @@
-_base_=[
-    '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
-    '../_base_/datasets/_data_base_.py',
-    '../_base_/default_runtime.py',
-    ]
-
-import numpy as np
-model=dict(
-    decode_head=dict(
-        type='RAFTDepthNormalDPT5',
-        iters=4,
-        n_downsample=2,
-        detach=False,
-    )
-)
-
-
-max_value = 200
-# configs of the canonical space
-data_basic=dict(
-    canonical_space = dict(
-        # img_size=(540, 960),
-        focal_length=1000.0,
-    ),
-    depth_range=(0, 1),
-    depth_normalize=(0.1, max_value),
-    crop_size = (616, 1064), # %28 = 0
-    clip_depth_range=(0.1, 200),
-    vit_size=(616,1064)
-)
-
-batchsize_per_gpu = 1
-thread_per_gpu = 1
diff --git a/mono/configs/__init__.py b/mono/configs/__init__.py
deleted file mode 100644
index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000
--- a/mono/configs/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/mono/configs/_base_/_data_base_.py b/mono/configs/_base_/_data_base_.py
deleted file mode 100644
index 35f3844f24191b6b9452e136ea3205b7622466d7..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/_data_base_.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# canonical camera setting and basic data setting
-# we set it same as the E300 camera (crop version)
-#
-data_basic=dict(
-    canonical_space = dict(
-        img_size=(540, 960),
-        focal_length=1196.0,
-    ),
-    depth_range=(0.9, 150),
-    depth_normalize=(0.006, 1.001),
-    crop_size = (512, 960),
-    clip_depth_range=(0.9, 150),
-)
diff --git a/mono/configs/_base_/datasets/_data_base_.py b/mono/configs/_base_/datasets/_data_base_.py
deleted file mode 100644
index b554444e9b75b4519b862e726890dcf7859be0ec..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/datasets/_data_base_.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# canonical camera setting and basic data setting
-#
-data_basic=dict(
-    canonical_space = dict(
-        img_size=(540, 960),
-        focal_length=1196.0,
-    ),
-    depth_range=(0.9, 150),
-    depth_normalize=(0.006, 1.001),
-    crop_size = (512, 960),
-    clip_depth_range=(0.9, 150),
-)
diff --git a/mono/configs/_base_/default_runtime.py b/mono/configs/_base_/default_runtime.py
deleted file mode 100644
index a690b491bf50aad5c2fd7e9ac387609123a4594a..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/default_runtime.py
+++ /dev/null
@@ -1,4 +0,0 @@
-
-load_from = None
-cudnn_benchmark = True
-test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3','rmse_log', 'log10', 'sq_rel']
diff --git a/mono/configs/_base_/models/backbones/convnext_large.py b/mono/configs/_base_/models/backbones/convnext_large.py
deleted file mode 100644
index 5a22f7e1b53ca154bfae1672e6ee3b52028039b9..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/backbones/convnext_large.py
+++ /dev/null
@@ -1,16 +0,0 @@
-#_base_ = ['./_model_base_.py',]
-
-#'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth'
-model = dict(
-    #type='EncoderDecoderAuxi',
-    backbone=dict(
-        type='convnext_large',
-        pretrained=True,
-        in_22k=True,
-        out_indices=[0, 1, 2, 3],
-        drop_path_rate=0.4,
-        layer_scale_init_value=1.0,
-        checkpoint='data/pretrained_weight_repo/convnext/convnext_large_22k_1k_384.pth',
-        prefix='backbones.',
-        out_channels=[192, 384, 768, 1536]),
-    )
diff --git a/mono/configs/_base_/models/backbones/dino_vit_large.py b/mono/configs/_base_/models/backbones/dino_vit_large.py
deleted file mode 100644
index 843178ed6e61d74070b971f01148f87fdf2a62cf..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/backbones/dino_vit_large.py
+++ /dev/null
@@ -1,7 +0,0 @@
-model = dict(
-    backbone=dict(
-        type='vit_large',
-        prefix='backbones.',
-        out_channels=[1024, 1024, 1024, 1024],
-        drop_path_rate = 0.0),
-    )
diff --git a/mono/configs/_base_/models/backbones/dino_vit_large_reg.py b/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
deleted file mode 100644
index 25e96747d459d42df299f8a6a1e14044a0e56164..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
+++ /dev/null
@@ -1,7 +0,0 @@
-model = dict(
-    backbone=dict(
-        type='vit_large_reg',
-        prefix='backbones.',
-        out_channels=[1024, 1024, 1024, 1024],
-        drop_path_rate = 0.0),
-    )
diff --git a/mono/configs/_base_/models/backbones/dino_vit_small_reg.py b/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
deleted file mode 100644
index 0c8bd97dccb9cdee7517250f40e01bb3124144e6..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
+++ /dev/null
@@ -1,7 +0,0 @@
-model = dict(
-    backbone=dict(
-        type='vit_small_reg',
-        prefix='backbones.',
-        out_channels=[384, 384, 384, 384],
-        drop_path_rate = 0.0),
-    )
diff --git a/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py b/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py
deleted file mode 100644
index f262288c49e7ffccb6174b09b0daf80ff79dd684..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# model settings
-_base_ = ['../backbones/convnext_large.py',]
-model = dict(
-    type='DensePredModel',
-    decode_head=dict(
-        type='HourglassDecoder',
-        in_channels=[192, 384, 768, 1536],
-        decoder_channel=[128, 128, 256, 512],
-        prefix='decode_heads.'),
-)
diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py
deleted file mode 100644
index bd69efefab2c03de435996c6b7b65ff941db1e5d..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# model settings
-_base_ = ['../backbones/dino_vit_large.py']
-model = dict(
-    type='DensePredModel',
-    decode_head=dict(
-        type='RAFTDepthDPT',
-        in_channels=[1024, 1024, 1024, 1024],
-        use_cls_token=True,
-        feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
-        decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
-        up_scale = 7,
-        hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
-        n_gru_layers=3,
-        n_downsample=2,
-        iters=12,
-        slow_fast_gru=True,
-        corr_radius=4,
-        corr_levels=4,
-        prefix='decode_heads.'),
-)
diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
deleted file mode 100644
index 26ab6dc090e9cdb840d84fab10587becb536dbb8..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# model settings
-_base_ = ['../backbones/dino_vit_large_reg.py']
-model = dict(
-    type='DensePredModel',
-    decode_head=dict(
-        type='RAFTDepthDPT',
-        in_channels=[1024, 1024, 1024, 1024],
-        use_cls_token=True,
-        feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
-        decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
-        up_scale = 7,
-        hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
-        n_gru_layers=3,
-        n_downsample=2,
-        iters=3,
-        slow_fast_gru=True,
-        num_register_tokens=4,
-        prefix='decode_heads.'),
-)
diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
deleted file mode 100644
index 19466c191e9f2a83903e55ca4fc0827d9a11bcb9..0000000000000000000000000000000000000000
--- a/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# model settings
-_base_ = ['../backbones/dino_vit_small_reg.py']
-model = dict(
-    type='DensePredModel',
-    decode_head=dict(
-        type='RAFTDepthDPT',
-        in_channels=[384, 384, 384, 384],
-        use_cls_token=True,
-        feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14]
-        decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14]
-        up_scale = 7,
-        hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -]
-        n_gru_layers=3,
-        n_downsample=2,
-        iters=3,
-        slow_fast_gru=True,
-        num_register_tokens=4,
-        prefix='decode_heads.'),
-)
diff --git a/mono/model/__init__.py b/mono/model/__init__.py
deleted file mode 100644
index 9e1ea3d3e3b880e28ef880083b3c79e3b00cd119..0000000000000000000000000000000000000000
--- a/mono/model/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .monodepth_model import DepthModel
-# from .__base_model__ import BaseDepthModel
-
-
-__all__ = ['DepthModel', 'BaseDepthModel']
diff --git a/mono/model/__pycache__/__init__.cpython-39.pyc b/mono/model/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 3c9c860c14219cf199bdb577cb7e0e6dd7e5eadb..0000000000000000000000000000000000000000
Binary files a/mono/model/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/mono/model/__pycache__/monodepth_model.cpython-39.pyc b/mono/model/__pycache__/monodepth_model.cpython-39.pyc
deleted file mode 100644
index bd965a942a758e150ac2ca0854800bce82b83f14..0000000000000000000000000000000000000000
Binary files a/mono/model/__pycache__/monodepth_model.cpython-39.pyc and /dev/null differ
diff --git a/mono/model/backbones/ConvNeXt.py b/mono/model/backbones/ConvNeXt.py
deleted file mode 100644
index f1c4be0e6463ae2b0dda6d20fc273a300afa5ebf..0000000000000000000000000000000000000000
--- a/mono/model/backbones/ConvNeXt.py
+++ /dev/null
@@ -1,271 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from timm.models.layers import trunc_normal_, DropPath
-from timm.models.registry import register_model
-
-class Block(nn.Module):
-    r""" ConvNeXt Block. There are two equivalent implementations:
-    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
-    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
-    We use (2) as we find it slightly faster in PyTorch
-
-    Args:
-        dim (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
- """ - def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): - super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(4 * dim, dim) - self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), - requires_grad=True) if layer_scale_init_value > 0 else None - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, x): - input = x - x = self.dwconv(x) - x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - if self.gamma is not None: - x = self.gamma * x - x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) - - x = input + self.drop_path(x) - return x - -class ConvNeXt(nn.Module): - r""" ConvNeXt - A PyTorch impl of : `A ConvNet for the 2020s` - - https://arxiv.org/pdf/2201.03545.pdf - Args: - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] - dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] - drop_path_rate (float): Stochastic depth rate. Default: 0. - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. - """ - def __init__(self, in_chans=3, num_classes=1000, - depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., - layer_scale_init_value=1e-6, head_init_scale=1., - **kwargs,): - super().__init__() - - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first") - ) - self.downsample_layers.append(stem) - for i in range(3): - downsample_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks - dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - cur = 0 - for i in range(4): - stage = nn.Sequential( - *[Block(dim=dims[i], drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] - ) - self.stages.append(stage) - cur += depths[i] - - #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer - #self.head = nn.Linear(dims[-1], num_classes) - - self.apply(self._init_weights) - #self.head.weight.data.mul_(head_init_scale) - #self.head.bias.data.mul_(head_init_scale) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - nn.init.constant_(m.bias, 0) - - def forward_features(self, x): - features = [] - for i in range(4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - features.append(x) - return features # global average pooling, (N, C, H, W) -> (N, C) - - def forward(self, x): - #x = self.forward_features(x) - #x = self.head(x) - features = self.forward_features(x) - return features - -class LayerNorm(nn.Module): - r""" LayerNorm that 
supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape, ) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - - -model_urls = { - "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", - "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", - "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", - "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", - "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", - "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", - "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", - "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", - "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", -} - -def convnext_tiny(pretrained=True,in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) - if pretrained: - checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") - #url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] - #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) - model_dict = model.state_dict() - pretrained_dict = {} - unmatched_pretrained_dict = {} - for k, v in checkpoint['model'].items(): - if k in model_dict: - pretrained_dict[k] = v - else: - unmatched_pretrained_dict[k] = v - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) - print( - 'Successfully loaded pretrained %d params, and %d paras are unmatched.' 
- %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) - print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) - return model - -def convnext_small(pretrained=True,in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - if pretrained: - checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") - #url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] - #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model_dict = model.state_dict() - pretrained_dict = {} - unmatched_pretrained_dict = {} - for k, v in checkpoint['model'].items(): - if k in model_dict: - pretrained_dict[k] = v - else: - unmatched_pretrained_dict[k] = v - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) - print( - 'Successfully loaded pretrained %d params, and %d paras are unmatched.' - %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) - print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) - return model - -def convnext_base(pretrained=True, in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - if pretrained: - checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") - #url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] - #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model_dict = model.state_dict() - pretrained_dict = {} - unmatched_pretrained_dict = {} - for k, v in checkpoint['model'].items(): - if k in model_dict: - pretrained_dict[k] = v - else: - unmatched_pretrained_dict[k] = v - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) - print( - 'Successfully loaded pretrained %d params, and %d paras are unmatched.' - %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) - print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) - return model - -def convnext_large(pretrained=True, in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - if pretrained: - checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") - #url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] - #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model_dict = model.state_dict() - pretrained_dict = {} - unmatched_pretrained_dict = {} - for k, v in checkpoint['model'].items(): - if k in model_dict: - pretrained_dict[k] = v - else: - unmatched_pretrained_dict[k] = v - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) - print( - 'Successfully loaded pretrained %d params, and %d paras are unmatched.' 
- %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) - print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) - return model - -def convnext_xlarge(pretrained=True, in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) - if pretrained: - assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" - checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") - #url = model_urls['convnext_xlarge_22k'] - #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model_dict = model.state_dict() - pretrained_dict = {} - unmatched_pretrained_dict = {} - for k, v in checkpoint['model'].items(): - if k in model_dict: - pretrained_dict[k] = v - else: - unmatched_pretrained_dict[k] = v - model_dict.update(pretrained_dict) - model.load_state_dict(model_dict) - print( - 'Successfully loaded pretrained %d params, and %d paras are unmatched.' - %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) - print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) - return model - -if __name__ == '__main__': - import torch - model = convnext_base(True, in_22k=False).cuda() - - rgb = torch.rand((2, 3, 256, 256)).cuda() - out = model(rgb) - print(len(out)) - for i, ft in enumerate(out): - print(i, ft.shape) diff --git a/mono/model/backbones/ViT_DINO.py b/mono/model/backbones/ViT_DINO.py deleted file mode 100644 index 5a1998f0dd5024fbe69895e244fc054245a06568..0000000000000000000000000000000000000000 --- a/mono/model/backbones/ViT_DINO.py +++ /dev/null @@ -1,1504 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -# References: -# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py - -from functools import partial -import math -import logging -from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List - -import torch -import torch.nn as nn -from torch import Tensor -import torch.utils.checkpoint -from torch.nn.init import trunc_normal_ - -#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block - -logger = logging.getLogger("dinov2") - -class ConvBlock(nn.Module): - def __init__(self, channels): - super(ConvBlock, self).__init__() - - self.act = nn.ReLU(inplace=True) - self.conv1 = nn.Conv2d( - channels, - channels, - kernel_size=3, - stride=1, - padding=1 - ) - self.norm1 = nn.BatchNorm2d(channels) - self.conv2 = nn.Conv2d( - channels, - channels, - kernel_size=3, - stride=1, - padding=1 - ) - self.norm2 = nn.BatchNorm2d(channels) - - def forward(self, x): - - out = self.norm1(x) - out = self.act(out) - out = self.conv1(out) - out = self.norm2(out) - out = self.act(out) - out = self.conv2(out) - return x + out - -def make_2tuple(x): - if isinstance(x, tuple): - assert len(x) == 2 - return x - - assert isinstance(x, int) - return (x, x) - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0: - random_tensor.div_(keep_prob) - output = x * random_tensor - return output - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - -class LayerScale(nn.Module): - def __init__( - self, - dim: int, - init_values: Union[float, Tensor] = 1e-5, - inplace: bool = False, - ) -> None: - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - -class PatchEmbed(nn.Module): - """ - 2D image to patch embedding: (B,C,H,W) -> (B,N,D) - - Args: - img_size: Image size. - patch_size: Patch token size. - in_chans: Number of input image channels. - embed_dim: Number of linear projection output channels. - norm_layer: Normalization layer. 
- """ - - def __init__( - self, - img_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer: Optional[Callable] = None, - flatten_embedding: bool = True, - ) -> None: - super().__init__() - - image_HW = make_2tuple(img_size) - patch_HW = make_2tuple(patch_size) - patch_grid_size = ( - image_HW[0] // patch_HW[0], - image_HW[1] // patch_HW[1], - ) - - self.img_size = image_HW - self.patch_size = patch_HW - self.patches_resolution = patch_grid_size - self.num_patches = patch_grid_size[0] * patch_grid_size[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.flatten_embedding = flatten_embedding - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - _, _, H, W = x.shape - patch_H, patch_W = self.patch_size - - assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" - assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" - - x = self.proj(x) # B C H W - H, W = x.size(2), x.size(3) - x = x.flatten(2).transpose(1, 2) # B HW C - x = self.norm(x) - if not self.flatten_embedding: - x = x.reshape(-1, H, W, self.embed_dim) # B H W C - return x - - def flops(self) -> float: - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class Mlp(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = nn.GELU, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - self.drop = nn.Dropout(drop) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class SwiGLUFFN(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) - self.w3 = nn.Linear(hidden_features, out_features, bias=bias) - - def forward(self, x: Tensor) -> Tensor: - x12 = self.w12(x) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) - - -try: - from xformers.ops import SwiGLU - #import numpy.bool - XFORMERS_AVAILABLE = True -except ImportError: - SwiGLU = SwiGLUFFN - XFORMERS_AVAILABLE = False - -class SwiGLUFFNFused(SwiGLU): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - out_features = out_features or in_features - hidden_features = hidden_features or in_features - hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 
* 8 - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - out_features=out_features, - bias=bias, - ) - - -try: - from xformers.ops import memory_efficient_attention, unbind, fmha - from xformers.components.attention import ScaledDotProduct - from xformers.components import MultiHeadDispatch - #import numpy.bool - XFORMERS_AVAILABLE = True -except ImportError: - logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False - - -class Attention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - proj_bias: bool = True, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - window_size: int = 0, - ) -> None: - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.proj_drop = nn.Dropout(proj_drop) - - #if not self.training: - # - # self.attn = ScaledDotProduct() - #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) - - def forward(self, x: Tensor, attn_bias=None) -> Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - - q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = q @ k.transpose(-2, -1) - - if attn_bias is not None: - attn = attn + attn_bias[:, :, :N] - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class MemEffAttention(Attention): - def forward(self, x: Tensor, attn_bias=None) -> Tensor: - if not XFORMERS_AVAILABLE: - #if True: - assert attn_bias is None, "xFormers is required for nested tensors usage" - return super().forward(x, attn_bias) - - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - - q, k, v = unbind(qkv, 2) - if attn_bias is not None: - x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) - else: - x = memory_efficient_attention(q, k, v) - x = x.reshape([B, N, C]) - - x = self.proj(x) - x = self.proj_drop(x) - return x - -try: - from xformers.ops import fmha - from xformers.ops import scaled_index_add, index_select_cat - #import numpy.bool - XFORMERS_AVAILABLE = True -except ImportError: - logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = False, - proj_bias: bool = True, - ffn_bias: bool = True, - drop: float = 0.0, - attn_drop: float = 0.0, - init_values = None, - drop_path: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, - norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_class: Callable[..., nn.Module] = Attention, - ffn_layer: Callable[..., nn.Module] = Mlp, - ) -> None: - super().__init__() - # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") - self.norm1 = norm_layer(dim) - self.attn = attn_class( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - attn_drop=attn_drop, - proj_drop=drop, - ) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = ffn_layer( - 
in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - bias=ffn_bias, - ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.sample_drop_ratio = drop_path - - def forward(self, x: Tensor, attn_bias=None) -> Tensor: - def attn_residual_func(x: Tensor, attn_bias) -> Tensor: - return self.ls1(self.attn(self.norm1(x), attn_bias)) - - def ffn_residual_func(x: Tensor) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - if self.training and self.sample_drop_ratio > 0.1: - # the overhead is compensated only for a drop path rate larger than 0.1 - x = drop_add_residual_stochastic_depth( - x, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - attn_bias=attn_bias - ) - x = drop_add_residual_stochastic_depth( - x, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - ) - elif self.training and self.sample_drop_ratio > 0.0: - x = x + self.drop_path1(attn_residual_func(x, attn_bias)) - x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 - else: - x = x + attn_residual_func(x, attn_bias) - x = x + ffn_residual_func(x) - return x - - -def drop_add_residual_stochastic_depth( - x: Tensor, - residual_func: Callable[[Tensor], Tensor], - sample_drop_ratio: float = 0.0, attn_bias=None -) -> Tensor: - # 1) extract subset using permutation - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - x_subset = x[brange] - - # 2) apply residual_func to get residual - residual = residual_func(x_subset, attn_bias) - - x_flat = x.flatten(1) - residual = residual.flatten(1) - - residual_scale_factor = b / sample_subset_size - - # 3) add the residual - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - return x_plus_residual.view_as(x) - - -def get_branges_scales(x, sample_drop_ratio=0.0): - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - residual_scale_factor = b / sample_subset_size - return brange, residual_scale_factor - - -def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): - if scaling_vector is None: - x_flat = x.flatten(1) - residual = residual.flatten(1) - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - else: - x_plus_residual = scaled_index_add( - x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor - ) - return x_plus_residual - - -attn_bias_cache: Dict[Tuple, Any] = {} - - -def get_attn_bias_and_cat(x_list, branges=None): - """ - this will perform the index select, cat the tensors, and provide the attn_bias from cache - """ - batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] - all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) - if all_shapes not in attn_bias_cache.keys(): - seqlens = [] - for b, x in zip(batch_sizes, x_list): - for _ in range(b): - seqlens.append(x.shape[1]) - attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) - attn_bias._batch_sizes = batch_sizes - attn_bias_cache[all_shapes] = attn_bias - - if branges is not None: - cat_tensors = index_select_cat([x.flatten(1) for x in x_list], 
branges).view(1, -1, x_list[0].shape[-1]) - else: - tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) - cat_tensors = torch.cat(tensors_bs1, dim=1) - - return attn_bias_cache[all_shapes], cat_tensors - - -def drop_add_residual_stochastic_depth_list( - x_list: List[Tensor], - residual_func: Callable[[Tensor, Any], Tensor], - sample_drop_ratio: float = 0.0, - scaling_vector=None, -) -> Tensor: - # 1) generate random set of indices for dropping samples in the batch - branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] - branges = [s[0] for s in branges_scales] - residual_scale_factors = [s[1] for s in branges_scales] - - # 2) get attention bias and index+concat the tensors - attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) - - # 3) apply residual_func to get residual, and split the result - residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore - - outputs = [] - for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): - outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) - return outputs - - -class NestedTensorBlock(Block): - def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: - """ - x_list contains a list of tensors to nest together and run - """ - assert isinstance(self.attn, MemEffAttention) - - if self.training and self.sample_drop_ratio > 0.0: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.attn(self.norm1(x), attn_bias=attn_bias) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.mlp(self.norm2(x)) - - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, - ) - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, - ) - return x_list - else: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - attn_bias, x = get_attn_bias_and_cat(x_list) - x = x + attn_residual_func(x, attn_bias=attn_bias) - x = x + ffn_residual_func(x) - return attn_bias.split(x) - - def forward(self, x_or_x_list, attn_bias=None): - if isinstance(x_or_x_list, Tensor): - return super().forward(x_or_x_list, attn_bias) - elif isinstance(x_or_x_list, list): - assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" - return self.forward_nested(x_or_x_list) - else: - raise AssertionError - - -def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = ".".join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -class BlockChunk(nn.ModuleList): - def forward(self, x, others=None): - for b in self: - if others == None: - x = b(x) - else: - x = b(x, 
others) - return x - - -class DinoVisionTransformer(nn.Module): - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - ffn_bias=True, - proj_bias=True, - drop_path_rate=0.0, - drop_path_uniform=False, - #init_values=None, # for layerscale: None or 0 => no layerscale - init_values=1e-5, # for layerscale: None or 0 => no layerscale - embed_layer=PatchEmbed, - act_layer=nn.GELU, - block_fn=NestedTensorBlock, - ffn_layer="mlp", - block_chunks=1, - window_size=37, - **kwargs - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - proj_bias (bool): enable bias for proj in attn if True - ffn_bias (bool): enable bias for ffn if True - drop_path_rate (float): stochastic depth rate - drop_path_uniform (bool): apply uniform drop rate across blocks - weight_init (str): weight init scheme - init_values (float): layer-scale init values - embed_layer (nn.Module): patch embedding layer - act_layer (nn.Module): MLP activation layer - block_fn (nn.Module): transformer block class - ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" - block_chunks: (int) split block sequence into block_chunks units for FSDP wrap - """ - super().__init__() - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 - self.n_blocks = depth - self.num_heads = num_heads - self.patch_size = patch_size - self.window_size = window_size - - self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - - if drop_path_uniform is True: - dpr = [drop_path_rate] * depth - else: - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - - if ffn_layer == "mlp": - logger.info("using MLP layer as FFN") - ffn_layer = Mlp - elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": - logger.info("using SwiGLU layer as FFN") - ffn_layer = SwiGLUFFNFused - elif ffn_layer == "identity": - logger.info("using Identity layer as FFN") - - def f(*args, **kwargs): - return nn.Identity() - - ffn_layer = f - else: - raise NotImplementedError - - blocks_list = [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - drop_path=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - ffn_layer=ffn_layer, - init_values=init_values, - ) - for i in range(depth) - ] - if block_chunks > 0: - self.chunked_blocks = True - chunked_blocks = [] - chunksize = depth // block_chunks - for i in range(0, depth, chunksize): - # this is to keep the block index consistent if we chunk the block list - chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) - self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) - else: - self.chunked_blocks = False - self.blocks = nn.ModuleList(blocks_list) - - self.norm = norm_layer(embed_dim) - self.head = 
nn.Identity() - - self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) - - self.init_weights() - - def init_weights(self): - trunc_normal_(self.pos_embed, std=0.02) - nn.init.normal_(self.cls_token, std=1e-6) - named_apply(init_weights_vit_timm, self) - - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.pos_embed.shape[1] - 1 - if npatch == N and w == h: - return self.pos_embed - pos_embed = self.pos_embed.float() - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - w0, h0 = w0 + 0.1, h0 + 0.1 - - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), - scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), - mode="bicubic", - ) - - assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - def prepare_tokens_with_masks(self, x, masks=None): - B, nc, w, h = x.shape - x = self.patch_embed(x) - if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.interpolate_pos_encoding(x, w, h) - - return x - - def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] - for blk in self.blocks: - x = blk(x) - - all_x = x - output = [] - for x, masks in zip(all_x, masks_list): - x_norm = self.norm(x) - output.append( - { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_patchtokens": x_norm[:, 1:], - "x_prenorm": x, - "masks": masks, - } - ) - return output - - def forward_features(self, x, masks=None): - if isinstance(x, list): - return self.forward_features_list(x, masks) - - B, C, H, W = x.size() - pad_h = (self.patch_size - H % self.patch_size) - pad_w = (self.patch_size - W % self.patch_size) - if pad_h == self.patch_size: - pad_h = 0 - if pad_w == self.patch_size: - pad_w = 0 - #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) - if pad_h + pad_w > 0: - x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') - - x = self.prepare_tokens_with_masks(x, masks) - - features = [] - for blk in self.blocks: - x = blk(x) - # for idx in range(len(self.blocks[0])): - # x = self.blocks[0][idx](x) - # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: - # features.append(x) - - #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] - - x_norm = self.norm(x) - # return { - # "x_norm_clstoken": x_norm[:, 0], - # "x_norm_patchtokens": x_norm[:, 1:], - # "x_prenorm": x, - # "masks": masks, - # } - features = [] - features.append(x_norm) - features.append(x_norm) - features.append(x_norm) - features.append(x_norm) - return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] - - def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - # If n is an int, take the n last blocks. 
If it's a list, take them - output, total_block_len = [], len(self.blocks) - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in blocks_to_take: - output.append(x) - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def _get_intermediate_layers_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - output, i, total_block_len = [], 0, len(self.blocks[-1]) - # If n is an int, take the n last blocks. If it's a list, take them - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for block_chunk in self.blocks: - for blk in block_chunk[i:]: # Passing the nn.Identity() - x = blk(x) - if i in blocks_to_take: - output.append(x) - i += 1 - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - norm=True, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - if self.chunked_blocks: - outputs = self._get_intermediate_layers_chunked(x, n) - else: - outputs = self._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1:] for out in outputs] - if reshape: - B, _, w, h = x.shape - outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() - for out in outputs - ] - if return_class_token: - return tuple(zip(outputs, class_tokens)) - return tuple(outputs) - - def forward(self, *args, is_training=False, **kwargs): - ret = self.forward_features(*args, **kwargs) - return ret - # if is_training: - # return ret - # else: - # return self.head(ret["x_norm_clstoken"]) - - -class PosConv(nn.Module): - # PEG from https://arxiv.org/abs/2102.10882 - def __init__(self, in_chans, embed_dim=768, stride=1): - super(PosConv, self).__init__() - self.proj = nn.Sequential( - nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim), - ) - self.stride = stride - - def forward(self, x, size): - B, N, C = x.shape - cnn_feat_token = x.transpose(1, 2).view(B, C, *size) - x = self.proj(cnn_feat_token) - if self.stride == 1: - x += cnn_feat_token - x = x.flatten(2).transpose(1, 2) - return x - - #def no_weight_decay(self): - #return ['proj.%d.weight' % i for i in range(4)] - -class DinoWindowVisionTransformer(nn.Module): - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - ffn_bias=True, - proj_bias=True, - drop_path_rate=0.0, - drop_path_uniform=False, - #init_values=None, # for layerscale: None or 0 => no layerscale - init_values=1e-5, # for layerscale: None or 0 => no layerscale - embed_layer=PatchEmbed, - act_layer=nn.GELU, - block_fn=NestedTensorBlock, - ffn_layer="mlp", - block_chunks=1, - window_size=7, - **kwargs - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - 
proj_bias (bool): enable bias for proj in attn if True - ffn_bias (bool): enable bias for ffn if True - drop_path_rate (float): stochastic depth rate - drop_path_uniform (bool): apply uniform drop rate across blocks - weight_init (str): weight init scheme - init_values (float): layer-scale init values - embed_layer (nn.Module): patch embedding layer - act_layer (nn.Module): MLP activation layer - block_fn (nn.Module): transformer block class - ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" - block_chunks: (int) split block sequence into block_chunks units for FSDP wrap - """ - super().__init__() - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 - self.n_blocks = depth - self.num_heads = num_heads - self.patch_size = patch_size - - self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - - self.pos_conv = PosConv(self.embed_dim, self.embed_dim) - - self.window_size = window_size - #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)]) - #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)]) - - if drop_path_uniform is True: - dpr = [drop_path_rate] * depth - else: - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - - if ffn_layer == "mlp": - logger.info("using MLP layer as FFN") - ffn_layer = Mlp - elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": - logger.info("using SwiGLU layer as FFN") - ffn_layer = SwiGLUFFNFused - elif ffn_layer == "identity": - logger.info("using Identity layer as FFN") - - def f(*args, **kwargs): - return nn.Identity() - - ffn_layer = f - else: - raise NotImplementedError - - blocks_list = [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - drop_path=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - ffn_layer=ffn_layer, - init_values=init_values, - ) - for i in range(depth) - ] - if block_chunks > 0: - self.chunked_blocks = True - chunked_blocks = [] - chunksize = depth // block_chunks - for i in range(0, depth, chunksize): - # this is to keep the block index consistent if we chunk the block list - chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) - self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) - else: - self.chunked_blocks = False - self.blocks = nn.ModuleList(blocks_list) - - self.norm = norm_layer(embed_dim) - self.head = nn.Identity() - - self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) - - self.nh = -1 - self.nw = -1 - try: - H = cfg.data_basic['crop_size'][0] - W = cfg.data_basic['crop_size'][1] - pad_h = (self.patch_size - H % self.patch_size) - pad_w = (self.patch_size - W % self.patch_size) - if pad_h == self.patch_size: - pad_h = 0 - if pad_w == self.patch_size: - pad_w = 0 - self.nh = (H + pad_h) // self.patch_size - self.nw = (W + pad_w) // self.patch_size - self.prepare_attn_bias((self.nh, self.nw)) - except: - pass - self.init_weights() - - self.total_step = 10000 # For PE -> GPE transfer - self.start_step = 2000 - self.current_step = 20000 - - def 
init_weights(self): - #trunc_normal_(self.pos_embed, std=0.02) - #nn.init.normal_(self.cls_token, std=1e-6) - named_apply(init_weights_vit_timm, self) - for i in range(4): - try: - nn.init.constant_(self.conv_block[i].conv2.weight, 0.0) - except: - pass - - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - #npatch = x.shape[1] - 1 - #N = self.pos_embed.shape[1] - 1 - npatch = x.shape[1] - N = self.pos_embed.shape[1] - if npatch == N and w == h: - return self.pos_embed - pos_embed = self.pos_embed.float() - #class_pos_embed = pos_embed[:, 0] - #patch_pos_embed = pos_embed[:, 1:] - patch_pos_embed = pos_embed - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - w0, h0 = w0 + 0.1, h0 + 0.1 - - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), - scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), - mode="bicubic", - ) - - assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return patch_pos_embed.to(previous_dtype) - #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]: - """ - Partition into non-overlapping windows with padding if needed. - Args: - x (tensor): input tokens with [B, H, W, C]. - window_size (int): window size. - - Returns: - windows: windows after partition with [B * num_windows, window_size, window_size, C]. - (Hp, Wp): padded height and width before partition - """ - if conv_feature == False: - B, N, C = x.shape - H, W = hw[0], hw[1] - - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C) - else: - B, C, H, W = x.shape - - x = x.view(B, C, H // window_size, window_size, W // window_size, window_size) - - windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C) - - #y = torch.cat((x_cls, windows), dim=1) - return windows #, (Hp, Wp) - - - def window_unpartition(self, - windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False - ) -> torch.Tensor: - """ - Window unpartition into original sequences and removing padding. - Args: - windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. - window_size (int): window size. - pad_hw (Tuple): padded height and width (Hp, Wp). - hw (Tuple): original height and width (H, W) before padding. - - Returns: - x: unpartitioned sequences with [B, H, W, C]. 
- """ - H, W = hw - - B = windows.shape[0] // (H * W // window_size // window_size) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - - if conv_feature == False: - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp * Wp, -1) - else: - C = windows.shape[-1] - x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W) - - # if Hp > H or Wp > W: - # x = x[:, :H, :W, :].contiguous() - return x - - def prepare_tokens_with_masks(self, x, masks=None, step=-1): - B, nc, w, h = x.shape - x = self.patch_embed(x) - if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - - #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - if step == -1: - step = self.current_step - else: - self.current_step = step - - if step < self.start_step: - coef = 0.0 - elif step < self.total_step: - coef = (step - self.start_step) / (self.total_step - self.start_step) - else: - coef = 1.0 - - x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw)) - - return x - - def prepare_attn_bias(self, shape): - window_size = self.window_size - if window_size <= 0: - return - - import xformers.components.attention.attention_patterns as AP - - nh, nw = shape - radius = (window_size-1)//2 - mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() - - pad = (8 - (nh * nw) % 8) - if pad == 8: - pad = 0 - mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous() - if pad > 0: - mask = mask_pad[:, :-pad].view(nh, nw, nh, nw) - else: - mask = mask_pad[:, :].view(nh, nw, nh, nw) - - # angle - mask[:radius+1, :radius+1, :window_size, :window_size] = True - mask[:radius+1, -radius-1:, :window_size, -window_size:] = True - mask[-radius-1:, :radius+1, -window_size:, :window_size] = True - mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True - - # edge - mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] - mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] - mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] - mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] - - mask = mask.view(nh*nw, nh*nw) - bias_pad = torch.log(mask_pad) - #bias = bias_pad[:, :-pad] - self.register_buffer('attn_bias', bias_pad) - - return bias_pad - - def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] - for blk in self.blocks: - x = blk(x) - - all_x = x - output = [] - for x, masks in zip(all_x, masks_list): - x_norm = self.norm(x) - output.append( - { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_patchtokens": x_norm[:, 1:], - "x_prenorm": x, - "masks": masks, - } - ) - return output - - def forward_features(self, x, masks=None, **kwargs): - if isinstance(x, list): - return self.forward_features_list(x, masks) - - B, C, H, W = x.size() - pad_h = (self.patch_size - H % self.patch_size) - pad_w = (self.patch_size - W % self.patch_size) - if pad_h == self.patch_size: - pad_h = 0 - if pad_w == self.patch_size: - pad_w = 0 - #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) - if pad_h + pad_w > 0: - x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') - - nh = (H+pad_h)//self.patch_size - nw = (W+pad_w)//self.patch_size - - if self.window_size > 0: - if nh == 
self.nh and nw == self.nw: - attn_bias = self.attn_bias - else: - attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size)) - self.nh = nh - self.nw = nw - attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1) - else: - attn_bias = None - - x = self.prepare_tokens_with_masks(x, masks) - #x = self.patch_embed(x) - - features = [] - #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) - for blk in self.blocks: - x = blk(x, attn_bias) - #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) - - # for idx in range(len(self.blocks[0])): - # x = self.blocks[0][idx](x, attn_bias) - - # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: - # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) - # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x) - # if idx + 1 != len(self.blocks[0]): - # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) - # else: - # b, c, h, w = x.size() - # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c) - #features.append(x) - - #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] - - x_norm = self.norm(x) - # return { - # "x_norm_clstoken": x_norm[:, 0], - # "x_norm_patchtokens": x_norm[:, 1:], - # "x_prenorm": x, - # "masks": masks, - # } - features = [] - features.append(x_norm) - features.append(x_norm) - features.append(x_norm) - features.append(x_norm) - return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] - - def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - # If n is an int, take the n last blocks. If it's a list, take them - output, total_block_len = [], len(self.blocks) - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in blocks_to_take: - output.append(x) - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def _get_intermediate_layers_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - output, i, total_block_len = [], 0, len(self.blocks[-1]) - # If n is an int, take the n last blocks. 
If it's a list, take them - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for block_chunk in self.blocks: - for blk in block_chunk[i:]: # Passing the nn.Identity() - x = blk(x) - if i in blocks_to_take: - output.append(x) - i += 1 - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - norm=True, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - if self.chunked_blocks: - outputs = self._get_intermediate_layers_chunked(x, n) - else: - outputs = self._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1:] for out in outputs] - if reshape: - B, _, w, h = x.shape - outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() - for out in outputs - ] - if return_class_token: - return tuple(zip(outputs, class_tokens)) - return tuple(outputs) - - def forward(self, *args, is_training=False, **kwargs): - ret = self.forward_features(*args, **kwargs) - return ret - # if is_training: - # return ret - # else: - # return self.head(ret["x_norm_clstoken"]) - - - - -def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - - -def vit_small(patch_size=14, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4, - block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), - **kwargs, - ) - return model - - -def vit_base(patch_size=14, **kwargs): - model = DinoWindowVisionTransformer( - patch_size=patch_size, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4, - block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), - **kwargs, - ) - return model - - -def vit_large(patch_size=14, checkpoint=None, **kwargs): - model = DinoVisionTransformer( - img_size = 518, - patch_size=patch_size, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), - **kwargs, - ) - - if checkpoint is not None: - with open(checkpoint, "rb") as f: - state_dict = torch.load(f) - try: - model.load_state_dict(state_dict, strict=True) - except: - new_state_dict = {} - for key, value in state_dict.items(): - if 'blocks' in key: - key_new = 'blocks.0' + key[len('blocks'):] - else: - key_new = key - new_state_dict[key_new] = value - - model.load_state_dict(new_state_dict, strict=True) - #del model.norm - del model.mask_token - return model - - # model = DinoWindowVisionTransformer( - # img_size = 518, - # patch_size=patch_size, - # embed_dim=1024, - # depth=24, - # num_heads=16, - # mlp_ratio=4, - # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), - # window_size=37, - # **kwargs, - # ) - - # if checkpoint is not None: - # with open(checkpoint, "rb") as f: - # state_dict = torch.load(f) - # try: - # model.load_state_dict(state_dict, strict=True) - # except: - # new_state_dict = {} - # for key, value in state_dict.items(): - # if 'blocks' in key: - # key_new = 'blocks.0' + 
key[len('blocks'):] - # else: - # key_new = key - # if 'pos_embed' in key: - # value = value[:, 1:, :] - # new_state_dict[key_new] = value - - # model.load_state_dict(new_state_dict, strict=False) - # #del model.norm - # del model.mask_token - return model - - -def vit_giant2(patch_size=16, **kwargs): - """ - Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 - """ - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1536, - depth=40, - num_heads=24, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - **kwargs, - ) - return model - -if __name__ == '__main__': - try: - from mmcv.utils import Config - except: - from mmengine import Config - - #rgb = torch.rand((2, 3, 518, 518)).cuda() - - #cfg.data_basic['crop_size']['0'] - #cfg.data_basic['crop_size']['1'] - cfg = Config.fromfile('/cpfs01/user/mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py') - - #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) - rgb = torch.zeros(1, 3, 1400, 1680).cuda() - model = vit_large(checkpoint="/cpfs02/shared/public/custom/group_local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda() - - #import timm - #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() - #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) - - out1 = model(rgb) - #out2 = model2(rgb) - temp = 0 - - - -# import time -# window_size = 37 -# def prepare_window_masks(shape): -# if window_size <= 0: -# return None -# import xformers.components.attention.attention_patterns as AP - -# B, nh, nw, _, _ = shape -# radius = (window_size-1)//2 -# #time0 = time.time() -# d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() -# #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() -# # mask = mask.view(nh, nw, nh, nw) -# # #time1 = time.time() - time0 - -# # # angle -# # mask[:radius+1, :radius+1, :window_size, :window_size] = True -# # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True -# # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True -# # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True -# # time2 = time.time() - time0 - time1 - -# # # edge -# # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] -# # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] -# # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] -# # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] -# # time3 = time.time() - time0 - time2 -# # print(time1, time2, time3) - -# # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1) - -# shape = (1, 55, 55, None, None) -# mask = prepare_window_masks(shape) -# # temp = 1 \ No newline at end of file diff --git a/mono/model/backbones/ViT_DINO_reg.py b/mono/model/backbones/ViT_DINO_reg.py deleted file mode 100644 index 854f96320ea93752e023c8cd845bf38353dfab17..0000000000000000000000000000000000000000 --- a/mono/model/backbones/ViT_DINO_reg.py +++ /dev/null @@ -1,1293 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
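# Before the SSF/LoRA utilities defined below: a minimal, self-contained sketch of
# the scale-and-shift finetuning (SSF) adapter pattern that init_ssf_scale_shift /
# ssf_ada implement in this file. The class name SSFAdapter and the example shapes
# are illustrative assumptions, not part of the original code.

import torch
import torch.nn as nn

class SSFAdapter(nn.Module):
    """Per-channel affine adapter: y = x * scale + shift, with scale ~ 1, shift ~ 0."""
    def __init__(self, dim: int):
        super().__init__()
        # Initialized like init_ssf_scale_shift: scale around 1, shift around 0.
        self.scale = nn.Parameter(torch.normal(1.0, 0.02, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(0.0, 0.02, size=(dim,)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcasts over a trailing channel dimension, e.g. tokens of shape (B, N, C).
        return x * self.scale + self.shift

# Frozen-backbone usage: only the adapter's 2*dim parameters receive gradients.
# tokens = torch.randn(2, 16, 384)
# out = SSFAdapter(384)(tokens)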
- -# References: -# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py - -from functools import partial -import math -import logging -from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List - -import torch -import torch.nn as nn -from torch import Tensor -import torch.utils.checkpoint -from torch.nn.init import trunc_normal_ -import torch.nn.init -import torch.nn.functional as F - -#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block - -logger = logging.getLogger("dinov2") - -# SSF finetuning originally by dongzelian -def init_ssf_scale_shift(dim): - scale = nn.Parameter(torch.ones(dim)) - shift = nn.Parameter(torch.zeros(dim)) - - nn.init.normal_(scale, mean=1, std=.02) - nn.init.normal_(shift, std=.02) - - return scale, shift - -def ssf_ada(x, scale, shift): - assert scale.shape == shift.shape - if x.shape[-1] == scale.shape[0]: - return x * scale + shift - elif x.shape[1] == scale.shape[0]: - return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) - else: - raise ValueError('the input tensor shape does not match the shape of the scale factor.') - -# LoRA finetuning originally by edwardjhu -class LoRALayer(): - def __init__( - self, - r: int, - lora_alpha: int, - lora_dropout: float, - merge_weights: bool, - ): - self.r = r - self.lora_alpha = lora_alpha - # Optional dropout - if lora_dropout > 0.: - self.lora_dropout = nn.Dropout(p=lora_dropout) - else: - self.lora_dropout = lambda x: x - # Mark the weight as unmerged - self.merged = False - self.merge_weights = merge_weights - -class LoRALinear(nn.Linear, LoRALayer): - # LoRA implemented in a dense layer - def __init__( - self, - in_features: int, - out_features: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - merge_weights: bool = True, - **kwargs - ): - nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, - merge_weights=merge_weights) - - self.fan_in_fan_out = fan_in_fan_out - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) - self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - if fan_in_fan_out: - self.weight.data = self.weight.data.transpose(0, 1) - - def reset_parameters(self): - #nn.Linear.reset_parameters(self) - if hasattr(self, 'lora_A'): - # initialize B the same way as the default for nn.Linear and A to zero - # this is different than what is described in the paper but should not affect performance - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - # def train(self, mode: bool = True): - # def T(w): - # return w.transpose(0, 1) if self.fan_in_fan_out else w - # nn.Linear.train(self, mode) - # if mode: - # if self.merge_weights and self.merged: - # # Make sure that the weights are not merged - # if self.r > 0: - # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling - # self.merged = False - # else: - # if self.merge_weights and not self.merged: - # # Merge the weights and mark it - # if self.r > 0: - # self.weight.data += 
T(self.lora_B @ self.lora_A) * self.scaling - # self.merged = True - - def forward(self, x: torch.Tensor): - def T(w): - return w.transpose(0, 1) if self.fan_in_fan_out else w - if self.r > 0 and not self.merged: - result = F.linear(x, T(self.weight), bias=self.bias) - result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling - return result - else: - return F.linear(x, T(self.weight), bias=self.bias) - - - -def make_2tuple(x): - if isinstance(x, tuple): - assert len(x) == 2 - return x - - assert isinstance(x, int) - return (x, x) - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0: - random_tensor.div_(keep_prob) - output = x * random_tensor - return output - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - -class LayerScale(nn.Module): - def __init__( - self, - dim: int, - init_values: Union[float, Tensor] = 1e-5, - inplace: bool = False, - ) -> None: - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - -class PatchEmbed(nn.Module): - """ - 2D image to patch embedding: (B,C,H,W) -> (B,N,D) - - Args: - img_size: Image size. - patch_size: Patch token size. - in_chans: Number of input image channels. - embed_dim: Number of linear projection output channels. - norm_layer: Normalization layer. 
- """ - - def __init__( - self, - img_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer: Optional[Callable] = None, - flatten_embedding: bool = True, - tuning_mode: Optional[str] = None - ) -> None: - super().__init__() - - image_HW = make_2tuple(img_size) - patch_HW = make_2tuple(patch_size) - patch_grid_size = ( - image_HW[0] // patch_HW[0], - image_HW[1] // patch_HW[1], - ) - - self.img_size = image_HW - self.patch_size = patch_HW - self.patches_resolution = patch_grid_size - self.num_patches = patch_grid_size[0] * patch_grid_size[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.flatten_embedding = flatten_embedding - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - if tuning_mode != None: - self.tuning_mode = tuning_mode - if tuning_mode == 'ssf': - self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) - else: - pass - #raise NotImplementedError() - else: - self.tuning_mode = None - - def forward(self, x: Tensor) -> Tensor: - _, _, H, W = x.shape - patch_H, patch_W = self.patch_size - - assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" - assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" - - x = self.proj(x) # B C H W - H, W = x.size(2), x.size(3) - x = x.flatten(2).transpose(1, 2) # B HW C - x = self.norm(x) - if self.tuning_mode == 'ssf': - x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) - if not self.flatten_embedding: - x = x.reshape(-1, H, W, self.embed_dim) # B H W C - return x - - def flops(self) -> float: - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class Mlp(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = nn.GELU, - drop: float = 0.0, - bias: bool = True, - tuning_mode: Optional[int] = None - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - self.drop = nn.Dropout(drop) - - if tuning_mode != None: - self.tuning_mode = tuning_mode - if tuning_mode == 'ssf': - self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) - self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) - else: - pass - #raise NotImplementedError() - else: - self.tuning_mode = None - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - if self.tuning_mode == 'ssf': - x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) - - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - if self.tuning_mode == 'ssf': - x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) - - x = self.drop(x) - return x - - -class SwiGLUFFN(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - tuning_mode: Optional[int] = None - ) -> None: - super().__init__() - out_features = 
out_features or in_features - hidden_features = hidden_features or in_features - self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) - self.w3 = nn.Linear(hidden_features, out_features, bias=bias) - - if tuning_mode != None: - self.tuning_mode = tuning_mode - if tuning_mode == 'ssf': - self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features) - self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) - else: - pass - #raise NotImplementedError() - else: - self.tuning_mode = None - - - def forward(self, x: Tensor) -> Tensor: - x12 = self.w12(x) - if self.tuning_mode == 'ssf': - x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1) - - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - out = self.w3(hidden) - - if self.tuning_mode == 'ssf': - out = ssf_ada(out, self.ssf_scale_2, self.ssf_shift_2) - - return out - - -try: - from xformers.ops import SwiGLU - #import numpy.bool - XFORMERS_AVAILABLE = True -except ImportError: - SwiGLU = SwiGLUFFN - XFORMERS_AVAILABLE = False - -class SwiGLUFFNFused(SwiGLU): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - out_features = out_features or in_features - hidden_features = hidden_features or in_features - hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 - super().__init__( - in_features=in_features, - hidden_features=hidden_features, - out_features=out_features, - bias=bias, - ) - - -try: - from xformers.ops import memory_efficient_attention, unbind, fmha - from xformers.components.attention import ScaledDotProduct - from xformers.components import MultiHeadDispatch - #import numpy.bool - XFORMERS_AVAILABLE = True -except ImportError: - logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False - - -class Attention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - proj_bias: bool = True, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - window_size: int = 0, - tuning_mode: Optional[int] = None - ) -> None: - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - if tuning_mode == 'lora': - self.tuning_mode = tuning_mode - self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8) - else: - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - - self.attn_drop = nn.Dropout(attn_drop) - - if tuning_mode == 'lora': - self.tuning_mode = tuning_mode - self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8) - else: - self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.proj_drop = nn.Dropout(proj_drop) - - if tuning_mode != None: - self.tuning_mode = tuning_mode - if tuning_mode == 'ssf': - self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3) - self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) - else: - pass - #raise NotImplementedError() - else: - self.tuning_mode = None - - #if not self.training: - # - # self.attn = ScaledDotProduct() - #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) - - def forward(self, x: Tensor, attn_bias=None) -> Tensor: - B, N, C = x.shape - if self.tuning_mode == 'ssf': - qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - else: - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // 
self.num_heads).permute(2, 0, 3, 1, 4) - - q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = q @ k.transpose(-2, -1) - - if attn_bias is not None: - attn = attn + attn_bias[:, :, :N] - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - - if self.tuning_mode == 'ssf': - x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) - - x = self.proj_drop(x) - return x - - -class MemEffAttention(Attention): - def forward(self, x: Tensor, attn_bias=None) -> Tensor: - if not XFORMERS_AVAILABLE: - #if True: - assert attn_bias is None, "xFormers is required for nested tensors usage" - return super().forward(x, attn_bias) - - B, N, C = x.shape - if self.tuning_mode == 'ssf': - qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads) - else: - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - - q, k, v = unbind(qkv, 2) - if attn_bias is not None: - x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) - else: - x = memory_efficient_attention(q, k, v) - x = x.reshape([B, N, C]) - - x = self.proj(x) - if self.tuning_mode == 'ssf': - x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) - - x = self.proj_drop(x) - return x - -try: - from xformers.ops import fmha - from xformers.ops import scaled_index_add, index_select_cat - #import numpy.bool - XFORMERS_AVAILABLE = True -except ImportError: - logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = False, - proj_bias: bool = True, - ffn_bias: bool = True, - drop: float = 0.0, - attn_drop: float = 0.0, - init_values = None, - drop_path: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, - norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_class: Callable[..., nn.Module] = Attention, - ffn_layer: Callable[..., nn.Module] = Mlp, - tuning_mode: Optional[int] = None - ) -> None: - super().__init__() - # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") - self.norm1 = norm_layer(dim) - self.attn = attn_class( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - attn_drop=attn_drop, - proj_drop=drop, - tuning_mode=tuning_mode - ) - - if tuning_mode != None: - self.tuning_mode = tuning_mode - if tuning_mode == 'ssf': - self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) - self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) - else: - pass - #raise NotImplementedError() - else: - self.tuning_mode = None - - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = ffn_layer( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - bias=ffn_bias, - ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.sample_drop_ratio = drop_path - - def forward(self, x: Tensor, attn_bias=None) -> Tensor: - def attn_residual_func(x: Tensor, attn_bias) -> Tensor: - if self.tuning_mode == 'ssf': - return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias)) - else: - return self.ls1(self.attn(self.norm1(x), 
attn_bias)) - - def ffn_residual_func(x: Tensor) -> Tensor: - if self.tuning_mode == 'ssf': - return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))) - else: - return self.ls2(self.mlp(self.norm2(x))) - - if self.training and self.sample_drop_ratio > 0.1: - # the overhead is compensated only for a drop path rate larger than 0.1 - x = drop_add_residual_stochastic_depth( - x, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - attn_bias=attn_bias - ) - x = drop_add_residual_stochastic_depth( - x, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - ) - elif self.training and self.sample_drop_ratio > 0.0: - x = x + self.drop_path1(attn_residual_func(x, attn_bias)) - x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 - else: - x = x + attn_residual_func(x, attn_bias) - x = x + ffn_residual_func(x) - return x - - -def drop_add_residual_stochastic_depth( - x: Tensor, - residual_func: Callable[[Tensor], Tensor], - sample_drop_ratio: float = 0.0, attn_bias=None -) -> Tensor: - # 1) extract subset using permutation - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - x_subset = x[brange] - - # 2) apply residual_func to get residual - residual = residual_func(x_subset, attn_bias) - - x_flat = x.flatten(1) - residual = residual.flatten(1) - - residual_scale_factor = b / sample_subset_size - - # 3) add the residual - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - return x_plus_residual.view_as(x) - - -def get_branges_scales(x, sample_drop_ratio=0.0): - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - residual_scale_factor = b / sample_subset_size - return brange, residual_scale_factor - - -def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): - if scaling_vector is None: - x_flat = x.flatten(1) - residual = residual.flatten(1) - x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) - else: - x_plus_residual = scaled_index_add( - x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor - ) - return x_plus_residual - - -attn_bias_cache: Dict[Tuple, Any] = {} - - -def get_attn_bias_and_cat(x_list, branges=None): - """ - this will perform the index select, cat the tensors, and provide the attn_bias from cache - """ - batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] - all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) - if all_shapes not in attn_bias_cache.keys(): - seqlens = [] - for b, x in zip(batch_sizes, x_list): - for _ in range(b): - seqlens.append(x.shape[1]) - attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) - attn_bias._batch_sizes = batch_sizes - attn_bias_cache[all_shapes] = attn_bias - - if branges is not None: - cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) - else: - tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) - cat_tensors = torch.cat(tensors_bs1, dim=1) - - return attn_bias_cache[all_shapes], cat_tensors - - -def drop_add_residual_stochastic_depth_list( - x_list: List[Tensor], - residual_func: Callable[[Tensor, Any], Tensor], - 
sample_drop_ratio: float = 0.0, - scaling_vector=None, -) -> Tensor: - # 1) generate random set of indices for dropping samples in the batch - branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] - branges = [s[0] for s in branges_scales] - residual_scale_factors = [s[1] for s in branges_scales] - - # 2) get attention bias and index+concat the tensors - attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) - - # 3) apply residual_func to get residual, and split the result - residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore - - outputs = [] - for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): - outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) - return outputs - - -class NestedTensorBlock(Block): - def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: - """ - x_list contains a list of tensors to nest together and run - """ - assert isinstance(self.attn, MemEffAttention) - - if self.training and self.sample_drop_ratio > 0.0: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.attn(self.norm1(x), attn_bias=attn_bias) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.mlp(self.norm2(x)) - - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, - ) - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, - ) - return x_list - else: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - attn_bias, x = get_attn_bias_and_cat(x_list) - x = x + attn_residual_func(x, attn_bias=attn_bias) - x = x + ffn_residual_func(x) - return attn_bias.split(x) - - def forward(self, x_or_x_list, attn_bias=None): - if isinstance(x_or_x_list, Tensor): - return super().forward(x_or_x_list, attn_bias) - elif isinstance(x_or_x_list, list): - assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" - return self.forward_nested(x_or_x_list) - else: - raise AssertionError - - -def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = ".".join((name, child_name)) if name else child_name - named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -class BlockChunk(nn.ModuleList): - def forward(self, x, others=None): - for b in self: - if others == None: - x = b(x) - else: - x = b(x, others) - return x - - -class DinoVisionTransformer(nn.Module): - def __init__( - self, - img_size=518, - patch_size=16, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - ffn_bias=True, - proj_bias=True, - drop_path_rate=0.0, - drop_path_uniform=False, - init_values=1e-5, # for layerscale: None or 0 => no 
layerscale - embed_layer=PatchEmbed, - act_layer=nn.GELU, - block_fn=Block, - ffn_layer="mlp", - block_chunks=1, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - tuning_mode=None, - **kwargs - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - proj_bias (bool): enable bias for proj in attn if True - ffn_bias (bool): enable bias for ffn if True - drop_path_rate (float): stochastic depth rate - drop_path_uniform (bool): apply uniform drop rate across blocks - weight_init (str): weight init scheme - init_values (float): layer-scale init values - embed_layer (nn.Module): patch embedding layer - act_layer (nn.Module): MLP activation layer - block_fn (nn.Module): transformer block class - ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" - block_chunks: (int) split block sequence into block_chunks units for FSDP wrap - num_register_tokens: (int) number of extra cls tokens (so-called "registers") - interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings - interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings - """ - super().__init__() - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 - self.n_blocks = depth - self.num_heads = num_heads - self.patch_size = patch_size - self.num_register_tokens = num_register_tokens - self.interpolate_antialias = interpolate_antialias - self.interpolate_offset = interpolate_offset - - if tuning_mode != None: - self.tuning_mode = tuning_mode - if tuning_mode == 'ssf': - self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) - else: - pass - #raise NotImplementedError() - else: - self.tuning_mode = None - tuning_mode_list = [tuning_mode] * depth - - self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - assert num_register_tokens >= 0 - self.register_tokens = ( - nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None - ) - - if drop_path_uniform is True: - dpr = [drop_path_rate] * depth - else: - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - - if ffn_layer == "mlp": - logger.info("using MLP layer as FFN") - ffn_layer = Mlp - elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": - logger.info("using SwiGLU layer as FFN") - ffn_layer = SwiGLUFFNFused - elif ffn_layer == "identity": - logger.info("using Identity layer as FFN") - - def f(*args, **kwargs): - return nn.Identity() - - ffn_layer = f - else: - raise NotImplementedError - - blocks_list = [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - drop_path=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - ffn_layer=ffn_layer, - init_values=init_values, - 
tuning_mode=tuning_mode_list[i] - ) - for i in range(depth) - ] - if block_chunks > 0: - self.chunked_blocks = True - chunked_blocks = [] - chunksize = depth // block_chunks - for i in range(0, depth, chunksize): - # this is to keep the block index consistent if we chunk the block list - chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) - self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) - else: - self.chunked_blocks = False - self.blocks = nn.ModuleList(blocks_list) - - self.norm = norm_layer(embed_dim) - self.head = nn.Identity() - - self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) - - self.init_weights() - - def init_weights(self): - trunc_normal_(self.pos_embed, std=0.02) - nn.init.normal_(self.cls_token, std=1e-6) - if self.register_tokens is not None: - nn.init.normal_(self.register_tokens, std=1e-6) - named_apply(init_weights_vit_timm, self) - - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.pos_embed.shape[1] - 1 - if npatch == N and w == h: - return self.pos_embed - pos_embed = self.pos_embed.float() - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset - - sqrt_N = math.sqrt(N) - sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), - scale_factor=(sx, sy), - mode="bicubic", - antialias=self.interpolate_antialias, - ) - - assert int(w0) == patch_pos_embed.shape[-2] - assert int(h0) == patch_pos_embed.shape[-1] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - def prepare_tokens_with_masks(self, x, masks=None): - B, nc, w, h = x.shape - x = self.patch_embed(x) - if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.interpolate_pos_encoding(x, w, h) - - if self.register_tokens is not None: - x = torch.cat( - ( - x[:, :1], - self.register_tokens.expand(x.shape[0], -1, -1), - x[:, 1:], - ), - dim=1, - ) - - return x - - def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] - for blk in self.blocks: - x = blk(x) - - all_x = x - output = [] - for x, masks in zip(all_x, masks_list): - x_norm = self.norm(x) - output.append( - { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - ) - return output - - def forward_features(self, x, masks=None): - if isinstance(x, list): - return self.forward_features_list(x, masks) - - B, C, H, W = x.size() - pad_h = (self.patch_size - H % self.patch_size) - pad_w = (self.patch_size - W % self.patch_size) - if pad_h == self.patch_size: - pad_h = 0 - if pad_w == self.patch_size: - pad_w = 0 - #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) - if pad_h + pad_w > 0: - x = 
torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') - - x = self.prepare_tokens_with_masks(x, masks) - - for blk in self.blocks: - x = blk(x) - - x_norm = self.norm(x) - if self.tuning_mode == 'ssf': - x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1) - - # return { - # "x_norm_clstoken": x_norm[:, 0], - # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - # "x_prenorm": x, - # "masks": masks, - # } - features = [] - features.append(x_norm) - features.append(x_norm) - features.append(x_norm) - features.append(x_norm) - return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] - - - def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - # If n is an int, take the n last blocks. If it's a list, take them - output, total_block_len = [], len(self.blocks) - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in blocks_to_take: - output.append(x) - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def _get_intermediate_layers_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - output, i, total_block_len = [], 0, len(self.blocks[-1]) - # If n is an int, take the n last blocks. If it's a list, take them - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for block_chunk in self.blocks: - for blk in block_chunk[i:]: # Passing the nn.Identity() - x = blk(x) - if i in blocks_to_take: - output.append(x) - i += 1 - assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - norm=True, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - if self.chunked_blocks: - outputs = self._get_intermediate_layers_chunked(x, n) - else: - outputs = self._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1:] for out in outputs] - if reshape: - B, _, w, h = x.shape - outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() - for out in outputs - ] - if return_class_token: - return tuple(zip(outputs, class_tokens)) - return tuple(outputs) - - def forward(self, *args, is_training=False, **kwargs): - ret = self.forward_features(*args, **kwargs) - return ret - # if is_training: - # return ret - # else: - # return self.head(ret["x_norm_clstoken"]) - - -def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - - -def load_ckpt_dino(checkpoint, model): - if checkpoint is not None: - try: - with open(checkpoint, "rb") as f: - state_dict = torch.load(f) - except: - print('NO pretrained imagenet ckpt available! 
Check your path!') - del model.mask_token - return - - try: - model.load_state_dict(state_dict, strict=True) - except: - new_state_dict = {} - for key, value in state_dict.items(): - if 'blocks' in key: - key_new = 'blocks.0' + key[len('blocks'):] - else: - key_new = key - new_state_dict[key_new] = value - - model.load_state_dict(new_state_dict, strict=True) - del model.mask_token - return - else: - return - - -def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - - load_ckpt_dino(checkpoint, model) - - return model - - -def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model - - -def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - - if checkpoint is not None: - with open(checkpoint, "rb") as f: - state_dict = torch.load(f) - try: - model.load_state_dict(state_dict, strict=True) - except: - new_state_dict = {} - for key, value in state_dict.items(): - if 'blocks' in key: - key_new = 'blocks.0' + key[len('blocks'):] - else: - key_new = key - new_state_dict[key_new] = value - - model.load_state_dict(new_state_dict, strict=True) - del model.mask_token - return model - - -def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): - """ - Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 - """ - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1536, - depth=40, - num_heads=24, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - ffn_layer='swiglu', - **kwargs, - ) - return model - - - -def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - tuning_mode=tuning_mode, - **kwargs, - ) - - load_ckpt_dino(checkpoint, model) - - return model - - -def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - - load_ckpt_dino(checkpoint, model) - - return model - - -def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): - model = DinoVisionTransformer( - img_size = 518, - patch_size=patch_size, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - tuning_mode=tuning_mode, - **kwargs, - ) - - load_ckpt_dino(checkpoint, model) - - return 
model - - -def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): - """ - Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 - """ - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1536, - depth=40, - num_heads=24, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - ffn_layer='swiglu', - tuning_mode=tuning_mode, - **kwargs, - ) - - load_ckpt_dino(checkpoint, model) - - return model - -if __name__ == '__main__': - try: - from mmcv.utils import Config - except: - from mmengine import Config - - #rgb = torch.rand((2, 3, 518, 518)).cuda() - - #cfg.data_basic['crop_size']['0'] - #cfg.data_basic['crop_size']['1'] - cfg = Config.fromfile('/opt/ml/project/mu.hu/projects/monodepth_vit/mono/configs/RAFTDecoder/vit.raft5.large.kitti.py') - - #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) - rgb = torch.zeros(1, 3, 616, 1064).cuda() - cfg['tuning_mode'] = 'ssf' - #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda() - model = vit_large_reg(tuning_mode='ssf').cuda() - - #import timm - #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() - #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) - - out1 = model(rgb) - #out2 = model2(rgb) - temp = 0 - - diff --git a/mono/model/backbones/__init__.py b/mono/model/backbones/__init__.py deleted file mode 100644 index 8cc3ba70ef5ef867f0518d73a189e7531466cbab..0000000000000000000000000000000000000000 --- a/mono/model/backbones/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .ConvNeXt import convnext_xlarge -from .ConvNeXt import convnext_small -from .ConvNeXt import convnext_base -from .ConvNeXt import convnext_large -from .ConvNeXt import convnext_tiny -from .ViT_DINO import vit_large -from .ViT_DINO_reg import vit_small_reg, vit_large_reg - -__all__ = [ - 'convnext_xlarge', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_tiny', 'vit_small_reg', 'vit_large_reg' -] diff --git a/mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc b/mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc deleted file mode 100644 index 126ed2ec9338fdbaf1a3d9445815a8ff3f03aea5..0000000000000000000000000000000000000000 Binary files a/mono/model/backbones/__pycache__/ConvNeXt.cpython-39.pyc and /dev/null differ diff --git a/mono/model/backbones/__pycache__/__init__.cpython-39.pyc b/mono/model/backbones/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 16cdbeb696a2cca7a544cc37a01f66e764632aeb..0000000000000000000000000000000000000000 Binary files a/mono/model/backbones/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/mono/model/decode_heads/HourGlassDecoder.py b/mono/model/decode_heads/HourGlassDecoder.py deleted file mode 100644 index e084382601e21e6ce5144abbd6a65f563905b659..0000000000000000000000000000000000000000 --- a/mono/model/decode_heads/HourGlassDecoder.py +++ /dev/null @@ -1,274 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -import math -import torch.nn.functional as F - -def compute_depth_expectation(prob, depth_values): - depth_values = depth_values.view(*depth_values.shape, 1, 1) - depth = torch.sum(prob * depth_values, 1) - return depth - -class 
ConvBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3): - super(ConvBlock, self).__init__() - - if kernel_size == 3: - self.conv = nn.Sequential( - nn.ReflectionPad2d(1), - nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1), - ) - elif kernel_size == 1: - self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1) - - self.nonlin = nn.ELU(inplace=True) - - def forward(self, x): - out = self.conv(x) - out = self.nonlin(out) - return out - - -class ConvBlock_double(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3): - super(ConvBlock_double, self).__init__() - - if kernel_size == 3: - self.conv = nn.Sequential( - nn.ReflectionPad2d(1), - nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1), - ) - elif kernel_size == 1: - self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1) - - self.nonlin = nn.ELU(inplace=True) - self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1) - self.nonlin_2 =nn.ELU(inplace=True) - - def forward(self, x): - out = self.conv(x) - out = self.nonlin(out) - out = self.conv_2(out) - out = self.nonlin_2(out) - return out - -class DecoderFeature(nn.Module): - def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]): - super(DecoderFeature, self).__init__() - self.num_ch_dec = num_ch_dec - self.feat_channels = feat_channels - - self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1) - self.upconv_3_1 = ConvBlock_double( - self.feat_channels[2] + self.num_ch_dec[3], - self.num_ch_dec[3], - kernel_size=1) - - self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3) - self.upconv_2_1 = ConvBlock_double( - self.feat_channels[1] + self.num_ch_dec[2], - self.num_ch_dec[2], - kernel_size=3) - - self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3) - self.upconv_1_1 = ConvBlock_double( - self.feat_channels[0] + self.num_ch_dec[1], - self.num_ch_dec[1], - kernel_size=3) - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - - def forward(self, ref_feature): - x = ref_feature[3] - - x = self.upconv_3_0(x) - x = torch.cat((self.upsample(x), ref_feature[2]), 1) - x = self.upconv_3_1(x) - - x = self.upconv_2_0(x) - x = torch.cat((self.upsample(x), ref_feature[1]), 1) - x = self.upconv_2_1(x) - - x = self.upconv_1_0(x) - x = torch.cat((self.upsample(x), ref_feature[0]), 1) - x = self.upconv_1_1(x) - return x - - -class UNet(nn.Module): - def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'): - super(UNet, self).__init__() - basic_block = ConvBnReLU - num_depth = 128 - - self.conv0 = basic_block(inp_ch, num_depth) - if channel_mode == 'v0': - channels = [num_depth, num_depth//2, num_depth//4, num_depth//8, num_depth // 8] - elif channel_mode == 'v1': - channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth] - self.down_sample_times = down_sample_times - for i in range(down_sample_times): - setattr( - self, 'conv_%d' % i, - nn.Sequential( - basic_block(channels[i], channels[i+1], stride=2), - basic_block(channels[i+1], channels[i+1]) - ) - ) - for i in range(down_sample_times-1,-1,-1): - setattr(self, 'deconv_%d' % i, - nn.Sequential( - nn.ConvTranspose2d( - channels[i+1], - channels[i], - kernel_size=3, - padding=1, - output_padding=1, - stride=2, - bias=False), - nn.BatchNorm2d(channels[i]), - nn.ReLU(inplace=True) - ) - ) - self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0) 
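# Shape walk-through (illustrative, for the default down_sample_times=3 and
# channel_mode='v0', where channels = [128, 64, 32, 16, 16]): an input
# x of shape (B, inp_ch, H, W) passes through
#   conv0  -> (B, 128, H,   W)    saved as features[0]
#   conv_0 -> (B,  64, H/2, W/2)  saved as features[1]
#   conv_1 -> (B,  32, H/4, W/4)  saved as features[2]
#   conv_2 -> (B,  16, H/8, W/8)
# then each deconv_i upsamples back one level and adds features[i], so the
# final self.prob conv sees a (B, 128, H, W) map and emits output_chal channels.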
- - def forward(self, x): - features = {} - conv0 = self.conv0(x) - x = conv0 - features[0] = conv0 - for i in range(self.down_sample_times): - x = getattr(self, 'conv_%d' % i)(x) - features[i+1] = x - for i in range(self.down_sample_times-1,-1,-1): - x = features[i] + getattr(self, 'deconv_%d' % i)(x) - x = self.prob(x) - return x - -class ConvBnReLU(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): - super(ConvBnReLU, self).__init__() - self.conv = nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=pad, - bias=False - ) - self.bn = nn.BatchNorm2d(out_channels) - - def forward(self, x): - return F.relu(self.bn(self.conv(x)), inplace=True) - - -class HourglassDecoder(nn.Module): - def __init__(self, cfg): - super(HourglassDecoder, self).__init__() - self.inchannels = cfg.model.decode_head.in_channels # [256, 512, 1024, 2048] - self.decoder_channels = cfg.model.decode_head.decoder_channel # [64, 64, 128, 256] - self.min_val = cfg.data_basic.depth_normalize[0] - self.max_val = cfg.data_basic.depth_normalize[1] - - self.num_ch_dec = self.decoder_channels # [64, 64, 128, 256] - self.num_depth_regressor_anchor = 512 - self.feat_channels = self.inchannels - unet_in_channel = self.num_ch_dec[1] - unet_out_channel = 256 - - self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec) - self.conv_out_2 = UNet(inp_ch=unet_in_channel, - output_chal=unet_out_channel + 1, - down_sample_times=3, - channel_mode='v0', - ) - - self.depth_regressor_2 = nn.Sequential( - nn.Conv2d(unet_out_channel, - self.num_depth_regressor_anchor, - kernel_size=3, - padding=1, - ), - nn.BatchNorm2d(self.num_depth_regressor_anchor), - nn.ReLU(inplace=True), - nn.Conv2d( - self.num_depth_regressor_anchor, - self.num_depth_regressor_anchor, - kernel_size=1, - ) - ) - self.residual_channel = 16 - self.conv_up_2 = nn.Sequential( - nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1), - nn.BatchNorm2d(self.residual_channel), - nn.ReLU(), - nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1), - nn.Upsample(scale_factor=4), - nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1), - nn.ReLU(), - nn.Conv2d(self.residual_channel, 1, 1, padding=0), - ) - - def get_bins(self, bins_num): - depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda') - depth_bins_vec = torch.exp(depth_bins_vec) - return depth_bins_vec - - def register_depth_expectation_anchor(self, bins_num, B): - depth_bins_vec = self.get_bins(bins_num) - depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) - self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) - - def upsample(self, x, scale_factor=2): - return F.interpolate(x, scale_factor=scale_factor, mode='nearest') - - def regress_depth_2(self, feature_map_d): - prob = self.depth_regressor_2(feature_map_d).softmax(dim=1) - B = prob.shape[0] - if "depth_expectation_anchor" not in self._buffers: - self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B) - d = compute_depth_expectation( - prob, - self.depth_expectation_anchor[:B, ...] 
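The bin scheme used by regress_depth_2 above turns depth regression into soft classification: anchors are log-uniformly spaced between the depth_normalize bounds, and the prediction is the softmax-weighted expectation over them, which keeps the output inside [min_val, max_val] by construction. A standalone sketch (bounds illustrative; the real module builds the anchors on CUDA and caches them in a non-persistent buffer):

```python
import math
import torch

min_val, max_val, bins_num = 0.3, 150.0, 512   # assumed depth_normalize bounds
bins = torch.exp(torch.linspace(math.log(min_val), math.log(max_val), bins_num))

logits = torch.randn(2, bins_num, 4, 4)        # regressor output (B, bins, H, W)
prob = logits.softmax(dim=1)
# expectation over anchors -> (B, 1, H, W), always within [min_val, max_val]
depth = (prob * bins.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)
assert min_val <= depth.min() and depth.max() <= max_val
```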
- ).unsqueeze(1) - return d - - def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True): - y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device), - torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij') - meshgrid = torch.stack((x, y)) - meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1) - return meshgrid - - def forward(self, features_mono, **kwargs): - ''' - trans_ref2src: list of transformation matrix from the reference view to source view. [B, 4, 4] - inv_intrinsic_pool: list of inverse intrinsic matrix. - features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...]. - ''' - outputs = {} - # get encoder feature of the reference view - ref_feat = features_mono - - feature_map_mono = self.decoder_mono(ref_feat) - feature_map_mono_pred = self.conv_out_2(feature_map_mono) - confidence_map_2 = feature_map_mono_pred[:, -1:, :, :] - feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :] - - depth_pred_2 = self.regress_depth_2(feature_map_d_2) - - B, _, H, W = depth_pred_2.shape - - meshgrid = self.create_mesh_grid(H, W, B) - - depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \ - self.conv_up_2( - torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1) - ) - confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4) - - outputs=dict( - prediction=depth_pred_mono, - confidence=confidence_map_mono, - pred_logit=None, - ) - return outputs \ No newline at end of file diff --git a/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py b/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py deleted file mode 100644 index 9af89f9b4b1878a2e4bcfcd489075c2e97cd8d3d..0000000000000000000000000000000000000000 --- a/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py +++ /dev/null @@ -1,1033 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np -import math -import torch.nn.functional as F - -# LORA finetuning originally by edwardjhu -class LoRALayer(): - def __init__( - self, - r: int, - lora_alpha: int, - lora_dropout: float, - merge_weights: bool, - ): - self.r = r - self.lora_alpha = lora_alpha - # Optional dropout - if lora_dropout > 0.: - self.lora_dropout = nn.Dropout(p=lora_dropout) - else: - self.lora_dropout = lambda x: x - # Mark the weight as unmerged - self.merged = False - self.merge_weights = merge_weights - -class LoRALinear(nn.Linear, LoRALayer): - # LoRA implemented in a dense layer - def __init__( - self, - in_features: int, - out_features: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - merge_weights: bool = True, - **kwargs - ): - nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, - merge_weights=merge_weights) - - self.fan_in_fan_out = fan_in_fan_out - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) - self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - if fan_in_fan_out: - self.weight.data = self.weight.data.transpose(0, 1) - - def reset_parameters(self): - #nn.Linear.reset_parameters(self) - if hasattr(self, 'lora_A'): - # initialize 
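Two properties of this LoRA parameterization are worth making concrete: with lora_B initialized to zero the adapter is an exact no-op, and merging W' = W + (alpha/r) * B @ A reproduces the unmerged forward in a single matmul. A numeric check (dimensions arbitrary, not from the original file):

```python
import torch

torch.manual_seed(0)
in_f, out_f, r, alpha = 8, 4, 2, 4
W = torch.randn(out_f, in_f)                  # frozen pretrained weight
A = torch.randn(r, in_f)                      # lora_A (kaiming-init in the real code)
B = torch.zeros(out_f, r)                     # lora_B starts at zero
x = torch.randn(3, in_f)

y_lora = x @ W.t() + (x @ A.t() @ B.t()) * (alpha / r)
assert torch.allclose(y_lora, x @ W.t())      # exact identity at initialization

B = torch.randn(out_f, r)                     # after some training steps
y_lora = x @ W.t() + (x @ A.t() @ B.t()) * (alpha / r)
y_merged = x @ (W + (alpha / r) * B @ A).t()  # merged weight, single matmul
assert torch.allclose(y_lora, y_merged, atol=1e-5)
```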
B the same way as the default for nn.Linear and A to zero - # this is different than what is described in the paper but should not affect performance - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - # def train(self, mode: bool = True): - # def T(w): - # return w.transpose(0, 1) if self.fan_in_fan_out else w - # nn.Linear.train(self, mode) - # if mode: - # if self.merge_weights and self.merged: - # # Make sure that the weights are not merged - # if self.r > 0: - # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling - # self.merged = False - # else: - # if self.merge_weights and not self.merged: - # # Merge the weights and mark it - # if self.r > 0: - # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - # self.merged = True - - def forward(self, x: torch.Tensor): - def T(w): - return w.transpose(0, 1) if self.fan_in_fan_out else w - if self.r > 0 and not self.merged: - result = F.linear(x, T(self.weight), bias=self.bias) - result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling - return result - else: - return F.linear(x, T(self.weight), bias=self.bias) - -class ConvLoRA(nn.Conv2d, LoRALayer): - def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): - #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) - nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) - LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) - assert isinstance(kernel_size, int) - - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter( - self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) - ) - self.lora_B = nn.Parameter( - self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) - ) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - self.merged = False - - def reset_parameters(self): - #self.conv.reset_parameters() - if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - # def train(self, mode=True): - # super(ConvLoRA, self).train(mode) - # if mode: - # if self.merge_weights and self.merged: - # if self.r > 0: - # # Make sure that the weights are not merged - # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling - # self.merged = False - # else: - # if self.merge_weights and not self.merged: - # if self.r > 0: - # # Merge the weights and mark it - # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling - # self.merged = True - - def forward(self, x): - if self.r > 0 and not self.merged: - # return self.conv._conv_forward( - # x, - # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling, - # self.conv.bias - # ) - weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling - bias = self.bias - - return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) - else: - return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) - -class 
ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer): - def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): - #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) - nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) - LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) - assert isinstance(kernel_size, int) - - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter( - self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) - ) - self.lora_B = nn.Parameter( - self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) - ) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - self.merged = False - - def reset_parameters(self): - #self.conv.reset_parameters() - if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - # def train(self, mode=True): - # super(ConvTransposeLoRA, self).train(mode) - # if mode: - # if self.merge_weights and self.merged: - # if self.r > 0: - # # Make sure that the weights are not merged - # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling - # self.merged = False - # else: - # if self.merge_weights and not self.merged: - # if self.r > 0: - # # Merge the weights and mark it - # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling - # self.merged = True - - def forward(self, x): - if self.r > 0 and not self.merged: - weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling - bias = self.bias - return F.conv_transpose2d(x, weight, - bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, - groups=self.groups, dilation=self.dilation) - else: - return F.conv_transpose2d(x, self.weight, - bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, - groups=self.groups, dilation=self.dilation) - #return self.conv(x) - -class Conv2dLoRA(ConvLoRA): - def __init__(self, *args, **kwargs): - super(Conv2dLoRA, self).__init__(*args, **kwargs) - -class ConvTranspose2dLoRA(ConvTransposeLoRA): - def __init__(self, *args, **kwargs): - super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs) - - -def compute_depth_expectation(prob, depth_values): - depth_values = depth_values.view(*depth_values.shape, 1, 1) - depth = torch.sum(prob * depth_values, 1) - return depth - -def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None): - with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): - return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners) - -# def upflow8(flow, mode='bilinear'): -# new_size = (8 * flow.shape[2], 8 * flow.shape[3]) -# return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) - -def upflow4(flow, mode='bilinear'): - new_size = (4 * flow.shape[2], 4 * flow.shape[3]) - with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): - return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) - -def coords_grid(batch, ht, wd): - # coords = torch.meshgrid(torch.arange(ht), 
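The convolutional variants above store the low-rank factors in a flattened layout; reshaping lora_B @ lora_A back to the kernel shape shows that a rank-r product still perturbs the full (out_ch, in_ch, k, k) kernel. An illustrative check of that reshaping (sizes assumed, groups=1):

```python
import torch
import torch.nn.functional as F

out_ch, in_ch, k, r, alpha = 8, 4, 3, 2, 4
weight = torch.randn(out_ch, in_ch, k, k)          # frozen kernel
lora_A = torch.randn(r * k, in_ch * k)             # (r*k, in_ch*k)
lora_B = torch.zeros(out_ch * k, r * k)            # zero at init

# (out_ch*k, in_ch*k) -> full kernel shape; rank-r in the matrix sense
delta = (lora_B @ lora_A).view(weight.shape) * (alpha / r)
x = torch.randn(1, in_ch, 16, 16)
y = F.conv2d(x, weight + delta, padding=1)
assert torch.allclose(y, F.conv2d(x, weight, padding=1))  # no-op at init
```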
torch.arange(wd)) - coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd))) - coords = torch.stack(coords[::-1], dim=0).float() - return coords[None].repeat(batch, 1, 1, 1) - -def norm_normalize(norm_out): - min_kappa = 0.01 - norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) - norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 - kappa = F.elu(kappa) + 1.0 + min_kappa - final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) - return final_out - -# uncertainty-guided sampling (only used during training) -@torch.no_grad() -def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta): - device = init_normal.device - B, _, H, W = init_normal.shape - N = int(sampling_ratio * H * W) - beta = beta - - # uncertainty map - uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W - - # gt_invalid_mask (B, H, W) - if gt_norm_mask is not None: - gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest') - gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5 - uncertainty_map[gt_invalid_mask] = -1e4 - - # (B, H*W) - _, idx = uncertainty_map.view(B, -1).sort(1, descending=True) - - # importance sampling - if int(beta * N) > 0: - importance = idx[:, :int(beta * N)] # B, beta*N - - # remaining - remaining = idx[:, int(beta * N):] # B, H*W - beta*N - - # coverage - num_coverage = N - int(beta * N) - - if num_coverage <= 0: - samples = importance - else: - coverage_list = [] - for i in range(B): - idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" - coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N - coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N - samples = torch.cat((importance, coverage), dim=1) # B, N - - else: - # remaining - remaining = idx[:, :] # B, H*W - - # coverage - num_coverage = N - - coverage_list = [] - for i in range(B): - idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" - coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N - coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N - samples = coverage - - # point coordinates - rows_int = samples // W # 0 for first row, H-1 for last row - rows_float = rows_int / float(H-1) # 0 to 1.0 - rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0 - - cols_int = samples % W # 0 for first column, W-1 for last column - cols_float = cols_int / float(W-1) # 0 to 1.0 - cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0 - - point_coords = torch.zeros(B, 1, N, 2) - point_coords[:, 0, :, 0] = cols_float # x coord - point_coords[:, 0, :, 1] = rows_float # y coord - point_coords = point_coords.to(device) - return point_coords, rows_int, cols_int - -class FlowHead(nn.Module): - def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None): - super(FlowHead, self).__init__() - self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) - self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) - - self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) - self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - depth = 
self.conv2d(self.relu(self.conv1d(x))) - normal = self.conv2n(self.relu(self.conv1n(x))) - return torch.cat((depth, normal), dim=1) - - -class ConvGRU(nn.Module): - def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None): - super(ConvGRU, self).__init__() - self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) - self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) - self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) - - def forward(self, h, cz, cr, cq, *x_list): - x = torch.cat(x_list, dim=1) - hx = torch.cat([h, x], dim=1) - - z = torch.sigmoid((self.convz(hx) + cz)) - r = torch.sigmoid((self.convr(hx) + cr)) - q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq)) - - # z = torch.sigmoid((self.convz(hx) + cz).float()) - # r = torch.sigmoid((self.convr(hx) + cr).float()) - # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float()) - - h = (1-z) * h + z * q - return h - -def pool2x(x): - return F.avg_pool2d(x, 3, stride=2, padding=1) - -def pool4x(x): - return F.avg_pool2d(x, 5, stride=4, padding=1) - -def interp(x, dest): - interp_args = {'mode': 'bilinear', 'align_corners': True} - return interpolate_float32(x, dest.shape[2:], **interp_args) - -class BasicMultiUpdateBlock(nn.Module): - def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None): - super().__init__() - self.args = args - self.n_gru_layers = args.model.decode_head.n_gru_layers # 3 - self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) - - # self.encoder = BasicMotionEncoder(args) - # encoder_output_dim = 128 # if there is corr volume - encoder_output_dim = 6 # no corr volume - - self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode) - self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode) - self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode) - self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode) - factor = 2**self.n_downsample - - self.mask = nn.Sequential( - Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0), - nn.ReLU(inplace=True), - Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0)) - - def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True): - - if iter32: - net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1])) - if iter16: - if self.n_gru_layers > 2: - net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1])) - else: - net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1])) - if iter08: - if corr is not None: - motion_features = self.encoder(flow, corr) - else: - motion_features = flow - if self.n_gru_layers > 1: - net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0])) - else: - net[0] = self.gru08(net[0], *(inp[0]), motion_features) - - if not update: - return net - - delta_flow = self.flow_head(net[0]) - - # scale mask to balence gradients - mask = .25 * self.mask(net[0]) - return net, mask, delta_flow - -class LayerNorm2d(nn.LayerNorm): - def __init__(self, dim): - 
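A plain-tensor sketch of the ConvGRU cell above may help: the context tensors cz, cr, cq act as per-pixel biases computed once outside the iteration loop (by context_zqr_convs), while the gates are standard GRU algebra over channel-concatenated inputs. Shapes below are assumptions:

```python
import torch
import torch.nn as nn

hdim, xdim = 16, 8
convz = nn.Conv2d(hdim + xdim, hdim, 3, padding=1)
convr = nn.Conv2d(hdim + xdim, hdim, 3, padding=1)
convq = nn.Conv2d(hdim + xdim, hdim, 3, padding=1)

h = torch.zeros(1, hdim, 12, 12)              # hidden state
x = torch.randn(1, xdim, 12, 12)              # motion features / coarser state
cz = cr = cq = torch.zeros(1, hdim, 12, 12)   # precomputed context biases

hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(convz(hx) + cz)             # update gate
r = torch.sigmoid(convr(hx) + cr)             # reset gate
q = torch.tanh(convq(torch.cat([r * h, x], dim=1)) + cq)  # candidate state
h = (1 - z) * h + z * q                       # gated update
print(h.shape)                                # torch.Size([1, 16, 12, 12])
```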
super(LayerNorm2d, self).__init__(dim) - - def forward(self, x): - x = x.permute(0, 2, 3, 1).contiguous() - x = super(LayerNorm2d, self).forward(x) - x = x.permute(0, 3, 1, 2).contiguous() - return x - -class ResidualBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None): - super(ResidualBlock, self).__init__() - - self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0) - self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0) - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not (stride == 1 and in_planes == planes): - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(planes) - self.norm2 = nn.BatchNorm2d(planes) - if not (stride == 1 and in_planes == planes): - self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(planes) - self.norm2 = nn.InstanceNorm2d(planes) - if not (stride == 1 and in_planes == planes): - self.norm3 = nn.InstanceNorm2d(planes) - - elif norm_fn == 'layer': - self.norm1 = LayerNorm2d(planes) - self.norm2 = LayerNorm2d(planes) - if not (stride == 1 and in_planes == planes): - self.norm3 = LayerNorm2d(planes) - - elif norm_fn == 'none': - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - if not (stride == 1 and in_planes == planes): - self.norm3 = nn.Sequential() - - if stride == 1 and in_planes == planes: - self.downsample = None - - else: - self.downsample = nn.Sequential( - Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3) - - def forward(self, x): - y = x - y = self.conv1(y) - y = self.norm1(y) - y = self.relu(y) - y = self.conv2(y) - y = self.norm2(y) - y = self.relu(y) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x+y) - - -class ContextFeatureEncoder(nn.Module): - ''' - Encoder features are used to: - 1. initialize the hidden state of the update operator - 2. 
and also injected into the GRU during each iteration of the update operator - ''' - def __init__(self, in_dim, output_dim, tuning_mode=None): - ''' - in_dim = [x4, x8, x16, x32] - output_dim = [hindden_dims, context_dims] - [[x4,x8,x16,x32],[x4,x8,x16,x32]] - ''' - super().__init__() - - output_list = [] - for dim in output_dim: - conv_out = nn.Sequential( - ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode), - Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) - output_list.append(conv_out) - - self.outputs04 = nn.ModuleList(output_list) - - output_list = [] - for dim in output_dim: - conv_out = nn.Sequential( - ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode), - Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) - output_list.append(conv_out) - - self.outputs08 = nn.ModuleList(output_list) - - output_list = [] - for dim in output_dim: - conv_out = nn.Sequential( - ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode), - Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) - output_list.append(conv_out) - - self.outputs16 = nn.ModuleList(output_list) - - # output_list = [] - # for dim in output_dim: - # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1) - # output_list.append(conv_out) - - # self.outputs32 = nn.ModuleList(output_list) - - def forward(self, encoder_features): - x_4, x_8, x_16, x_32 = encoder_features - - outputs04 = [f(x_4) for f in self.outputs04] - outputs08 = [f(x_8) for f in self.outputs08] - outputs16 = [f(x_16)for f in self.outputs16] - # outputs32 = [f(x_32) for f in self.outputs32] - - return (outputs04, outputs08, outputs16) - -class ConvBlock(nn.Module): - # reimplementation of DPT - def __init__(self, channels, tuning_mode=None): - super(ConvBlock, self).__init__() - - self.act = nn.ReLU(inplace=True) - self.conv1 = Conv2dLoRA( - channels, - channels, - kernel_size=3, - stride=1, - padding=1, - r = 8 if tuning_mode == 'lora' else 0 - ) - self.conv2 = Conv2dLoRA( - channels, - channels, - kernel_size=3, - stride=1, - padding=1, - r = 8 if tuning_mode == 'lora' else 0 - ) - - def forward(self, x): - out = self.act(x) - out = self.conv1(out) - out = self.act(out) - out = self.conv2(out) - return x + out - -class FuseBlock(nn.Module): - # reimplementation of DPT - def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None): - super(FuseBlock, self).__init__() - - self.fuse = fuse - self.scale_factor = scale_factor - self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode) - if self.fuse: - self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode) - - self.out_conv = Conv2dLoRA( - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - r = 8 if tuning_mode == 'lora' else 0 - ) - self.upsample = upsample - - def forward(self, x1, x2=None): - if x2 is not None: - x2 = self.way_branch(x2) - x1 = x1 + x2 - - out = self.way_trunk(x1) - - if self.upsample: - out = interpolate_float32( - out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True - ) - out = self.out_conv(out) - return out - -class Readout(nn.Module): - # From DPT - def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None): - super(Readout, self).__init__() - self.use_cls_token = use_cls_token - if self.use_cls_token == True: - self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' 
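The LayerNorm2d wrapper above exists because nn.LayerNorm normalizes the trailing dimension, so NCHW feature maps are permuted to NHWC, normalized over channels, and permuted back; the 'layer' option of ResidualBlock relies on it. A quick demonstration (channel count arbitrary):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(32)
x = torch.randn(2, 32, 8, 8)                       # NCHW
y = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # normalize over C, restore NCHW
print(y.shape)                                     # torch.Size([2, 32, 8, 8])
# per-location channel statistics are ~zero-mean at init
print(y.mean(dim=1).abs().max() < 1e-4)
```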
else 0) - self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0) - self.act = nn.GELU() - else: - self.project = nn.Identity() - - def forward(self, x): - - if self.use_cls_token == True: - x_patch = self.project_patch(x[0]) - x_learn = self.project_learn(x[1]) - x_learn = x_learn.expand_as(x_patch).contiguous() - features = x_patch + x_learn - return self.act(features) - else: - return self.project(x) - -class Token2Feature(nn.Module): - # From DPT - def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None): - super(Token2Feature, self).__init__() - self.scale_factor = scale_factor - self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) - if scale_factor > 1 and isinstance(scale_factor, int): - self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0, - in_channels=vit_channel, - out_channels=feature_channel, - kernel_size=scale_factor, - stride=scale_factor, - padding=0, - ) - - elif scale_factor > 1: - self.sample = nn.Sequential( - # Upsample2(upscale=scale_factor), - # nn.Upsample(scale_factor=scale_factor), - Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, - in_channels=vit_channel, - out_channels=feature_channel, - kernel_size=1, - stride=1, - padding=0, - ), - ) - - - elif scale_factor < 1: - scale_factor = int(1.0 / scale_factor) - self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, - in_channels=vit_channel, - out_channels=feature_channel, - kernel_size=scale_factor+1, - stride=scale_factor, - padding=1, - ) - - else: - self.sample = nn.Identity() - - def forward(self, x): - x = self.readoper(x) - #if use_cls_token == True: - x = x.permute(0, 3, 1, 2).contiguous() - if isinstance(self.scale_factor, float): - x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest') - x = self.sample(x) - return x - -class EncoderFeature(nn.Module): - def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None): - super(EncoderFeature, self).__init__() - self.vit_channel = vit_channel - self.num_ch_dec = num_ch_dec - - self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) - self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) - self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) - self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) - - def forward(self, ref_feature): - x = self.read_3(ref_feature[3]) # 1/14 - x2 = self.read_2(ref_feature[2]) # 1/14 - x1 = self.read_1(ref_feature[1]) # 1/7 - x0 = self.read_0(ref_feature[0]) # 1/4 - - return x, x2, x1, x0 - -class DecoderFeature(nn.Module): - def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None): - super(DecoderFeature, self).__init__() - self.vit_channel = vit_channel - self.num_ch_dec = num_ch_dec - - self.upconv_3 = FuseBlock( - self.num_ch_dec[4], - 
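The Readout module above follows DPT: patch tokens are projected directly, while the class and register tokens are flattened into one vector, projected without bias, and broadcast-added over the patch grid before a GELU. A shape walk-through with assumed sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, H, W, C, n_reg = 2, 37, 37, 384, 4
project_patch = nn.Linear(C, C)
project_learn = nn.Linear((1 + n_reg) * C, C, bias=False)

patch = torch.randn(B, H, W, C)                  # patch tokens as a grid
learn = torch.randn(B, 1, 1, (1 + n_reg) * C)    # cls + register tokens, flattened
feat = project_patch(patch) + project_learn(learn).expand(B, H, W, C)
feat = F.gelu(feat)
print(feat.shape)                                # torch.Size([2, 37, 37, 384])
```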
self.num_ch_dec[3], - fuse=False, upsample=False, tuning_mode=tuning_mode) - - self.upconv_2 = FuseBlock( - self.num_ch_dec[3], - self.num_ch_dec[2], - tuning_mode=tuning_mode) - - self.upconv_1 = FuseBlock( - self.num_ch_dec[2], - self.num_ch_dec[1] + 2, - scale_factor=7/4, - tuning_mode=tuning_mode) - - # self.upconv_0 = FuseBlock( - # self.num_ch_dec[1], - # self.num_ch_dec[0] + 1, - # ) - - def forward(self, ref_feature): - x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4 - - x = self.upconv_3(x) # 1/14 - x = self.upconv_2(x, x2) # 1/7 - x = self.upconv_1(x, x1) # 1/4 - # x = self.upconv_0(x, x0) # 4/7 - return x - -class RAFTDepthNormalDPT5(nn.Module): - def __init__(self, cfg): - super().__init__() - self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024] - self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14] - self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14] - self.use_cls_token = cfg.model.decode_head.use_cls_token - self.up_scale = cfg.model.decode_head.up_scale - self.num_register_tokens = cfg.model.decode_head.num_register_tokens - self.min_val = cfg.data_basic.depth_normalize[0] - self.max_val = cfg.data_basic.depth_normalize[1] - self.regress_scale = 100.0\ - - try: - tuning_mode = cfg.model.decode_head.tuning_mode - except: - tuning_mode = None - self.tuning_mode = tuning_mode - - self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128] - self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3 - self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) - self.iters = cfg.model.decode_head.iters # 22 - self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True - - self.num_depth_regressor_anchor = 256 # 512 - self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res - self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode) - self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode) - self.depth_regressor = nn.Sequential( - Conv2dLoRA(self.used_res_channel, - self.num_depth_regressor_anchor, - kernel_size=3, - padding=1, r = 8 if tuning_mode == 'lora' else 0), - # nn.BatchNorm2d(self.num_depth_regressor_anchor), - nn.ReLU(inplace=True), - Conv2dLoRA(self.num_depth_regressor_anchor, - self.num_depth_regressor_anchor, - kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), - ) - self.normal_predictor = nn.Sequential( - Conv2dLoRA(self.used_res_channel, - 128, - kernel_size=3, - padding=1, r = 8 if tuning_mode == 'lora' else 0,), - # nn.BatchNorm2d(128), - nn.ReLU(inplace=True), - Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), - Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), - Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), - ) - - self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode) - self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)]) - self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, 
tuning_mode=tuning_mode) - - self.relu = nn.ReLU(inplace=True) - - def get_bins(self, bins_num): - depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda") - depth_bins_vec = torch.exp(depth_bins_vec) - return depth_bins_vec - - def register_depth_expectation_anchor(self, bins_num, B): - depth_bins_vec = self.get_bins(bins_num) - depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) - self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) - - def clamp(self, x): - y = self.relu(x - self.min_val) + self.min_val - y = self.max_val - self.relu(self.max_val - y) - return y - - def regress_depth(self, feature_map_d): - prob_feature = self.depth_regressor(feature_map_d) - prob = prob_feature.softmax(dim=1) - #prob = prob_feature.float().softmax(dim=1) - - ## Error logging - if torch.isnan(prob).any(): - print('prob_feat_nan!!!') - if torch.isinf(prob).any(): - print('prob_feat_inf!!!') - - # h = prob[0,:,0,0].cpu().numpy().reshape(-1) - # import matplotlib.pyplot as plt - # plt.bar(range(len(h)), h) - B = prob.shape[0] - if "depth_expectation_anchor" not in self._buffers: - self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B) - d = compute_depth_expectation( - prob, - self.depth_expectation_anchor[:B, ...]).unsqueeze(1) - - ## Error logging - if torch.isnan(d ).any(): - print('d_nan!!!') - if torch.isinf(d ).any(): - print('d_inf!!!') - - return (self.clamp(d) - self.max_val)/ self.regress_scale, prob_feature - - def pred_normal(self, feature_map, confidence): - normal_out = self.normal_predictor(feature_map) - - ## Error logging - if torch.isnan(normal_out).any(): - print('norm_nan!!!') - if torch.isinf(normal_out).any(): - print('norm_feat_inf!!!') - - return norm_normalize(torch.cat([normal_out, confidence], dim=1)) - #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float()) - - def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True): - y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device), - torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij') - meshgrid = torch.stack((x, y)) - meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1) - #self.register_buffer('meshgrid', meshgrid, persistent=False) - return meshgrid - - def upsample_flow(self, flow, mask): - """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ - N, D, H, W = flow.shape - factor = 2 ** self.n_downsample - mask = mask.view(N, 1, 9, factor, factor, H, W) - mask = torch.softmax(mask, dim=2) - #mask = torch.softmax(mask.float(), dim=2) - - #up_flow = F.unfold(factor * flow, [3,3], padding=1) - up_flow = F.unfold(flow, [3,3], padding=1) - up_flow = up_flow.view(N, D, 9, 1, 1, H, W) - - up_flow = torch.sum(mask * up_flow, dim=2) - up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) - return up_flow.reshape(N, D, factor*H, factor*W) - - def initialize_flow(self, img): - """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" - N, _, H, W = img.shape - - coords0 = coords_grid(N, H, W).to(img.device) - coords1 = coords_grid(N, H, W).to(img.device) - - return coords0, coords1 - - def upsample(self, x, scale_factor=2): - """Upsample input tensor by a factor of 2 - """ - return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest") - - def forward(self, vit_features, **kwargs): - ## read vit token to multi-scale features - B, H, W, _, _, num_register_tokens = 
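upsample_flow above is RAFT-style convex upsampling: the network predicts, per fine-grid pixel, softmax weights over the 3x3 coarse neighbourhood, so each upsampled value is a convex combination of nine coarse values. A minimal re-creation with assumed sizes (factor is 2**n_downsample in the module):

```python
import torch
import torch.nn.functional as F

N, D, H, W, factor = 1, 6, 5, 7, 4
flow = torch.randn(N, D, H, W)
mask = torch.randn(N, 9 * factor * factor, H, W)   # predicted by self.mask(net[0])

mask = mask.view(N, 1, 9, factor, factor, H, W).softmax(dim=2)
up = F.unfold(flow, [3, 3], padding=1)             # (N, D*9, H*W): 3x3 neighbourhoods
up = up.view(N, D, 9, 1, 1, H, W)
up = (mask * up).sum(dim=2)                        # (N, D, factor, factor, H, W)
up = up.permute(0, 1, 4, 2, 5, 3).reshape(N, D, factor * H, factor * W)
print(up.shape)                                    # torch.Size([1, 6, 20, 28])
```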
vit_features[1] - vit_features = vit_features[0] - - ## Error logging - if torch.isnan(vit_features[0]).any(): - print('vit_feature_nan!!!') - if torch.isinf(vit_features[0]).any(): - print('vit_feature_inf!!!') - - if self.use_cls_token == True: - vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \ - ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features] - else: - vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features] - encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4 - - ## Error logging - for en_ft in encoder_features: - if torch.isnan(en_ft).any(): - print('decoder_feature_nan!!!') - print(en_ft.shape) - if torch.isinf(en_ft).any(): - print('decoder_feature_inf!!!') - print(en_ft.shape) - - ## decode features to init-depth (and confidence) - ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth - - ## Error logging - if torch.isnan(ref_feat).any(): - print('ref_feat_nan!!!') - if torch.isinf(ref_feat).any(): - print('ref_feat_inf!!!') - - feature_map = ref_feat[:, :-2, :, :] # feature map share of depth and normal prediction - depth_confidence_map = ref_feat[:, -2:-1, :, :] - normal_confidence_map = ref_feat[:, -1:, :, :] - depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth - normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal - - depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W) - - ## encoder features to context-feature for init-hidden-state and contex-features - cnet_list = self.context_feature_encoder(encoder_features[::-1]) - net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state - inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features - - # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning - inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)] - - coords0, coords1 = self.initialize_flow(net_list[0]) - if depth_init is not None: - coords1 = coords1 + depth_init - - if self.training: - low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())] - init_depth = upflow4(depth_init) - flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)] - conf_predictions = [init_depth[:,1:2]] - normal_outs = [norm_normalize(init_depth[:,2:].clone())] - - else: - flow_predictions = [] - conf_predictions = [] - samples_pred_list = [] - coord_list = [] - normal_outs = [] - low_resolution_init = [] - - for itr in range(self.iters): - # coords1 = coords1.detach() - flow = coords1 - coords0 - if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU - net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False) - if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU - net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False) - net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2) - - # F(t+1) = F(t) + \Delta(t) - coords1 = coords1 + delta_flow - - # We do not need to upsample or output intermediate results 
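The refinement loop above reads most simply as a plain recurrence: coords0 is all zeros (see coords_grid), so the "flow" is a 6-channel state — 1 depth, 1 confidence, 3 normal components plus kappa — updated additively each iteration, F(t+1) = F(t) + Delta(t). A toy version with the GRU stack replaced by noise:

```python
import torch

B, H, W, iters = 1, 20, 28, 4
coords0 = torch.zeros(B, 6, H, W)              # zero grid, as in coords_grid above
coords1 = coords0 + torch.randn(B, 6, H, W)    # + depth_init
for _ in range(iters):
    delta = 0.1 * torch.randn(B, 6, H, W)      # stands in for the update block output
    coords1 = coords1 + delta                  # F(t+1) = F(t) + Delta(t)
depth, conf, normal = coords1[:, :1], coords1[:, 1:2], coords1[:, 2:]
print(depth.shape, conf.shape, normal.shape)   # (B,1,H,W) (B,1,H,W) (B,4,H,W)
```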
in test_mode - #if (not self.training) and itr < self.iters-1: - #continue - - # upsample predictions - if up_mask is None: - flow_up = self.upsample(coords1-coords0, 4) - else: - flow_up = self.upsample_flow(coords1 - coords0, up_mask) - # flow_up = self.upsample(coords1-coords0, 4) - - flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val)) - conf_predictions.append(flow_up[:,1:2]) - normal_outs.append(norm_normalize(flow_up[:,2:].clone())) - - outputs=dict( - prediction=flow_predictions[-1], - predictions_list=flow_predictions, - confidence=conf_predictions[-1], - confidence_list=conf_predictions, - pred_logit=None, - # samples_pred_list=samples_pred_list, - # coord_list=coord_list, - prediction_normal=normal_outs[-1], - normal_out_list=normal_outs, - low_resolution_init=low_resolution_init, - ) - - return outputs - - -if __name__ == "__main__": - try: - from mmcv.utils import Config - except: - from mmengine import Config - cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py') - cfg.model.decode_head.in_channels = [384, 384, 384, 384] - cfg.model.decode_head.feature_channels = [96, 192, 384, 768] - cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384] - cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48] - cfg.model.decode_head.up_scale = 7 - - # cfg.model.decode_head.use_cls_token = True - # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ - # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ - # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ - # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]] - - cfg.model.decode_head.use_cls_token = True - cfg.model.decode_head.num_register_tokens = 4 - vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\ - torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ - torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ - torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)] - - decoder = RAFTDepthNormalDPT5(cfg).cuda() - output = decoder(vit_feature) - temp = 1 - - - - diff --git a/mono/model/decode_heads/__init__.py b/mono/model/decode_heads/__init__.py deleted file mode 100644 index 92381a5fc3dad0ca8009c1ab0a153ce6b107c634..0000000000000000000000000000000000000000 --- a/mono/model/decode_heads/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .HourGlassDecoder import HourglassDecoder -from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5 - -__all__=['HourglassDecoder', 'RAFTDepthNormalDPT5'] diff --git a/mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc b/mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc deleted file mode 100644 index 47c981bd3124006156222a76cf044e3d5033d77c..0000000000000000000000000000000000000000 Binary files a/mono/model/decode_heads/__pycache__/HourGlassDecoder.cpython-39.pyc and /dev/null differ diff --git a/mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc b/mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index e1433a90f49acb41d64bf291bebc75e844e4bc5b..0000000000000000000000000000000000000000 Binary files a/mono/model/decode_heads/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/mono/model/model_pipelines/__base_model__.py b/mono/model/model_pipelines/__base_model__.py deleted file mode 100644 index d599c418b3d9677a195fe87d45bb31bf1068fbce..0000000000000000000000000000000000000000 --- 
a/mono/model/model_pipelines/__base_model__.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch -import torch.nn as nn -from mono.utils.comm import get_func - - -class BaseDepthModel(nn.Module): - def __init__(self, cfg, **kwargs) -> None: - super(BaseDepthModel, self).__init__() - model_type = cfg.model.type - self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg) - - def forward(self, data): - output = self.depth_model(**data) - - return output['prediction'], output['confidence'], output - - def inference(self, data): - with torch.no_grad(): - pred_depth, confidence, _ = self.forward(data) - return pred_depth, confidence \ No newline at end of file diff --git a/mono/model/model_pipelines/__init__.py b/mono/model/model_pipelines/__init__.py deleted file mode 100644 index b962a3f858573466e429219c4ad70951b545b637..0000000000000000000000000000000000000000 --- a/mono/model/model_pipelines/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ - -from .dense_pipeline import DensePredModel -from .__base_model__ import BaseDepthModel -__all__ = [ - 'DensePredModel', 'BaseDepthModel', -] \ No newline at end of file diff --git a/mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc b/mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc deleted file mode 100644 index 09f4eaf42b12b9e4820d30f7d2e0a651bef48ad1..0000000000000000000000000000000000000000 Binary files a/mono/model/model_pipelines/__pycache__/__base_model__.cpython-39.pyc and /dev/null differ diff --git a/mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc b/mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index efc35f49a0d2964b808e7900728e404c68ba5435..0000000000000000000000000000000000000000 Binary files a/mono/model/model_pipelines/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc b/mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc deleted file mode 100644 index ee8194cf14465d26eed75d612ba77ff7c19699f9..0000000000000000000000000000000000000000 Binary files a/mono/model/model_pipelines/__pycache__/dense_pipeline.cpython-39.pyc and /dev/null differ diff --git a/mono/model/model_pipelines/dense_pipeline.py b/mono/model/model_pipelines/dense_pipeline.py deleted file mode 100644 index 1362a11b6b9d45e50795dd705906aa3f79ec4a9a..0000000000000000000000000000000000000000 --- a/mono/model/model_pipelines/dense_pipeline.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -import torch.nn as nn -from mono.utils.comm import get_func - -class DensePredModel(nn.Module): - def __init__(self, cfg) -> None: - super(DensePredModel, self).__init__() - - self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone) - self.decoder = get_func('mono.model.' 
+ cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg) - - def forward(self, input, **kwargs): - # [f_32, f_16, f_8, f_4] - features = self.encoder(input) - out = self.decoder(features, **kwargs) - return out \ No newline at end of file diff --git a/mono/model/monodepth_model.py b/mono/model/monodepth_model.py deleted file mode 100644 index 0b58b7643ee43f84fd4e621e5b3b61b1f3f85564..0000000000000000000000000000000000000000 --- a/mono/model/monodepth_model.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch -import torch.nn as nn -from .model_pipelines.__base_model__ import BaseDepthModel - -class DepthModel(BaseDepthModel): - def __init__(self, cfg, **kwards): - super(DepthModel, self).__init__(cfg) - model_type = cfg.model.type - - def inference(self, data): - with torch.no_grad(): - pred_depth, confidence, output_dict = self.forward(data) - return pred_depth, confidence, output_dict - -def get_monodepth_model( - cfg : dict, - **kwargs - ) -> nn.Module: - # config depth model - model = DepthModel(cfg, **kwargs) - #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath) - assert isinstance(model, nn.Module) - return model - -def get_configured_monodepth_model( - cfg: dict, - ) -> nn.Module: - """ - Args: - @ configs: configures for the network. - @ load_imagenet_model: whether to initialize from ImageNet-pretrained model. - @ imagenet_ckpt_fpath: string representing path to file with weights to initialize model with. - Returns: - # model: depth model. - """ - model = get_monodepth_model(cfg) - return model diff --git a/mono/tools/test_scale_cano.py b/mono/tools/test_scale_cano.py deleted file mode 100644 index 684fb841a004833e27edd52192ad0821bf2d43af..0000000000000000000000000000000000000000 --- a/mono/tools/test_scale_cano.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import os.path as osp -import cv2 -import time -import sys -CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(CODE_SPACE) -import argparse -import mmcv -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -try: - from mmcv.utils import Config, DictAction -except: - from mmengine import Config, DictAction -from datetime import timedelta -import random -import numpy as np -from mono.utils.logger import setup_logger -import glob -from mono.utils.comm import init_env -from mono.model.monodepth_model import get_configured_monodepth_model -from mono.utils.running import load_ckpt -from mono.utils.do_test import do_scalecano_test_with_custom_data -from mono.utils.mldb import load_data_info, reset_ckpt_path -from mono.utils.custom_data import load_from_annos, load_data - -def parse_args(): - parser = argparse.ArgumentParser(description='Train a segmentor') - parser.add_argument('config', help='train config file path') - parser.add_argument('--show-dir', help='the dir to save logs and visualization results') - parser.add_argument('--load-from', help='the checkpoint file to load weights from') - parser.add_argument('--node_rank', type=int, default=0) - parser.add_argument('--nnodes', type=int, default=1, help='number of nodes') - parser.add_argument('--options', nargs='+', action=DictAction, help='custom options') - parser.add_argument('--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm', help='job launcher') - parser.add_argument('--test_data_path', default='None', type=str, help='the path of test data') - args = parser.parse_args() - return args - -def main(args): - os.chdir(CODE_SPACE) - cfg = 
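DensePredModel above wires the backbone and decode head together purely from config strings via get_func. The project's own helper lives in mono.utils.comm and is not shown in this diff; a hypothetical minimal equivalent based on importlib would look like:

```python
import importlib

def get_func(dotted_name):
    """Resolve a dotted path like 'pkg.module.ClassName' to the attribute."""
    module_name, attr = dotted_name.rsplit('.', 1)
    return getattr(importlib.import_module(module_name), attr)

# e.g. resolves to the class object, which the pipeline then instantiates
print(get_func('torch.nn.Conv2d'))
```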
Config.fromfile(args.config) - - if args.options is not None: - cfg.merge_from_dict(args.options) - - # show_dir is determined in this priority: CLI > segment in file > filename - if args.show_dir is not None: - # update configs according to CLI args if args.show_dir is not None - cfg.show_dir = args.show_dir - else: - # use condig filename + timestamp as default show_dir if args.show_dir is None - cfg.show_dir = osp.join('./show_dirs', - osp.splitext(osp.basename(args.config))[0], - args.timestamp) - - # ckpt path - if args.load_from is None: - raise RuntimeError('Please set model path!') - cfg.load_from = args.load_from - - # load data info - data_info = {} - load_data_info('data_info', data_info=data_info) - cfg.mldb_info = data_info - # update check point info - reset_ckpt_path(cfg.model, data_info) - - # create show dir - os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True) - - # init the logger before other steps - cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log') - logger = setup_logger(cfg.log_file) - - # log some basic info - logger.info(f'Config:\n{cfg.pretty_text}') - - # init distributed env dirst, since logger depends on the dist info - if args.launcher == 'None': - cfg.distributed = False - else: - cfg.distributed = True - init_env(args.launcher, cfg) - logger.info(f'Distributed training: {cfg.distributed}') - - # dump config - cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config))) - test_data_path = args.test_data_path - if not os.path.isabs(test_data_path): - test_data_path = osp.join(CODE_SPACE, test_data_path) - - if 'json' in test_data_path: - test_data = load_from_annos(test_data_path) - else: - test_data = load_data(args.test_data_path) - - if not cfg.distributed: - main_worker(0, cfg, args.launcher, test_data) - else: - # distributed training - if args.launcher == 'ror': - local_rank = cfg.dist_params.local_rank - main_worker(local_rank, cfg, args.launcher, test_data) - else: - mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher, test_data)) - -def main_worker(local_rank: int, cfg: dict, launcher: str, test_data: list): - if cfg.distributed: - cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank - cfg.dist_params.local_rank = local_rank - - if launcher == 'ror': - init_torch_process_group(use_hvd=False) - else: - torch.cuda.set_device(local_rank) - default_timeout = timedelta(minutes=30) - dist.init_process_group( - backend=cfg.dist_params.backend, - init_method=cfg.dist_params.dist_url, - world_size=cfg.dist_params.world_size, - rank=cfg.dist_params.global_rank, - timeout=default_timeout) - - logger = setup_logger(cfg.log_file) - # build model - model = get_configured_monodepth_model(cfg, ) - - # config distributed training - if cfg.distributed: - model = torch.nn.parallel.DistributedDataParallel(model.cuda(), - device_ids=[local_rank], - output_device=local_rank, - find_unused_parameters=True) - else: - model = torch.nn.DataParallel(model).cuda() - - # load ckpt - model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False) - model.eval() - - do_scalecano_test_with_custom_data( - model, - cfg, - test_data, - logger, - cfg.distributed, - local_rank - ) - -if __name__ == '__main__': - args = parse_args() - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) - args.timestamp = timestamp - main(args) \ No newline at end of file diff --git a/mono/utils/__init__.py b/mono/utils/__init__.py deleted file mode 100644 index 
8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/mono/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/mono/utils/__pycache__/__init__.cpython-39.pyc b/mono/utils/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 1124c910af2391269228568f250c6894520aab54..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/avg_meter.cpython-39.pyc b/mono/utils/__pycache__/avg_meter.cpython-39.pyc deleted file mode 100644 index b12793a25471d7debe0eb2159c47beb8d732a51d..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/avg_meter.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/comm.cpython-39.pyc b/mono/utils/__pycache__/comm.cpython-39.pyc deleted file mode 100644 index 8bb531447618f36553e4a63ef17b8d874a97759c..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/comm.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/custom_data.cpython-39.pyc b/mono/utils/__pycache__/custom_data.cpython-39.pyc deleted file mode 100644 index cbcc9fc30c9e18379a164f674430f83afa87eb78..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/custom_data.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/do_test.cpython-39.pyc b/mono/utils/__pycache__/do_test.cpython-39.pyc deleted file mode 100644 index 9a32049271c1c59666f2f40d6e28775e96be8d64..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/do_test.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/logger.cpython-39.pyc b/mono/utils/__pycache__/logger.cpython-39.pyc deleted file mode 100644 index e74e49882c0ff109592180b63ed208e949db92fa..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/logger.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/mldb.cpython-39.pyc b/mono/utils/__pycache__/mldb.cpython-39.pyc deleted file mode 100644 index a19b422fb3981948973a2fd58165f160bf9e2824..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/mldb.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/running.cpython-39.pyc b/mono/utils/__pycache__/running.cpython-39.pyc deleted file mode 100644 index 63270065ee2d57c2b20eeb400302824402cb0738..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/running.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/transform.cpython-39.pyc b/mono/utils/__pycache__/transform.cpython-39.pyc deleted file mode 100644 index e7886dadc9311f0ae919a928fc157f377537b4ec..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/transform.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/unproj_pcd.cpython-39.pyc b/mono/utils/__pycache__/unproj_pcd.cpython-39.pyc deleted file mode 100644 index 54df84153aadb93b17d666f2a058a7c53543ba2d..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/unproj_pcd.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/__pycache__/visualization.cpython-39.pyc b/mono/utils/__pycache__/visualization.cpython-39.pyc deleted file mode 100644 index 908f7aa5fecdfbc13df19333c0a8ac19a49d75a5..0000000000000000000000000000000000000000 Binary files a/mono/utils/__pycache__/visualization.cpython-39.pyc and /dev/null differ diff --git a/mono/utils/avg_meter.py 
diff --git a/mono/utils/avg_meter.py b/mono/utils/avg_meter.py
deleted file mode 100644
index 3f935df9760cee1d73c6cba00b954d03e659ccb3..0000000000000000000000000000000000000000
--- a/mono/utils/avg_meter.py
+++ /dev/null
@@ -1,475 +0,0 @@
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-import matplotlib.pyplot as plt
-
-
-class AverageMeter(object):
-    """Computes and stores the average and current value"""
-    def __init__(self) -> None:
-        self.reset()
-
-    def reset(self) -> None:
-        self.val = np.longdouble(0.0)
-        self.avg = np.longdouble(0.0)
-        self.sum = np.longdouble(0.0)
-        self.count = np.longdouble(0.0)
-
-    def update(self, val, n: float = 1) -> None:
-        self.val = val
-        self.sum += val
-        self.count += n
-        self.avg = self.sum / (self.count + 1e-6)
-
-class MetricAverageMeter(AverageMeter):
-    """
-    An AverageMeter designed for accumulating depth and surface-normal evaluation metrics.
-    """
-    def __init__(self, metrics: list) -> None:
-        """ Initialize object. """
-        # average meters for metrics
-        self.abs_rel = AverageMeter()
-        self.rmse = AverageMeter()
-        self.silog = AverageMeter()
-        self.delta1 = AverageMeter()
-        self.delta2 = AverageMeter()
-        self.delta3 = AverageMeter()
-
-        self.metrics = metrics
-
-        self.consistency = AverageMeter()
-        self.log10 = AverageMeter()
-        self.rmse_log = AverageMeter()
-        self.sq_rel = AverageMeter()
-
-        # normal
-        self.normal_mean = AverageMeter()
-        self.normal_rmse = AverageMeter()
-        self.normal_a1 = AverageMeter()
-        self.normal_a2 = AverageMeter()
-
-        self.normal_median = AverageMeter()
-        self.normal_a3 = AverageMeter()
-        self.normal_a4 = AverageMeter()
-        self.normal_a5 = AverageMeter()
-
-
-    def update_metrics_cpu(self,
-                           pred: torch.Tensor,
-                           target: torch.Tensor,
-                           mask: torch.Tensor,):
-        """
-        Update metrics on cpu
-        """
-
-        assert pred.shape == target.shape
-
-        if len(pred.shape) == 3:
-            pred = pred[:, None, :, :]
-            target = target[:, None, :, :]
-            mask = mask[:, None, :, :]
-        elif len(pred.shape) == 2:
-            pred = pred[None, None, :, :]
-            target = target[None, None, :, :]
-            mask = mask[None, None, :, :]
-
-
-        # Absolute relative error
-        abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
-        abs_rel_sum = abs_rel_sum.numpy()
-        valid_pics = valid_pics.numpy()
-        self.abs_rel.update(abs_rel_sum, valid_pics)
-
-        # squared relative error
-        sqrel_sum, _ = get_sqrel_err(pred, target, mask)
-        sqrel_sum = sqrel_sum.numpy()
-        self.sq_rel.update(sqrel_sum, valid_pics)
-
-        # root mean squared error
-        rmse_sum, _ = get_rmse_err(pred, target, mask)
-        rmse_sum = rmse_sum.numpy()
-        self.rmse.update(rmse_sum, valid_pics)
-
-        # log root mean squared error
-        log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
-        log_rmse_sum = log_rmse_sum.numpy()
-        self.rmse_log.update(log_rmse_sum, valid_pics)
-
-        # log10 error
-        log10_sum, _ = get_log10_err(pred, target, mask)
-        log10_sum = log10_sum.numpy()
-        self.log10.update(log10_sum, valid_pics)
-
-        # scale-invariant root mean squared error in log space
-        silog_sum, _ = get_silog_err(pred, target, mask)
-        silog_sum = silog_sum.numpy()
-        self.silog.update(silog_sum, valid_pics)
-
-        # ratio errors: delta1, delta2, delta3
-        delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_err(pred, target, mask)
-        delta1_sum = delta1_sum.numpy()
-        delta2_sum = delta2_sum.numpy()
-        delta3_sum = delta3_sum.numpy()
-
-        self.delta1.update(delta1_sum, valid_pics)
-        self.delta2.update(delta2_sum, valid_pics)
-        self.delta3.update(delta3_sum, valid_pics)
-
-
-    def update_metrics_gpu(
-        self,
-        pred: torch.Tensor,
-        target: torch.Tensor,
-        mask: torch.Tensor,
-        is_distributed: bool,
-        pred_next: torch.tensor = None,
-        pose_f1_to_f2: torch.tensor = None,
-        intrinsic: torch.tensor = None):
-        """
-        Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
-        set 'is_distributed' as True.
-        """
-        assert pred.shape == target.shape
-
-        if len(pred.shape) == 3:
-            pred = pred[:, None, :, :]
-            target = target[:, None, :, :]
-            mask = mask[:, None, :, :]
-        elif len(pred.shape) == 2:
-            pred = pred[None, None, :, :]
-            target = target[None, None, :, :]
-            mask = mask[None, None, :, :]
-
-
-        # Absolute relative error
-        abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
-        if is_distributed:
-            dist.all_reduce(abs_rel_sum), dist.all_reduce(valid_pics)
-        abs_rel_sum = abs_rel_sum.cpu().numpy()
-        valid_pics = int(valid_pics)
-        self.abs_rel.update(abs_rel_sum, valid_pics)
-
-        # root mean squared error
-        rmse_sum, _ = get_rmse_err(pred, target, mask)
-        if is_distributed:
-            dist.all_reduce(rmse_sum)
-        rmse_sum = rmse_sum.cpu().numpy()
-        self.rmse.update(rmse_sum, valid_pics)
-
-        # log root mean squared error
-        log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
-        if is_distributed:
-            dist.all_reduce(log_rmse_sum)
-        log_rmse_sum = log_rmse_sum.cpu().numpy()
-        self.rmse_log.update(log_rmse_sum, valid_pics)
-
-        # log10 error
-        log10_sum, _ = get_log10_err(pred, target, mask)
-        if is_distributed:
-            dist.all_reduce(log10_sum)
-        log10_sum = log10_sum.cpu().numpy()
-        self.log10.update(log10_sum, valid_pics)
-
-        # scale-invariant root mean squared error in log space
-        silog_sum, _ = get_silog_err(pred, target, mask)
-        if is_distributed:
-            dist.all_reduce(silog_sum)
-        silog_sum = silog_sum.cpu().numpy()
-        self.silog.update(silog_sum, valid_pics)
-
-        # ratio errors: delta1, delta2, delta3
-        delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_err(pred, target, mask)
-        if is_distributed:
-            dist.all_reduce(delta1_sum), dist.all_reduce(delta2_sum), dist.all_reduce(delta3_sum)
-        delta1_sum = delta1_sum.cpu().numpy()
-        delta2_sum = delta2_sum.cpu().numpy()
-        delta3_sum = delta3_sum.cpu().numpy()
-
-        self.delta1.update(delta1_sum, valid_pics)
-        self.delta2.update(delta2_sum, valid_pics)
-        self.delta3.update(delta3_sum, valid_pics)
-
-        # video consistency error (only when a next-frame prediction is provided)
-        if pred_next is not None:
-            consistency_rel_sum, valid_warps = get_video_consistency_err(pred, pred_next, pose_f1_to_f2, intrinsic)
-            if is_distributed:
-                dist.all_reduce(consistency_rel_sum), dist.all_reduce(valid_warps)
-            consistency_rel_sum = consistency_rel_sum.cpu().numpy()
-            valid_warps = int(valid_warps)
-            self.consistency.update(consistency_rel_sum, valid_warps)
-
-    ## for surface normal
-    def update_normal_metrics_gpu(
-        self,
-        pred: torch.Tensor,     # (B, 3, H, W)
-        target: torch.Tensor,   # (B, 3, H, W)
-        mask: torch.Tensor,     # (B, 1, H, W)
-        is_distributed: bool,
-        ):
-        """
-        Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
-        set 'is_distributed' as True.
-        """
-        assert pred.shape == target.shape
-
-        valid_pics = torch.sum(mask, dtype=torch.float32) + 1e-6
-
-        if valid_pics < 10:
-            return
-
-        mean_error = rmse_error = a1_error = a2_error = dist_node_cnt = valid_pics
-        normal_error = torch.cosine_similarity(pred, target, dim=1)
-        normal_error = torch.clamp(normal_error, min=-1.0, max=1.0)
-        angle_error = torch.acos(normal_error) * 180.0 / torch.pi
-        angle_error = angle_error[:, None, :, :]
-        angle_error = angle_error[mask]
-        # compute the angular errors
-        mean_error = angle_error.sum() / valid_pics
-        rmse_error = torch.sqrt( torch.sum(torch.square(angle_error)) / valid_pics )
-        median_error = angle_error.median()
-        a1_error = 100.0 * (torch.sum(angle_error < 5) / valid_pics)
-        a2_error = 100.0 * (torch.sum(angle_error < 7.5) / valid_pics)
-
-        a3_error = 100.0 * (torch.sum(angle_error < 11.25) / valid_pics)
-        a4_error = 100.0 * (torch.sum(angle_error < 22.5) / valid_pics)
-        a5_error = 100.0 * (torch.sum(angle_error < 30) / valid_pics)
-
-        # if valid_pics > 1e-5:
-        # If the current node gets data with valid normal
-        dist_node_cnt = (valid_pics - 1e-6) / valid_pics
-
-        if is_distributed:
-            dist.all_reduce(dist_node_cnt)
-            dist.all_reduce(mean_error)
-            dist.all_reduce(rmse_error)
-            dist.all_reduce(a1_error)
-            dist.all_reduce(a2_error)
-
-            dist.all_reduce(a3_error)
-            dist.all_reduce(a4_error)
-            dist.all_reduce(a5_error)
-
-        dist_node_cnt = dist_node_cnt.cpu().numpy()
-        self.normal_mean.update(mean_error.cpu().numpy(), dist_node_cnt)
-        self.normal_rmse.update(rmse_error.cpu().numpy(), dist_node_cnt)
-        self.normal_a1.update(a1_error.cpu().numpy(), dist_node_cnt)
-        self.normal_a2.update(a2_error.cpu().numpy(), dist_node_cnt)
-
-        self.normal_median.update(median_error.cpu().numpy(), dist_node_cnt)
-        self.normal_a3.update(a3_error.cpu().numpy(), dist_node_cnt)
-        self.normal_a4.update(a4_error.cpu().numpy(), dist_node_cnt)
-        self.normal_a5.update(a5_error.cpu().numpy(), dist_node_cnt)
-
-
-    def get_metrics(self,):
-        """
-        Collect the average of each tracked metric into a dict.
-        """
-        metrics_dict = {}
-        for metric in self.metrics:
-            metrics_dict[metric] = self.__getattribute__(metric).avg
-        return metrics_dict
-
-def get_absrel_err(pred: torch.tensor,
-                   target: torch.tensor,
-                   mask: torch.tensor,
-                   ):
-    """
-    Computes absolute relative error.
-    Takes preprocessed depths (no nans, infs and non-positive values).
-    pred, target, and mask should be in the shape of [b, c, h, w]
-    """
-
-    assert len(pred.shape) == 4 and len(target.shape) == 4
-    b, c, h, w = pred.shape
-    mask = mask.to(torch.float)
-    t_m = target * mask
-    p_m = pred * mask
-
-    # Mean Absolute Relative Error
-    rel = torch.abs(t_m - p_m) / (t_m + 1e-10)  # compute errors
-    abs_rel_sum = torch.sum(rel.reshape((b, c, -1)), dim=2)  # [b, c]
-    num = torch.sum(mask.reshape((b, c, -1)), dim=2)  # [b, c]
-    abs_err = abs_rel_sum / (num + 1e-10)
-    valid_pics = torch.sum(num > 0)
-    return torch.sum(abs_err), valid_pics
-
-def get_sqrel_err(pred: torch.tensor,
-                  target: torch.tensor,
-                  mask: torch.tensor,
-                  ):
-    """
-    Computes squared relative error.
-    Takes preprocessed depths (no nans, infs and non-positive values).
-    pred, target, and mask should be in the shape of [b, c, h, w]
-    """
-
-    assert len(pred.shape) == 4 and len(target.shape) == 4
-    b, c, h, w = pred.shape
-    mask = mask.to(torch.float)
-    t_m = target * mask
-    p_m = pred * mask
-
-    # squared Relative Error
-    sq_rel = torch.abs(t_m - p_m) ** 2 / (t_m + 1e-10)  # compute errors
-    sq_rel_sum = torch.sum(sq_rel.reshape((b, c, -1)), dim=2)  # [b, c]
-    num = torch.sum(mask.reshape((b, c, -1)), dim=2)  # [b, c]
-    sqrel_err = sq_rel_sum / (num + 1e-10)
-    valid_pics = torch.sum(num > 0)
-    return torch.sum(sqrel_err), valid_pics
-
-def get_log10_err(pred: torch.tensor,
-                  target: torch.tensor,
-                  mask: torch.tensor,
-                  ):
-    """
-    Computes log10 error.
-    Takes preprocessed depths (no nans, infs and non-positive values).
-    pred, target, and mask should be in the shape of [b, c, h, w]
-    """
-
-    assert len(pred.shape) == 4 and len(target.shape) == 4
-    b, c, h, w = pred.shape
-    mask = mask.to(torch.float)
-    t_m = target * mask
-    p_m = pred * mask
-
-    diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
-    log10_diff = torch.abs(diff_log)
-    log10_sum = torch.sum(log10_diff.reshape((b, c, -1)), dim=2)  # [b, c]
-    num = torch.sum(mask.reshape((b, c, -1)), dim=2)  # [b, c]
-    log10_err = log10_sum / (num + 1e-10)
-    valid_pics = torch.sum(num > 0)
-    return torch.sum(log10_err), valid_pics
-
-def get_rmse_err(pred: torch.tensor,
-                 target: torch.tensor,
-                 mask: torch.tensor,
-                 ):
-    """
-    Computes rmse error.
-    Takes preprocessed depths (no nans, infs and non-positive values).
-    pred, target, and mask should be in the shape of [b, c, h, w]
-    """
-
-    assert len(pred.shape) == 4 and len(target.shape) == 4
-    b, c, h, w = pred.shape
-    mask = mask.to(torch.float)
-    t_m = target * mask
-    p_m = pred * mask
-
-    square = (t_m - p_m) ** 2
-    rmse_sum = torch.sum(square.reshape((b, c, -1)), dim=2)  # [b, c]
-    num = torch.sum(mask.reshape((b, c, -1)), dim=2)  # [b, c]
-    rmse = torch.sqrt(rmse_sum / (num + 1e-10))
-    valid_pics = torch.sum(num > 0)
-    return torch.sum(rmse), valid_pics
-
-def get_rmse_log_err(pred: torch.tensor,
-                     target: torch.tensor,
-                     mask: torch.tensor,
-                     ):
-    """
-    Computes log rmse error.
-    Takes preprocessed depths (no nans, infs and non-positive values).
-    pred, target, and mask should be in the shape of [b, c, h, w]
-    """
-
-    assert len(pred.shape) == 4 and len(target.shape) == 4
-    b, c, h, w = pred.shape
-    mask = mask.to(torch.float)
-    t_m = target * mask
-    p_m = pred * mask
-
-    diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
-    square = diff_log ** 2
-    rmse_log_sum = torch.sum(square.reshape((b, c, -1)), dim=2)  # [b, c]
-    num = torch.sum(mask.reshape((b, c, -1)), dim=2)  # [b, c]
-    rmse_log = torch.sqrt(rmse_log_sum / (num + 1e-10))
-    valid_pics = torch.sum(num > 0)
-    return torch.sum(rmse_log), valid_pics
-
-def get_silog_err(pred: torch.tensor,
-                  target: torch.tensor,
-                  mask: torch.tensor,
-                  ):
-    """
-    Computes the scale-invariant log RMSE (silog) error.
-    Takes preprocessed depths (no nans, infs and non-positive values).
-    pred, target, and mask should be in the shape of [b, c, h, w]
-    """
-
-    assert len(pred.shape) == 4 and len(target.shape) == 4
-    b, c, h, w = pred.shape
-    mask = mask.to(torch.float)
-    t_m = target * mask
-    p_m = pred * mask
-
-    diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
-    diff_log_sum = torch.sum(diff_log.reshape((b, c, -1)), dim=2)  # [b, c]
-    diff_log_square = diff_log ** 2
-    diff_log_square_sum = torch.sum(diff_log_square.reshape((b, c, -1)), dim=2)  # [b, c]
-    num = torch.sum(mask.reshape((b, c, -1)), dim=2)  # [b, c]
-    silog = torch.sqrt(diff_log_square_sum / (num + 1e-10) - (diff_log_sum / (num + 1e-10)) ** 2)
-    valid_pics = torch.sum(num > 0)
-    return torch.sum(silog), valid_pics
-
-def get_ratio_err(pred: torch.tensor,
-                  target: torch.tensor,
-                  mask: torch.tensor,
-                  ):
-    """
-    Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold.
-    Takes preprocessed depths (no nans, infs and non-positive values).
-    pred, target, and mask should be in the shape of [b, c, h, w]
-    """
-    assert len(pred.shape) == 4 and len(target.shape) == 4
-    b, c, h, w = pred.shape
-    mask = mask.to(torch.float)
-    t_m = target * mask
-    # pred is left unmasked: masked-out targets are zero, which pushes the
-    # pred/gt ratio above any threshold, so those pixels never count as hits
-    p_m = pred
-
-    gt_pred = t_m / (p_m + 1e-10)
-    pred_gt = p_m / (t_m + 1e-10)
-    gt_pred = gt_pred.reshape((b, c, -1))
-    pred_gt = pred_gt.reshape((b, c, -1))
-    gt_pred_gt = torch.cat((gt_pred, pred_gt), axis=1)
-    ratio_max = torch.amax(gt_pred_gt, axis=1)
-
-    delta_1_sum = torch.sum((ratio_max < 1.25), dim=1)       # [b, ]
-    delta_2_sum = torch.sum((ratio_max < 1.25 ** 2), dim=1)  # [b, ]
-    delta_3_sum = torch.sum((ratio_max < 1.25 ** 3), dim=1)  # [b, ]
-    num = torch.sum(mask.reshape((b, -1)), dim=1)  # [b, ]
-
-    delta_1 = delta_1_sum / (num + 1e-10)
-    delta_2 = delta_2_sum / (num + 1e-10)
-    delta_3 = delta_3_sum / (num + 1e-10)
-    valid_pics = torch.sum(num > 0)
-
-    return torch.sum(delta_1), torch.sum(delta_2), torch.sum(delta_3), valid_pics
-
-
-if __name__ == '__main__':
-    cfg = ['abs_rel', 'delta1']
-    dam = MetricAverageMeter(cfg)
-
-    pred_depth = np.random.random([2, 480, 640])
-    gt_depth = np.random.random([2, 480, 640]) - 0.5
-    intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]]
-
-    pred = torch.from_numpy(pred_depth).cuda()
-    gt = torch.from_numpy(gt_depth).cuda()
-
-    mask = gt > 0
-    dam.update_metrics_gpu(pred, gt, mask, False)
-    eval_error = dam.get_metrics()
-    print(eval_error)
-
\ No newline at end of file
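Editor's note on the ratio metrics deleted above: δ_k is the fraction of valid pixels whose max(gt/pred, pred/gt) falls below 1.25^k. A tiny NumPy illustration with toy values (not repo code):

```python
import numpy as np

pred = np.array([1.0, 2.0, 4.0])
gt   = np.array([1.1, 2.0, 8.0])

ratio = np.maximum(gt / pred, pred / gt)   # per-pixel worst-case ratio
delta1 = float(np.mean(ratio < 1.25))      # ratios 1.1 and 1.0 pass, 2.0 fails
delta2 = float(np.mean(ratio < 1.25 ** 2))
print(delta1, delta2)                      # 0.666..., 0.666...
```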
diff --git a/mono/utils/comm.py b/mono/utils/comm.py
deleted file mode 100644
index 939e4e175c14563d5d13e77e6b56fd1a34668ebf..0000000000000000000000000000000000000000
--- a/mono/utils/comm.py
+++ /dev/null
@@ -1,322 +0,0 @@
-import importlib
-import torch
-import torch.distributed as dist
-from .avg_meter import AverageMeter
-from collections import defaultdict, OrderedDict
-import os
-import socket
-from mmcv.utils import collect_env as collect_base_env
-try:
-    from mmcv.utils import get_git_hash
-except ImportError:
-    from mmengine.utils import get_git_hash
-#import mono.mmseg as mmseg
-# import mmseg
-import time
-import datetime
-import logging
-
-
-def main_process() -> bool:
-    return get_rank() == 0
-    #return not cfg.distributed or \
-    #    (cfg.distributed and cfg.local_rank == 0)
-
-def get_world_size() -> int:
-    if not dist.is_available():
-        return 1
-    if not dist.is_initialized():
-        return 1
-    return dist.get_world_size()
-
-def get_rank() -> int:
-    if not dist.is_available():
-        return 0
-    if not dist.is_initialized():
-        return 0
-    return dist.get_rank()
-
-def _find_free_port():
-    # refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    # Binding to port 0 will cause the OS to find an available port for us
-    sock.bind(('', 0))
-    port = sock.getsockname()[1]
-    sock.close()
-    # NOTE: there is still a chance the port could be taken by other processes.
-    return port
-
-def _is_free_port(port):
-    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
-    ips.append('localhost')
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        return all(s.connect_ex((ip, port)) != 0 for ip in ips)
-
-
-# def collect_env():
-#     """Collect the information of the running environments."""
-#     env_info = collect_base_env()
-#     env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
-
-#     return env_info
-
-def init_env(launcher, cfg):
-    """Initialize distributed training environment.
-    If argument ``cfg.dist_params.dist_url`` is specified as 'env://', then the master port will be system
-    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
-    environment variable, then a default port ``29500`` will be used.
-    """
-    if launcher == 'slurm':
-        _init_dist_slurm(cfg)
-    elif launcher == 'ror':
-        _init_dist_ror(cfg)
-    elif launcher == 'None':
-        _init_none_dist(cfg)
-    else:
-        raise RuntimeError(f'{launcher} has not been supported!')
-
-def _init_none_dist(cfg):
-    cfg.dist_params.num_gpus_per_node = 1
-    cfg.dist_params.world_size = 1
-    cfg.dist_params.nnodes = 1
-    cfg.dist_params.node_rank = 0
-    cfg.dist_params.global_rank = 0
-    cfg.dist_params.local_rank = 0
-    os.environ["WORLD_SIZE"] = str(1)
-
-def _init_dist_ror(cfg):
-    from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size
-    cfg.dist_params.num_gpus_per_node = get_local_size()
-    cfg.dist_params.world_size = get_world_size()
-    cfg.dist_params.nnodes = (get_world_size()) // (get_local_size())
-    cfg.dist_params.node_rank = get_node_rank()
-    cfg.dist_params.global_rank = get_world_rank()
-    cfg.dist_params.local_rank = get_local_rank()
-    os.environ["WORLD_SIZE"] = str(get_world_size())
-
-
-def _init_dist_slurm(cfg):
-    if 'NNODES' not in os.environ:
-        os.environ['NNODES'] = str(cfg.dist_params.nnodes)
-    if 'NODE_RANK' not in os.environ:
-        os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)
-
-    num_gpus = torch.cuda.device_count()
-    world_size = int(os.environ['NNODES']) * num_gpus
-    os.environ['WORLD_SIZE'] = str(world_size)
-
-    # config port
-    if 'MASTER_PORT' in os.environ:
-        master_port = str(os.environ['MASTER_PORT'])  # use MASTER_PORT from the environment
-    else:
-        # if the preset port (16500) is available, use it; otherwise find a free port
-        if _is_free_port(16500):
-            master_port = '16500'
-        else:
-            master_port = str(_find_free_port())
-        os.environ['MASTER_PORT'] = master_port
-
-    # config addr
-    if 'MASTER_ADDR' in os.environ:
-        master_addr = str(os.environ['MASTER_ADDR'])  # use MASTER_ADDR from the environment
-    # elif cfg.dist_params.dist_url is not None:
-    #     master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2])
-    else:
-        master_addr = '127.0.0.1'  #'tcp://127.0.0.1'
-    os.environ['MASTER_ADDR'] = master_addr
-
-    # set dist_url to 'env://'
-    cfg.dist_params.dist_url = 'env://'  #f"{master_addr}:{master_port}"
-
-    cfg.dist_params.num_gpus_per_node = num_gpus
-    cfg.dist_params.world_size = world_size
-    cfg.dist_params.nnodes = int(os.environ['NNODES'])
-    cfg.dist_params.node_rank = int(os.environ['NODE_RANK'])
-
-    # if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"):
-    #     raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://")
-
-
-def get_func(func_name):
-    """
-    Helper to return a function object by name. func_name must identify
-    a function in this module or the path to a function relative to the base
-    module.
-    @ func_name: function name.
-    """
-    if func_name == '':
-        return None
-    try:
-        parts = func_name.split('.')
-        # Refers to a function in this module
-        if len(parts) == 1:
-            return globals()[parts[0]]
-        # Otherwise, assume we're referencing a module under modeling
-        module_name = '.'.join(parts[:-1])
-        module = importlib.import_module(module_name)
-        return getattr(module, parts[-1])
-    except Exception:
-        raise RuntimeError(f'Failed to find function: {func_name}')
-
-class Timer(object):
-    """A simple timer."""
-
-    def __init__(self):
-        self.reset()
-
-    def tic(self):
-        # using time.time instead of time.clock because time.clock
-        # does not normalize for multithreading
-        self.start_time = time.time()
-
-    def toc(self, average=True):
-        self.diff = time.time() - self.start_time
-        self.total_time += self.diff
-        self.calls += 1
-        self.average_time = self.total_time / self.calls
-        if average:
-            return self.average_time
-        else:
-            return self.diff
-
-    def reset(self):
-        self.total_time = 0.
-        self.calls = 0
-        self.start_time = 0.
-        self.diff = 0.
-        self.average_time = 0.
-
-class TrainingStats(object):
-    """Track vital training statistics."""
-    def __init__(self, log_period, tensorboard_logger=None):
-        self.log_period = log_period
-        self.tblogger = tensorboard_logger
-        self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time']
-        self.iter_timer = Timer()
-        # Window size for smoothing tracked values (running average over the log period)
-        self.filter_size = log_period
-        def create_smoothed_value():
-            return AverageMeter()
-        self.smoothed_losses = defaultdict(create_smoothed_value)
-        #self.smoothed_metrics = defaultdict(create_smoothed_value)
-        #self.smoothed_total_loss = AverageMeter()
-
-
-    def IterTic(self):
-        self.iter_timer.tic()
-
-    def IterToc(self):
-        return self.iter_timer.toc(average=False)
-
-    def reset_iter_time(self):
-        self.iter_timer.reset()
-
-    def update_iter_stats(self, losses_dict):
-        """Update tracked iteration statistics."""
-        for k, v in losses_dict.items():
-            self.smoothed_losses[k].update(float(v), 1)
-
-    def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}):
-        """Log the tracked statistics."""
-        if (cur_iter % self.log_period == 0):
-            stats = self.get_stats(cur_iter, optimizer, max_iters, val_err)
-            log_stats(stats)
-            if self.tblogger:
-                self.tb_log_stats(stats, cur_iter)
-            for k, v in self.smoothed_losses.items():
-                v.reset()
-
-    def tb_log_stats(self, stats, cur_iter):
-        """Log the tracked statistics to tensorboard"""
-        for k in stats:
-            # ignore some logs
-            if k not in self.tb_ignored_keys:
-                v = stats[k]
-                if isinstance(v, dict):
-                    self.tb_log_stats(v, cur_iter)
-                else:
-                    self.tblogger.add_scalar(k, v, cur_iter)
-
-
-    def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}):
-        eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter)
-
-        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
-        stats = OrderedDict(
-            iter=cur_iter,  # 1-indexed
-            time=self.iter_timer.average_time,
-            eta=eta,
-        )
-        optimizer_state_dict = optimizer.state_dict()
-        lr = {}
-        for i in range(len(optimizer_state_dict['param_groups'])):
-            lr_name = 'group%d_lr' % i
-            lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr']
-
-        stats['lr'] = OrderedDict(lr)
-        for k, v in self.smoothed_losses.items():
-            stats[k] = v.avg
-
-        stats['val_err'] = OrderedDict(val_err)
-        stats['max_iters'] = max_iters
-        return stats
-
-
-def reduce_dict(input_dict, average=True):
-    """
-    Reduce the values in the dictionary from all processes so that process with rank
-    0 has the reduced results.
-    Args:
-        @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
-        @average (bool): whether to do average or sum
-    Returns:
-        a dict with the same keys as input_dict, after reduction.
-    """
-    world_size = get_world_size()
-    if world_size < 2:
-        return input_dict
-    with torch.no_grad():
-        names = []
-        values = []
-        # sort the keys so that they are consistent across processes
-        for k in sorted(input_dict.keys()):
-            names.append(k)
-            values.append(input_dict[k])
-        values = torch.stack(values, dim=0)
-        dist.reduce(values, dst=0)
-        if dist.get_rank() == 0 and average:
-            # only main process gets accumulated, so only divide by
-            # world_size in this case
-            values /= world_size
-        reduced_dict = {k: v for k, v in zip(names, values)}
-    return reduced_dict
-
-
-def log_stats(stats):
-    """Log training statistics to terminal"""
-    logger = logging.getLogger()
-    lines = "[Step %d/%d]\n" % (
-        stats['iter'], stats['max_iters'])
-
-    lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % (
-        stats['total_loss'], stats['time'], stats['eta'])
-
-    # log loss
-    lines += "\t\t"
-    for k, v in stats.items():
-        if 'loss' in k.lower() and 'total_loss' not in k.lower():
-            lines += "%s: %.3f" % (k, v) + ", "
-    lines = lines[:-3]
-    lines += '\n'
-
-    # validate criteria
-    lines += "\t\tlast val err:" + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", "
-    lines += '\n'
-
-    # lr in different groups
-    lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items())
-    lines += '\n'
-    logger.info(lines[:-1])  # remove the trailing newline
-
diff --git a/mono/utils/custom_data.py b/mono/utils/custom_data.py
deleted file mode 100644
index d9fab47478bc471c51b5454cc15550079ebec21b..0000000000000000000000000000000000000000
--- a/mono/utils/custom_data.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import glob
-import os
-import json
-import cv2
-
-def load_from_annos(anno_path):
-    with open(anno_path, 'r') as f:
-        annos = json.load(f)['files']
-
-    datas = []
-    for i, anno in enumerate(annos):
-        rgb = anno['rgb']
-        depth = anno['depth'] if 'depth' in anno else None
-        depth_scale = anno['depth_scale'] if 'depth_scale' in anno else 1.0
-        intrinsic = anno['cam_in'] if 'cam_in' in anno else None
-        normal = anno['normal'] if 'normal' in anno else None
-
-        data_i = {
-            'rgb': rgb,
-            'depth': depth,
-            'depth_scale': depth_scale,
-            'intrinsic': intrinsic,
-            'filename': os.path.basename(rgb),
-            'folder': rgb.split('/')[-3],
-            'normal': normal
-        }
-        datas.append(data_i)
-    return datas
-
-def load_data(path: str):
-    rgbs = glob.glob(path + '/*.jpg') + glob.glob(path + '/*.png')
-    #intrinsic = [835.8179931640625, 835.8179931640625, 961.5419921875, 566.8090209960938] #[721.53769, 721.53769, 609.5593, 172.854]
-    data = [{'rgb': i, 'depth': None, 'intrinsic': None, 'filename': os.path.basename(i), 'folder': i.split('/')[-3]} for i in rgbs]
-    return data
\ No newline at end of file
diff --git a/mono/utils/do_test.py b/mono/utils/do_test.py
deleted file mode 100644
index 89ee4afc9d6cd67ec491af6726c850347cafc099..0000000000000000000000000000000000000000
--- a/mono/utils/do_test.py
+++ /dev/null
@@ -1,364 +0,0 @@
-import torch
-import torch.nn.functional as F
-import logging
-import os
-import os.path as osp
-from mono.utils.avg_meter import MetricAverageMeter
-from mono.utils.visualization import save_val_imgs, create_html, save_raw_imgs, save_normal_val_imgs
-import cv2
-from tqdm import tqdm
-import numpy as np
-from PIL import Image
-import matplotlib.pyplot as plt
-
-from mono.utils.unproj_pcd import reconstruct_pcd, save_point_cloud
-
-def to_cuda(data: dict):
-    for k, v in data.items():
-        if isinstance(v, torch.Tensor):
-            data[k] = v.cuda(non_blocking=True)
-        if isinstance(v, list) and len(v) >= 1 and isinstance(v[0], torch.Tensor):
-            for i, l_i in enumerate(v):
-                data[k][i] = l_i.cuda(non_blocking=True)
-    return data
-
-def align_scale(pred: torch.tensor, target: torch.tensor):
-    mask = target > 0
-    if torch.sum(mask) > 10:
-        scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
-    else:
-        scale = 1
-    pred_scaled = pred * scale
-    return pred_scaled, scale
-
-def align_scale_shift(pred: torch.tensor, target: torch.tensor):
-    mask = target > 0
-    target_mask = target[mask].cpu().numpy()
-    pred_mask = pred[mask].cpu().numpy()
-    if torch.sum(mask) > 10:
-        scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
-        if scale < 0:
-            scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
-            shift = 0
-    else:
-        scale = 1
-        shift = 0
-    pred = pred * scale + shift
-    return pred, scale
-
-def align_scale_shift_numpy(pred: np.array, target: np.array):
-    mask = target > 0
-    target_mask = target[mask]
-    pred_mask = pred[mask]
-    if np.sum(mask) > 10:
-        scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
-        if scale < 0:
-            scale = np.median(target[mask]) / (np.median(pred[mask]) + 1e-8)
-            shift = 0
-    else:
-        scale = 1
-        shift = 0
-    pred = pred * scale + shift
-    return pred, scale
-
-
-def build_camera_model(H : int, W : int, intrinsics : list) -> np.array:
-    """
-    Encode the camera intrinsic parameters (focal length and principal point) to a 4-channel map.
-    """
-    fx, fy, u0, v0 = intrinsics
-    f = (fx + fy) / 2.0
-    # principal point location
-    x_row = np.arange(0, W).astype(np.float32)
-    x_row_center_norm = (x_row - u0) / W
-    x_center = np.tile(x_row_center_norm, (H, 1))  # [H, W]
-
-    y_col = np.arange(0, H).astype(np.float32)
-    y_col_center_norm = (y_col - v0) / H
-    y_center = np.tile(y_col_center_norm, (W, 1)).T  # [H, W]
-
-    # FoV
-    fov_x = np.arctan(x_center / (f / W))
-    fov_y = np.arctan(y_center / (f / H))
-
-    cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
-    return cam_model
-
-def resize_for_input(image, output_shape, intrinsic, canonical_shape, to_canonical_ratio):
-    """
-    Resize the input.
-    Resizing consists of two processes: 1) mapping to the canonical space (adjusting the camera
-    model); 2) resizing the image while the camera model holds. The label will therefore be
-    scaled with the resize factor.
-    """
-    padding = [123.675, 116.28, 103.53]
-    h, w, _ = image.shape
-    resize_ratio_h = output_shape[0] / canonical_shape[0]
-    resize_ratio_w = output_shape[1] / canonical_shape[1]
-    to_scale_ratio = min(resize_ratio_h, resize_ratio_w)
-
-    resize_ratio = to_canonical_ratio * to_scale_ratio
-
-    reshape_h = int(resize_ratio * h)
-    reshape_w = int(resize_ratio * w)
-
-    pad_h = max(output_shape[0] - reshape_h, 0)
-    pad_w = max(output_shape[1] - reshape_w, 0)
-    pad_h_half = int(pad_h / 2)
-    pad_w_half = int(pad_w / 2)
-
-    # resize
-    image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
-    # padding
-    image = cv2.copyMakeBorder(
-        image,
-        pad_h_half,
-        pad_h - pad_h_half,
-        pad_w_half,
-        pad_w - pad_w_half,
-        cv2.BORDER_CONSTANT,
-        value=padding)
-
-    # Resize, adjust principal point
-    intrinsic[2] = intrinsic[2] * to_scale_ratio
-    intrinsic[3] = intrinsic[3] * to_scale_ratio
-
-    cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
-    cam_model = cv2.copyMakeBorder(
-        cam_model,
-        pad_h_half,
-        pad_h - pad_h_half,
-        pad_w_half,
-        pad_w - pad_w_half,
-        cv2.BORDER_CONSTANT,
-        value=-1)
-
-    pad = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
-    label_scale_factor = 1 / to_scale_ratio
-    return image, cam_model, pad, label_scale_factor
-
-
-def get_prediction(
-    model: torch.nn.Module,
-    input: torch.tensor,
-    cam_model: torch.tensor,
-    pad_info: torch.tensor,
-    scale_info: torch.tensor,
-    gt_depth: torch.tensor,
-    normalize_scale: float,
-    ori_shape: list=[],
-):
-
-    data = dict(
-        input=input,
-        cam_model=cam_model,
-    )
-    pred_depth, confidence, output_dict = model.module.inference(data)
-    pred_depth = pred_depth.squeeze()
-    pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]
-    if gt_depth is not None:
-        resize_shape = gt_depth.shape
-    elif ori_shape != []:
-        resize_shape = ori_shape
-    else:
-        resize_shape = pred_depth.shape
-
-    pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], resize_shape, mode='bilinear').squeeze()  # to original size
-    pred_depth = pred_depth * normalize_scale / scale_info
-    if gt_depth is not None:
-        pred_depth_scale, scale = align_scale(pred_depth, gt_depth)
-    else:
-        pred_depth_scale = None
-        scale = None
-
-    return pred_depth, pred_depth_scale, scale, output_dict
-
-def transform_test_data_scalecano(rgb, intrinsic, data_basic):
-    """
-    Pre-process the input for forwarding. Employs the label-scale canonical transformation.
-    Args:
-        rgb: input rgb image. [H, W, 3]
-        intrinsic: camera intrinsic parameter, [fx, fy, u0, v0]
-        data_basic: predefined canonical space in configs.
-    """
-    canonical_space = data_basic['canonical_space']
-    forward_size = data_basic.crop_size
-    mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
-    std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
-
-    # BGR to RGB
-    rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
-
-    ori_h, ori_w, _ = rgb.shape
-    ori_focal = (intrinsic[0] + intrinsic[1]) / 2
-    canonical_focal = canonical_space['focal_length']
-
-    cano_label_scale_ratio = canonical_focal / ori_focal
-
-    canonical_intrinsic = [
-        intrinsic[0] * cano_label_scale_ratio,
-        intrinsic[1] * cano_label_scale_ratio,
-        intrinsic[2],
-        intrinsic[3],
-    ]
-
-    # resize
-    rgb, cam_model, pad, resize_label_scale_ratio = resize_for_input(rgb, forward_size, canonical_intrinsic, [ori_h, ori_w], 1.0)
-
-    # label scale factor
-    label_scale_factor = cano_label_scale_ratio * resize_label_scale_ratio
-
-    rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
-    rgb = torch.div((rgb - mean), std)
-    rgb = rgb[None, :, :, :].cuda()
-
-    cam_model = torch.from_numpy(cam_model.transpose((2, 0, 1))).float()
-    cam_model = cam_model[None, :, :, :].cuda()
-    cam_model_stacks = [
-        torch.nn.functional.interpolate(cam_model, size=(cam_model.shape[2]//i, cam_model.shape[3]//i), mode='bilinear', align_corners=False)
-        for i in [2, 4, 8, 16, 32]
-    ]
-    return rgb, cam_model_stacks, pad, label_scale_factor
-
-def do_scalecano_test_with_custom_data(
-    model: torch.nn.Module,
-    cfg: dict,
-    test_data: list,
-    logger: logging.RootLogger,
-    is_distributed: bool = True,
-    local_rank: int = 0,
-):
-
-    show_dir = cfg.show_dir
-    save_interval = 1
-    save_imgs_dir = show_dir + '/vis'
-    os.makedirs(save_imgs_dir, exist_ok=True)
-    save_pcd_dir = show_dir + '/pcd'
-    os.makedirs(save_pcd_dir, exist_ok=True)
-
-    normalize_scale = cfg.data_basic.depth_range[1]
-    dam = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
-    dam_median = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
-    dam_global = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
-
-    for i, an in tqdm(enumerate(test_data)):
-    #for i, an in enumerate(test_data):
-        print(an['rgb'])
-        rgb_origin = cv2.imread(an['rgb'])[:, :, ::-1].copy()
-        if an['depth'] is not None:
-            gt_depth = cv2.imread(an['depth'], -1)
-            gt_depth_scale = an['depth_scale']
-            gt_depth = gt_depth / gt_depth_scale
-            gt_depth_flag = True
-        else:
-            gt_depth = None
-            gt_depth_flag = False
-        intrinsic = an['intrinsic']
-        if intrinsic is None:
-            intrinsic = [1000.0, 1000.0, rgb_origin.shape[1]/2, rgb_origin.shape[0]/2]
-            # intrinsic = [542.0, 542.0, 963.706, 760.199]
-        print(intrinsic)
-        rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(rgb_origin, intrinsic, cfg.data_basic)
-
-        pred_depth, pred_depth_scale, scale, output = get_prediction(
-            model = model,
-            input = rgb_input,
-            cam_model = cam_models_stacks,
-            pad_info = pad,
-            scale_info = label_scale_factor,
-            gt_depth = None,
-            normalize_scale = normalize_scale,
-            ori_shape=[rgb_origin.shape[0], rgb_origin.shape[1]],
-        )
-
-        pred_depth = (pred_depth > 0) * (pred_depth < 300) * pred_depth
-        if gt_depth_flag:
-
-            pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], (gt_depth.shape[0], gt_depth.shape[1]), mode='bilinear').squeeze()  # to original size
-
-            gt_depth = torch.from_numpy(gt_depth).cuda()
-
-            pred_depth_median = pred_depth * gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median()
-            pred_global, _ = align_scale_shift(pred_depth, gt_depth)
-
-            mask = (gt_depth > 1e-8)
-            dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed)
-            dam_median.update_metrics_gpu(pred_depth_median, gt_depth, mask, is_distributed)
-            dam_global.update_metrics_gpu(pred_global, gt_depth, mask, is_distributed)
-            print(gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median(), )
-
-        if i % save_interval == 0:
-            os.makedirs(osp.join(save_imgs_dir, an['folder']), exist_ok=True)
-            rgb_torch = torch.from_numpy(rgb_origin).to(pred_depth.device).permute(2, 0, 1)
-            mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None].to(rgb_torch.device)
-            std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None].to(rgb_torch.device)
-            rgb_torch = torch.div((rgb_torch - mean), std)
-
-            save_val_imgs(
-                i,
-                pred_depth,
-                gt_depth if gt_depth is not None else torch.ones_like(pred_depth, device=pred_depth.device),
-                rgb_torch,
-                osp.join(an['folder'], an['filename']),
-                save_imgs_dir,
-            )
-            #save_raw_imgs(pred_depth.detach().cpu().numpy(), rgb_torch, osp.join(an['folder'], an['filename']), save_imgs_dir, 1000.0)
-
-        # pcd
-        pred_depth = pred_depth.detach().cpu().numpy()
-        #pcd = reconstruct_pcd(pred_depth, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3])
-        #os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True)
-        #save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4]+'.ply'))
-
-        if an['intrinsic'] is None:
-            #for r in [0.9, 1.0, 1.1]:
-            for r in [1.0]:
-                #for f in [600, 800, 1000, 1250, 1500]:
-                for f in [1000]:
-                    pcd = reconstruct_pcd(pred_depth, f * r, f * (2-r), intrinsic[2], intrinsic[3])
-                    fstr = '_fx_' + str(int(f * r)) + '_fy_' + str(int(f * (2-r)))
-                    os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True)
-                    save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4] + fstr + '.ply'))
-
-        if "normal_out_list" in output.keys():
-
-            normal_out_list = output['normal_out_list']
-            pred_normal = normal_out_list[0][:, :3, :, :]  # (B, 3, H, W)
-            H, W = pred_normal.shape[2:]
-            pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]
-
-            gt_normal = None
-            #if gt_normal_flag:
-            if False:
-                pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True)
-                gt_normal = cv2.imread(norm_path)
-                gt_normal = cv2.cvtColor(gt_normal, cv2.COLOR_BGR2RGB)
-                gt_normal = np.array(gt_normal).astype(np.uint8)
-                gt_normal = ((gt_normal.astype(np.float32) / 255.0) * 2.0) - 1.0
-                norm_valid_mask = (np.linalg.norm(gt_normal, axis=2, keepdims=True) > 0.5)
-                gt_normal = gt_normal * norm_valid_mask
-                gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
-                dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)  # save valid normal
-
-            if i % save_interval == 0:
-                save_normal_val_imgs(
-                    i,
-                    pred_normal,
-                    gt_normal if gt_normal is not None else torch.ones_like(pred_normal, device=pred_normal.device),
-                    rgb_torch,  # data['input'],
-                    osp.join(an['folder'], 'normal_' + an['filename']),
-                    save_imgs_dir,
-                )
-
-
-        #if gt_depth_flag:
-        if False:
-            eval_error = dam.get_metrics()
-            print('w/o match :', eval_error)
-
-            eval_error_median = dam_median.get_metrics()
-            print('median match :', eval_error_median)
-
-            eval_error_global = dam_global.get_metrics()
-            print('global match :', eval_error_global)
-        else:
-            print('missing gt_depth, only save visualizations...')
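Editor's note on the scale-canonical bookkeeping in the file deleted above: `transform_test_data_scalecano` folds the camera into the canonical space (factor `canonical_focal / ori_focal`) together with the resize factor into `label_scale_factor`, and `get_prediction` multiplies the network output by `normalize_scale` and divides that factor back out to recover metric depth. A toy recomputation of that arithmetic with made-up numbers (the real ones come from `cfg.data_basic` and `resize_for_input`):

```python
# All values are hypothetical illustrations, not the repo's configs.
canonical_focal = 1000.0           # focal length of the canonical camera
real_focal = 721.5                 # (fx + fy) / 2 of the test camera
resize_label_scale_ratio = 0.85    # 1 / to_scale_ratio from resize_for_input

# transform_test_data_scalecano: total factor applied to the labels
label_scale_factor = (canonical_focal / real_focal) * resize_label_scale_ratio

# get_prediction: the raw prediction is scaled by normalize_scale
# (cfg.data_basic.depth_range[1]) and divided by the factor above.
normalize_scale = 200.0            # hypothetical depth_range[1]
pred_canonical = 0.05              # one predicted pixel value
pred_metric = pred_canonical * normalize_scale / label_scale_factor
print(f'{pred_metric:.3f} m')
```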
diff --git a/mono/utils/logger.py b/mono/utils/logger.py
deleted file mode 100644
index ca48c613b2fdc5352b13ccb7d0bfdc1df5e3b531..0000000000000000000000000000000000000000
--- a/mono/utils/logger.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import atexit
-import logging
-import os
-import sys
-import time
-import torch
-from termcolor import colored
-
-__all__ = ["setup_logger", ]
-
-class _ColorfulFormatter(logging.Formatter):
-    def __init__(self, *args, **kwargs):
-        self._root_name = kwargs.pop("root_name") + "."
-        self._abbrev_name = kwargs.pop("abbrev_name", "")
-        if len(self._abbrev_name):
-            self._abbrev_name = self._abbrev_name + "."
-        super(_ColorfulFormatter, self).__init__(*args, **kwargs)
-
-    def formatMessage(self, record):
-        record.name = record.name.replace(self._root_name, self._abbrev_name)
-        log = super(_ColorfulFormatter, self).formatMessage(record)
-        if record.levelno == logging.WARNING:
-            prefix = colored("WARNING", "red", attrs=["blink"])
-        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
-            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
-        else:
-            return log
-        return prefix + " " + log
-
-def setup_logger(
-    output=None, distributed_rank=0, *, name='metricdepth', color=True, abbrev_name=None
-):
-    """
-    Initialize the logger (adapted from detectron2) and set its verbosity level to "INFO".
-    Args:
-        output (str): a file name or a directory to save log. If None, will not save log file.
-            If ends with ".txt" or ".log", assumed to be a file name.
-            Otherwise, logs will be saved to `output/log.txt`.
-        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
-            Set to "" to not log the root module in logs.
-            By default, will abbreviate "detectron2" to "d2" and leave other
-            modules unchanged.
-    Returns:
-        logging.Logger: a logger
-    """
-    logger = logging.getLogger()
-    logger.setLevel(logging.INFO)  # NOTE: if more detailed, change it to logging.DEBUG
-    logger.propagate = False
-
-    if abbrev_name is None:
-        abbrev_name = "d2"
-
-    plain_formatter = logging.Formatter(
-        "[%(asctime)s] %(name)s %(levelname)s %(message)s ", datefmt="%m/%d %H:%M:%S"
-    )
-    # stdout logging: master only
-    if distributed_rank == 0:
-        ch = logging.StreamHandler(stream=sys.stdout)
-        ch.setLevel(logging.INFO)  # NOTE: if more detailed, change it to logging.DEBUG
-        if color:
-            formatter = _ColorfulFormatter(
-                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
-                datefmt="%m/%d %H:%M:%S",
-                root_name=name,
-                abbrev_name=str(abbrev_name),
-            )
-        else:
-            formatter = plain_formatter
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
-
-    # file logging: all workers
-    if output is not None:
-        if output.endswith(".txt") or output.endswith(".log"):
-            filename = output
-        else:
-            filename = os.path.join(output, "log.txt")
-        if distributed_rank > 0:
-            filename = filename + ".rank{}".format(distributed_rank)
-        os.makedirs(os.path.dirname(filename), exist_ok=True)
-
-        fh = logging.StreamHandler(_cached_log_stream(filename))
-        fh.setLevel(logging.INFO)  # NOTE: if more detailed, change it to logging.DEBUG
-        fh.setFormatter(plain_formatter)
-        logger.addHandler(fh)
-
-
-    return logger
-
-from iopath.common.file_io import PathManager as PathManagerBase
-
-
-PathManager = PathManagerBase()
-
-# cache the opened file object, so that different calls to `setup_logger`
-# with the same file name can safely write to the same file.
-def _cached_log_stream(filename):
-    # use 1K buffer if writing to cloud storage
-    io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
-    atexit.register(io.close)
-    return io
-
\ No newline at end of file
diff --git a/mono/utils/mldb.py b/mono/utils/mldb.py
deleted file mode 100644
index d74ac53fd0302e2e954105bade52e6de4c18e2f6..0000000000000000000000000000000000000000
--- a/mono/utils/mldb.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from types import ModuleType
-import data_info
-
-def load_data_info(module_name, data_info={}, mldb_type='mldb_info', module=None):
-    if module is None:
-        module = globals().get(module_name, None)
-    if module:
-        for key, value in module.__dict__.items():
-            if not (key.startswith('__')) and not (key.startswith('_')):
-                if key == 'mldb_info':
-                    data_info.update(value)
-                elif isinstance(value, ModuleType):
-                    load_data_info(module_name + '.' + key, data_info, module=value)
-    else:
-        raise RuntimeError(f'Try to access "mldb_info", but cannot find {module_name} module.')
-
-def reset_ckpt_path(cfg, data_info):
-    if isinstance(cfg, dict):
-        for key in cfg.keys():
-            if key == 'backbone':
-                new_ckpt_path = data_info['checkpoint']['mldb_root'] + '/' + data_info['checkpoint'][cfg.backbone.type]
-                cfg.backbone.update(checkpoint=new_ckpt_path)
-                continue
-            elif isinstance(cfg.get(key), dict):
-                reset_ckpt_path(cfg.get(key), data_info)
-            else:
-                continue
-    else:
-        return
-
-if __name__ == '__main__':
-    mldb_info_tmp = {}
-    load_data_info('mldb_data_info', mldb_info_tmp)
-    print('results', mldb_info_tmp.keys())
\ No newline at end of file
diff --git a/mono/utils/pcd_filter.py b/mono/utils/pcd_filter.py
deleted file mode 100644
index 2d26314d806ea961f6bf09d1fb195bf5e364f181..0000000000000000000000000000000000000000
--- a/mono/utils/pcd_filter.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import open3d as o3d
-import numpy as np
-
-def downsample_and_filter(pcd_file, max_bound_div=750, neighbor_num=8):
-    pcd = o3d.io.read_point_cloud(pcd_file)
-    point_num = len(pcd.points)
-    if (point_num > 10000000):
-        voxel_down_pcd = o3d.geometry.PointCloud.uniform_down_sample(pcd, int(point_num / 10000000)+1)
-    else:
-        voxel_down_pcd = pcd
-    max_bound = voxel_down_pcd.get_max_bound()
-    ball_radius = np.linalg.norm(max_bound) / max_bound_div
-    pcd_filter, _ = voxel_down_pcd.remove_radius_outlier(neighbor_num, ball_radius)
-    print('filtered size', len(pcd_filter.points), 'pre size:', len(pcd.points))
-    o3d.io.write_point_cloud(pcd_file[:-4] + '_filtered.ply', pcd_filter)
-
-
-if __name__ == "__main__":
-    import os
-    dir_path = './data/demo_pcd'
-    for pcd_file in os.listdir(dir_path):
-        #if 'jonathan' in pcd_file: set max_bound_div to 300 and neighbor_num to 8
-        downsample_and_filter(os.path.join(dir_path, pcd_file))
-
\ No newline at end of file
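Editor's note on `setup_logger` in the file deleted above: rank 0 gets a colorized stdout handler plus the log file, while non-zero ranks write only to a per-rank file. A hypothetical usage sketch, assuming the module is importable as `mono.utils.logger`:

```python
from mono.utils.logger import setup_logger

# Rank 0 logs to stdout and to 'work_dirs/demo/log.txt'; a worker created
# with distributed_rank=1 would skip stdout and write 'log.txt.rank1'.
logger = setup_logger('work_dirs/demo/log.txt', distributed_rank=0)
logger.info('config loaded, starting evaluation')
logger.warning('this WARNING line is colorized on the console')
```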
- """ - logger = logging.getLogger() - if os.path.isfile(load_path): - if main_process(): - logger.info(f"Loading weight '{load_path}'") - checkpoint = torch.load(load_path, map_location="cpu") - ckpt_state_dict = checkpoint['model_state_dict'] - model.module.load_state_dict(ckpt_state_dict, strict=strict_match) - - if optimizer is not None: - optimizer.load_state_dict(checkpoint['optimizer']) - if scheduler is not None: - scheduler.load_state_dict(checkpoint['scheduler']) - if loss_scaler is not None and 'scaler' in checkpoint: - scheduler.load_state_dict(checkpoint['scaler']) - del ckpt_state_dict - del checkpoint - if main_process(): - logger.info(f"Successfully loaded weight: '{load_path}'") - if scheduler is not None and optimizer is not None: - logger.info(f"Resume training from: '{load_path}'") - else: - if main_process(): - raise RuntimeError(f"No weight found at '{load_path}'") - return model, optimizer, scheduler, loss_scaler - - -def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None): - """ - Save the model, optimizer, lr scheduler. - """ - logger = logging.getLogger() - - if 'IterBasedRunner' in cfg.runner.type: - max_iters = cfg.runner.max_iters - elif 'EpochBasedRunner' in cfg.runner.type: - max_iters = cfg.runner.max_epochs - else: - raise TypeError(f'{cfg.runner.type} is not supported') - - ckpt = dict( - model_state_dict=model.module.state_dict(), - optimizer=optimizer.state_dict(), - max_iter=cfg.runner.max_iters if 'max_iters' in cfg.runner \ - else cfg.runner.max_epochs, - scheduler=scheduler.state_dict(), - ) - - if loss_scaler is not None: - ckpt.update(dict(scaler=loss_scaler.state_dict())) - - ckpt_dir = os.path.join(cfg.work_dir, 'ckpt') - os.makedirs(ckpt_dir, exist_ok=True) - - save_name = os.path.join(ckpt_dir, 'step%08d.pth' %curr_iter) - saved_ckpts = glob.glob(ckpt_dir + '/step*.pth') - torch.save(ckpt, save_name) - - # keep the last 8 ckpts - if len(saved_ckpts) > 20: - saved_ckpts.sort() - os.remove(saved_ckpts.pop(0)) - - logger.info(f'Save model: {save_name}') diff --git a/mono/utils/transform.py b/mono/utils/transform.py deleted file mode 100644 index 2af94efe754d6f72325db6fdc170f30fbfb8c2fe..0000000000000000000000000000000000000000 --- a/mono/utils/transform.py +++ /dev/null @@ -1,408 +0,0 @@ -import collections -import cv2 -import math -import numpy as np -import numbers -import random -import torch - -import matplotlib -import matplotlib.cm - - -""" -Provides a set of Pytorch transforms that use OpenCV instead of PIL (Pytorch default) -for image manipulation. -""" - -class Compose(object): - # Composes transforms: transforms.Compose([transforms.RandScale([0.5, 2.0]), transforms.ToTensor()]) - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): - for t in self.transforms: - images, labels, intrinsics, cam_models, other_labels, transform_paras = t(images, labels, intrinsics, cam_models, other_labels, transform_paras) - return images, labels, intrinsics, cam_models, other_labels, transform_paras - - -class ToTensor(object): - # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 
- def __init__(self, **kwargs): - return - def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): - if not isinstance(images, list) or not isinstance(labels, list) or not isinstance(intrinsics, list): - raise (RuntimeError("transform.ToTensor() only handle inputs/labels/intrinsics lists.")) - if len(images) != len(intrinsics): - raise (RuntimeError("Numbers of images and intrinsics are not matched.")) - if not isinstance(images[0], np.ndarray) or not isinstance(labels[0], np.ndarray): - raise (RuntimeError("transform.ToTensor() only handle np.ndarray for the input and label." - "[eg: data readed by cv2.imread()].\n")) - if not isinstance(intrinsics[0], list): - raise (RuntimeError("transform.ToTensor() only handle list for the camera intrinsics")) - - if len(images[0].shape) > 3 or len(images[0].shape) < 2: - raise (RuntimeError("transform.ToTensor() only handle image(np.ndarray) with 3 dims or 2 dims.\n")) - if len(labels[0].shape) > 3 or len(labels[0].shape) < 2: - raise (RuntimeError("transform.ToTensor() only handle label(np.ndarray) with 3 dims or 2 dims.\n")) - - if len(intrinsics[0]) >4 or len(intrinsics[0]) < 3: - raise (RuntimeError("transform.ToTensor() only handle intrinsic(list) with 3 sizes or 4 sizes.\n")) - - for i, img in enumerate(images): - if len(img.shape) == 2: - img = np.expand_dims(img, axis=2) - images[i] = torch.from_numpy(img.transpose((2, 0, 1))).float() - for i, lab in enumerate(labels): - if len(lab.shape) == 2: - lab = np.expand_dims(lab, axis=0) - labels[i] = torch.from_numpy(lab).float() - for i, intrinsic in enumerate(intrinsics): - if len(intrinsic) == 3: - intrinsic = [intrinsic[0],] + intrinsic - intrinsics[i] = torch.tensor(intrinsic, dtype=torch.float) - if cam_models is not None: - for i, cam_model in enumerate(cam_models): - cam_models[i] = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() if cam_model is not None else None - if other_labels is not None: - for i, lab in enumerate(other_labels): - if len(lab.shape) == 2: - lab = np.expand_dims(lab, axis=0) - other_labels[i] = torch.from_numpy(lab).float() - return images, labels, intrinsics, cam_models, other_labels, transform_paras - - -class Normalize(object): - # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std - def __init__(self, mean, std=None, **kwargs): - if std is None: - assert len(mean) > 0 - else: - assert len(mean) == len(std) - self.mean = torch.tensor(mean).float()[:, None, None] - self.std = torch.tensor(std).float()[:, None, None] if std is not None \ - else torch.tensor([1.0, 1.0, 1.0]).float()[:, None, None] - - def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): - # if self.std is None: - # # for t, m in zip(image, self.mean): - # # t.sub(m) - # image = image - self.mean - # if ref_images is not None: - # for i, ref_i in enumerate(ref_images): - # ref_images[i] = ref_i - self.mean - # else: - # # for t, m, s in zip(image, self.mean, self.std): - # # t.sub(m).div(s) - # image = (image - self.mean) / self.std - # if ref_images is not None: - # for i, ref_i in enumerate(ref_images): - # ref_images[i] = (ref_i - self.mean) / self.std - for i, img in enumerate(images): - img = torch.div((img - self.mean), self.std) - images[i] = img - return images, labels, intrinsics, cam_models, other_labels, transform_paras - - -class LableScaleCanonical(object): - """ - To solve the ambiguity observation for the mono branch, i.e. 
different focal length (object size) with the same depth, cameras are - mapped to a canonical space. To mimic this, we set the focal length to a canonical one and scale the depth value. NOTE: resize the image based on the ratio can also solve - Args: - images: list of RGB images. - labels: list of depth/disparity labels. - other labels: other labels, such as instance segmentations, semantic segmentations... - """ - def __init__(self, **kwargs): - self.canonical_focal = kwargs['focal_length'] - - def _get_scale_ratio(self, intrinsic): - target_focal_x = intrinsic[0] - label_scale_ratio = self.canonical_focal / target_focal_x - pose_scale_ratio = 1.0 - return label_scale_ratio, pose_scale_ratio - - def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): - assert len(images[0].shape) == 3 and len(labels[0].shape) == 2 - assert labels[0].dtype == np.float32 - - label_scale_ratio = None - pose_scale_ratio = None - - for i in range(len(intrinsics)): - img_i = images[i] - label_i = labels[i] if i < len(labels) else None - intrinsic_i = intrinsics[i].copy() - cam_model_i = cam_models[i] if cam_models is not None and i < len(cam_models) else None - - label_scale_ratio, pose_scale_ratio = self._get_scale_ratio(intrinsic_i) - - # adjust the focal length, map the current camera to the canonical space - intrinsics[i] = [intrinsic_i[0] * label_scale_ratio, intrinsic_i[1] * label_scale_ratio, intrinsic_i[2], intrinsic_i[3]] - - # scale the label to the canonical space - if label_i is not None: - labels[i] = label_i * label_scale_ratio - - if cam_model_i is not None: - # As the focal length is adjusted (canonical focal length), the camera model should be re-built - ori_h, ori_w, _ = img_i.shape - cam_models[i] = build_camera_model(ori_h, ori_w, intrinsics[i]) - - - if transform_paras is not None: - transform_paras.update(label_scale_factor=label_scale_ratio, focal_scale_factor=label_scale_ratio) - - return images, labels, intrinsics, cam_models, other_labels, transform_paras - - -class ResizeKeepRatio(object): - """ - Resize and pad to a given size. Hold the aspect ratio. - This resizing assumes that the camera model remains unchanged. - Args: - resize_size: predefined output size. 
- """ - def __init__(self, resize_size, padding=None, ignore_label=-1, **kwargs): - if isinstance(resize_size, int): - self.resize_h = resize_size - self.resize_w = resize_size - elif isinstance(resize_size, collections.Iterable) and len(resize_size) == 2 \ - and isinstance(resize_size[0], int) and isinstance(resize_size[1], int) \ - and resize_size[0] > 0 and resize_size[1] > 0: - self.resize_h = resize_size[0] - self.resize_w = resize_size[1] - else: - raise (RuntimeError("crop size error.\n")) - if padding is None: - self.padding = padding - elif isinstance(padding, list): - if all(isinstance(i, numbers.Number) for i in padding): - self.padding = padding - else: - raise (RuntimeError("padding in Crop() should be a number list\n")) - if len(padding) != 3: - raise (RuntimeError("padding channel is not equal with 3\n")) - else: - raise (RuntimeError("padding in Crop() should be a number list\n")) - if isinstance(ignore_label, int): - self.ignore_label = ignore_label - else: - raise (RuntimeError("ignore_label should be an integer number\n")) - # self.crop_size = kwargs['crop_size'] - self.canonical_focal = kwargs['focal_length'] - - def main_data_transform(self, image, label, intrinsic, cam_model, resize_ratio, padding, to_scale_ratio): - """ - Resize data first and then do the padding. - 'label' will be scaled. - """ - h, w, _ = image.shape - reshape_h = int(resize_ratio * h) - reshape_w = int(resize_ratio * w) - - pad_h, pad_w, pad_h_half, pad_w_half = padding - - # resize - image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) - # padding - image = cv2.copyMakeBorder( - image, - pad_h_half, - pad_h - pad_h_half, - pad_w_half, - pad_w - pad_w_half, - cv2.BORDER_CONSTANT, - value=self.padding) - - if label is not None: - # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST) - label = resize_depth_preserve(label, (reshape_h, reshape_w)) - label = cv2.copyMakeBorder( - label, - pad_h_half, - pad_h - pad_h_half, - pad_w_half, - pad_w - pad_w_half, - cv2.BORDER_CONSTANT, - value=self.ignore_label) - # scale the label - label = label / to_scale_ratio - - # Resize, adjust principle point - if intrinsic is not None: - intrinsic[0] = intrinsic[0] * resize_ratio / to_scale_ratio - intrinsic[1] = intrinsic[1] * resize_ratio / to_scale_ratio - intrinsic[2] = intrinsic[2] * resize_ratio - intrinsic[3] = intrinsic[3] * resize_ratio - - if cam_model is not None: - #cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) - cam_model = build_camera_model(reshape_h, reshape_w, intrinsic) - cam_model = cv2.copyMakeBorder( - cam_model, - pad_h_half, - pad_h - pad_h_half, - pad_w_half, - pad_w - pad_w_half, - cv2.BORDER_CONSTANT, - value=self.ignore_label) - - # Pad, adjust the principle point - if intrinsic is not None: - intrinsic[2] = intrinsic[2] + pad_w_half - intrinsic[3] = intrinsic[3] + pad_h_half - return image, label, intrinsic, cam_model - - def get_label_scale_factor(self, image, intrinsic, resize_ratio): - ori_h, ori_w, _ = image.shape - # crop_h, crop_w = self.crop_size - ori_focal = intrinsic[0] - - to_canonical_ratio = self.canonical_focal / ori_focal - to_scale_ratio = resize_ratio / to_canonical_ratio - return to_scale_ratio - - def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): - target_h, target_w, _ = images[0].shape - resize_ratio_h = self.resize_h / target_h - resize_ratio_w = self.resize_w / target_w - resize_ratio = 
-
-    def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
-        ori_h, ori_w, _ = images[0].shape
-        resize_ratio_h = self.resize_h / ori_h
-        resize_ratio_w = self.resize_w / ori_w
-        resize_ratio = min(resize_ratio_h, resize_ratio_w)
-        reshape_h = int(resize_ratio * ori_h)
-        reshape_w = int(resize_ratio * ori_w)
-        pad_h = max(self.resize_h - reshape_h, 0)
-        pad_w = max(self.resize_w - reshape_w, 0)
-        pad_h_half = pad_h // 2
-        pad_w_half = pad_w // 2
-
-        pad_info = [pad_h, pad_w, pad_h_half, pad_w_half]
-        to_scale_ratio = self.get_label_scale_factor(images[0], intrinsics[0], resize_ratio)
-
-        for i in range(len(images)):
-            img = images[i]
-            label = labels[i] if i < len(labels) else None
-            intrinsic = intrinsics[i] if i < len(intrinsics) else None
-            cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
-            img, label, intrinsic, cam_model = self.main_data_transform(
-                img, label, intrinsic, cam_model, resize_ratio, pad_info, to_scale_ratio)
-            images[i] = img
-            if label is not None:
-                labels[i] = label
-            if intrinsic is not None:
-                intrinsics[i] = intrinsic
-            if cam_model is not None:
-                cam_models[i] = cam_model
-
-        if other_labels is not None:
-            for i, other_lab in enumerate(other_labels):
-                # resize
-                other_lab = cv2.resize(other_lab, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
-                # pad
-                other_labels[i] = cv2.copyMakeBorder(
-                    other_lab,
-                    pad_h_half,
-                    pad_h - pad_h_half,
-                    pad_w_half,
-                    pad_w - pad_w_half,
-                    cv2.BORDER_CONSTANT,
-                    value=self.ignore_label)
-
-        pad = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
-        if transform_paras is not None:
-            pad_old = transform_paras['pad'] if 'pad' in transform_paras else [0, 0, 0, 0]
-            new_pad = [pad_old[0] + pad[0], pad_old[1] + pad[1], pad_old[2] + pad[2], pad_old[3] + pad[3]]
-            transform_paras.update(dict(pad=new_pad))
-            if 'label_scale_factor' in transform_paras:
-                transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * 1.0 / to_scale_ratio
-            else:
-                transform_paras.update(label_scale_factor=1.0 / to_scale_ratio)
-        return images, labels, intrinsics, cam_models, other_labels, transform_paras
-
-
-class BGR2RGB(object):
-    # Converts images from BGR order to RGB order, for models initialized from PyTorch (RGB) pretrained weights
-    def __init__(self, **kwargs):
-        return
-
-    def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
-        for i, img in enumerate(images):
-            images[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        return images, labels, intrinsics, cam_models, other_labels, transform_paras
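A worked example of the ratio/padding arithmetic in __call__ above, for a hypothetical 480x640 input resized into a 512x960 canvas (sizes chosen for illustration only):

resize_h, resize_w = 512, 960   # hypothetical target canvas
h, w = 480, 640                 # hypothetical input size

resize_ratio = min(resize_h / h, resize_w / w)   # min(1.0667, 1.5) = 512/480
reshape_h = int(resize_ratio * h)                # 512
reshape_w = int(resize_ratio * w)                # 682
pad_h = max(resize_h - reshape_h, 0)             # 0
pad_w = max(resize_w - reshape_w, 0)             # 278
pad_h_half, pad_w_half = pad_h // 2, pad_w // 2  # 0, 139
# the image becomes 512x682, then is padded to 512x960 with a
# 139 px border on both the left and the right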
-
-
-def resize_depth_preserve(depth, shape):
-    """
-    Resizes a depth map while preserving all valid depth pixels.
-    Multiple downsampled points can be assigned to the same pixel.
-
-    Parameters
-    ----------
-    depth : np.array [h,w]
-        Depth map
-    shape : tuple (H,W)
-        Output shape
-
-    Returns
-    -------
-    depth : np.array [H,W]
-        Resized depth map
-    """
-    # store dimensions and reshape to a single column
-    depth = np.squeeze(depth)
-    h, w = depth.shape
-    x = depth.reshape(-1)
-    # create coordinate grid
-    uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2)
-    # filter valid points
-    idx = x > 0
-    crd, val = uv[idx], x[idx]
-    # downsample coordinates
-    crd[:, 0] = (crd[:, 0] * (shape[0] / h) + 0.5).astype(np.int32)
-    crd[:, 1] = (crd[:, 1] * (shape[1] / w) + 0.5).astype(np.int32)
-    # keep only the points that fall inside the output image
-    idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1])
-    crd, val = crd[idx], val[idx]
-    # create the downsampled depth image and assign the points
-    depth = np.zeros(shape)
-    depth[crd[:, 0], crd[:, 1]] = val
-    return depth
-
-
-def build_camera_model(H: int, W: int, intrinsics: list) -> np.array:
-    """
-    Encode the camera intrinsic parameters (focal length and principal point) into a 4-channel map.
-    """
-    fx, fy, u0, v0 = intrinsics
-    f = (fx + fy) / 2.0
-    # principal point location
-    x_row = np.arange(0, W).astype(np.float32)
-    x_row_center_norm = (x_row - u0) / W
-    x_center = np.tile(x_row_center_norm, (H, 1))  # [H, W]
-
-    y_col = np.arange(0, H).astype(np.float32)
-    y_col_center_norm = (y_col - v0) / H
-    y_center = np.tile(y_col_center_norm, (W, 1)).T
-
-    # FoV maps
-    fov_x = np.arctan(x_center / (f / W))
-    fov_y = np.arctan(y_center / (f / H))
-
-    cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
-    return cam_model
-
-
-def gray_to_colormap(img, cmap='rainbow'):
-    """
-    Convert a gray map to a matplotlib colormap.
-    """
-    assert img.ndim == 2
-
-    img[img < 0] = 0
-    mask_invalid = img < 1e-10
-    img = img / (img.max() + 1e-8)
-    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
-    cmap_m = matplotlib.cm.get_cmap(cmap)
-    mapper = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
-    colormap = (mapper.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
-    colormap[mask_invalid] = 0
-    return colormap
\ No newline at end of file
diff --git a/mono/utils/unproj_pcd.py b/mono/utils/unproj_pcd.py
deleted file mode 100644
index a0986d482a2ec68be1dd65719adec662272b833c..0000000000000000000000000000000000000000
--- a/mono/utils/unproj_pcd.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import numpy as np
-import torch
-from plyfile import PlyData, PlyElement
-import cv2
-
-
-def get_pcd_base(H, W, u0, v0, fx, fy):
-    x_row = np.arange(0, W)
-    x = np.tile(x_row, (H, 1))
-    x = x.astype(np.float32)
-    u_m_u0 = x - u0
-
-    y_col = np.arange(0, H)
-    y = np.tile(y_col, (W, 1)).T
-    y = y.astype(np.float32)
-    v_m_v0 = y - v0
-
-    x = u_m_u0 / fx
-    y = v_m_v0 / fy
-    z = np.ones_like(x)
-    pw = np.stack([x, y, z], axis=2)  # [h, w, c]
-    return pw
-
-
-def reconstruct_pcd(depth, fx, fy, u0, v0, pcd_base=None, mask=None):
-    if isinstance(depth, torch.Tensor):
-        depth = depth.cpu().numpy().squeeze()
-    depth = cv2.medianBlur(depth, 5)
-    if pcd_base is None:
-        H, W = depth.shape
-        pcd_base = get_pcd_base(H, W, u0, v0, fx, fy)
-    pcd = depth[:, :, None] * pcd_base
-    if mask is not None:
-        pcd[mask] = 0
-    return pcd
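get_pcd_base precomputes the per-pixel ray directions ((u - u0)/fx, (v - v0)/fy, 1), and reconstruct_pcd scales them by depth. A small self-contained check of that pinhole unprojection with hypothetical intrinsics:

import numpy as np

fx = fy = 100.0                                 # hypothetical intrinsics
u0, v0 = 2.0, 1.5
depth = np.full((4, 4), 5.0, dtype=np.float32)  # toy depth map, 5 m everywhere

v, u = np.mgrid[:4, :4].astype(np.float32)      # pixel row/column coordinates
x = (u - u0) / fx * depth                       # X = (u - u0) * Z / fx
y = (v - v0) / fy * depth                       # Y = (v - v0) * Z / fy
pcd = np.stack([x, y, depth], axis=2)           # [H, W, 3] camera-space points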
-
-
-def save_point_cloud(pcd, rgb, filename, binary=True):
-    """Save an RGB point cloud as a PLY file.
-    :params
-        @pcd: Nx3 matrix, the XYZ coordinates
-        @rgb: Nx3 matrix, the RGB colors for each 3D point
-    """
-    assert rgb is None or pcd.shape[0] == rgb.shape[0]
-
-    if rgb is None:
-        gray_concat = np.tile(np.array([128], dtype=np.uint8),
-                              (pcd.shape[0], 3))
-        points_3d = np.hstack((pcd, gray_concat))
-    else:
-        points_3d = np.hstack((pcd, rgb))
-    python_types = (float, float, float, int, int, int)
-    npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'),
-                 ('green', 'u1'), ('blue', 'u1')]
-    if binary is True:
-        # format into a NumPy structured array
-        vertices = []
-        for row_idx in range(points_3d.shape[0]):
-            cur_point = points_3d[row_idx]
-            vertices.append(
-                tuple(
-                    dtype(point)
-                    for dtype, point in zip(python_types, cur_point)))
-        vertices_array = np.array(vertices, dtype=npy_types)
-        el = PlyElement.describe(vertices_array, 'vertex')
-
-        # write
-        PlyData([el]).write(filename)
-    else:
-        x = np.squeeze(points_3d[:, 0])
-        y = np.squeeze(points_3d[:, 1])
-        z = np.squeeze(points_3d[:, 2])
-        r = np.squeeze(points_3d[:, 3])
-        g = np.squeeze(points_3d[:, 4])
-        b = np.squeeze(points_3d[:, 5])
-
-        ply_head = 'ply\n' \
-                   'format ascii 1.0\n' \
-                   'element vertex %d\n' \
-                   'property float x\n' \
-                   'property float y\n' \
-                   'property float z\n' \
-                   'property uchar red\n' \
-                   'property uchar green\n' \
-                   'property uchar blue\n' \
-                   'end_header' % r.shape[0]
-        # ---- save ply data to disk
-        np.savetxt(filename, np.column_stack((x, y, z, r, g, b)), fmt='%f %f %f %d %d %d', header=ply_head, comments='')
\ No newline at end of file
diff --git a/mono/utils/visualization.py b/mono/utils/visualization.py
deleted file mode 100644
index 07275030c48aeea062c0041b11ba60d911c14a3f..0000000000000000000000000000000000000000
--- a/mono/utils/visualization.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import matplotlib.pyplot as plt
-import os, cv2
-import numpy as np
-from mono.utils.transform import gray_to_colormap
-import shutil
-import glob
-from mono.utils.running import main_process
-import torch
-from html4vision import Col, imagetable
-
-
-def save_raw_imgs(
-    pred: torch.Tensor,
-    rgb: torch.Tensor,
-    filename: str,
-    save_dir: str,
-    scale: float = 200.0,
-    target: torch.Tensor = None,
-    ):
-    """
-    Save the raw prediction, GT, and RGB as separate image files.
-    """
-    cv2.imwrite(os.path.join(save_dir, filename[:-4] + '_rgb.jpg'), rgb)
-    cv2.imwrite(os.path.join(save_dir, filename[:-4] + '_d.png'), (pred * scale).astype(np.uint16))
-    if target is not None:
-        cv2.imwrite(os.path.join(save_dir, filename[:-4] + '_gt.png'), (target * scale).astype(np.uint16))
-
-
-def save_val_imgs(
-    iter: int,
-    pred: torch.Tensor,
-    target: torch.Tensor,
-    rgb: torch.Tensor,
-    filename: str,
-    save_dir: str,
-    tb_logger=None
-    ):
-    """
-    Save GT, prediction, and RGB concatenated in the same image file.
-    """
-    rgb, pred_scale, target_scale, pred_color, target_color = get_data_for_log(pred, target, rgb)
-    rgb = rgb.transpose((1, 2, 0))
-    cat_img = np.concatenate([rgb, pred_color, target_color], axis=0)
-    plt.imsave(os.path.join(save_dir, filename[:-4] + '_merge.jpg'), cat_img)
-
-    # save to tensorboard
-    if tb_logger is not None:
-        tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
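save_raw_imgs stores depth as a 16-bit PNG after multiplying by `scale` (200.0 by default), which bounds the representable depth at 65535 / 200 ≈ 327 m. A sketch of the round trip a consumer of these files would perform; the file name is hypothetical:

import cv2
import numpy as np

depth = np.array([[1.25, 3.5]], dtype=np.float32)     # metric depth in meters
cv2.imwrite('demo_d.png', (depth * 200.0).astype(np.uint16))

raw = cv2.imread('demo_d.png', cv2.IMREAD_UNCHANGED)  # uint16, values 250 and 700
depth_restored = raw.astype(np.float32) / 200.0       # back to meters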
- """ - mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :] - std= np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :] - pred = pred.squeeze() - targ = targ.squeeze() - rgb = rgb.squeeze() - - if pred.size(0) == 3: - pred = pred.permute(1,2,0) - if targ.size(0) == 3: - targ = targ.permute(1,2,0) - if rgb.size(0) == 3: - rgb = rgb.permute(1,2,0) - - pred_color = vis_surface_normal(pred, mask) - targ_color = vis_surface_normal(targ, mask) - rgb_color = ((rgb.cpu().numpy() * std) + mean).astype(np.uint8) - - try: - cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0) - except: - pred_color = cv2.resize(pred_color, (rgb.shape[1], rgb.shape[0])) - targ_color = cv2.resize(targ_color, (rgb.shape[1], rgb.shape[0])) - cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0) - - plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img) - # cv2.imwrite(os.path.join(save_dir, filename[:-4]+'.jpg'), pred_color) - # save to tensorboard - if tb_logger is not None: - tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter) - -def get_data_for_log(pred: torch.tensor, target: torch.tensor, rgb: torch.tensor): - mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis] - std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis] - - pred = pred.squeeze().cpu().numpy() - target = target.squeeze().cpu().numpy() - rgb = rgb.squeeze().cpu().numpy() - - pred[pred<0] = 0 - target[target<0] = 0 - max_scale = max(pred.max(), target.max()) - pred_scale = (pred/max_scale * 10000).astype(np.uint16) - target_scale = (target/max_scale * 10000).astype(np.uint16) - pred_color = gray_to_colormap(pred) - target_color = gray_to_colormap(target) - pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1])) - target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1])) - - rgb = ((rgb * std) + mean).astype(np.uint8) - return rgb, pred_scale, target_scale, pred_color, target_color - - -def create_html(name2path, save_path='index.html', size=(256, 384)): - # table description - cols = [] - for k, v in name2path.items(): - col_i = Col('img', k, v) # specify image content for column - cols.append(col_i) - # html table generation - imagetable(cols, out_file=save_path, imsize=size) - -def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.array: - """ - Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255] - Aargs: - normal (torch.tensor, [h, w, 3]): surface normal - mask (torch.tensor, [h, w]): valid masks - """ - normal = normal.cpu().numpy().squeeze() - n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True)) - n_img_norm = normal / (n_img_L2 + 1e-8) - normal_vis = n_img_norm * 127 - normal_vis += 128 - normal_vis = normal_vis.astype(np.uint8) - if mask is not None: - mask = mask.cpu().numpy().squeeze() - normal_vis[~mask] = 0 - return normal_vis -