Browse Source

Initialize text encoder to target dtype.

pull/546/head
comfyanonymous 1 year ago
parent
commit
00c0b2c507
  1. 13
      comfy/ops.py
  2. 7
      comfy/sd.py
  3. 6
      comfy/sd1_clip.py
  4. 4
      comfy/sd2_clip.py
  5. 14
      comfy/sdxl_clip.py

13
comfy/ops.py

@ -28,9 +28,18 @@ def conv_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}")
@contextmanager
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear = torch.nn.Linear
torch.nn.Linear = Linear
force_device = device
force_dtype = dtype
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None:
device = force_device
if force_dtype is not None:
dtype = force_dtype
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
torch.nn.Linear = linear_with_dtype
try:
yield
finally:

7
comfy/sd.py

@ -545,9 +545,12 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = load_device
self.cond_stage_model = clip(**(params))
if model_management.should_use_fp16(load_device):
self.cond_stage_model.half()
params['dtype'] = torch.float16
else:
params['dtype'] = torch.float32
self.cond_stage_model = clip(**(params))
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)

6
comfy/sd1_clip.py

@ -43,7 +43,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
@ -54,10 +54,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops():
with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
if dtype is not None:
self.transformer.to(dtype)
self.max_length = max_length
if freeze:
self.freeze()

4
comfy/sd2_clip.py

@ -3,13 +3,13 @@ import torch
import os
class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=23
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
def clip_layer(self, layer_idx):

14
comfy/sdxl_clip.py

@ -3,13 +3,13 @@ import torch
import os
class SDXLClipG(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
@ -42,11 +42,11 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
return self.clip_g.untokenize(token_weight_pair)
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu"):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device)
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l.layer_norm_hidden_state = False
self.clip_g = SDXLClipG(device=device)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
self.clip_l.clip_layer(layer_idx)
@ -70,9 +70,9 @@ class SDXLClipModel(torch.nn.Module):
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(torch.nn.Module):
def __init__(self, device="cpu"):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_g = SDXLClipG(device=device)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
self.clip_g.clip_layer(layer_idx)

Loading…
Cancel
Save