
Switch text encoder to manual cast.

Use fp16 text encoder weights for CPU inference to lower memory usage.

comfyanonymous · 11 months ago · pull/2269/head · commit 57926635e8
3 changed files:

  1. comfy/model_management.py (+3)
  2. comfy/ops.py (+33)
  3. comfy/sd1_clip.py (+1, -7)

comfy/model_management.py

@@ -503,6 +503,9 @@ def text_encoder_dtype(device=None):
     elif args.fp32_text_enc:
         return torch.float32
 
+    if is_device_cpu(device):
+        return torch.float16
+
     if should_use_fp16(device, prioritize_performance=False):
         return torch.float16
     else:
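Why fp16 on CPU works here even though most CPU kernels lack fp16 implementations: the weights are only stored in fp16; the manual_cast ops added below cast them up to the activation dtype at call time. A minimal sketch of that storage/compute split in plain torch (illustrative only, not code from the commit):

    import torch

    # Store the layer's parameters in fp16 to halve their memory footprint.
    layer = torch.nn.Linear(768, 768).half()

    x = torch.randn(2, 768)  # activations remain fp32 on CPU

    # Cast weights up per call; the matmul itself runs in fp32,
    # so no fp16 CPU kernels are required.
    w = layer.weight.to(dtype=x.dtype)
    b = layer.bias.to(dtype=x.dtype)
    y = torch.nn.functional.linear(x, w, b)
    assert y.dtype == torch.float32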

comfy/ops.py

@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
     else:
         raise ValueError(f"unsupported dimensions: {dims}")
 
+def cast_bias_weight(s, input):
+    bias = None
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype)
+    weight = s.weight.to(device=input.device, dtype=input.dtype)
+    return weight, bias
+
+class manual_cast:
+    class Linear(Linear):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+    class Conv2d(Conv2d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class Conv3d(Conv3d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class GroupNorm(GroupNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    class LayerNorm(LayerNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
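The pattern above is the whole trick: each manual_cast layer is a drop-in replacement for its torch.nn counterpart, except that forward() first copies weight and bias to the input's device and dtype. A hedged usage sketch (assuming comfy.ops is importable and its base Linear takes the usual (in_features, out_features) arguments; the call below is illustrative, not from the commit):

    import torch
    import comfy.ops

    # Parameters live in fp16...
    layer = comfy.ops.manual_cast.Linear(768, 768).half()

    # ...but an fp32 input triggers a per-call cast inside forward(),
    # so the linear op itself executes in fp32.
    x = torch.randn(1, 768)
    y = layer(x)
    print(y.dtype)  # torch.float32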

comfy/sd1_clip.py

@@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)
 
-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers
         self.max_length = max_length
@@ -160,12 +160,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
         attention_mask = None
         if self.enable_attention_masks:
             attention_mask = torch.zeros_like(tokens)
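The model_class call passes an operations object whose attributes are used in place of the torch.nn layer classes, so swapping comfy.ops for comfy.ops.manual_cast switches every layer of the text encoder to the casting variants without touching the model definition; and because each layer now handles its own dtype conversion, the torch.autocast wrapper removed in the second hunk is no longer needed. A minimal sketch of that dependency-injection pattern (class and parameter names here are hypothetical, not from sd1_clip.py):

    import torch
    import comfy.ops

    class TinyTextModel(torch.nn.Module):
        # 'operations' supplies the layer classes, so the caller chooses
        # between plain ops and manual_cast ops at construction time.
        def __init__(self, dim, operations):
            super().__init__()
            self.proj = operations.Linear(dim, dim)
            self.norm = operations.LayerNorm(dim)

        def forward(self, x):
            return self.norm(self.proj(x))

    model = TinyTextModel(64, operations=comfy.ops.manual_cast).half()
    out = model(torch.randn(1, 64))  # fp32 in, fp32 out; fp16 weights cast per call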
