comfyanonymous
1 year ago
20 changed files with 1 additions and 4034 deletions
@ -1,24 +0,0 @@
|
||||
import torch |
||||
|
||||
from ldm.modules.midas.api import load_midas_transform |
||||
|
||||
|
||||
class AddMiDaS(object): |
||||
def __init__(self, model_type): |
||||
super().__init__() |
||||
self.transform = load_midas_transform(model_type) |
||||
|
||||
def pt2np(self, x): |
||||
x = ((x + 1.0) * .5).detach().cpu().numpy() |
||||
return x |
||||
|
||||
def np2pt(self, x): |
||||
x = torch.from_numpy(x) * 2 - 1. |
||||
return x |
||||
|
||||
def __call__(self, sample): |
||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point |
||||
x = self.pt2np(sample['jpg']) |
||||
x = self.transform({"image": x})["image"] |
||||
sample['midas_in'] = x |
||||
return sample |
File diff suppressed because it is too large
Load Diff
@ -1,59 +0,0 @@
|
||||
|
||||
|
||||
from typing import List, Tuple, Union |
||||
|
||||
import torch |
||||
import torch.nn as nn |
||||
|
||||
#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py |
||||
|
||||
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: |
||||
r"""Normalize an image/video tensor with mean and standard deviation. |
||||
.. math:: |
||||
\text{input[channel] = (input[channel] - mean[channel]) / std[channel]} |
||||
Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels, |
||||
Args: |
||||
data: Image tensor of size :math:`(B, C, *)`. |
||||
mean: Mean for each channel. |
||||
std: Standard deviations for each channel. |
||||
Return: |
||||
Normalised tensor with same size as input :math:`(B, C, *)`. |
||||
Examples: |
||||
>>> x = torch.rand(1, 4, 3, 3) |
||||
>>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.])) |
||||
>>> out.shape |
||||
torch.Size([1, 4, 3, 3]) |
||||
>>> x = torch.rand(1, 4, 3, 3) |
||||
>>> mean = torch.zeros(4) |
||||
>>> std = 255. * torch.ones(4) |
||||
>>> out = normalize(x, mean, std) |
||||
>>> out.shape |
||||
torch.Size([1, 4, 3, 3]) |
||||
""" |
||||
shape = data.shape |
||||
if len(mean.shape) == 0 or mean.shape[0] == 1: |
||||
mean = mean.expand(shape[1]) |
||||
if len(std.shape) == 0 or std.shape[0] == 1: |
||||
std = std.expand(shape[1]) |
||||
|
||||
# Allow broadcast on channel dimension |
||||
if mean.shape and mean.shape[0] != 1: |
||||
if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]: |
||||
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.") |
||||
|
||||
# Allow broadcast on channel dimension |
||||
if std.shape and std.shape[0] != 1: |
||||
if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]: |
||||
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.") |
||||
|
||||
mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype) |
||||
std = torch.as_tensor(std, device=data.device, dtype=data.dtype) |
||||
|
||||
if mean.shape: |
||||
mean = mean[..., :, None] |
||||
if std.shape: |
||||
std = std[..., :, None] |
||||
|
||||
out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std |
||||
|
||||
return out.view(shape) |
@ -1,314 +0,0 @@
|
||||
import torch |
||||
import torch.nn as nn |
||||
from . import kornia_functions |
||||
from torch.utils.checkpoint import checkpoint |
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel |
||||
|
||||
import open_clip |
||||
from ldm.util import default, count_params |
||||
|
||||
|
||||
class AbstractEncoder(nn.Module): |
||||
def __init__(self): |
||||
super().__init__() |
||||
|
||||
def encode(self, *args, **kwargs): |
||||
raise NotImplementedError |
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder): |
||||
|
||||
def encode(self, x): |
||||
return x |
||||
|
||||
|
||||
class ClassEmbedder(nn.Module): |
||||
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): |
||||
super().__init__() |
||||
self.key = key |
||||
self.embedding = nn.Embedding(n_classes, embed_dim) |
||||
self.n_classes = n_classes |
||||
self.ucg_rate = ucg_rate |
||||
|
||||
def forward(self, batch, key=None, disable_dropout=False): |
||||
if key is None: |
||||
key = self.key |
||||
# this is for use in crossattn |
||||
c = batch[key][:, None] |
||||
if self.ucg_rate > 0. and not disable_dropout: |
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) |
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) |
||||
c = c.long() |
||||
c = self.embedding(c) |
||||
return c |
||||
|
||||
def get_unconditional_conditioning(self, bs, device="cuda"): |
||||
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) |
||||
uc = torch.ones((bs,), device=device) * uc_class |
||||
uc = {self.key: uc} |
||||
return uc |
||||
|
||||
|
||||
def disabled_train(self, mode=True): |
||||
"""Overwrite model.train with this function to make sure train/eval mode |
||||
does not change anymore.""" |
||||
return self |
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder): |
||||
"""Uses the T5 transformer encoder for text""" |
||||
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, |
||||
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl |
||||
super().__init__() |
||||
self.tokenizer = T5Tokenizer.from_pretrained(version) |
||||
self.transformer = T5EncoderModel.from_pretrained(version) |
||||
self.device = device |
||||
self.max_length = max_length # TODO: typical value? |
||||
if freeze: |
||||
self.freeze() |
||||
|
||||
def freeze(self): |
||||
self.transformer = self.transformer.eval() |
||||
# self.train = disabled_train |
||||
for param in self.parameters(): |
||||
param.requires_grad = False |
||||
|
||||
def forward(self, text): |
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, |
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") |
||||
tokens = batch_encoding["input_ids"].to(self.device) |
||||
outputs = self.transformer(input_ids=tokens) |
||||
|
||||
z = outputs.last_hidden_state |
||||
return z |
||||
|
||||
def encode(self, text): |
||||
return self(text) |
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder): |
||||
"""Uses the CLIP transformer encoder for text (from huggingface)""" |
||||
LAYERS = [ |
||||
"last", |
||||
"pooled", |
||||
"hidden" |
||||
] |
||||
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, |
||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 |
||||
super().__init__() |
||||
assert layer in self.LAYERS |
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version) |
||||
self.transformer = CLIPTextModel.from_pretrained(version) |
||||
self.device = device |
||||
self.max_length = max_length |
||||
if freeze: |
||||
self.freeze() |
||||
self.layer = layer |
||||
self.layer_idx = layer_idx |
||||
if layer == "hidden": |
||||
assert layer_idx is not None |
||||
assert 0 <= abs(layer_idx) <= 12 |
||||
|
||||
def freeze(self): |
||||
self.transformer = self.transformer.eval() |
||||
# self.train = disabled_train |
||||
for param in self.parameters(): |
||||
param.requires_grad = False |
||||
|
||||
def forward(self, text): |
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, |
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") |
||||
tokens = batch_encoding["input_ids"].to(self.device) |
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") |
||||
if self.layer == "last": |
||||
z = outputs.last_hidden_state |
||||
elif self.layer == "pooled": |
||||
z = outputs.pooler_output[:, None, :] |
||||
else: |
||||
z = outputs.hidden_states[self.layer_idx] |
||||
return z |
||||
|
||||
def encode(self, text): |
||||
return self(text) |
||||
|
||||
|
||||
class ClipImageEmbedder(nn.Module): |
||||
def __init__( |
||||
self, |
||||
model, |
||||
jit=False, |
||||
device='cuda' if torch.cuda.is_available() else 'cpu', |
||||
antialias=True, |
||||
ucg_rate=0. |
||||
): |
||||
super().__init__() |
||||
from clip import load as load_clip |
||||
self.model, _ = load_clip(name=model, device=device, jit=jit) |
||||
|
||||
self.antialias = antialias |
||||
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) |
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) |
||||
self.ucg_rate = ucg_rate |
||||
|
||||
def preprocess(self, x): |
||||
# normalize to [0,1] |
||||
# x = kornia_functions.geometry_resize(x, (224, 224), |
||||
# interpolation='bicubic', align_corners=True, |
||||
# antialias=self.antialias) |
||||
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True) |
||||
x = (x + 1.) / 2. |
||||
# re-normalize according to clip |
||||
x = kornia_functions.enhance_normalize(x, self.mean, self.std) |
||||
return x |
||||
|
||||
def forward(self, x, no_dropout=False): |
||||
# x is assumed to be in range [-1,1] |
||||
out = self.model.encode_image(self.preprocess(x)) |
||||
out = out.to(x.dtype) |
||||
if self.ucg_rate > 0. and not no_dropout: |
||||
out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out |
||||
return out |
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder): |
||||
""" |
||||
Uses the OpenCLIP transformer encoder for text |
||||
""" |
||||
LAYERS = [ |
||||
# "pooled", |
||||
"last", |
||||
"penultimate" |
||||
] |
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, |
||||
freeze=True, layer="last"): |
||||
super().__init__() |
||||
assert layer in self.LAYERS |
||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) |
||||
del model.visual |
||||
self.model = model |
||||
|
||||
self.device = device |
||||
self.max_length = max_length |
||||
if freeze: |
||||
self.freeze() |
||||
self.layer = layer |
||||
if self.layer == "last": |
||||
self.layer_idx = 0 |
||||
elif self.layer == "penultimate": |
||||
self.layer_idx = 1 |
||||
else: |
||||
raise NotImplementedError() |
||||
|
||||
def freeze(self): |
||||
self.model = self.model.eval() |
||||
for param in self.parameters(): |
||||
param.requires_grad = False |
||||
|
||||
def forward(self, text): |
||||
tokens = open_clip.tokenize(text) |
||||
z = self.encode_with_transformer(tokens.to(self.device)) |
||||
return z |
||||
|
||||
def encode_with_transformer(self, text): |
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] |
||||
x = x + self.model.positional_embedding |
||||
x = x.permute(1, 0, 2) # NLD -> LND |
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) |
||||
x = x.permute(1, 0, 2) # LND -> NLD |
||||
x = self.model.ln_final(x) |
||||
return x |
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): |
||||
for i, r in enumerate(self.model.transformer.resblocks): |
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx: |
||||
break |
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): |
||||
x = checkpoint(r, x, attn_mask) |
||||
else: |
||||
x = r(x, attn_mask=attn_mask) |
||||
return x |
||||
|
||||
def encode(self, text): |
||||
return self(text) |
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedder(AbstractEncoder): |
||||
""" |
||||
Uses the OpenCLIP vision transformer encoder for images |
||||
""" |
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, |
||||
freeze=True, layer="pooled", antialias=True, ucg_rate=0.): |
||||
super().__init__() |
||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), |
||||
pretrained=version, ) |
||||
del model.transformer |
||||
self.model = model |
||||
|
||||
self.device = device |
||||
self.max_length = max_length |
||||
if freeze: |
||||
self.freeze() |
||||
self.layer = layer |
||||
if self.layer == "penultimate": |
||||
raise NotImplementedError() |
||||
self.layer_idx = 1 |
||||
|
||||
self.antialias = antialias |
||||
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) |
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) |
||||
self.ucg_rate = ucg_rate |
||||
|
||||
def preprocess(self, x): |
||||
# normalize to [0,1] |
||||
# x = kornia.geometry.resize(x, (224, 224), |
||||
# interpolation='bicubic', align_corners=True, |
||||
# antialias=self.antialias) |
||||
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True) |
||||
x = (x + 1.) / 2. |
||||
# renormalize according to clip |
||||
x = kornia_functions.enhance_normalize(x, self.mean, self.std) |
||||
return x |
||||
|
||||
def freeze(self): |
||||
self.model = self.model.eval() |
||||
for param in self.parameters(): |
||||
param.requires_grad = False |
||||
|
||||
def forward(self, image, no_dropout=False): |
||||
z = self.encode_with_vision_transformer(image) |
||||
if self.ucg_rate > 0. and not no_dropout: |
||||
z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z |
||||
return z |
||||
|
||||
def encode_with_vision_transformer(self, img): |
||||
img = self.preprocess(img) |
||||
x = self.model.visual(img) |
||||
return x |
||||
|
||||
def encode(self, text): |
||||
return self(text) |
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder): |
||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", |
||||
clip_max_length=77, t5_max_length=77): |
||||
super().__init__() |
||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) |
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) |
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " |
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") |
||||
|
||||
def encode(self, text): |
||||
return self(text) |
||||
|
||||
def forward(self, text): |
||||
clip_z = self.clip_encoder.encode(text) |
||||
t5_z = self.t5_encoder.encode(text) |
||||
return [clip_z, t5_z] |
@ -1,170 +0,0 @@
|
||||
# based on https://github.com/isl-org/MiDaS |
||||
|
||||
import cv2 |
||||
import torch |
||||
import torch.nn as nn |
||||
from torchvision.transforms import Compose |
||||
|
||||
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel |
||||
from ldm.modules.midas.midas.midas_net import MidasNet |
||||
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small |
||||
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet |
||||
|
||||
|
||||
ISL_PATHS = { |
||||
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", |
||||
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", |
||||
"midas_v21": "", |
||||
"midas_v21_small": "", |
||||
} |
||||
|
||||
|
||||
def disabled_train(self, mode=True): |
||||
"""Overwrite model.train with this function to make sure train/eval mode |
||||
does not change anymore.""" |
||||
return self |
||||
|
||||
|
||||
def load_midas_transform(model_type): |
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py |
||||
# load transform only |
||||
if model_type == "dpt_large": # DPT-Large |
||||
net_w, net_h = 384, 384 |
||||
resize_mode = "minimal" |
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
||||
|
||||
elif model_type == "dpt_hybrid": # DPT-Hybrid |
||||
net_w, net_h = 384, 384 |
||||
resize_mode = "minimal" |
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
||||
|
||||
elif model_type == "midas_v21": |
||||
net_w, net_h = 384, 384 |
||||
resize_mode = "upper_bound" |
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
||||
|
||||
elif model_type == "midas_v21_small": |
||||
net_w, net_h = 256, 256 |
||||
resize_mode = "upper_bound" |
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
||||
|
||||
else: |
||||
assert False, f"model_type '{model_type}' not implemented, use: --model_type large" |
||||
|
||||
transform = Compose( |
||||
[ |
||||
Resize( |
||||
net_w, |
||||
net_h, |
||||
resize_target=None, |
||||
keep_aspect_ratio=True, |
||||
ensure_multiple_of=32, |
||||
resize_method=resize_mode, |
||||
image_interpolation_method=cv2.INTER_CUBIC, |
||||
), |
||||
normalization, |
||||
PrepareForNet(), |
||||
] |
||||
) |
||||
|
||||
return transform |
||||
|
||||
|
||||
def load_model(model_type): |
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py |
||||
# load network |
||||
model_path = ISL_PATHS[model_type] |
||||
if model_type == "dpt_large": # DPT-Large |
||||
model = DPTDepthModel( |
||||
path=model_path, |
||||
backbone="vitl16_384", |
||||
non_negative=True, |
||||
) |
||||
net_w, net_h = 384, 384 |
||||
resize_mode = "minimal" |
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
||||
|
||||
elif model_type == "dpt_hybrid": # DPT-Hybrid |
||||
model = DPTDepthModel( |
||||
path=model_path, |
||||
backbone="vitb_rn50_384", |
||||
non_negative=True, |
||||
) |
||||
net_w, net_h = 384, 384 |
||||
resize_mode = "minimal" |
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
||||
|
||||
elif model_type == "midas_v21": |
||||
model = MidasNet(model_path, non_negative=True) |
||||
net_w, net_h = 384, 384 |
||||
resize_mode = "upper_bound" |
||||
normalization = NormalizeImage( |
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
||||
) |
||||
|
||||
elif model_type == "midas_v21_small": |
||||
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, |
||||
non_negative=True, blocks={'expand': True}) |
||||
net_w, net_h = 256, 256 |
||||
resize_mode = "upper_bound" |
||||
normalization = NormalizeImage( |
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
||||
) |
||||
|
||||
else: |
||||
print(f"model_type '{model_type}' not implemented, use: --model_type large") |
||||
assert False |
||||
|
||||
transform = Compose( |
||||
[ |
||||
Resize( |
||||
net_w, |
||||
net_h, |
||||
resize_target=None, |
||||
keep_aspect_ratio=True, |
||||
ensure_multiple_of=32, |
||||
resize_method=resize_mode, |
||||
image_interpolation_method=cv2.INTER_CUBIC, |
||||
), |
||||
normalization, |
||||
PrepareForNet(), |
||||
] |
||||
) |
||||
|
||||
return model.eval(), transform |
||||
|
||||
|
||||
class MiDaSInference(nn.Module): |
||||
MODEL_TYPES_TORCH_HUB = [ |
||||
"DPT_Large", |
||||
"DPT_Hybrid", |
||||
"MiDaS_small" |
||||
] |
||||
MODEL_TYPES_ISL = [ |
||||
"dpt_large", |
||||
"dpt_hybrid", |
||||
"midas_v21", |
||||
"midas_v21_small", |
||||
] |
||||
|
||||
def __init__(self, model_type): |
||||
super().__init__() |
||||
assert (model_type in self.MODEL_TYPES_ISL) |
||||
model, _ = load_model(model_type) |
||||
self.model = model |
||||
self.model.train = disabled_train |
||||
|
||||
def forward(self, x): |
||||
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array |
||||
# NOTE: we expect that the correct transform has been called during dataloading. |
||||
with torch.no_grad(): |
||||
prediction = self.model(x) |
||||
prediction = torch.nn.functional.interpolate( |
||||
prediction.unsqueeze(1), |
||||
size=x.shape[2:], |
||||
mode="bicubic", |
||||
align_corners=False, |
||||
) |
||||
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) |
||||
return prediction |
||||
|
@ -1,16 +0,0 @@
|
||||
import torch |
||||
|
||||
|
||||
class BaseModel(torch.nn.Module): |
||||
def load(self, path): |
||||
"""Load model from file. |
||||
|
||||
Args: |
||||
path (str): file path |
||||
""" |
||||
parameters = torch.load(path, map_location=torch.device('cpu')) |
||||
|
||||
if "optimizer" in parameters: |
||||
parameters = parameters["model"] |
||||
|
||||
self.load_state_dict(parameters) |
@ -1,342 +0,0 @@
|
||||
import torch |
||||
import torch.nn as nn |
||||
|
||||
from .vit import ( |
||||
_make_pretrained_vitb_rn50_384, |
||||
_make_pretrained_vitl16_384, |
||||
_make_pretrained_vitb16_384, |
||||
forward_vit, |
||||
) |
||||
|
||||
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): |
||||
if backbone == "vitl16_384": |
||||
pretrained = _make_pretrained_vitl16_384( |
||||
use_pretrained, hooks=hooks, use_readout=use_readout |
||||
) |
||||
scratch = _make_scratch( |
||||
[256, 512, 1024, 1024], features, groups=groups, expand=expand |
||||
) # ViT-L/16 - 85.0% Top1 (backbone) |
||||
elif backbone == "vitb_rn50_384": |
||||
pretrained = _make_pretrained_vitb_rn50_384( |
||||
use_pretrained, |
||||
hooks=hooks, |
||||
use_vit_only=use_vit_only, |
||||
use_readout=use_readout, |
||||
) |
||||
scratch = _make_scratch( |
||||
[256, 512, 768, 768], features, groups=groups, expand=expand |
||||
) # ViT-H/16 - 85.0% Top1 (backbone) |
||||
elif backbone == "vitb16_384": |
||||
pretrained = _make_pretrained_vitb16_384( |
||||
use_pretrained, hooks=hooks, use_readout=use_readout |
||||
) |
||||
scratch = _make_scratch( |
||||
[96, 192, 384, 768], features, groups=groups, expand=expand |
||||
) # ViT-B/16 - 84.6% Top1 (backbone) |
||||
elif backbone == "resnext101_wsl": |
||||
pretrained = _make_pretrained_resnext101_wsl(use_pretrained) |
||||
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 |
||||
elif backbone == "efficientnet_lite3": |
||||
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) |
||||
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 |
||||
else: |
||||
print(f"Backbone '{backbone}' not implemented") |
||||
assert False |
||||
|
||||
return pretrained, scratch |
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False): |
||||
scratch = nn.Module() |
||||
|
||||
out_shape1 = out_shape |
||||
out_shape2 = out_shape |
||||
out_shape3 = out_shape |
||||
out_shape4 = out_shape |
||||
if expand==True: |
||||
out_shape1 = out_shape |
||||
out_shape2 = out_shape*2 |
||||
out_shape3 = out_shape*4 |
||||
out_shape4 = out_shape*8 |
||||
|
||||
scratch.layer1_rn = nn.Conv2d( |
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups |
||||
) |
||||
scratch.layer2_rn = nn.Conv2d( |
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups |
||||
) |
||||
scratch.layer3_rn = nn.Conv2d( |
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups |
||||
) |
||||
scratch.layer4_rn = nn.Conv2d( |
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups |
||||
) |
||||
|
||||
return scratch |
||||
|
||||
|
||||
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): |
||||
efficientnet = torch.hub.load( |
||||
"rwightman/gen-efficientnet-pytorch", |
||||
"tf_efficientnet_lite3", |
||||
pretrained=use_pretrained, |
||||
exportable=exportable |
||||
) |
||||
return _make_efficientnet_backbone(efficientnet) |
||||
|
||||
|
||||
def _make_efficientnet_backbone(effnet): |
||||
pretrained = nn.Module() |
||||
|
||||
pretrained.layer1 = nn.Sequential( |
||||
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] |
||||
) |
||||
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) |
||||
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) |
||||
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) |
||||
|
||||
return pretrained |
||||
|
||||
|
||||
def _make_resnet_backbone(resnet): |
||||
pretrained = nn.Module() |
||||
pretrained.layer1 = nn.Sequential( |
||||
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 |
||||
) |
||||
|
||||
pretrained.layer2 = resnet.layer2 |
||||
pretrained.layer3 = resnet.layer3 |
||||
pretrained.layer4 = resnet.layer4 |
||||
|
||||
return pretrained |
||||
|
||||
|
||||
def _make_pretrained_resnext101_wsl(use_pretrained): |
||||
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") |
||||
return _make_resnet_backbone(resnet) |
||||
|
||||
|
||||
|
||||
class Interpolate(nn.Module): |
||||
"""Interpolation module. |
||||
""" |
||||
|
||||
def __init__(self, scale_factor, mode, align_corners=False): |
||||
"""Init. |
||||
|
||||
Args: |
||||
scale_factor (float): scaling |
||||
mode (str): interpolation mode |
||||
""" |
||||
super(Interpolate, self).__init__() |
||||
|
||||
self.interp = nn.functional.interpolate |
||||
self.scale_factor = scale_factor |
||||
self.mode = mode |
||||
self.align_corners = align_corners |
||||
|
||||
def forward(self, x): |
||||
"""Forward pass. |
||||
|
||||
Args: |
||||
x (tensor): input |
||||
|
||||
Returns: |
||||
tensor: interpolated data |
||||
""" |
||||
|
||||
x = self.interp( |
||||
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners |
||||
) |
||||
|
||||
return x |
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module): |
||||
"""Residual convolution module. |
||||
""" |
||||
|
||||
def __init__(self, features): |
||||
"""Init. |
||||
|
||||
Args: |
||||
features (int): number of features |
||||
""" |
||||
super().__init__() |
||||
|
||||
self.conv1 = nn.Conv2d( |
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True |
||||
) |
||||
|
||||
self.conv2 = nn.Conv2d( |
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True |
||||
) |
||||
|
||||
self.relu = nn.ReLU(inplace=True) |
||||
|
||||
def forward(self, x): |
||||
"""Forward pass. |
||||
|
||||
Args: |
||||
x (tensor): input |
||||
|
||||
Returns: |
||||
tensor: output |
||||
""" |
||||
out = self.relu(x) |
||||
out = self.conv1(out) |
||||
out = self.relu(out) |
||||
out = self.conv2(out) |
||||
|
||||
return out + x |
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module): |
||||
"""Feature fusion block. |
||||
""" |
||||
|
||||
def __init__(self, features): |
||||
"""Init. |
||||
|
||||
Args: |
||||
features (int): number of features |
||||
""" |
||||
super(FeatureFusionBlock, self).__init__() |
||||
|
||||
self.resConfUnit1 = ResidualConvUnit(features) |
||||
self.resConfUnit2 = ResidualConvUnit(features) |
||||
|
||||
def forward(self, *xs): |
||||
"""Forward pass. |
||||
|
||||
Returns: |
||||
tensor: output |
||||
""" |
||||
output = xs[0] |
||||
|
||||
if len(xs) == 2: |
||||
output += self.resConfUnit1(xs[1]) |
||||
|
||||
output = self.resConfUnit2(output) |
||||
|
||||
output = nn.functional.interpolate( |
||||
output, scale_factor=2, mode="bilinear", align_corners=True |
||||
) |
||||
|
||||
return output |
||||
|
||||
|
||||
|
||||
|
||||
class ResidualConvUnit_custom(nn.Module): |
||||
"""Residual convolution module. |
||||
""" |
||||
|
||||
def __init__(self, features, activation, bn): |
||||
"""Init. |
||||
|
||||
Args: |
||||
features (int): number of features |
||||
""" |
||||
super().__init__() |
||||
|
||||
self.bn = bn |
||||
|
||||
self.groups=1 |
||||
|
||||
self.conv1 = nn.Conv2d( |
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups |
||||
) |
||||
|
||||
self.conv2 = nn.Conv2d( |
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups |
||||
) |
||||
|
||||
if self.bn==True: |
||||
self.bn1 = nn.BatchNorm2d(features) |
||||
self.bn2 = nn.BatchNorm2d(features) |
||||
|
||||
self.activation = activation |
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional() |
||||
|
||||
def forward(self, x): |
||||
"""Forward pass. |
||||
|
||||
Args: |
||||
x (tensor): input |
||||
|
||||
Returns: |
||||
tensor: output |
||||
""" |
||||
|
||||
out = self.activation(x) |
||||
out = self.conv1(out) |
||||
if self.bn==True: |
||||
out = self.bn1(out) |
||||
|
||||
out = self.activation(out) |
||||
out = self.conv2(out) |
||||
if self.bn==True: |
||||
out = self.bn2(out) |
||||
|
||||
if self.groups > 1: |
||||
out = self.conv_merge(out) |
||||
|
||||
return self.skip_add.add(out, x) |
||||
|
||||
# return out + x |
||||
|
||||
|
||||
class FeatureFusionBlock_custom(nn.Module): |
||||
"""Feature fusion block. |
||||
""" |
||||
|
||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): |
||||
"""Init. |
||||
|
||||
Args: |
||||
features (int): number of features |
||||
""" |
||||
super(FeatureFusionBlock_custom, self).__init__() |
||||
|
||||
self.deconv = deconv |
||||
self.align_corners = align_corners |
||||
|
||||
self.groups=1 |
||||
|
||||
self.expand = expand |
||||
out_features = features |
||||
if self.expand==True: |
||||
out_features = features//2 |
||||
|
||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) |
||||
|
||||
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) |
||||
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) |
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional() |
||||
|
||||
def forward(self, *xs): |
||||
"""Forward pass. |
||||
|
||||
Returns: |
||||
tensor: output |
||||
""" |
||||
output = xs[0] |
||||
|
||||
if len(xs) == 2: |
||||
res = self.resConfUnit1(xs[1]) |
||||
output = self.skip_add.add(output, res) |
||||
# output += res |
||||
|
||||
output = self.resConfUnit2(output) |
||||
|
||||
output = nn.functional.interpolate( |
||||
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners |
||||
) |
||||
|
||||
output = self.out_conv(output) |
||||
|
||||
return output |
||||
|
@ -1,109 +0,0 @@
|
||||
import torch |
||||
import torch.nn as nn |
||||
import torch.nn.functional as F |
||||
|
||||
from .base_model import BaseModel |
||||
from .blocks import ( |
||||
FeatureFusionBlock, |
||||
FeatureFusionBlock_custom, |
||||
Interpolate, |
||||
_make_encoder, |
||||
forward_vit, |
||||
) |
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn): |
||||
return FeatureFusionBlock_custom( |
||||
features, |
||||
nn.ReLU(False), |
||||
deconv=False, |
||||
bn=use_bn, |
||||
expand=False, |
||||
align_corners=True, |
||||
) |
||||
|
||||
|
||||
class DPT(BaseModel): |
||||
def __init__( |
||||
self, |
||||
head, |
||||
features=256, |
||||
backbone="vitb_rn50_384", |
||||
readout="project", |
||||
channels_last=False, |
||||
use_bn=False, |
||||
): |
||||
|
||||
super(DPT, self).__init__() |
||||
|
||||
self.channels_last = channels_last |
||||
|
||||
hooks = { |
||||
"vitb_rn50_384": [0, 1, 8, 11], |
||||
"vitb16_384": [2, 5, 8, 11], |
||||
"vitl16_384": [5, 11, 17, 23], |
||||
} |
||||
|
||||
# Instantiate backbone and reassemble blocks |
||||
self.pretrained, self.scratch = _make_encoder( |
||||
backbone, |
||||
features, |
||||
False, # Set to true of you want to train from scratch, uses ImageNet weights |
||||
groups=1, |
||||
expand=False, |
||||
exportable=False, |
||||
hooks=hooks[backbone], |
||||
use_readout=readout, |
||||
) |
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn) |
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn) |
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn) |
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn) |
||||
|
||||
self.scratch.output_conv = head |
||||
|
||||
|
||||
def forward(self, x): |
||||
if self.channels_last == True: |
||||
x.contiguous(memory_format=torch.channels_last) |
||||
|
||||
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) |
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1) |
||||
layer_2_rn = self.scratch.layer2_rn(layer_2) |
||||
layer_3_rn = self.scratch.layer3_rn(layer_3) |
||||
layer_4_rn = self.scratch.layer4_rn(layer_4) |
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn) |
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn) |
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn) |
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn) |
||||
|
||||
out = self.scratch.output_conv(path_1) |
||||
|
||||
return out |
||||
|
||||
|
||||
class DPTDepthModel(DPT): |
||||
def __init__(self, path=None, non_negative=True, **kwargs): |
||||
features = kwargs["features"] if "features" in kwargs else 256 |
||||
|
||||
head = nn.Sequential( |
||||
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), |
||||
Interpolate(scale_factor=2, mode="bilinear", align_corners=True), |
||||
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), |
||||
nn.ReLU(True), |
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), |
||||
nn.ReLU(True) if non_negative else nn.Identity(), |
||||
nn.Identity(), |
||||
) |
||||
|
||||
super().__init__(head, **kwargs) |
||||
|
||||
if path is not None: |
||||
self.load(path) |
||||
|
||||
def forward(self, x): |
||||
return super().forward(x).squeeze(dim=1) |
||||
|
@ -1,76 +0,0 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. |
||||
This file contains code that is adapted from |
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py |
||||
""" |
||||
import torch |
||||
import torch.nn as nn |
||||
|
||||
from .base_model import BaseModel |
||||
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder |
||||
|
||||
|
||||
class MidasNet(BaseModel): |
||||
"""Network for monocular depth estimation. |
||||
""" |
||||
|
||||
def __init__(self, path=None, features=256, non_negative=True): |
||||
"""Init. |
||||
|
||||
Args: |
||||
path (str, optional): Path to saved model. Defaults to None. |
||||
features (int, optional): Number of features. Defaults to 256. |
||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50 |
||||
""" |
||||
print("Loading weights: ", path) |
||||
|
||||
super(MidasNet, self).__init__() |
||||
|
||||
use_pretrained = False if path is None else True |
||||
|
||||
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) |
||||
|
||||
self.scratch.refinenet4 = FeatureFusionBlock(features) |
||||
self.scratch.refinenet3 = FeatureFusionBlock(features) |
||||
self.scratch.refinenet2 = FeatureFusionBlock(features) |
||||
self.scratch.refinenet1 = FeatureFusionBlock(features) |
||||
|
||||
self.scratch.output_conv = nn.Sequential( |
||||
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), |
||||
Interpolate(scale_factor=2, mode="bilinear"), |
||||
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), |
||||
nn.ReLU(True), |
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), |
||||
nn.ReLU(True) if non_negative else nn.Identity(), |
||||
) |
||||
|
||||
if path: |
||||
self.load(path) |
||||
|
||||
def forward(self, x): |
||||
"""Forward pass. |
||||
|
||||
Args: |
||||
x (tensor): input data (image) |
||||
|
||||
Returns: |
||||
tensor: depth |
||||
""" |
||||
|
||||
layer_1 = self.pretrained.layer1(x) |
||||
layer_2 = self.pretrained.layer2(layer_1) |
||||
layer_3 = self.pretrained.layer3(layer_2) |
||||
layer_4 = self.pretrained.layer4(layer_3) |
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1) |
||||
layer_2_rn = self.scratch.layer2_rn(layer_2) |
||||
layer_3_rn = self.scratch.layer3_rn(layer_3) |
||||
layer_4_rn = self.scratch.layer4_rn(layer_4) |
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn) |
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn) |
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn) |
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn) |
||||
|
||||
out = self.scratch.output_conv(path_1) |
||||
|
||||
return torch.squeeze(out, dim=1) |
@ -1,128 +0,0 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. |
||||
This file contains code that is adapted from |
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py |
||||
""" |
||||
import torch |
||||
import torch.nn as nn |
||||
|
||||
from .base_model import BaseModel |
||||
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder |
||||
|
||||
|
||||
class MidasNet_small(BaseModel): |
||||
"""Network for monocular depth estimation. |
||||
""" |
||||
|
||||
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, |
||||
blocks={'expand': True}): |
||||
"""Init. |
||||
|
||||
Args: |
||||
path (str, optional): Path to saved model. Defaults to None. |
||||
features (int, optional): Number of features. Defaults to 256. |
||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50 |
||||
""" |
||||
print("Loading weights: ", path) |
||||
|
||||
super(MidasNet_small, self).__init__() |
||||
|
||||
use_pretrained = False if path else True |
||||
|
||||
self.channels_last = channels_last |
||||
self.blocks = blocks |
||||
self.backbone = backbone |
||||
|
||||
self.groups = 1 |
||||
|
||||
features1=features |
||||
features2=features |
||||
features3=features |
||||
features4=features |
||||
self.expand = False |
||||
if "expand" in self.blocks and self.blocks['expand'] == True: |
||||
self.expand = True |
||||
features1=features |
||||
features2=features*2 |
||||
features3=features*4 |
||||
features4=features*8 |
||||
|
||||
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) |
||||
|
||||
self.scratch.activation = nn.ReLU(False) |
||||
|
||||
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) |
||||
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) |
||||
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) |
||||
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) |
||||
|
||||
|
||||
self.scratch.output_conv = nn.Sequential( |
||||
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), |
||||
Interpolate(scale_factor=2, mode="bilinear"), |
||||
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), |
||||
self.scratch.activation, |
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), |
||||
nn.ReLU(True) if non_negative else nn.Identity(), |
||||
nn.Identity(), |
||||
) |
||||
|
||||
if path: |
||||
self.load(path) |
||||
|
||||
|
||||
def forward(self, x): |
||||
"""Forward pass. |
||||
|
||||
Args: |
||||
x (tensor): input data (image) |
||||
|
||||
Returns: |
||||
tensor: depth |
||||
""" |
||||
if self.channels_last==True: |
||||
print("self.channels_last = ", self.channels_last) |
||||
x.contiguous(memory_format=torch.channels_last) |
||||
|
||||
|
||||
layer_1 = self.pretrained.layer1(x) |
||||
layer_2 = self.pretrained.layer2(layer_1) |
||||
layer_3 = self.pretrained.layer3(layer_2) |
||||
layer_4 = self.pretrained.layer4(layer_3) |
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1) |
||||
layer_2_rn = self.scratch.layer2_rn(layer_2) |
||||
layer_3_rn = self.scratch.layer3_rn(layer_3) |
||||
layer_4_rn = self.scratch.layer4_rn(layer_4) |
||||
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn) |
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn) |
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn) |
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn) |
||||
|
||||
out = self.scratch.output_conv(path_1) |
||||
|
||||
return torch.squeeze(out, dim=1) |
||||
|
||||
|
||||
|
||||
def fuse_model(m): |
||||
prev_previous_type = nn.Identity() |
||||
prev_previous_name = '' |
||||
previous_type = nn.Identity() |
||||
previous_name = '' |
||||
for name, module in m.named_modules(): |
||||
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: |
||||
# print("FUSED ", prev_previous_name, previous_name, name) |
||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) |
||||
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: |
||||
# print("FUSED ", prev_previous_name, previous_name) |
||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) |
||||
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU: |
||||
# print("FUSED ", previous_name, name) |
||||
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) |
||||
|
||||
prev_previous_type = previous_type |
||||
prev_previous_name = previous_name |
||||
previous_type = type(module) |
||||
previous_name = name |
@ -1,234 +0,0 @@
|
||||
import numpy as np |
||||
import cv2 |
||||
import math |
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): |
||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio. |
||||
|
||||
Args: |
||||
sample (dict): sample |
||||
size (tuple): image size |
||||
|
||||
Returns: |
||||
tuple: new size |
||||
""" |
||||
shape = list(sample["disparity"].shape) |
||||
|
||||
if shape[0] >= size[0] and shape[1] >= size[1]: |
||||
return sample |
||||
|
||||
scale = [0, 0] |
||||
scale[0] = size[0] / shape[0] |
||||
scale[1] = size[1] / shape[1] |
||||
|
||||
scale = max(scale) |
||||
|
||||
shape[0] = math.ceil(scale * shape[0]) |
||||
shape[1] = math.ceil(scale * shape[1]) |
||||
|
||||
# resize |
||||
sample["image"] = cv2.resize( |
||||
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method |
||||
) |
||||
|
||||
sample["disparity"] = cv2.resize( |
||||
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST |
||||
) |
||||
sample["mask"] = cv2.resize( |
||||
sample["mask"].astype(np.float32), |
||||
tuple(shape[::-1]), |
||||
interpolation=cv2.INTER_NEAREST, |
||||
) |
||||
sample["mask"] = sample["mask"].astype(bool) |
||||
|
||||
return tuple(shape) |
||||
|
||||
|
||||
class Resize(object): |
||||
"""Resize sample to given size (width, height). |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
width, |
||||
height, |
||||
resize_target=True, |
||||
keep_aspect_ratio=False, |
||||
ensure_multiple_of=1, |
||||
resize_method="lower_bound", |
||||
image_interpolation_method=cv2.INTER_AREA, |
||||
): |
||||
"""Init. |
||||
|
||||
Args: |
||||
width (int): desired output width |
||||
height (int): desired output height |
||||
resize_target (bool, optional): |
||||
True: Resize the full sample (image, mask, target). |
||||
False: Resize image only. |
||||
Defaults to True. |
||||
keep_aspect_ratio (bool, optional): |
||||
True: Keep the aspect ratio of the input sample. |
||||
Output sample might not have the given width and height, and |
||||
resize behaviour depends on the parameter 'resize_method'. |
||||
Defaults to False. |
||||
ensure_multiple_of (int, optional): |
||||
Output width and height is constrained to be multiple of this parameter. |
||||
Defaults to 1. |
||||
resize_method (str, optional): |
||||
"lower_bound": Output will be at least as large as the given size. |
||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) |
||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.) |
||||
Defaults to "lower_bound". |
||||
""" |
||||
self.__width = width |
||||
self.__height = height |
||||
|
||||
self.__resize_target = resize_target |
||||
self.__keep_aspect_ratio = keep_aspect_ratio |
||||
self.__multiple_of = ensure_multiple_of |
||||
self.__resize_method = resize_method |
||||
self.__image_interpolation_method = image_interpolation_method |
||||
|
||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None): |
||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) |
||||
|
||||
if max_val is not None and y > max_val: |
||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) |
||||
|
||||
if y < min_val: |
||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) |
||||
|
||||
return y |
||||
|
||||
def get_size(self, width, height): |
||||
# determine new height and width |
||||
scale_height = self.__height / height |
||||
scale_width = self.__width / width |
||||
|
||||
if self.__keep_aspect_ratio: |
||||
if self.__resize_method == "lower_bound": |
||||
# scale such that output size is lower bound |
||||
if scale_width > scale_height: |
||||
# fit width |
||||
scale_height = scale_width |
||||
else: |
||||
# fit height |
||||
scale_width = scale_height |
||||
elif self.__resize_method == "upper_bound": |
||||
# scale such that output size is upper bound |
||||
if scale_width < scale_height: |
||||
# fit width |
||||
scale_height = scale_width |
||||
else: |
||||
# fit height |
||||
scale_width = scale_height |
||||
elif self.__resize_method == "minimal": |
||||
# scale as least as possbile |
||||
if abs(1 - scale_width) < abs(1 - scale_height): |
||||
# fit width |
||||
scale_height = scale_width |
||||
else: |
||||
# fit height |
||||
scale_width = scale_height |
||||
else: |
||||
raise ValueError( |
||||
f"resize_method {self.__resize_method} not implemented" |
||||
) |
||||
|
||||
if self.__resize_method == "lower_bound": |
||||
new_height = self.constrain_to_multiple_of( |
||||
scale_height * height, min_val=self.__height |
||||
) |
||||
new_width = self.constrain_to_multiple_of( |
||||
scale_width * width, min_val=self.__width |
||||
) |
||||
elif self.__resize_method == "upper_bound": |
||||
new_height = self.constrain_to_multiple_of( |
||||
scale_height * height, max_val=self.__height |
||||
) |
||||
new_width = self.constrain_to_multiple_of( |
||||
scale_width * width, max_val=self.__width |
||||
) |
||||
elif self.__resize_method == "minimal": |
||||
new_height = self.constrain_to_multiple_of(scale_height * height) |
||||
new_width = self.constrain_to_multiple_of(scale_width * width) |
||||
else: |
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented") |
||||
|
||||
return (new_width, new_height) |
||||
|
||||
def __call__(self, sample): |
||||
width, height = self.get_size( |
||||
sample["image"].shape[1], sample["image"].shape[0] |
||||
) |
||||
|
||||
# resize sample |
||||
sample["image"] = cv2.resize( |
||||
sample["image"], |
||||
(width, height), |
||||
interpolation=self.__image_interpolation_method, |
||||
) |
||||
|
||||
if self.__resize_target: |
||||
if "disparity" in sample: |
||||
sample["disparity"] = cv2.resize( |
||||
sample["disparity"], |
||||
(width, height), |
||||
interpolation=cv2.INTER_NEAREST, |
||||
) |
||||
|
||||
if "depth" in sample: |
||||
sample["depth"] = cv2.resize( |
||||
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST |
||||
) |
||||
|
||||
sample["mask"] = cv2.resize( |
||||
sample["mask"].astype(np.float32), |
||||
(width, height), |
||||
interpolation=cv2.INTER_NEAREST, |
||||
) |
||||
sample["mask"] = sample["mask"].astype(bool) |
||||
|
||||
return sample |
||||
|
||||
|
||||
class NormalizeImage(object): |
||||
"""Normlize image by given mean and std. |
||||
""" |
||||
|
||||
def __init__(self, mean, std): |
||||
self.__mean = mean |
||||
self.__std = std |
||||
|
||||
def __call__(self, sample): |
||||
sample["image"] = (sample["image"] - self.__mean) / self.__std |
||||
|
||||
return sample |
||||
|
||||
|
||||
class PrepareForNet(object): |
||||
"""Prepare sample for usage as network input. |
||||
""" |
||||
|
||||
def __init__(self): |
||||
pass |
||||
|
||||
def __call__(self, sample): |
||||
image = np.transpose(sample["image"], (2, 0, 1)) |
||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32) |
||||
|
||||
if "mask" in sample: |
||||
sample["mask"] = sample["mask"].astype(np.float32) |
||||
sample["mask"] = np.ascontiguousarray(sample["mask"]) |
||||
|
||||
if "disparity" in sample: |
||||
disparity = sample["disparity"].astype(np.float32) |
||||
sample["disparity"] = np.ascontiguousarray(disparity) |
||||
|
||||
if "depth" in sample: |
||||
depth = sample["depth"].astype(np.float32) |
||||
sample["depth"] = np.ascontiguousarray(depth) |
||||
|
||||
return sample |
@ -1,491 +0,0 @@
|
||||
import torch |
||||
import torch.nn as nn |
||||
import timm |
||||
import types |
||||
import math |
||||
import torch.nn.functional as F |
||||
|
||||
|
||||
class Slice(nn.Module): |
||||
def __init__(self, start_index=1): |
||||
super(Slice, self).__init__() |
||||
self.start_index = start_index |
||||
|
||||
def forward(self, x): |
||||
return x[:, self.start_index :] |
||||
|
||||
|
||||
class AddReadout(nn.Module): |
||||
def __init__(self, start_index=1): |
||||
super(AddReadout, self).__init__() |
||||
self.start_index = start_index |
||||
|
||||
def forward(self, x): |
||||
if self.start_index == 2: |
||||
readout = (x[:, 0] + x[:, 1]) / 2 |
||||
else: |
||||
readout = x[:, 0] |
||||
return x[:, self.start_index :] + readout.unsqueeze(1) |
||||
|
||||
|
||||
class ProjectReadout(nn.Module): |
||||
def __init__(self, in_features, start_index=1): |
||||
super(ProjectReadout, self).__init__() |
||||
self.start_index = start_index |
||||
|
||||
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) |
||||
|
||||
def forward(self, x): |
||||
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) |
||||
features = torch.cat((x[:, self.start_index :], readout), -1) |
||||
|
||||
return self.project(features) |
||||
|
||||
|
||||
class Transpose(nn.Module): |
||||
def __init__(self, dim0, dim1): |
||||
super(Transpose, self).__init__() |
||||
self.dim0 = dim0 |
||||
self.dim1 = dim1 |
||||
|
||||
def forward(self, x): |
||||
x = x.transpose(self.dim0, self.dim1) |
||||
return x |
||||
|
||||
|
||||
def forward_vit(pretrained, x): |
||||
b, c, h, w = x.shape |
||||
|
||||
glob = pretrained.model.forward_flex(x) |
||||
|
||||
layer_1 = pretrained.activations["1"] |
||||
layer_2 = pretrained.activations["2"] |
||||
layer_3 = pretrained.activations["3"] |
||||
layer_4 = pretrained.activations["4"] |
||||
|
||||
layer_1 = pretrained.act_postprocess1[0:2](layer_1) |
||||
layer_2 = pretrained.act_postprocess2[0:2](layer_2) |
||||
layer_3 = pretrained.act_postprocess3[0:2](layer_3) |
||||
layer_4 = pretrained.act_postprocess4[0:2](layer_4) |
||||
|
||||
unflatten = nn.Sequential( |
||||
nn.Unflatten( |
||||
2, |
||||
torch.Size( |
||||
[ |
||||
h // pretrained.model.patch_size[1], |
||||
w // pretrained.model.patch_size[0], |
||||
] |
||||
), |
||||
) |
||||
) |
||||
|
||||
if layer_1.ndim == 3: |
||||
layer_1 = unflatten(layer_1) |
||||
if layer_2.ndim == 3: |
||||
layer_2 = unflatten(layer_2) |
||||
if layer_3.ndim == 3: |
||||
layer_3 = unflatten(layer_3) |
||||
if layer_4.ndim == 3: |
||||
layer_4 = unflatten(layer_4) |
||||
|
||||
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) |
||||
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) |
||||
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) |
||||
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) |
||||
|
||||
return layer_1, layer_2, layer_3, layer_4 |
||||
|
||||
|
||||
def _resize_pos_embed(self, posemb, gs_h, gs_w): |
||||
posemb_tok, posemb_grid = ( |
||||
posemb[:, : self.start_index], |
||||
posemb[0, self.start_index :], |
||||
) |
||||
|
||||
gs_old = int(math.sqrt(len(posemb_grid))) |
||||
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) |
||||
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") |
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) |
||||
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1) |
||||
|
||||
return posemb |
||||
|
||||
|
||||
def forward_flex(self, x): |
||||
b, c, h, w = x.shape |
||||
|
||||
pos_embed = self._resize_pos_embed( |
||||
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] |
||||
) |
||||
|
||||
B = x.shape[0] |
||||
|
||||
if hasattr(self.patch_embed, "backbone"): |
||||
x = self.patch_embed.backbone(x) |
||||
if isinstance(x, (list, tuple)): |
||||
x = x[-1] # last feature if backbone outputs list/tuple of features |
||||
|
||||
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) |
||||
|
||||
if getattr(self, "dist_token", None) is not None: |
||||
cls_tokens = self.cls_token.expand( |
||||
B, -1, -1 |
||||
) # stole cls_tokens impl from Phil Wang, thanks |
||||
dist_token = self.dist_token.expand(B, -1, -1) |
||||
x = torch.cat((cls_tokens, dist_token, x), dim=1) |
||||
else: |
||||
cls_tokens = self.cls_token.expand( |
||||
B, -1, -1 |
||||
) # stole cls_tokens impl from Phil Wang, thanks |
||||
x = torch.cat((cls_tokens, x), dim=1) |
||||
|
||||
x = x + pos_embed |
||||
x = self.pos_drop(x) |
||||
|
||||
for blk in self.blocks: |
||||
x = blk(x) |
||||
|
||||
x = self.norm(x) |
||||
|
||||
return x |
||||
|
||||
|
||||
activations = {} |
||||
|
||||
|
||||
def get_activation(name): |
||||
def hook(model, input, output): |
||||
activations[name] = output |
||||
|
||||
return hook |
||||
|
||||
|
||||
def get_readout_oper(vit_features, features, use_readout, start_index=1): |
||||
if use_readout == "ignore": |
||||
readout_oper = [Slice(start_index)] * len(features) |
||||
elif use_readout == "add": |
||||
readout_oper = [AddReadout(start_index)] * len(features) |
||||
elif use_readout == "project": |
||||
readout_oper = [ |
||||
ProjectReadout(vit_features, start_index) for out_feat in features |
||||
] |
||||
else: |
||||
assert ( |
||||
False |
||||
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" |
||||
|
||||
return readout_oper |
||||
|
||||
|
||||
def _make_vit_b16_backbone( |
||||
model, |
||||
features=[96, 192, 384, 768], |
||||
size=[384, 384], |
||||
hooks=[2, 5, 8, 11], |
||||
vit_features=768, |
||||
use_readout="ignore", |
||||
start_index=1, |
||||
): |
||||
pretrained = nn.Module() |
||||
|
||||
pretrained.model = model |
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) |
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) |
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) |
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) |
||||
|
||||
pretrained.activations = activations |
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) |
||||
|
||||
# 32, 48, 136, 384 |
||||
pretrained.act_postprocess1 = nn.Sequential( |
||||
readout_oper[0], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[0], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
nn.ConvTranspose2d( |
||||
in_channels=features[0], |
||||
out_channels=features[0], |
||||
kernel_size=4, |
||||
stride=4, |
||||
padding=0, |
||||
bias=True, |
||||
dilation=1, |
||||
groups=1, |
||||
), |
||||
) |
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential( |
||||
readout_oper[1], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[1], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
nn.ConvTranspose2d( |
||||
in_channels=features[1], |
||||
out_channels=features[1], |
||||
kernel_size=2, |
||||
stride=2, |
||||
padding=0, |
||||
bias=True, |
||||
dilation=1, |
||||
groups=1, |
||||
), |
||||
) |
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential( |
||||
readout_oper[2], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[2], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
) |
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential( |
||||
readout_oper[3], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[3], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
nn.Conv2d( |
||||
in_channels=features[3], |
||||
out_channels=features[3], |
||||
kernel_size=3, |
||||
stride=2, |
||||
padding=1, |
||||
), |
||||
) |
||||
|
||||
pretrained.model.start_index = start_index |
||||
pretrained.model.patch_size = [16, 16] |
||||
|
||||
# We inject this function into the VisionTransformer instances so that |
||||
# we can use it with interpolated position embeddings without modifying the library source. |
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) |
||||
pretrained.model._resize_pos_embed = types.MethodType( |
||||
_resize_pos_embed, pretrained.model |
||||
) |
||||
|
||||
return pretrained |
||||
|
||||
|
||||
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): |
||||
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) |
||||
|
||||
hooks = [5, 11, 17, 23] if hooks == None else hooks |
||||
return _make_vit_b16_backbone( |
||||
model, |
||||
features=[256, 512, 1024, 1024], |
||||
hooks=hooks, |
||||
vit_features=1024, |
||||
use_readout=use_readout, |
||||
) |
||||
|
||||
|
||||
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): |
||||
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) |
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks |
||||
return _make_vit_b16_backbone( |
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout |
||||
) |
||||
|
||||
|
||||
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): |
||||
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) |
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks |
||||
return _make_vit_b16_backbone( |
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout |
||||
) |
||||
|
||||
|
||||
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): |
||||
model = timm.create_model( |
||||
"vit_deit_base_distilled_patch16_384", pretrained=pretrained |
||||
) |
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks |
||||
return _make_vit_b16_backbone( |
||||
model, |
||||
features=[96, 192, 384, 768], |
||||
hooks=hooks, |
||||
use_readout=use_readout, |
||||
start_index=2, |
||||
) |
||||
|
||||
|
||||
def _make_vit_b_rn50_backbone( |
||||
model, |
||||
features=[256, 512, 768, 768], |
||||
size=[384, 384], |
||||
hooks=[0, 1, 8, 11], |
||||
vit_features=768, |
||||
use_vit_only=False, |
||||
use_readout="ignore", |
||||
start_index=1, |
||||
): |
||||
pretrained = nn.Module() |
||||
|
||||
pretrained.model = model |
||||
|
||||
if use_vit_only == True: |
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) |
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) |
||||
else: |
||||
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( |
||||
get_activation("1") |
||||
) |
||||
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( |
||||
get_activation("2") |
||||
) |
||||
|
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) |
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) |
||||
|
||||
pretrained.activations = activations |
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) |
||||
|
||||
if use_vit_only == True: |
||||
pretrained.act_postprocess1 = nn.Sequential( |
||||
readout_oper[0], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[0], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
nn.ConvTranspose2d( |
||||
in_channels=features[0], |
||||
out_channels=features[0], |
||||
kernel_size=4, |
||||
stride=4, |
||||
padding=0, |
||||
bias=True, |
||||
dilation=1, |
||||
groups=1, |
||||
), |
||||
) |
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential( |
||||
readout_oper[1], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[1], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
nn.ConvTranspose2d( |
||||
in_channels=features[1], |
||||
out_channels=features[1], |
||||
kernel_size=2, |
||||
stride=2, |
||||
padding=0, |
||||
bias=True, |
||||
dilation=1, |
||||
groups=1, |
||||
), |
||||
) |
||||
else: |
||||
pretrained.act_postprocess1 = nn.Sequential( |
||||
nn.Identity(), nn.Identity(), nn.Identity() |
||||
) |
||||
pretrained.act_postprocess2 = nn.Sequential( |
||||
nn.Identity(), nn.Identity(), nn.Identity() |
||||
) |
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential( |
||||
readout_oper[2], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[2], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
) |
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential( |
||||
readout_oper[3], |
||||
Transpose(1, 2), |
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), |
||||
nn.Conv2d( |
||||
in_channels=vit_features, |
||||
out_channels=features[3], |
||||
kernel_size=1, |
||||
stride=1, |
||||
padding=0, |
||||
), |
||||
nn.Conv2d( |
||||
in_channels=features[3], |
||||
out_channels=features[3], |
||||
kernel_size=3, |
||||
stride=2, |
||||
padding=1, |
||||
), |
||||
) |
||||
|
||||
pretrained.model.start_index = start_index |
||||
pretrained.model.patch_size = [16, 16] |
||||
|
||||
# We inject this function into the VisionTransformer instances so that |
||||
# we can use it with interpolated position embeddings without modifying the library source. |
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) |
||||
|
||||
# We inject this function into the VisionTransformer instances so that |
||||
# we can use it with interpolated position embeddings without modifying the library source. |
||||
pretrained.model._resize_pos_embed = types.MethodType( |
||||
_resize_pos_embed, pretrained.model |
||||
) |
||||
|
||||
return pretrained |
||||
|
||||
|
||||
def _make_pretrained_vitb_rn50_384( |
||||
pretrained, use_readout="ignore", hooks=None, use_vit_only=False |
||||
): |
||||
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) |
||||
|
||||
hooks = [0, 1, 8, 11] if hooks == None else hooks |
||||
return _make_vit_b_rn50_backbone( |
||||
model, |
||||
features=[256, 512, 768, 768], |
||||
size=[384, 384], |
||||
hooks=hooks, |
||||
use_vit_only=use_vit_only, |
||||
use_readout=use_readout, |
||||
) |
@ -1,189 +0,0 @@
|
||||
"""Utils for monoDepth.""" |
||||
import sys |
||||
import re |
||||
import numpy as np |
||||
import cv2 |
||||
import torch |
||||
|
||||
|
||||
def read_pfm(path): |
||||
"""Read pfm file. |
||||
|
||||
Args: |
||||
path (str): path to file |
||||
|
||||
Returns: |
||||
tuple: (data, scale) |
||||
""" |
||||
with open(path, "rb") as file: |
||||
|
||||
color = None |
||||
width = None |
||||
height = None |
||||
scale = None |
||||
endian = None |
||||
|
||||
header = file.readline().rstrip() |
||||
if header.decode("ascii") == "PF": |
||||
color = True |
||||
elif header.decode("ascii") == "Pf": |
||||
color = False |
||||
else: |
||||
raise Exception("Not a PFM file: " + path) |
||||
|
||||
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) |
||||
if dim_match: |
||||
width, height = list(map(int, dim_match.groups())) |
||||
else: |
||||
raise Exception("Malformed PFM header.") |
||||
|
||||
scale = float(file.readline().decode("ascii").rstrip()) |
||||
if scale < 0: |
||||
# little-endian |
||||
endian = "<" |
||||
scale = -scale |
||||
else: |
||||
# big-endian |
||||
endian = ">" |
||||
|
||||
data = np.fromfile(file, endian + "f") |
||||
shape = (height, width, 3) if color else (height, width) |
||||
|
||||
data = np.reshape(data, shape) |
||||
data = np.flipud(data) |
||||
|
||||
return data, scale |
||||
|
||||
|
||||
def write_pfm(path, image, scale=1): |
||||
"""Write pfm file. |
||||
|
||||
Args: |
||||
path (str): pathto file |
||||
image (array): data |
||||
scale (int, optional): Scale. Defaults to 1. |
||||
""" |
||||
|
||||
with open(path, "wb") as file: |
||||
color = None |
||||
|
||||
if image.dtype.name != "float32": |
||||
raise Exception("Image dtype must be float32.") |
||||
|
||||
image = np.flipud(image) |
||||
|
||||
if len(image.shape) == 3 and image.shape[2] == 3: # color image |
||||
color = True |
||||
elif ( |
||||
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 |
||||
): # greyscale |
||||
color = False |
||||
else: |
||||
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") |
||||
|
||||
file.write("PF\n" if color else "Pf\n".encode()) |
||||
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) |
||||
|
||||
endian = image.dtype.byteorder |
||||
|
||||
if endian == "<" or endian == "=" and sys.byteorder == "little": |
||||
scale = -scale |
||||
|
||||
file.write("%f\n".encode() % scale) |
||||
|
||||
image.tofile(file) |
||||
|
||||
|
||||
def read_image(path): |
||||
"""Read image and output RGB image (0-1). |
||||
|
||||
Args: |
||||
path (str): path to file |
||||
|
||||
Returns: |
||||
array: RGB image (0-1) |
||||
""" |
||||
img = cv2.imread(path) |
||||
|
||||
if img.ndim == 2: |
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) |
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 |
||||
|
||||
return img |
||||
|
||||
|
||||
def resize_image(img): |
||||
"""Resize image and make it fit for network. |
||||
|
||||
Args: |
||||
img (array): image |
||||
|
||||
Returns: |
||||
tensor: data ready for network |
||||
""" |
||||
height_orig = img.shape[0] |
||||
width_orig = img.shape[1] |
||||
|
||||
if width_orig > height_orig: |
||||
scale = width_orig / 384 |
||||
else: |
||||
scale = height_orig / 384 |
||||
|
||||
height = (np.ceil(height_orig / scale / 32) * 32).astype(int) |
||||
width = (np.ceil(width_orig / scale / 32) * 32).astype(int) |
||||
|
||||
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) |
||||
|
||||
img_resized = ( |
||||
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() |
||||
) |
||||
img_resized = img_resized.unsqueeze(0) |
||||
|
||||
return img_resized |
||||
|
||||
|
||||
def resize_depth(depth, width, height): |
||||
"""Resize depth map and bring to CPU (numpy). |
||||
|
||||
Args: |
||||
depth (tensor): depth |
||||
width (int): image width |
||||
height (int): image height |
||||
|
||||
Returns: |
||||
array: processed depth |
||||
""" |
||||
depth = torch.squeeze(depth[0, :, :, :]).to("cpu") |
||||
|
||||
depth_resized = cv2.resize( |
||||
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC |
||||
) |
||||
|
||||
return depth_resized |
||||
|
||||
def write_depth(path, depth, bits=1): |
||||
"""Write depth map to pfm and png file. |
||||
|
||||
Args: |
||||
path (str): filepath without extension |
||||
depth (array): depth |
||||
""" |
||||
write_pfm(path + ".pfm", depth.astype(np.float32)) |
||||
|
||||
depth_min = depth.min() |
||||
depth_max = depth.max() |
||||
|
||||
max_val = (2**(8*bits))-1 |
||||
|
||||
if depth_max - depth_min > np.finfo("float").eps: |
||||
out = max_val * (depth - depth_min) / (depth_max - depth_min) |
||||
else: |
||||
out = np.zeros(depth.shape, dtype=depth.type) |
||||
|
||||
if bits == 1: |
||||
cv2.imwrite(path + ".png", out.astype("uint8")) |
||||
elif bits == 2: |
||||
cv2.imwrite(path + ".png", out.astype("uint16")) |
||||
|
||||
return |
Loading…
Reference in new issue