|
|
|
@ -413,6 +413,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
|
|
|
|
|
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) |
|
|
|
|
for y in range(0, s.shape[2], tile_y - overlap): |
|
|
|
|
for x in range(0, s.shape[3], tile_x - overlap): |
|
|
|
|
x = max(0, min(s.shape[-1] - overlap, x)) |
|
|
|
|
y = max(0, min(s.shape[-2] - overlap, y)) |
|
|
|
|
s_in = s[:,:,y:y+tile_y,x:x+tile_x] |
|
|
|
|
|
|
|
|
|
ps = function(s_in).to(output_device) |
|
|
|
|