Fix to support the dtype_inference != dtype case

This commit is contained in:
Won-Kyu Park 2024-09-12 00:23:00 +09:00
parent 2f72fd89ff
commit 24f2c1b9e4
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
3 changed files with 4 additions and 4 deletions

View File

@@ -484,7 +484,7 @@ class StableDiffusionProcessing:
     cache = caches[0]
-    with devices.autocast():
+    with devices.autocast(target_dtype=devices.dtype_inference):
         cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
     cache[0] = cached_params
@@ -1439,7 +1439,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     with devices.autocast():
         extra_networks.activate(self, self.hr_extra_network_data)
-    with devices.autocast():
+    with devices.autocast(target_dtype=devices.dtype_inference):
         self.calculate_hr_conds()
     sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

View File

@@ -984,7 +984,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
     timer.record("scripts callbacks")
-    with devices.autocast(), torch.no_grad():
+    with devices.autocast(target_dtype=devices.dtype_inference), torch.no_grad():
         sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
     timer.record("calculate empty prompt")

View File

@@ -18,7 +18,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
     is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
     aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
-    devices_args = dict(device=devices.device, dtype=devices.dtype)
+    devices_args = dict(device=devices.device, dtype=devices.dtype_inference)
     sdxl_conds = {
         "txt": batch,