mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-06 12:49:02 +08:00
Fix bugs when arg dtype doesn't match
This commit is contained in:
parent
209c26a1cb
commit
42e6df723c
@ -134,24 +134,19 @@ patch_module_list = [
|
|||||||
|
|
||||||
def manual_cast_forward(target_dtype):
|
def manual_cast_forward(target_dtype):
|
||||||
def forward_wrapper(self, *args, **kwargs):
|
def forward_wrapper(self, *args, **kwargs):
|
||||||
org_dtype = torch_utils.get_param(self).dtype
|
if any(
|
||||||
if not target_dtype == org_dtype == dtype_inference:
|
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
|
||||||
self.to(target_dtype)
|
for arg in args
|
||||||
args = [
|
):
|
||||||
arg.to(target_dtype)
|
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||||
if isinstance(arg, torch.Tensor)
|
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||||
else arg
|
|
||||||
for arg in args
|
|
||||||
]
|
|
||||||
kwargs = {
|
|
||||||
k: v.to(target_dtype)
|
|
||||||
if isinstance(v, torch.Tensor)
|
|
||||||
else v
|
|
||||||
for k, v in kwargs.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
org_dtype = torch_utils.get_param(self).dtype
|
||||||
|
if org_dtype != target_dtype:
|
||||||
|
self.to(target_dtype)
|
||||||
result = self.org_forward(*args, **kwargs)
|
result = self.org_forward(*args, **kwargs)
|
||||||
self.to(org_dtype)
|
if org_dtype != target_dtype:
|
||||||
|
self.to(org_dtype)
|
||||||
|
|
||||||
if target_dtype != dtype_inference:
|
if target_dtype != dtype_inference:
|
||||||
if isinstance(result, tuple):
|
if isinstance(result, tuple):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user