From 209c26a1cb9e4be357ab3c5e7613caf3cbc26183 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 9 Jan 2024 22:11:44 +0800
Subject: [PATCH 1/7] improve efficiency and support more devices

---
 modules/devices.py     | 60 ++++++++++++++++++++++++++++++------------
 modules/shared_init.py |  1 +
 2 files changed, 44 insertions(+), 17 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index ff279ac50..6edfb1278 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -110,6 +110,7 @@ device_codeformer: torch.device = None
 dtype: torch.dtype = torch.float16
 dtype_vae: torch.dtype = torch.float16
 dtype_unet: torch.dtype = torch.float16
+dtype_inference: torch.dtype = torch.float16
 
 unet_needs_upcast = False
 
@@ -131,21 +132,49 @@ patch_module_list = [
 ]
 
 
-def manual_cast_forward(self, *args, **kwargs):
-    org_dtype = torch_utils.get_param(self).dtype
-    self.to(dtype)
-    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-    result = self.org_forward(*args, **kwargs)
-    self.to(org_dtype)
-    return result
+def manual_cast_forward(target_dtype):
+    def forward_wrapper(self, *args, **kwargs):
+        org_dtype = torch_utils.get_param(self).dtype
+        if not target_dtype == org_dtype == dtype_inference:
+            self.to(target_dtype)
+            args = [
+                arg.to(target_dtype)
+                if isinstance(arg, torch.Tensor)
+                else arg
+                for arg in args
+            ]
+            kwargs = {
+                k: v.to(target_dtype)
+                if isinstance(v, torch.Tensor)
+                else v
+                for k, v in kwargs.items()
+            }
+
+        result = self.org_forward(*args, **kwargs)
+        self.to(org_dtype)
+
+        if target_dtype != dtype_inference:
+            if isinstance(result, tuple):
+                result = tuple(
+                    i.to(dtype_inference)
+                    if isinstance(i, torch.Tensor)
+                    else i
+                    for i in result
+                )
+            elif isinstance(result, torch.Tensor):
+                result = result.to(dtype_inference)
+        return result
+    return forward_wrapper
 
 
 @contextlib.contextmanager
-def manual_cast():
+def manual_cast(target_dtype):
     for module_type in patch_module_list:
         org_forward = module_type.forward
-        module_type.forward = manual_cast_forward
+        if module_type == torch.nn.MultiheadAttention and has_xpu():
+            module_type.forward = manual_cast_forward(torch.float32)
+        else:
+            module_type.forward = manual_cast_forward(target_dtype)
         module_type.org_forward = org_forward
     try:
         yield None
@@ -161,15 +190,12 @@ def autocast(disable=False):
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
-    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
-        return manual_cast()
-
-    if has_mps() and shared.cmd_opts.precision != "full":
-        return manual_cast()
-
-    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+    if dtype == torch.float32 and shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
 
+    if has_xpu() or has_mps() or cuda_no_autocast():
+        return manual_cast(dtype_inference)
+
     return torch.autocast("cuda")
 
 
diff --git a/modules/shared_init.py b/modules/shared_init.py
index 586be3423..935e3a21c 100644
--- a/modules/shared_init.py
+++ b/modules/shared_init.py
@@ -29,6 +29,7 @@ def initialize():
 
     devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
     devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+    devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
 
     shared.device = devices.device
     shared.weight_load_location = None if cmd_opts.lowram else "cpu"
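The core mechanism of PATCH 1 is the reworked `manual_cast` context manager: it monkey-patches `forward` on every module class in `patch_module_list`, so each call casts the module's parameters and tensor arguments to a compute dtype and casts tensor results back to `dtype_inference`. Below is a minimal standalone sketch of that pattern, assuming nothing from the webui codebase; `PATCH_MODULES`, `compute_dtype`, and `inference_dtype` are illustrative stand-ins, and bfloat16 is used so the sketch also runs on CPUs where fp16 matmul support varies.

import contextlib
import torch

# Module classes whose forward() gets patched; illustrative stand-in for the
# real patch_module_list in modules/devices.py.
PATCH_MODULES = [torch.nn.Linear, torch.nn.Conv2d]
compute_dtype = torch.bfloat16   # stands in for devices.dtype
inference_dtype = torch.float32  # stands in for devices.dtype_inference


def make_cast_forward(target_dtype):
    def forward_wrapper(self, *args, **kwargs):
        # Cast tensor arguments and the module's own weights to the compute dtype.
        args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
        kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        org_dtype = next(self.parameters()).dtype
        self.to(target_dtype)
        result = self.org_forward(*args, **kwargs)
        self.to(org_dtype)
        # Hand results back in the inference dtype, as the patch does.
        if isinstance(result, torch.Tensor) and target_dtype != inference_dtype:
            result = result.to(inference_dtype)
        return result
    return forward_wrapper


@contextlib.contextmanager
def manual_cast(target_dtype):
    for module_type in PATCH_MODULES:
        module_type.org_forward = module_type.forward
        module_type.forward = make_cast_forward(target_dtype)
    try:
        yield
    finally:
        for module_type in PATCH_MODULES:
            module_type.forward = module_type.org_forward


layer = torch.nn.Linear(4, 4)   # fp32 weights
x = torch.randn(1, 4)           # fp32 input
with manual_cast(compute_dtype):
    y = layer(x)                # matmul runs in bfloat16
assert y.dtype == inference_dtype

Patching the class rather than individual instances covers modules constructed after the context is entered, and restoring `forward` in the `finally` block keeps the patch strictly scoped to the context.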
From 42e6df723c68af775b73c9fa4f43f99345348689 Mon Sep 17 00:00:00 2001
From: KohakuBlueleaf
Date: Tue, 9 Jan 2024 22:39:39 +0800
Subject: [PATCH 2/7] Fix bugs when arg dtype doesn't match

---
 modules/devices.py | 27 +++++++++++----------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index 6edfb1278..e05740524 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -134,24 +134,19 @@ def manual_cast_forward(target_dtype):
     def forward_wrapper(self, *args, **kwargs):
-        org_dtype = torch_utils.get_param(self).dtype
-        if not target_dtype == org_dtype == dtype_inference:
-            self.to(target_dtype)
-            args = [
-                arg.to(target_dtype)
-                if isinstance(arg, torch.Tensor)
-                else arg
-                for arg in args
-            ]
-            kwargs = {
-                k: v.to(target_dtype)
-                if isinstance(v, torch.Tensor)
-                else v
-                for k, v in kwargs.items()
-            }
+        if any(
+            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+            for arg in args
+        ):
+            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
 
+        org_dtype = torch_utils.get_param(self).dtype
+        if org_dtype != target_dtype:
+            self.to(target_dtype)
         result = self.org_forward(*args, **kwargs)
-        self.to(org_dtype)
+        if org_dtype != target_dtype:
+            self.to(org_dtype)
 
         if target_dtype != dtype_inference:
             if isinstance(result, tuple):

From c2c05fcca8f3547783c5440c04ec10cc63c65db5 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 9 Jan 2024 22:53:58 +0800
Subject: [PATCH 3/7] linting and debug fixes

---
 modules/devices.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index e05740524..ad36f6562 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -140,20 +140,20 @@ def manual_cast_forward(target_dtype):
         ):
             args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
             kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-        
+
         org_dtype = torch_utils.get_param(self).dtype
         if org_dtype != target_dtype:
             self.to(target_dtype)
         result = self.org_forward(*args, **kwargs)
         if org_dtype != target_dtype:
             self.to(org_dtype)
-        
+
         if target_dtype != dtype_inference:
             if isinstance(result, tuple):
                 result = tuple(
-                    i.to(dtype_inference) 
-                    if isinstance(i, torch.Tensor) 
-                    else i 
+                    i.to(dtype_inference)
+                    if isinstance(i, torch.Tensor)
+                    else i
                     for i in result
                 )
             elif isinstance(result, torch.Tensor):
@@ -185,7 +185,7 @@ def autocast(disable=False):
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
-    if dtype == torch.float32 and shared.cmd_opts.precision == "full":
+    if dtype == torch.float32:
         return contextlib.nullcontext()
 
     if has_xpu() or has_mps() or cuda_no_autocast():
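PATCH 2 replaces the unconditional conversions with a guard: arguments are only rebuilt when some tensor actually mismatches the target dtype, and the module is only moved when its parameters mismatch, so the common already-matching path costs a single `any()` scan and no copies. A tiny illustration of the guard (the names here are illustrative, not from the patch):

import torch

def needs_cast(args, target_dtype):
    # One cheap scan; no tensor copies when everything already matches.
    return any(isinstance(a, torch.Tensor) and a.dtype != target_dtype for a in args)

x16 = torch.ones(2, dtype=torch.float16)
x32 = torch.ones(2)
print(needs_cast((x16, 1.0), torch.float16))  # False: fast path, args untouched
print(needs_cast((x32, 1.0), torch.float16))  # True: args get converted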
From e00365962b17550a42235d1fbe2ad2c7cc4b8961 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 9 Jan 2024 23:13:34 +0800
Subject: [PATCH 4/7] Apply correct inference precision implementation

---
 modules/devices.py | 42 +++++++++++++++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 9 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index ad36f6562..9e1f207c3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -132,6 +132,21 @@ patch_module_list = [
 ]
 
 
+def cast_output(result):
+    if isinstance(result, tuple):
+        result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
+    elif isinstance(result, torch.Tensor):
+        result = result.to(dtype_inference)
+    return result
+
+
+def autocast_with_cast_output(self, *args, **kwargs):
+    result = self.org_forward(*args, **kwargs)
+    if dtype_inference != dtype:
+        result = cast_output(result)
+    return result
+
+
 def manual_cast_forward(target_dtype):
     def forward_wrapper(self, *args, **kwargs):
         if any(
@@ -149,15 +164,7 @@ def manual_cast_forward(target_dtype):
             self.to(org_dtype)
 
         if target_dtype != dtype_inference:
-            if isinstance(result, tuple):
-                result = tuple(
-                    i.to(dtype_inference)
-                    if isinstance(i, torch.Tensor)
-                    else i
-                    for i in result
-                )
-            elif isinstance(result, torch.Tensor):
-                result = result.to(dtype_inference)
+            result = cast_output(result)
         return result
     return forward_wrapper
 
@@ -178,6 +185,20 @@ def manual_cast(target_dtype):
         module_type.forward = module_type.org_forward
 
 
+@contextlib.contextmanager
+def precision_full_with_autocast(autocast_ctx):
+    for module_type in patch_module_list:
+        org_forward = module_type.forward
+        module_type.forward = autocast_with_cast_output
+        module_type.org_forward = org_forward
+    try:
+        with autocast_ctx:
+            yield None
+    finally:
+        for module_type in patch_module_list:
+            module_type.forward = module_type.org_forward
+
+
 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
@@ -191,6 +212,9 @@ def autocast(disable=False):
     if has_xpu() or has_mps() or cuda_no_autocast():
         return manual_cast(dtype_inference)
 
+    if dtype_inference == torch.float32 and dtype != torch.float32:
+        return precision_full_with_autocast(torch.autocast("cuda"))
+
     return torch.autocast("cuda")
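PATCH 4 tries a different route for fp32 inference on CUDA: keep `torch.autocast` for per-op dtype selection, but post-cast each patched module's output up to the fp32 inference dtype via `cast_output`. A minimal sketch of that idea, with stand-in names and CPU autocast so it runs anywhere (the patch wraps `torch.autocast("cuda")` the same way):

import torch

inference_dtype = torch.float32  # stands in for devices.dtype_inference

def cast_output(result):
    # Upcast tensor results (and tuples of tensors) to the inference dtype.
    if isinstance(result, tuple):
        return tuple(r.to(inference_dtype) if isinstance(r, torch.Tensor) else r for r in result)
    if isinstance(result, torch.Tensor):
        return result.to(inference_dtype)
    return result

layer = torch.nn.Linear(4, 4)
x = torch.randn(1, 4)
with torch.autocast("cpu", dtype=torch.bfloat16):
    y = cast_output(layer(x))  # linear ran in bfloat16 under autocast
assert y.dtype == inference_dtype

The next patch reverts this approach wholesale in favor of extending `manual_cast`.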
From 1fd69655fe340325863cbd7bf5297e034a6a3a0a Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 9 Jan 2024 23:15:05 +0800
Subject: [PATCH 5/7] Revert "Apply correct inference precision implementation"

This reverts commit e00365962b17550a42235d1fbe2ad2c7cc4b8961.

---
 modules/devices.py | 42 +++++++++--------------------------------
 1 file changed, 9 insertions(+), 33 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index 9e1f207c3..ad36f6562 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -132,21 +132,6 @@ patch_module_list = [
 ]
 
 
-def cast_output(result):
-    if isinstance(result, tuple):
-        result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
-    elif isinstance(result, torch.Tensor):
-        result = result.to(dtype_inference)
-    return result
-
-
-def autocast_with_cast_output(self, *args, **kwargs):
-    result = self.org_forward(*args, **kwargs)
-    if dtype_inference != dtype:
-        result = cast_output(result)
-    return result
-
-
 def manual_cast_forward(target_dtype):
     def forward_wrapper(self, *args, **kwargs):
         if any(
@@ -164,7 +149,15 @@ def manual_cast_forward(target_dtype):
             self.to(org_dtype)
 
         if target_dtype != dtype_inference:
-            result = cast_output(result)
+            if isinstance(result, tuple):
+                result = tuple(
+                    i.to(dtype_inference)
+                    if isinstance(i, torch.Tensor)
+                    else i
+                    for i in result
+                )
+            elif isinstance(result, torch.Tensor):
+                result = result.to(dtype_inference)
         return result
     return forward_wrapper
 
@@ -185,20 +178,6 @@ def manual_cast(target_dtype):
         module_type.forward = module_type.org_forward
 
 
-@contextlib.contextmanager
-def precision_full_with_autocast(autocast_ctx):
-    for module_type in patch_module_list:
-        org_forward = module_type.forward
-        module_type.forward = autocast_with_cast_output
-        module_type.org_forward = org_forward
-    try:
-        with autocast_ctx:
-            yield None
-    finally:
-        for module_type in patch_module_list:
-            module_type.forward = module_type.org_forward
-
-
 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
@@ -212,9 +191,6 @@ def autocast(disable=False):
     if has_xpu() or has_mps() or cuda_no_autocast():
         return manual_cast(dtype_inference)
 
-    if dtype_inference == torch.float32 and dtype != torch.float32:
-        return precision_full_with_autocast(torch.autocast("cuda"))
-
     return torch.autocast("cuda")

From 58d5b042cd02f287faabef399134b97d323691f2 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 9 Jan 2024 23:23:40 +0800
Subject: [PATCH 6/7] Apply the correct behavior of precision='full'

---
 modules/devices.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index ad36f6562..29a270d11 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -185,11 +185,14 @@ def autocast(disable=False):
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
-    if dtype == torch.float32:
-        return contextlib.nullcontext()
-
     if has_xpu() or has_mps() or cuda_no_autocast():
-        return manual_cast(dtype_inference)
+        return manual_cast(dtype)
+
+    if fp8 and dtype_inference == torch.float32:
+        return manual_cast(dtype)
+
+    if dtype == torch.float32 or dtype_inference == torch.float32:
+        return contextlib.nullcontext()
 
     return torch.autocast("cuda")
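After PATCH 6, `autocast()` implements a five-way decision. Condensed as a plain-Python paraphrase (the booleans stand in for module-level state; this is not code from the patch):

def pick_context(fp8, on_cpu, device_needs_manual_cast, dtype_fp32, inference_fp32):
    if fp8 and on_cpu:
        return "torch.autocast('cpu', bfloat16)"  # fp8 weights on CPU
    if device_needs_manual_cast:                  # has_xpu() / has_mps() / cuda_no_autocast()
        return "manual_cast(dtype)"
    if fp8 and inference_fp32:
        return "manual_cast(dtype)"               # fp8 weights, fp32 inference
    if dtype_fp32 or inference_fp32:
        return "nullcontext()"                    # fp32 everywhere, nothing to cast
    return "torch.autocast('cuda')"               # the common fp16-on-CUDA path

PATCH 7 below then moves the device check after the two fp32 checks, so fp32 setups (the usual case on CPU) short-circuit to `nullcontext()` before any device-specific probing.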
From ca671e5d7b9d03227f01e6bcb350032b6d14e722 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 9 Jan 2024 23:30:55 +0800
Subject: [PATCH 7/7] rearrange if-statements for CPU

---
 modules/devices.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index 29a270d11..0321d12c6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -185,15 +185,15 @@ def autocast(disable=False):
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
-    if has_xpu() or has_mps() or cuda_no_autocast():
-        return manual_cast(dtype)
-
     if fp8 and dtype_inference == torch.float32:
         return manual_cast(dtype)
 
     if dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()
 
+    if has_xpu() or has_mps() or cuda_no_autocast():
+        return manual_cast(dtype)
+
     return torch.autocast("cuda")
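Taken together, the series separates the storage dtype (`devices.dtype`) from the inference dtype (`devices.dtype_inference`). A sanity sketch of how the two resolve from the command-line flags, paraphrasing the shared_init.py hunk in PATCH 1, with the final `autocast()` outcome for a plain CUDA, non-fp8 setup noted in the comments:

import torch

def resolve_dtypes(no_half, precision):
    # Mirrors the shared_init.py change in PATCH 1 (flag handling simplified).
    dtype = torch.float32 if no_half else torch.float16
    dtype_inference = torch.float32 if precision == "full" else dtype
    return dtype, dtype_inference

print(resolve_dtypes(no_half=False, precision="autocast"))  # (fp16, fp16): torch.autocast("cuda")
print(resolve_dtypes(no_half=False, precision="full"))      # (fp16, fp32): nullcontext(), unless fp8 or a manual-cast device
print(resolve_dtypes(no_half=True,  precision="full"))      # (fp32, fp32): nullcontext()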