make it possible to use checkpoints of different types (SD1, SDXL) in first and second pass of hires fix

This commit is contained in:
AUTOMATIC1111 2023-07-30 15:12:09 +03:00
parent eec540b227
commit a64fbe8928

View File

@ -1060,16 +1060,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr: if not self.enable_hr:
return samples return samples
if self.latent_scale_mode is None:
decoded_samples = decode_first_stage(self.sd_model, samples)
else:
decoded_samples = None
current = shared.sd_model.sd_checkpoint_info current = shared.sd_model.sd_checkpoint_info
try: try:
if self.hr_checkpoint_info is not None: if self.hr_checkpoint_info is not None:
sd_models.reload_model_weights(info=self.hr_checkpoint_info) sd_models.reload_model_weights(info=self.hr_checkpoint_info)
return self.sample_hr_pass(samples, seeds, subseeds, subseed_strength, prompts) return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
finally: finally:
sd_models.reload_model_weights(info=current) sd_models.reload_model_weights(info=current)
def sample_hr_pass(self, samples, seeds, subseeds, subseed_strength, prompts): def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
self.is_hr_pass = True self.is_hr_pass = True
target_width = self.hr_upscale_to_x target_width = self.hr_upscale_to_x
@ -1100,7 +1105,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else: else:
image_conditioning = self.txt2img_image_conditioning(samples) image_conditioning = self.txt2img_image_conditioning(samples)
else: else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
batch_images = [] batch_images = []