mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
initial refiner support
This commit is contained in:
parent
57e8a11d17
commit
f1975b0213
@ -666,6 +666,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# after running refiner, the refiner model is not unloaded - webui swaps back to main model here
|
||||||
|
if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
|
||||||
|
sd_models.reload_model_weights()
|
||||||
|
|
||||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||||
p.override_settings.pop('sd_model_checkpoint', None)
|
p.override_settings.pop('sd_model_checkpoint', None)
|
||||||
|
@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class SkipWritingToConfig:
|
||||||
|
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
||||||
|
|
||||||
|
skip = False
|
||||||
|
previous = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.previous = SkipWritingToConfig.skip
|
||||||
|
SkipWritingToConfig.skip = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||||
|
SkipWritingToConfig.skip = self.previous
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
timer.record("calculate hash")
|
timer.record("calculate hash")
|
||||||
|
|
||||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
if not SkipWritingToConfig.skip:
|
||||||
|
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||||
|
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
@ -2,7 +2,7 @@ from collections import namedtuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
|
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
@ -127,3 +127,20 @@ def replace_torchsde_browinan():
|
|||||||
|
|
||||||
|
|
||||||
replace_torchsde_browinan()
|
replace_torchsde_browinan()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_refiner(sampler):
|
||||||
|
completed_ratio = sampler.step / sampler.steps
|
||||||
|
if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint:
|
||||||
|
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
|
||||||
|
if refiner_checkpoint_info is None:
|
||||||
|
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
|
||||||
|
|
||||||
|
with sd_models.SkipWritingToConfig():
|
||||||
|
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
sampler.update_inner_model()
|
||||||
|
|
||||||
|
sampler.p.setup_conds()
|
||||||
|
@ -19,7 +19,8 @@ samplers_data_compvis = [
|
|||||||
|
|
||||||
class VanillaStableDiffusionSampler:
|
class VanillaStableDiffusionSampler:
|
||||||
def __init__(self, constructor, sd_model):
|
def __init__(self, constructor, sd_model):
|
||||||
self.sampler = constructor(sd_model)
|
self.p = None
|
||||||
|
self.sampler = constructor(shared.sd_model)
|
||||||
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
||||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||||
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
||||||
@ -32,6 +33,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
|
self.steps = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.stop_at = None
|
self.stop_at = None
|
||||||
self.eta = None
|
self.eta = None
|
||||||
@ -44,6 +46,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
def launch_sampling(self, steps, func):
|
||||||
|
self.steps = steps
|
||||||
state.sampling_steps = steps
|
state.sampling_steps = steps
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
@ -61,10 +64,15 @@ class VanillaStableDiffusionSampler:
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def update_inner_model(self):
|
||||||
|
self.sampler.model = shared.sd_model
|
||||||
|
|
||||||
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
|
sd_samplers_common.apply_refiner(self)
|
||||||
|
|
||||||
if self.stop_at is not None and self.step > self.stop_at:
|
if self.stop_at is not None and self.step > self.stop_at:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
@ -134,6 +142,8 @@ class VanillaStableDiffusionSampler:
|
|||||||
self.update_step(x)
|
self.update_step(x)
|
||||||
|
|
||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
|
self.p = p
|
||||||
|
|
||||||
if self.is_ddim:
|
if self.is_ddim:
|
||||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||||
else:
|
else:
|
||||||
|
@ -2,7 +2,7 @@ from collections import deque
|
|||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
|
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra, sd_models
|
||||||
|
|
||||||
from modules.processing import StableDiffusionProcessing
|
from modules.processing import StableDiffusionProcessing
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
@ -87,15 +87,25 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
negative prompt.
|
negative prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.model_wrap = None
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
self.steps = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.image_cfg_scale = None
|
self.image_cfg_scale = None
|
||||||
self.padded_cond_uncond = False
|
self.padded_cond_uncond = False
|
||||||
|
self.p = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||||
|
|
||||||
|
return self.model_wrap
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
@ -113,10 +123,15 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
def update_inner_model(self):
|
||||||
|
self.model_wrap = None
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
|
sd_samplers_common.apply_refiner(self)
|
||||||
|
|
||||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||||
# so is_edit_model is set to False to support AND composition.
|
# so is_edit_model is set to False to support AND composition.
|
||||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||||
@ -267,13 +282,13 @@ class TorchHijack:
|
|||||||
|
|
||||||
class KDiffusionSampler:
|
class KDiffusionSampler:
|
||||||
def __init__(self, funcname, sd_model):
|
def __init__(self, funcname, sd_model):
|
||||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
|
||||||
|
|
||||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
self.p = None
|
||||||
self.funcname = funcname
|
self.funcname = funcname
|
||||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
self.model_wrap_cfg = CFGDenoiser()
|
||||||
|
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.stop_at = None
|
self.stop_at = None
|
||||||
self.eta = None
|
self.eta = None
|
||||||
@ -305,6 +320,7 @@ class KDiffusionSampler:
|
|||||||
shared.total_tqdm.update()
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
def launch_sampling(self, steps, func):
|
||||||
|
self.model_wrap_cfg.steps = steps
|
||||||
state.sampling_steps = steps
|
state.sampling_steps = steps
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
@ -324,6 +340,8 @@ class KDiffusionSampler:
|
|||||||
return p.steps
|
return p.steps
|
||||||
|
|
||||||
def initialize(self, p: StableDiffusionProcessing):
|
def initialize(self, p: StableDiffusionProcessing):
|
||||||
|
self.p = p
|
||||||
|
self.model_wrap_cfg.p = p
|
||||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
self.model_wrap_cfg.step = 0
|
self.model_wrap_cfg.step = 0
|
||||||
|
@ -461,6 +461,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||||
|
"sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"),
|
||||||
|
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||||
|
Loading…
Reference in New Issue
Block a user