mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-03 20:22:56 +08:00
make it possible to use checkpoints of different types (SD1, SDXL) in first and second pass of hires fix
This commit is contained in:
parent
eec540b227
commit
a64fbe8928
@ -1060,16 +1060,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
if self.latent_scale_mode is None:
|
||||||
|
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||||
|
else:
|
||||||
|
decoded_samples = None
|
||||||
|
|
||||||
current = shared.sd_model.sd_checkpoint_info
|
current = shared.sd_model.sd_checkpoint_info
|
||||||
try:
|
try:
|
||||||
if self.hr_checkpoint_info is not None:
|
if self.hr_checkpoint_info is not None:
|
||||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||||
|
|
||||||
return self.sample_hr_pass(samples, seeds, subseeds, subseed_strength, prompts)
|
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||||
finally:
|
finally:
|
||||||
sd_models.reload_model_weights(info=current)
|
sd_models.reload_model_weights(info=current)
|
||||||
|
|
||||||
def sample_hr_pass(self, samples, seeds, subseeds, subseed_strength, prompts):
|
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.is_hr_pass = True
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
@ -1100,7 +1105,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
else:
|
else:
|
||||||
image_conditioning = self.txt2img_image_conditioning(samples)
|
image_conditioning = self.txt2img_image_conditioning(samples)
|
||||||
else:
|
else:
|
||||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
|
||||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
batch_images = []
|
batch_images = []
|
||||||
|
Loading…
Reference in New Issue
Block a user