Browse Source
In `load_clip_model`, it used to check whether a GPU is being used by checking if `config.device` == "cuda". This is fine, assuming all users will pass a str for the device. Unfortunately, many users (including the `run_{cli,gradio}.py` scripts instead pass a `torch.device`, and `torch.device("cuda") != "cuda"` This commit makes it compare the `device.type` instead, which will be a string, making this condition pass, and uses float16 when possible.pull/46/head
bolshoytoster
2 years ago
1 changed files with 14 additions and 4 deletions
Loading…
Reference in new issue