@ -786,6 +786,7 @@ class UNetModel(nn.Module):
if control is not None:
hsp += control.pop()
h = th.cat([h, hsp], dim=1)
del hsp
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids: