|
|
|
@ -186,18 +186,60 @@ class AttnBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
# compute attention |
|
|
|
|
b,c,h,w = q.shape |
|
|
|
|
scale = (int(c)**(-0.5)) |
|
|
|
|
|
|
|
|
|
q = q.reshape(b,c,h*w) |
|
|
|
|
q = q.permute(0,2,1) # b,hw,c |
|
|
|
|
k = k.reshape(b,c,h*w) # b,c,hw |
|
|
|
|
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] |
|
|
|
|
w_ = w_ * (int(c)**(-0.5)) |
|
|
|
|
w_ = torch.nn.functional.softmax(w_, dim=2) |
|
|
|
|
|
|
|
|
|
# attend to values |
|
|
|
|
v = v.reshape(b,c,h*w) |
|
|
|
|
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) |
|
|
|
|
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] |
|
|
|
|
h_ = h_.reshape(b,c,h,w) |
|
|
|
|
|
|
|
|
|
r1 = torch.zeros_like(k, device=q.device) |
|
|
|
|
|
|
|
|
|
stats = torch.cuda.memory_stats(q.device) |
|
|
|
|
mem_active = stats['active_bytes.all.current'] |
|
|
|
|
mem_reserved = stats['reserved_bytes.all.current'] |
|
|
|
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) |
|
|
|
|
mem_free_torch = mem_reserved - mem_active |
|
|
|
|
mem_free_total = mem_free_cuda + mem_free_torch |
|
|
|
|
|
|
|
|
|
gb = 1024 ** 3 |
|
|
|
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() |
|
|
|
|
modifier = 3 if q.element_size() == 2 else 2.5 |
|
|
|
|
mem_required = tensor_size * modifier |
|
|
|
|
steps = 1 |
|
|
|
|
|
|
|
|
|
if mem_required > mem_free_total: |
|
|
|
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) |
|
|
|
|
|
|
|
|
|
first_op_done = False |
|
|
|
|
while True: |
|
|
|
|
try: |
|
|
|
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] |
|
|
|
|
for i in range(0, q.shape[1], slice_size): |
|
|
|
|
end = i + slice_size |
|
|
|
|
s1 = torch.bmm(q[:, i:end], k) * scale |
|
|
|
|
first_op_done = True |
|
|
|
|
|
|
|
|
|
torch.exp(s1, out=s1) |
|
|
|
|
summed = torch.sum(s1, dim=2, keepdim=True) |
|
|
|
|
s1 /= summed |
|
|
|
|
s2 = s1.permute(0,2,1) |
|
|
|
|
del s1 |
|
|
|
|
|
|
|
|
|
r1[:, :, i:end] = torch.bmm(v, s2) |
|
|
|
|
del s2 |
|
|
|
|
break |
|
|
|
|
except torch.cuda.OutOfMemoryError as e: |
|
|
|
|
if first_op_done == False: |
|
|
|
|
steps *= 2 |
|
|
|
|
if steps > 128: |
|
|
|
|
raise e |
|
|
|
|
print("out of memory error, increasing steps and trying again", steps) |
|
|
|
|
else: |
|
|
|
|
raise e |
|
|
|
|
|
|
|
|
|
h_ = r1.reshape(b,c,h,w) |
|
|
|
|
del r1 |
|
|
|
|
|
|
|
|
|
h_ = self.proj_out(h_) |
|
|
|
|
|
|
|
|
|