mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 05:45:05 +08:00

added highres fix feature

This commit is contained in:
parent 8a32a71ca3
commit 6d7ca54a1a
javascript/hints.js

@@ -72,6 +72,10 @@ titles = {
     "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
 
     "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
 
+
+    "Highres. fix": "Use a two-step process: partially create an image at a smaller resolution, upscale it, and then improve its details without changing the composition",
+    "Scale latent": "Upscale the image in latent space. The alternative is to produce the full image from the latent representation, upscale that, and then move it back to latent space.",
+
 }
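For reference, the "Scale latent" choice the new tooltip describes is a plain resize of the latent tensor, as opposed to decoding to pixels, upscaling, and re-encoding. A toy illustration (shapes assume the usual 4-channel Stable Diffusion latent and the 8x VAE downscale factor that appears as opt_f later in this commit; this is not webui code):

import torch
import torch.nn.functional as F

# Toy latent for a 512x512 image: 4 channels at 1/8 resolution (64x64).
latent = torch.randn(1, 4, 64, 64)

# "Scale latent" path: upscale directly in latent space; decoding the
# result yields a 1024x1024 image with no decode/re-encode round trip.
latent_up = F.interpolate(latent, size=(128, 128), mode="bilinear")
print(latent_up.shape)  # torch.Size([1, 4, 128, 128])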
modules/processing.py

@@ -74,11 +74,12 @@ class StableDiffusionProcessing:
         self.overlay_images = overlay_images
         self.paste_to = None
         self.color_corrections = None
+        self.denoising_strength: float = 0
 
-    def init(self, seed):
+    def init(self, all_prompts, all_seeds, all_subseeds):
         pass
 
-    def sample(self, x, conditioning, unconditional_conditioning):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
         raise NotImplementedError()
@@ -303,7 +304,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
     ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
     with torch.no_grad(), precision_scope("cuda"), ema_scope():
-        p.init(seed=all_seeds[0])
+        p.init(all_prompts, all_seeds, all_subseeds)
 
         if state.job_count == -1:
             state.job_count = p.n_iter
@@ -328,13 +329,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             for comment in model_hijack.comments:
                 comments[comment] = 1
 
-            # we manually generate all input noises because each one should have a specific seed
-            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
-
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
+            samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
             if state.interrupted:
 
                 # if we are interrupted, sample returns just noise
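The removed block is not gone: noise creation moves into each sample() implementation so it can also serve the highres fix's second pass. The principle stays the same, as in this standalone sketch (illustrative only, not the webui's create_random_tensors, which additionally handles subseeds and seed resizing):

import torch

# Each image in the batch draws noise from a generator seeded with that
# image's own seed, so any single image is reproducible on its own.
def make_noise(seeds, shape=(4, 64, 64)):
    return torch.stack([
        torch.randn(shape, generator=torch.Generator().manual_seed(s))
        for s in seeds
    ])

x = make_noise([1000, 1001])
print(x.shape)  # torch.Size([2, 4, 64, 64])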
@@ -406,13 +404,64 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
 class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     sampler = None
+    firstphase_width = 0
+    firstphase_height = 0
+    firstphase_width_truncated = 0
+    firstphase_height_truncated = 0
 
-    def init(self, seed):
+    def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
+        super().__init__(**kwargs)
+        self.enable_hr = enable_hr
+        self.scale_latent = scale_latent
+        self.denoising_strength = denoising_strength
+
+    def init(self, all_prompts, all_seeds, all_subseeds):
+        if self.enable_hr:
+            if state.job_count == -1:
+                state.job_count = self.n_iter * 2
+            else:
+                state.job_count = state.job_count * 2
+
+            desired_pixel_count = 512 * 512
+            actual_pixel_count = self.width * self.height
+            scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+
+            self.firstphase_width = math.ceil(scale * self.width / 64) * 64
+            self.firstphase_height = math.ceil(scale * self.height / 64) * 64
+            self.firstphase_width_truncated = int(scale * self.width)
+            self.firstphase_height_truncated = int(scale * self.height)
+
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
         self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
 
-    def sample(self, x, conditioning, unconditional_conditioning):
-        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
-        return samples_ddim
+        if not self.enable_hr:
+            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+            return samples
+
+        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+
+        truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
+        truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
+
+        samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
+
+        if self.scale_latent:
+            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+        else:
+            decoded_samples = self.sd_model.decode_first_stage(samples)
+            decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
+            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+
+        shared.state.nextjob()
+
+        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+
+        return samples
 
 
 class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     sampler = None
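To make the first-phase geometry concrete, here is the init() arithmetic above traced for a 1024x768 target (pure Python; opt_f, the VAE downscale factor, is 8 in this codebase):

import math

width, height, opt_f = 1024, 768, 8

scale = math.sqrt((512 * 512) / (width * height))            # ~0.577

firstphase_width = math.ceil(scale * width / 64) * 64        # 640
firstphase_height = math.ceil(scale * height / 64) * 64      # 448
firstphase_width_truncated = int(scale * width)              # 591
firstphase_height_truncated = int(scale * height)            # 443

# The first pass renders at 640x448 (multiples of 64); the latent is then
# cropped by the rounding excess before being upscaled to 1024x768.
truncate_x = (firstphase_width - firstphase_width_truncated) // opt_f    # 6
truncate_y = (firstphase_height - firstphase_height_truncated) // opt_f  # 0
print(firstphase_width, firstphase_height, truncate_x, truncate_y)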
@@ -435,7 +484,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.mask = None
         self.nmask = None
 
-    def init(self, seed):
+    def init(self, all_prompts, all_seeds, all_subseeds):
         self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
         crop_region = None
@@ -529,12 +578,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
             self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
 
+            # this needs to be fixed to be done in sample() using actual seeds for batches
             if self.inpainting_fill == 2:
-                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
+                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask
 
-    def sample(self, x, conditioning, unconditional_conditioning):
+    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
 
         if self.mask is not None:
modules/sd_samplers.py

@@ -38,9 +38,9 @@ samplers = [
 samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
 
 
-def setup_img2img_steps(p):
-    if opts.img2img_fix_steps:
-        steps = int(p.steps / min(p.denoising_strength, 0.999))
+def setup_img2img_steps(p, steps=None):
+    if opts.img2img_fix_steps or steps is not None:
+        steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
         t_enc = p.steps - 1
     else:
         steps = p.steps
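A simplified, self-contained restatement of the reworked step logic. The per-call steps override and p.steps are merged into one argument here, and the else-branch t_enc formula comes from the surrounding function rather than this hunk, so treat the details as an approximation:

def setup_steps(requested_steps, denoising_strength, fix_steps):
    if fix_steps:
        # Inflate the encode steps so the decode still runs roughly the
        # requested number of steps at this denoising strength.
        steps = int(requested_steps / min(denoising_strength, 0.999)) if denoising_strength > 0 else 0
        t_enc = requested_steps - 1
    else:
        steps = requested_steps
        t_enc = int(min(denoising_strength, 0.999) * steps)
    return steps, t_enc

print(setup_steps(50, 0.7, fix_steps=True))   # (71, 49)
print(setup_steps(50, 0.7, fix_steps=False))  # (50, 35)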
@@ -115,8 +115,8 @@ class VanillaStableDiffusionSampler:
         self.step += 1
         return res
 
-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
-        steps, t_enc = setup_img2img_steps(p)
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+        steps, t_enc = setup_img2img_steps(p, steps)
 
         # existing code fails with certain step counts, like 9
         try:
@@ -127,16 +127,16 @@ class VanillaStableDiffusionSampler:
             x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
 
         self.sampler.p_sample_ddim = self.p_sample_ddim_hook
-        self.mask = p.mask
-        self.nmask = p.nmask
-        self.init_latent = p.init_latent
+        self.mask = p.mask if hasattr(p, 'mask') else None
+        self.nmask = p.nmask if hasattr(p, 'nmask') else None
+        self.init_latent = x
         self.step = 0
 
         samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
 
         return samples
 
-    def sample(self, p, x, conditioning, unconditional_conditioning):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
         for fieldname in ['p_sample_ddim', 'p_sample_plms']:
             if hasattr(self.sampler, fieldname):
                 setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
@@ -145,11 +145,13 @@ class VanillaStableDiffusionSampler:
         self.init_latent = None
         self.step = 0
 
+        steps = steps or p.steps
+
         # existing code fails with certain step counts, like 9
         try:
-            samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+            samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
         except Exception:
-            samples_ddim, _ = self.sampler.sample(S=p.steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+            samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
 
         return samples_ddim
@@ -186,7 +188,7 @@ class CFGDenoiser(torch.nn.Module):
         return denoised
 
 
-def extended_trange(count, *args, **kwargs):
+def extended_trange(sampler, count, *args, **kwargs):
     state.sampling_steps = count
     state.sampling_step = 0
@@ -194,6 +196,9 @@ def extended_trange(count, *args, **kwargs):
         if state.interrupted:
             break
 
+        if sampler.stop_at is not None and x > sampler.stop_at:
+            break
+
         yield x
 
         state.sampling_step += 1
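The effect of the new sampler argument, modeled as a standalone generator (extended_range here is illustrative, not the webui function): the loop behaves like trange but can now be cut short at a chosen step.

def extended_range(stop_at, count):
    for x in range(count):
        # Bail out early once the requested stopping step is passed.
        if stop_at is not None and x > stop_at:
            break
        yield x

print(list(extended_range(None, 5)))  # [0, 1, 2, 3, 4]
print(list(extended_range(2, 5)))     # [0, 1, 2]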
@@ -222,6 +227,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
         self.sampler_noise_index = 0
+        self.stop_at = None
 
     def callback_state(self, d):
         store_latent(d["denoised"])
@@ -240,8 +246,8 @@ class KDiffusionSampler:
         self.sampler_noise_index += 1
         return res
 
-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
-        steps, t_enc = setup_img2img_steps(p)
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+        steps, t_enc = setup_img2img_steps(p, steps)
 
         sigmas = self.model_wrap.get_sigmas(steps)
@@ -251,33 +257,36 @@
 
         sigma_sched = sigmas[steps - t_enc - 1:]
 
-        self.model_wrap_cfg.mask = p.mask
-        self.model_wrap_cfg.nmask = p.nmask
-        self.model_wrap_cfg.init_latent = p.init_latent
+        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+        self.model_wrap_cfg.init_latent = x
         self.model_wrap.step = 0
         self.sampler_noise_index = 0
 
         if hasattr(k_diffusion.sampling, 'trange'):
-            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
+            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
 
         if self.sampler_noises is not None:
             k_diffusion.sampling.torch = TorchHijack(self)
 
         return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
 
-    def sample(self, p, x, conditioning, unconditional_conditioning):
-        sigmas = self.model_wrap.get_sigmas(p.steps)
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+        steps = steps or p.steps
+
+        sigmas = self.model_wrap.get_sigmas(steps)
         x = x * sigmas[0]
 
         self.model_wrap_cfg.step = 0
         self.sampler_noise_index = 0
 
         if hasattr(k_diffusion.sampling, 'trange'):
-            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
+            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
 
         if self.sampler_noises is not None:
             k_diffusion.sampling.torch = TorchHijack(self)
 
-        samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
-        return samples_ddim
+        samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+
+        return samples
modules/txt2img.py

@@ -6,7 +6,7 @@ import modules.processing as processing
 from modules.ui import plaintext_to_html
 
 
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
     p = StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -28,6 +28,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         height=height,
         restore_faces=restore_faces,
         tiling=tiling,
+        enable_hr=enable_hr,
+        scale_latent=scale_latent,
+        denoising_strength=denoising_strength,
     )
 
     print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
modules/ui.py

@@ -327,6 +327,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
         outputs=[seed, dummy_component]
     )
 
+
 def create_toprow(is_img2img):
     with gr.Row(elem_id="toprow"):
         with gr.Column(scale=4):
@@ -392,6 +393,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                 with gr.Row():
                     restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
                     tiling = gr.Checkbox(label='Tiling', value=False)
+                    enable_hr = gr.Checkbox(label='Highres. fix', value=False)
+
+                with gr.Row(visible=False) as hr_options:
+                    scale_latent = gr.Checkbox(label='Scale latent', value=True)
+                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
 
                 with gr.Row():
                     batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
@@ -451,6 +457,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                     subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
                     height,
                     width,
+                    enable_hr,
+                    scale_latent,
+                    denoising_strength,
                 ] + custom_inputs,
                 outputs=[
                     txt2img_gallery,
@@ -463,6 +472,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
             txt2img_prompt.submit(**txt2img_args)
             submit.click(**txt2img_args)
 
+            enable_hr.change(
+                fn=lambda x: gr_show(x),
+                inputs=[enable_hr],
+                outputs=[hr_options],
+            )
+
             interrupt.click(
                 fn=lambda: shared.state.interrupt(),
                 inputs=[],
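The wiring above shows or hides the hr_options row whenever the checkbox flips. A minimal standalone equivalent (assuming a gradio 3.x-style API where gr.update is available; the webui's gr_show helper returns the same kind of visibility update):

import gradio as gr

with gr.Blocks() as demo:
    enable_hr = gr.Checkbox(label='Highres. fix', value=False)
    with gr.Row(visible=False) as hr_options:
        gr.Checkbox(label='Scale latent', value=True)
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)

    # Make the row's visibility track the checkbox state.
    enable_hr.change(fn=lambda x: gr.update(visible=x), inputs=[enable_hr], outputs=[hr_options])

demo.launch()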