mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-06 05:32:52 +08:00
Merge branch 'dev' into efficient-vae-methods
This commit is contained in:
commit
d56a9cfe6a
@ -2,16 +2,14 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
||||||
|
|
||||||
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
||||||
var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
|
var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
|
||||||
|
var search = searchDiv.querySelector('textarea');
|
||||||
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
||||||
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
||||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||||
|
|
||||||
search.classList.add('search');
|
|
||||||
sort.classList.add('sort');
|
|
||||||
sortOrder.classList.add('sortorder');
|
|
||||||
sort.dataset.sortkey = 'sortDefault';
|
sort.dataset.sortkey = 'sortDefault';
|
||||||
tabs.appendChild(search);
|
tabs.appendChild(searchDiv);
|
||||||
tabs.appendChild(sort);
|
tabs.appendChild(sort);
|
||||||
tabs.appendChild(sortOrder);
|
tabs.appendChild(sortOrder);
|
||||||
tabs.appendChild(refresh);
|
tabs.appendChild(refresh);
|
||||||
@ -179,7 +177,7 @@ function saveCardPreview(event, tabname, filename) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksSearchButton(tabs_id, event) {
|
function extraNetworksSearchButton(tabs_id, event) {
|
||||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
|
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
|
||||||
var button = event.target;
|
var button = event.target;
|
||||||
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
||||||
|
|
||||||
|
@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "Hires sampler" not in res:
|
if "Hires sampler" not in res:
|
||||||
res["Hires sampler"] = "Use same sampler"
|
res["Hires sampler"] = "Use same sampler"
|
||||||
|
|
||||||
|
if "Hires checkpoint" not in res:
|
||||||
|
res["Hires checkpoint"] = "Use same checkpoint"
|
||||||
|
|
||||||
if "Hires prompt" not in res:
|
if "Hires prompt" not in res:
|
||||||
res["Hires prompt"] = ""
|
res["Hires prompt"] = ""
|
||||||
|
|
||||||
|
@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
invalid_filename_chars = '<>:"/\\|?*\n'
|
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
|
||||||
invalid_filename_prefix = ' '
|
invalid_filename_prefix = ' '
|
||||||
invalid_filename_postfix = ' .'
|
invalid_filename_postfix = ' .'
|
||||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||||
|
@ -15,6 +15,9 @@ def send_everything_to_cpu():
|
|||||||
|
|
||||||
|
|
||||||
def setup_for_low_vram(sd_model, use_medvram):
|
def setup_for_low_vram(sd_model, use_medvram):
|
||||||
|
if getattr(sd_model, 'lowvram', False):
|
||||||
|
return
|
||||||
|
|
||||||
sd_model.lowvram = True
|
sd_model.lowvram = True
|
||||||
|
|
||||||
parents = {}
|
parents = {}
|
||||||
|
@ -539,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DecodedSamples(list):
|
||||||
|
already_decoded = True
|
||||||
|
|
||||||
|
|
||||||
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
||||||
samples = []
|
samples = DecodedSamples()
|
||||||
|
|
||||||
for i in range(batch.shape[0]):
|
for i in range(batch.shape[0]):
|
||||||
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
||||||
@ -788,8 +792,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||||
|
|
||||||
|
if getattr(samples_ddim, 'already_decoded', False):
|
||||||
|
x_samples_ddim = samples_ddim
|
||||||
|
else:
|
||||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
@ -931,7 +939,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
cached_hr_uc = [None, None]
|
cached_hr_uc = [None, None]
|
||||||
cached_hr_c = [None, None]
|
cached_hr_c = [None, None]
|
||||||
|
|
||||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.enable_hr = enable_hr
|
self.enable_hr = enable_hr
|
||||||
self.denoising_strength = denoising_strength
|
self.denoising_strength = denoising_strength
|
||||||
@ -942,11 +950,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.hr_resize_y = hr_resize_y
|
self.hr_resize_y = hr_resize_y
|
||||||
self.hr_upscale_to_x = hr_resize_x
|
self.hr_upscale_to_x = hr_resize_x
|
||||||
self.hr_upscale_to_y = hr_resize_y
|
self.hr_upscale_to_y = hr_resize_y
|
||||||
|
self.hr_checkpoint_name = hr_checkpoint_name
|
||||||
|
self.hr_checkpoint_info = None
|
||||||
self.hr_sampler_name = hr_sampler_name
|
self.hr_sampler_name = hr_sampler_name
|
||||||
self.hr_prompt = hr_prompt
|
self.hr_prompt = hr_prompt
|
||||||
self.hr_negative_prompt = hr_negative_prompt
|
self.hr_negative_prompt = hr_negative_prompt
|
||||||
self.all_hr_prompts = None
|
self.all_hr_prompts = None
|
||||||
self.all_hr_negative_prompts = None
|
self.all_hr_negative_prompts = None
|
||||||
|
self.latent_scale_mode = None
|
||||||
|
|
||||||
if firstphase_width != 0 or firstphase_height != 0:
|
if firstphase_width != 0 or firstphase_height != 0:
|
||||||
self.hr_upscale_to_x = self.width
|
self.hr_upscale_to_x = self.width
|
||||||
@ -969,6 +980,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
if self.enable_hr:
|
if self.enable_hr:
|
||||||
|
if self.hr_checkpoint_name:
|
||||||
|
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
|
||||||
|
|
||||||
|
if self.hr_checkpoint_info is None:
|
||||||
|
raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
|
||||||
|
|
||||||
|
self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
|
||||||
|
|
||||||
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
||||||
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
||||||
|
|
||||||
@ -978,6 +997,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
||||||
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
||||||
|
|
||||||
|
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||||
|
if self.enable_hr and self.latent_scale_mode is None:
|
||||||
|
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
||||||
|
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
||||||
|
|
||||||
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
||||||
self.hr_resize_x = self.width
|
self.hr_resize_x = self.width
|
||||||
self.hr_resize_y = self.height
|
self.hr_resize_y = self.height
|
||||||
@ -1016,14 +1040,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
||||||
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
||||||
|
|
||||||
# special case: the user has chosen to do nothing
|
|
||||||
if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
|
|
||||||
self.enable_hr = False
|
|
||||||
self.denoising_strength = None
|
|
||||||
self.extra_generation_params.pop("Hires upscale", None)
|
|
||||||
self.extra_generation_params.pop("Hires resize", None)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not state.processing_has_refined_job_count:
|
if not state.processing_has_refined_job_count:
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = self.n_iter
|
state.job_count = self.n_iter
|
||||||
@ -1041,17 +1057,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
|
|
||||||
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
|
||||||
if self.enable_hr and latent_scale_mode is None:
|
|
||||||
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
|
||||||
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
|
||||||
|
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||||
|
del x
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
if self.latent_scale_mode is None:
|
||||||
|
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
decoded_samples = None
|
||||||
|
|
||||||
|
current = shared.sd_model.sd_checkpoint_info
|
||||||
|
try:
|
||||||
|
if self.hr_checkpoint_info is not None:
|
||||||
|
self.sampler = None
|
||||||
|
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||||
|
finally:
|
||||||
|
self.sampler = None
|
||||||
|
sd_models.reload_model_weights(info=current)
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.is_hr_pass = True
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
@ -1069,11 +1100,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
|
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
|
||||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
|
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
|
||||||
|
|
||||||
if latent_scale_mode is not None:
|
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
||||||
|
|
||||||
|
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
||||||
|
img2img_sampler_name = 'DDIM'
|
||||||
|
|
||||||
|
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
||||||
|
|
||||||
|
if self.latent_scale_mode is not None:
|
||||||
for i in range(samples.shape[0]):
|
for i in range(samples.shape[0]):
|
||||||
save_intermediate(samples, i)
|
save_intermediate(samples, i)
|
||||||
|
|
||||||
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
|
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
|
||||||
|
|
||||||
# Avoid making the inpainting conditioning unless necessary as
|
# Avoid making the inpainting conditioning unless necessary as
|
||||||
# this does need some extra compute to decode / encode the image again.
|
# this does need some extra compute to decode / encode the image again.
|
||||||
@ -1082,7 +1120,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
else:
|
else:
|
||||||
image_conditioning = self.txt2img_image_conditioning(samples)
|
image_conditioning = self.txt2img_image_conditioning(samples)
|
||||||
else:
|
else:
|
||||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
|
||||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
batch_images = []
|
batch_images = []
|
||||||
@ -1108,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
img2img_sampler_name = self.hr_sampler_name or self.sampler_name
|
|
||||||
|
|
||||||
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
|
||||||
img2img_sampler_name = 'DDIM'
|
|
||||||
|
|
||||||
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
|
|
||||||
|
|
||||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
||||||
|
|
||||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
||||||
|
|
||||||
# GC now before running the next img2img to prevent running out of memory
|
# GC now before running the next img2img to prevent running out of memory
|
||||||
x = None
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
if not self.disable_extra_networks:
|
if not self.disable_extra_networks:
|
||||||
@ -1139,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
|
||||||
|
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return samples
|
return decoded_samples
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
super().close()
|
super().close()
|
||||||
@ -1180,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if self.hr_c is not None:
|
if self.hr_c is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
||||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
||||||
|
|
||||||
|
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
||||||
|
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
super().setup_conds()
|
super().setup_conds()
|
||||||
@ -1189,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.hr_uc = None
|
self.hr_uc = None
|
||||||
self.hr_c = None
|
self.hr_c = None
|
||||||
|
|
||||||
if self.enable_hr:
|
if self.enable_hr and self.hr_checkpoint_info is None:
|
||||||
if shared.opts.hires_fix_use_firstpass_conds:
|
if shared.opts.hires_fix_use_firstpass_conds:
|
||||||
self.calculate_hr_conds()
|
self.calculate_hr_conds()
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from types import MethodType
|
|||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
|
|||||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||||
|
|
||||||
# silence new console spam from SD2
|
# silence new console spam from SD2
|
||||||
ldm.modules.attention.print = lambda *args: None
|
ldm.modules.attention.print = shared.ldm_print
|
||||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
|
ldm.util.print = shared.ldm_print
|
||||||
|
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
|
|
||||||
|
sd_hijack_inpainting.do_inpainting_hijack()
|
||||||
|
|
||||||
optimizers = []
|
optimizers = []
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
|
@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
|
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
def do_inpainting_hijack():
|
||||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
|
||||||
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||||
|
@ -15,7 +15,6 @@ import ldm.modules.midas as midas
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
|
||||||
@ -67,6 +66,7 @@ class CheckpointInfo:
|
|||||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||||
|
|
||||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||||
|
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||||
|
|
||||||
@ -87,6 +87,7 @@ class CheckpointInfo:
|
|||||||
|
|
||||||
checkpoints_list.pop(self.title, None)
|
checkpoints_list.pop(self.title, None)
|
||||||
self.title = f'{self.name} [{self.shorthash}]'
|
self.title = f'{self.name} [{self.shorthash}]'
|
||||||
|
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
|
||||||
self.register()
|
self.register()
|
||||||
|
|
||||||
return self.shorthash
|
return self.shorthash
|
||||||
@ -107,14 +108,8 @@ def setup_model():
|
|||||||
enable_midas_autodownload()
|
enable_midas_autodownload()
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_tiles():
|
def checkpoint_tiles(use_short=False):
|
||||||
def convert(name):
|
return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
|
||||||
return int(name) if name.isdigit() else name.lower()
|
|
||||||
|
|
||||||
def alphanumeric_key(key):
|
|
||||||
return [convert(c) for c in re.split('([0-9]+)', key)]
|
|
||||||
|
|
||||||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
@ -137,11 +132,14 @@ def list_models():
|
|||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
|
|
||||||
for filename in sorted(model_list, key=str.lower):
|
for filename in model_list:
|
||||||
checkpoint_info = CheckpointInfo(filename)
|
checkpoint_info = CheckpointInfo(filename)
|
||||||
checkpoint_info.register()
|
checkpoint_info.register()
|
||||||
|
|
||||||
|
|
||||||
|
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(search_string):
|
def get_closet_checkpoint_match(search_string):
|
||||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
@ -151,6 +149,11 @@ def get_closet_checkpoint_match(search_string):
|
|||||||
if found:
|
if found:
|
||||||
return found[0]
|
return found[0]
|
||||||
|
|
||||||
|
search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
|
||||||
|
found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
|
||||||
|
if found:
|
||||||
|
return found[0]
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -430,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
|||||||
class SdModelData:
|
class SdModelData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sd_model = None
|
self.sd_model = None
|
||||||
|
self.loaded_sd_models = []
|
||||||
self.was_loaded_at_least_once = False
|
self.was_loaded_at_least_once = False
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
@ -444,6 +448,7 @@ class SdModelData:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
load_model()
|
load_model()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
||||||
print("", file=sys.stderr)
|
print("", file=sys.stderr)
|
||||||
@ -455,11 +460,24 @@ class SdModelData:
|
|||||||
def set_sd_model(self, v):
|
def set_sd_model(self, v):
|
||||||
self.sd_model = v
|
self.sd_model = v
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.loaded_sd_models.remove(v)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if v is not None:
|
||||||
|
self.loaded_sd_models.insert(0, v)
|
||||||
|
|
||||||
|
|
||||||
model_data = SdModelData()
|
model_data = SdModelData()
|
||||||
|
|
||||||
|
|
||||||
def get_empty_cond(sd_model):
|
def get_empty_cond(sd_model):
|
||||||
|
from modules import extra_networks, processing
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img()
|
||||||
|
extra_networks.activate(p, {})
|
||||||
|
|
||||||
if hasattr(sd_model, 'conditioner'):
|
if hasattr(sd_model, 'conditioner'):
|
||||||
d = sd_model.get_learned_conditioning([""])
|
d = sd_model.get_learned_conditioning([""])
|
||||||
return d['crossattn']
|
return d['crossattn']
|
||||||
@ -467,20 +485,44 @@ def get_empty_cond(sd_model):
|
|||||||
return sd_model.cond_stage_model([""])
|
return sd_model.cond_stage_model([""])
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
def send_model_to_cpu(m):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.send_everything_to_cpu()
|
||||||
|
else:
|
||||||
|
m.to(devices.cpu)
|
||||||
|
|
||||||
if model_data.sd_model:
|
|
||||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
|
||||||
model_data.sd_model = None
|
|
||||||
gc.collect()
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
do_inpainting_hijack()
|
|
||||||
|
def send_model_to_device(m):
|
||||||
|
from modules import lowvram
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
|
||||||
|
else:
|
||||||
|
m.to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_trash(m):
|
||||||
|
m.to(device="meta")
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
|
from modules import sd_hijack
|
||||||
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
|
if model_data.sd_model:
|
||||||
|
send_model_to_trash(model_data.sd_model)
|
||||||
|
model_data.sd_model = None
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
timer.record("unload existing model")
|
||||||
|
|
||||||
if already_loaded_state_dict is not None:
|
if already_loaded_state_dict is not None:
|
||||||
state_dict = already_loaded_state_dict
|
state_dict = already_loaded_state_dict
|
||||||
else:
|
else:
|
||||||
@ -519,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
|
|
||||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
|
timer.record("load weights from state dict")
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
send_model_to_device(sd_model)
|
||||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
|
||||||
else:
|
|
||||||
sd_model.to(shared.device)
|
|
||||||
|
|
||||||
timer.record("move model to device")
|
timer.record("move model to device")
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
@ -532,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
timer.record("hijack")
|
timer.record("hijack")
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
model_data.sd_model = sd_model
|
model_data.set_sd_model(sd_model)
|
||||||
model_data.was_loaded_at_least_once = True
|
model_data.was_loaded_at_least_once = True
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
@ -553,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||||
|
"""
|
||||||
|
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
|
||||||
|
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
||||||
|
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||||
|
If no such model exists, returns None.
|
||||||
|
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||||
|
"""
|
||||||
|
|
||||||
|
already_loaded = None
|
||||||
|
for i in reversed(range(len(model_data.loaded_sd_models))):
|
||||||
|
loaded_model = model_data.loaded_sd_models[i]
|
||||||
|
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
|
already_loaded = loaded_model
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
|
||||||
|
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
|
||||||
|
model_data.loaded_sd_models.pop()
|
||||||
|
send_model_to_trash(loaded_model)
|
||||||
|
timer.record("send model to trash")
|
||||||
|
|
||||||
|
if shared.opts.sd_checkpoints_keep_in_cpu:
|
||||||
|
send_model_to_cpu(sd_model)
|
||||||
|
timer.record("send model to cpu")
|
||||||
|
|
||||||
|
if already_loaded is not None:
|
||||||
|
send_model_to_device(already_loaded)
|
||||||
|
timer.record("send model to device")
|
||||||
|
|
||||||
|
model_data.set_sd_model(already_loaded)
|
||||||
|
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
||||||
|
return model_data.sd_model
|
||||||
|
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
||||||
|
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
||||||
|
|
||||||
|
model_data.sd_model = None
|
||||||
|
load_model(checkpoint_info)
|
||||||
|
return model_data.sd_model
|
||||||
|
elif len(model_data.loaded_sd_models) > 0:
|
||||||
|
sd_model = model_data.loaded_sd_models.pop()
|
||||||
|
model_data.sd_model = sd_model
|
||||||
|
|
||||||
|
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
||||||
|
return sd_model
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def reload_model_weights(sd_model=None, info=None):
|
def reload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import devices, sd_hijack
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = model_data.sd_model
|
sd_model = model_data.sd_model
|
||||||
|
|
||||||
@ -565,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
else:
|
else:
|
||||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return sd_model
|
||||||
|
|
||||||
|
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||||
|
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
|
return sd_model
|
||||||
|
|
||||||
|
if sd_model is not None:
|
||||||
sd_unet.apply_unet("None")
|
sd_unet.apply_unet("None")
|
||||||
|
send_model_to_cpu(sd_model)
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.send_everything_to_cpu()
|
|
||||||
else:
|
|
||||||
sd_model.to(devices.cpu)
|
|
||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
timer = Timer()
|
|
||||||
|
|
||||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||||
@ -585,7 +673,9 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
timer.record("find config")
|
timer.record("find config")
|
||||||
|
|
||||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||||
del sd_model
|
if sd_model is not None:
|
||||||
|
send_model_to_trash(sd_model)
|
||||||
|
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return model_data.sd_model
|
return model_data.sd_model
|
||||||
|
|
||||||
@ -608,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
print(f"Weights loaded in {timer.summary()}.")
|
print(f"Weights loaded in {timer.summary()}.")
|
||||||
|
|
||||||
|
model_data.set_sd_model(sd_model)
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,10 +98,10 @@ def extend_sdxl(model):
|
|||||||
model.conditioner.wrapped = torch.nn.Module()
|
model.conditioner.wrapped = torch.nn.Module()
|
||||||
|
|
||||||
|
|
||||||
sgm.modules.attention.print = lambda *args: None
|
sgm.modules.attention.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||||
sgm.modules.encoders.modules.print = lambda *args: None
|
sgm.modules.encoders.modules.print = shared.ldm_print
|
||||||
|
|
||||||
# this gets the code to load the vanilla attention that we override
|
# this gets the code to load the vanilla attention that we override
|
||||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||||
|
@ -220,6 +220,8 @@ class State:
|
|||||||
return
|
return
|
||||||
|
|
||||||
import modules.sd_samplers
|
import modules.sd_samplers
|
||||||
|
|
||||||
|
try:
|
||||||
if opts.show_progress_grid:
|
if opts.show_progress_grid:
|
||||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||||
else:
|
else:
|
||||||
@ -227,6 +229,11 @@ class State:
|
|||||||
|
|
||||||
self.current_image_sampling_step = self.sampling_step
|
self.current_image_sampling_step = self.sampling_step
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
||||||
|
# we silently ignore this error
|
||||||
|
errors.record_exception()
|
||||||
|
|
||||||
def assign_current_image(self, image):
|
def assign_current_image(self, image):
|
||||||
self.current_image = image
|
self.current_image = image
|
||||||
self.id_live_preview += 1
|
self.id_live_preview += 1
|
||||||
@ -393,6 +400,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
|
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
@ -412,17 +420,14 @@ options_templates.update(options_section(('training', "Training"), {
|
|||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||||
|
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||||
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_restart(),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
|
||||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
|
||||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
@ -441,6 +446,22 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
|||||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
options_templates.update(options_section(('img2img', "img2img"), {
|
||||||
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||||
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||||
|
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
|
||||||
|
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
|
||||||
|
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_restart(),
|
||||||
|
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_restart(),
|
||||||
|
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_restart(),
|
||||||
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
@ -460,7 +481,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|||||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||||
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||||
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||||
@ -493,10 +514,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
||||||
"img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
|
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
@ -515,11 +533,12 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
|
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
|
||||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
|
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(),
|
||||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
||||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
@ -892,3 +911,10 @@ def walk_files(path, allowed_extensions=None):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
yield os.path.join(root, filename)
|
yield os.path.join(root, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def ldm_print(*args, **kwargs):
|
||||||
|
if opts.hide_ldm_prints:
|
||||||
|
return
|
||||||
|
|
||||||
|
print(*args, **kwargs)
|
||||||
|
@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
hr_second_pass_steps=hr_second_pass_steps,
|
hr_second_pass_steps=hr_second_pass_steps,
|
||||||
hr_resize_x=hr_resize_x,
|
hr_resize_x=hr_resize_x,
|
||||||
hr_resize_y=hr_resize_y,
|
hr_resize_y=hr_resize_y,
|
||||||
|
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
||||||
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
|
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
|
||||||
hr_prompt=hr_prompt,
|
hr_prompt=hr_prompt,
|
||||||
hr_negative_prompt=hr_negative_prompt,
|
hr_negative_prompt=hr_negative_prompt,
|
||||||
|
@ -261,7 +261,6 @@ class Toprow:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||||
|
|
||||||
|
|
||||||
self.button_interrogate = None
|
self.button_interrogate = None
|
||||||
self.button_deepbooru = None
|
self.button_deepbooru = None
|
||||||
if is_img2img:
|
if is_img2img:
|
||||||
@ -290,7 +289,6 @@ class Toprow:
|
|||||||
with gr.Row(elem_id=f"{id_part}_tools"):
|
with gr.Row(elem_id=f"{id_part}_tools"):
|
||||||
self.paste = ToolButton(value=paste_symbol, elem_id="paste")
|
self.paste = ToolButton(value=paste_symbol, elem_id="paste")
|
||||||
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
||||||
self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
|
|
||||||
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
||||||
|
|
||||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||||
@ -404,11 +402,10 @@ def create_ui():
|
|||||||
|
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
||||||
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
|
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
||||||
from modules import ui_extra_networks
|
extra_tabs.__enter__()
|
||||||
extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')
|
|
||||||
|
|
||||||
with gr.Row(equal_height=False):
|
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||||
scripts.scripts_txt2img.prepare_ui()
|
scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
@ -456,6 +453,10 @@ def create_ui():
|
|||||||
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||||
|
|
||||||
|
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||||
|
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||||
|
|
||||||
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
@ -533,6 +534,7 @@ def create_ui():
|
|||||||
hr_second_pass_steps,
|
hr_second_pass_steps,
|
||||||
hr_resize_x,
|
hr_resize_x,
|
||||||
hr_resize_y,
|
hr_resize_y,
|
||||||
|
hr_checkpoint_name,
|
||||||
hr_sampler_index,
|
hr_sampler_index,
|
||||||
hr_prompt,
|
hr_prompt,
|
||||||
hr_negative_prompt,
|
hr_negative_prompt,
|
||||||
@ -599,8 +601,9 @@ def create_ui():
|
|||||||
(hr_second_pass_steps, "Hires steps"),
|
(hr_second_pass_steps, "Hires steps"),
|
||||||
(hr_resize_x, "Hires resize-1"),
|
(hr_resize_x, "Hires resize-1"),
|
||||||
(hr_resize_y, "Hires resize-2"),
|
(hr_resize_y, "Hires resize-2"),
|
||||||
|
(hr_checkpoint_name, "Hires checkpoint"),
|
||||||
(hr_sampler_index, "Hires sampler"),
|
(hr_sampler_index, "Hires sampler"),
|
||||||
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
|
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
||||||
(hr_prompt, "Hires prompt"),
|
(hr_prompt, "Hires prompt"),
|
||||||
(hr_negative_prompt, "Hires negative prompt"),
|
(hr_negative_prompt, "Hires negative prompt"),
|
||||||
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
||||||
@ -625,19 +628,22 @@ def create_ui():
|
|||||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||||
|
|
||||||
|
from modules import ui_extra_networks
|
||||||
|
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||||
|
|
||||||
|
extra_tabs.__exit__()
|
||||||
|
|
||||||
scripts.scripts_current = scripts.scripts_img2img
|
scripts.scripts_current = scripts.scripts_img2img
|
||||||
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
toprow = Toprow(is_img2img=True)
|
toprow = Toprow(is_img2img=True)
|
||||||
|
|
||||||
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
|
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
||||||
from modules import ui_extra_networks
|
extra_tabs.__enter__()
|
||||||
extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')
|
|
||||||
|
|
||||||
with FormRow(equal_height=False):
|
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
||||||
copy_image_buttons = []
|
copy_image_buttons = []
|
||||||
copy_image_destinations = {}
|
copy_image_destinations = {}
|
||||||
@ -663,15 +669,15 @@ def create_ui():
|
|||||||
add_copy_image_controls('img2img', init_img)
|
add_copy_image_controls('img2img', init_img)
|
||||||
|
|
||||||
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height)
|
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
||||||
add_copy_image_controls('sketch', sketch)
|
add_copy_image_controls('sketch', sketch)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff')
|
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
|
||||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff')
|
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
||||||
inpaint_color_sketch_orig = gr.State(None)
|
inpaint_color_sketch_orig = gr.State(None)
|
||||||
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
||||||
|
|
||||||
@ -959,8 +965,6 @@ def create_ui():
|
|||||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||||
|
|
||||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
|
||||||
|
|
||||||
img2img_paste_fields = [
|
img2img_paste_fields = [
|
||||||
(toprow.prompt, "Prompt"),
|
(toprow.prompt, "Prompt"),
|
||||||
(toprow.negative_prompt, "Negative prompt"),
|
(toprow.negative_prompt, "Negative prompt"),
|
||||||
@ -989,6 +993,12 @@ def create_ui():
|
|||||||
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
|
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
from modules import ui_extra_networks
|
||||||
|
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
|
||||||
|
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||||
|
|
||||||
|
extra_tabs.__exit__()
|
||||||
|
|
||||||
scripts.scripts_current = None
|
scripts.scripts_current = None
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||||
|
@ -155,7 +155,7 @@ class ExtraNetworksPage:
|
|||||||
subdirs = {"": 1, **subdirs}
|
subdirs = {"": 1, **subdirs}
|
||||||
|
|
||||||
subdirs_html = "".join([f"""
|
subdirs_html = "".join([f"""
|
||||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
|
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
|
||||||
{html.escape(subdir if subdir!="" else "all")}
|
{html.escape(subdir if subdir!="" else "all")}
|
||||||
</button>
|
</button>
|
||||||
""" for subdir in subdirs])
|
""" for subdir in subdirs])
|
||||||
@ -347,7 +347,7 @@ def pages_in_preferred_order(pages):
|
|||||||
return sorted(pages, key=lambda x: tab_scores[x.name])
|
return sorted(pages, key=lambda x: tab_scores[x.name])
|
||||||
|
|
||||||
|
|
||||||
def create_ui(container, button, tabname):
|
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||||
ui = ExtraNetworksUi()
|
ui = ExtraNetworksUi()
|
||||||
ui.pages = []
|
ui.pages = []
|
||||||
ui.pages_contents = []
|
ui.pages_contents = []
|
||||||
@ -355,9 +355,10 @@ def create_ui(container, button, tabname):
|
|||||||
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
||||||
ui.tabname = tabname
|
ui.tabname = tabname
|
||||||
|
|
||||||
with gr.Tabs(elem_id=tabname+"_extra_tabs"):
|
related_tabs = []
|
||||||
|
|
||||||
for page in ui.stored_extra_pages:
|
for page in ui.stored_extra_pages:
|
||||||
with gr.Tab(page.title, id=page.id_page):
|
with gr.Tab(page.title, id=page.id_page) as tab:
|
||||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||||
ui.pages.append(page_elem)
|
ui.pages.append(page_elem)
|
||||||
@ -368,35 +369,27 @@ def create_ui(container, button, tabname):
|
|||||||
editor.create_ui()
|
editor.create_ui()
|
||||||
ui.user_metadata_editors.append(editor)
|
ui.user_metadata_editors.append(editor)
|
||||||
|
|
||||||
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
related_tabs.append(tab)
|
||||||
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
|
|
||||||
ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder")
|
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||||
|
button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
|
||||||
|
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||||
|
|
||||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||||
|
|
||||||
def toggle_visibility(is_visible):
|
for tab in unrelated_tabs:
|
||||||
is_visible = not is_visible
|
tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False)
|
||||||
|
|
||||||
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
|
for tab in related_tabs:
|
||||||
|
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False)
|
||||||
def fill_tabs(is_empty):
|
|
||||||
"""Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
|
|
||||||
|
|
||||||
|
def pages_html():
|
||||||
if not ui.pages_contents:
|
if not ui.pages_contents:
|
||||||
refresh()
|
return refresh()
|
||||||
|
|
||||||
if is_empty:
|
return ui.pages_contents
|
||||||
return True, *ui.pages_contents
|
|
||||||
|
|
||||||
return True, *[gr.update() for _ in ui.pages_contents]
|
|
||||||
|
|
||||||
state_visible = gr.State(value=False)
|
|
||||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
|
|
||||||
|
|
||||||
state_empty = gr.State(value=True)
|
|
||||||
button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
|
|
||||||
|
|
||||||
def refresh():
|
def refresh():
|
||||||
for pg in ui.stored_extra_pages:
|
for pg in ui.stored_extra_pages:
|
||||||
@ -406,6 +399,7 @@ def create_ui(container, button, tabname):
|
|||||||
|
|
||||||
return ui.pages_contents
|
return ui.pages_contents
|
||||||
|
|
||||||
|
interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
|
||||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||||
|
|
||||||
return ui
|
return ui
|
||||||
|
10
style.css
10
style.css
@ -779,9 +779,14 @@ footer {
|
|||||||
/* extra networks UI */
|
/* extra networks UI */
|
||||||
|
|
||||||
.extra-network-cards{
|
.extra-network-cards{
|
||||||
height: 725px;
|
height: calc(100vh - 24rem);
|
||||||
overflow: scroll;
|
overflow: clip scroll;
|
||||||
resize: vertical;
|
resize: vertical;
|
||||||
|
min-height: 52rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-networks > div.tab-nav{
|
||||||
|
min-height: 3.4rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
.extra-networks > div > [id *= '_extra_']{
|
.extra-networks > div > [id *= '_extra_']{
|
||||||
@ -797,7 +802,6 @@ footer {
|
|||||||
}
|
}
|
||||||
.extra-networks .tab-nav .search,
|
.extra-networks .tab-nav .search,
|
||||||
.extra-networks .tab-nav .sort{
|
.extra-networks .tab-nav .sort{
|
||||||
display: inline-block;
|
|
||||||
margin: 0.3em;
|
margin: 0.3em;
|
||||||
align-self: center;
|
align-self: center;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user