Fix SD15 dtype

This commit is contained in:
huchenlei 2024-05-17 13:23:12 -04:00
parent 2a8a60c2c5
commit dca9007ac7

View File

@ -733,6 +733,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config sd_model.used_config = checkpoint_config
# ldm's Unet is using self.dtype to cast input tensor. If we do not overwrite
# UnetModel.dtype, it will be the default dtype from config.
# sgm's Unet is not using dtype for casting. The value will be ignored.
sd_model.model.diffusion_model.dtype = devices.dtype_unet
timer.record("create model") timer.record("create model")