From 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:57:22 +0800 Subject: [PATCH] Add sdxl only arg --- modules/cmd_args.py | 1 + modules/sd_models.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 0f14c71e4..20bfb2c44 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -119,3 +119,4 @@ parser.add_argument("--disable-all-extensions", action='store_true', help="preve parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) +parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b8ff8209..08af128fc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -394,6 +394,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.opt_unet_fp8_storage: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16