SD2 v autodetection fix

This commit is contained in:
AUTOMATIC1111 2024-07-06 11:00:22 +03:00
parent 477869c044
commit 74069addc3

View File

@ -58,12 +58,13 @@ def is_using_v_parameterization_for_sd2(state_dict):
with torch.no_grad(): with torch.no_grad():
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
unet.load_state_dict(unet_sd, strict=True) unet.load_state_dict(unet_sd, strict=True)
unet.to(device=device, dtype=torch.float) unet.to(device=device, dtype=devices.dtype_unet)
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() with devices.autocast():
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
return out < -1 return out < -1