diff --git a/modules/sd_models.py b/modules/sd_models.py index f4274ae42..00b68289a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -952,7 +952,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if sd_model is not None: sd_unet.apply_unet("None") - send_model_to_cpu(sd_model) + send_model_to_device(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model) state_dict = get_checkpoint_state_dict(checkpoint_info, timer)