mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-14 00:39:01 +08:00
speedup, second order correction, intensity control
This commit is contained in:
parent
a63cf10650
commit
6b6396f4a6
@ -13,7 +13,7 @@ import k_diffusion as K
|
||||
|
||||
# Debugging notes - the original method apply_model is being called for sd1.5 is in modules.sd_hijack_utils and is ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
# For sdxl - OpenAIWrapper will be called, which will call the underlying diffusion_model
|
||||
|
||||
# When controlnet is enabled, the underlying model is not available to use, therefore we skip
|
||||
|
||||
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
||||
x = p.init_latent
|
||||
@ -78,11 +78,11 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
||||
return x / x.std()
|
||||
|
||||
|
||||
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
|
||||
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment", "second_order_correction", "noise_sigma_intensity"])
|
||||
|
||||
|
||||
# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
|
||||
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
||||
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps, correction_factor, sigma_intensity):
|
||||
x = p.init_latent
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
@ -98,11 +98,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
||||
|
||||
for i in trange(1, len(sigmas)):
|
||||
shared.state.sampling_step += 1
|
||||
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
|
||||
|
||||
|
||||
if shared.sd_model.is_sdxl:
|
||||
cond_tensor = cond['crossattn']
|
||||
uncond_tensor = uncond['crossattn']
|
||||
@ -113,13 +109,53 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
||||
image_conditioning = torch.cat([p.image_conditioning] * 2)
|
||||
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
|
||||
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
|
||||
|
||||
if i == 1:
|
||||
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
|
||||
dt = (sigmas[i] - sigmas[i - 1]) / (2 * sigmas[i])
|
||||
else:
|
||||
t = dnw.sigma_to_t(sigma_in)
|
||||
dt = (sigmas[i] - sigmas[i - 1]) / sigmas[i - 1]
|
||||
|
||||
noise = noise_from_model(x, t, dt, sigma_in, cond_in, cfg_scale, dnw, skip)
|
||||
|
||||
if correction_factor > 0:
|
||||
recalculated_noise = noise_from_model(x + noise, t, dt, sigma_in, cond_in, cfg_scale, dnw, skip)
|
||||
noise = recalculated_noise * correction_factor + noise * (1 - correction_factor)
|
||||
|
||||
x += noise
|
||||
|
||||
sd_samplers_common.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
#del x_in, sigma_in, cond_in, c_out, c_in, t
|
||||
#del eps, denoised_uncond, denoised_cond, denoised, dt
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
return x / (x.std()*(1 - sigma_intensity) + sigmas[-1]*sigma_intensity)
|
||||
|
||||
def noise_from_model(x, t, dt, sigma_in, cond_in, cfg_scale, dnw, skip):
|
||||
|
||||
if cfg_scale == 1: # Case where denoised_uncond should not be calculated - 50% speedup, also good for sdxl in experiments
|
||||
x_in = x
|
||||
sigma_in = sigma_in[1:2]
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
|
||||
cond_in = {"c_concat":[cond_in["c_concat"][0][1:2]], "c_crossattn": [cond_in["c_crossattn"][0][1:2]]}
|
||||
if shared.sd_model.is_sdxl:
|
||||
num_classes_hack = shared.sd_model.model.diffusion_model.num_classes
|
||||
shared.sd_model.model.diffusion_model.num_classes = None
|
||||
try:
|
||||
eps = shared.sd_model.model(x_in * c_in, t[1:2], {"crossattn": cond_in["c_crossattn"][0]})
|
||||
finally:
|
||||
shared.sd_model.model.diffusion_model.num_classes = num_classes_hack
|
||||
else:
|
||||
eps = shared.sd_model.apply_model(x_in * c_in, t[1:2], cond=cond_in)
|
||||
|
||||
return -eps * c_out* dt
|
||||
else :
|
||||
x_in = torch.cat([x] * 2)
|
||||
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
|
||||
|
||||
if shared.sd_model.is_sdxl:
|
||||
num_classes_hack = shared.sd_model.model.diffusion_model.num_classes
|
||||
@ -131,28 +167,11 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
||||
else:
|
||||
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
|
||||
|
||||
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
|
||||
denoised_uncond, denoised_cond = (eps * c_out).chunk(2)
|
||||
|
||||
denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
|
||||
|
||||
if i == 1:
|
||||
d = (x - denoised) / (2 * sigmas[i])
|
||||
else:
|
||||
d = (x - denoised) / sigmas[i - 1]
|
||||
|
||||
dt = sigmas[i] - sigmas[i - 1]
|
||||
x = x + d * dt
|
||||
|
||||
sd_samplers_common.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||
del eps, denoised_uncond, denoised_cond, denoised, d, dt
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
return x / sigmas[-1]
|
||||
|
||||
return -denoised * dt
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
@ -183,6 +202,8 @@ class Script(scripts.Script):
|
||||
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg"))
|
||||
randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness"))
|
||||
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
|
||||
second_order_correction = gr.Slider(label="Correct noise by running model again", minimum=0.0, maximum=1.0, step=0.01, value=0.5, elem_id=self.elem_id("second_order_correction"))
|
||||
noise_sigma_intensity = gr.Slider(label="Weight scaling std vs sigma based", minimum=-1.0, maximum=2.0, step=0.01, value=0.5, elem_id=self.elem_id("noise_sigma_intensity"))
|
||||
|
||||
return [
|
||||
info,
|
||||
@ -190,10 +211,11 @@ class Script(scripts.Script):
|
||||
override_prompt, original_prompt, original_negative_prompt,
|
||||
override_steps, st,
|
||||
override_strength,
|
||||
cfg, randomness, sigma_adjustment,
|
||||
cfg, randomness, sigma_adjustment, second_order_correction,
|
||||
noise_sigma_intensity
|
||||
]
|
||||
|
||||
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
|
||||
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment, second_order_correction, noise_sigma_intensity):
|
||||
# Override
|
||||
if override_sampler:
|
||||
p.sampler_name = "Euler"
|
||||
@ -211,7 +233,9 @@ class Script(scripts.Script):
|
||||
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
|
||||
and self.cache.original_prompt == original_prompt \
|
||||
and self.cache.original_negative_prompt == original_negative_prompt \
|
||||
and self.cache.sigma_adjustment == sigma_adjustment
|
||||
and self.cache.sigma_adjustment == sigma_adjustment \
|
||||
and self.cache.second_order_correction == second_order_correction \
|
||||
and self.cache.noise_sigma_intensity == noise_sigma_intensity
|
||||
same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
|
||||
|
||||
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
|
||||
@ -231,10 +255,10 @@ class Script(scripts.Script):
|
||||
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
|
||||
uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
|
||||
if sigma_adjustment:
|
||||
rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
|
||||
rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st, second_order_correction, noise_sigma_intensity)
|
||||
else:
|
||||
rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
|
||||
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
|
||||
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment, second_order_correction, noise_sigma_intensity)
|
||||
|
||||
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user