You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
24 lines
629 B
24 lines
629 B
import torch |
|
|
|
from ldm.modules.midas.api import load_midas_transform |
|
|
|
|
|
class AddMiDaS:
    """Attach a MiDaS-preprocessed copy of a sample's image under ``'midas_in'``.

    Instances are callables applied to a sample dict. The tensor stored
    under ``'jpg'`` is rescaled from [-1, 1] to [0, 1], turned into a NumPy
    array, and passed through the transform returned by
    ``load_midas_transform(model_type)``. The transform's ``"image"`` output
    is written back into the same dict.
    """

    def __init__(self, model_type):
        """Build the preprocessing transform for *model_type*.

        Args:
            model_type: Forwarded unchanged to ``load_midas_transform``.
        """
        super().__init__()
        midas_transform = load_midas_transform(model_type)
        self.transform = midas_transform

    def pt2np(self, x):
        """Rescale a tensor from [-1, 1] to [0, 1] and return it as a NumPy array.

        The affine map ``(x + 1) * 0.5`` is applied first. The result is then
        detached from autograd and moved to CPU (a no-op if it is already
        there) before being converted.
        """
        shifted = x + 1.0
        unit_range = shifted * 0.5
        return unit_range.detach().cpu().numpy()

    def np2pt(self, x):
        """Wrap a NumPy array as a tensor and map [0, 1] back to [-1, 1].

        This undoes the value rescaling of ``pt2np``. Device placement and
        autograd state are not restored. ``torch.from_numpy`` shares memory
        with *x*, but the arithmetic produces a new tensor, so *x* itself is
        left unmodified.
        """
        as_tensor = torch.from_numpy(x)
        return as_tensor * 2 - 1.0

    def __call__(self, sample):
        """Add ``sample['midas_in']`` derived from ``sample['jpg']``.

        Args:
            sample: Mapping whose ``'jpg'`` entry is expected to be an HWC
                tensor with values in [-1, 1] at this point.

        Returns:
            The same *sample* object, mutated in place with the new key.
        """
        # The MiDaS transform operates on a dict keyed by "image" and
        # returns a dict; only its "image" entry is kept.
        unit_image = self.pt2np(sample['jpg'])
        transformed = self.transform({"image": unit_image})
        sample['midas_in'] = transformed["image"]
        return sample