allow first pass and hires pass to use a single prompt to do different prompt editing, hires is 1.0..2.0:

relative time range is [1..2]
  absolute time range is [steps+1..steps+hire_steps], e.g. with 30 steps and 20 hires steps, '20' is 2/3rds through first pass, and 40 is halfway through hires pass
This commit is contained in:
Robert Barron 2023-08-09 07:46:30 -07:00
parent 5a38a9c0ee
commit d1ba46b6e1
3 changed files with 44 additions and 18 deletions

View File

@ -319,12 +319,14 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
def cached_params(self, required_prompts, steps, extra_network_data): def cached_params(self, required_prompts, steps, hires_steps, extra_network_data, use_old_scheduling):
"""Returns parameters that invalidate the cond cache if changed""" """Returns parameters that invalidate the cond cache if changed"""
return ( return (
required_prompts, required_prompts,
steps, steps,
hires_steps,
use_old_scheduling,
opts.CLIP_stop_at_last_layers, opts.CLIP_stop_at_last_layers,
shared.sd_model.sd_checkpoint_info, shared.sd_model.sd_checkpoint_info,
extra_network_data, extra_network_data,
@ -334,7 +336,7 @@ class StableDiffusionProcessing:
self.height, self.height,
) )
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): def get_conds_with_caching(self, function, required_prompts, steps, hires_steps, caches, extra_network_data):
""" """
Returns the result of calling function(shared.sd_model, required_prompts, steps) Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before. using a cache to store the result if the same arguments have been used before.
@ -347,7 +349,7 @@ class StableDiffusionProcessing:
caches is a list with items described above. caches is a list with items described above.
""" """
cached_params = self.cached_params(required_prompts, steps, extra_network_data) cached_params = self.cached_params(required_prompts, steps, hires_steps, extra_network_data, shared.opts.use_old_scheduling)
for cache in caches: for cache in caches:
if cache[0] is not None and cached_params == cache[0]: if cache[0] is not None and cached_params == cache[0]:
@ -356,7 +358,7 @@ class StableDiffusionProcessing:
cache = caches[0] cache = caches[0]
with devices.autocast(): with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps) cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
cache[0] = cached_params cache[0] = cached_params
return cache[1] return cache[1]
@ -367,8 +369,9 @@ class StableDiffusionProcessing:
sampler_config = sd_samplers.find_sampler_config(self.sampler_name) sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.firstpass_steps = self.steps * self.step_multiplier
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.firstpass_steps, None, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.firstpass_steps, None, [self.cached_c], self.extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@ -1225,8 +1228,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) hires_steps = (self.hr_second_pass_steps or self.steps) * self.step_multiplier
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self): def setup_conds(self):
super().setup_conds() super().setup_conds()

View File

@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER %import common.SIGNED_NUMBER -> NUMBER
""") """)
def get_learned_conditioning_prompt_schedules(prompts, steps): def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
""" """
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test") >>> g("test")
@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g("[fe|||]male") >>> g("[fe|||]male")
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
>>> g("a [b:.5] c")
[[10, 'a b c']]
>>> g("a [b:1.5] c")
[[5, 'a c'], [10, 'a b c']]
""" """
if hires_steps is None or use_old_scheduling:
int_offset = 0
flt_offset = 0
steps = base_steps
else:
int_offset = base_steps
flt_offset = 1.0
steps = hires_steps
def collect_steps(steps, tree): def collect_steps(steps, tree):
res = [steps] res = [steps]
class CollectSteps(lark.Visitor): class CollectSteps(lark.Visitor):
def scheduled(self, tree): def scheduled(self, tree):
tree.children[-2] = float(tree.children[-2]) s = tree.children[-2]
if tree.children[-2] < 1: v = float(s)
tree.children[-2] *= steps if use_old_scheduling:
tree.children[-2] = min(steps, int(tree.children[-2])) v = v*steps if v<1 else v
res.append(tree.children[-2]) else:
if "." in s:
v = (v - flt_offset) * steps
else:
v = (v - int_offset)
tree.children[-2] = min(steps, int(v))
if tree.children[-2] >= 1:
res.append(tree.children[-2])
def alternate(self, tree): def alternate(self, tree):
res.extend(range(1, steps+1)) res.extend(range(1, steps+1))
@ -134,7 +155,7 @@ class SdConditioning(list):
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one. and the sampling step at which this condition is to be replaced by the next one.
@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
""" """
res = [] res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
cache = {} cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules): for prompt, prompt_schedule in zip(prompts, prompt_schedules):
@ -229,7 +250,7 @@ class MulticondLearnedConditioning:
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator. For each prompt, the list is obtained by splitting the prompt using the AND separator.
@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps) learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
res = [] res = []
for indexes in res_indexes: for indexes in res_indexes:

View File

@ -516,6 +516,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt where first pass and hires both used the same timeline, and < 1 meant relative and >= 1 meant absolute"),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate"), { options_templates.update(options_section(('interrogate', "Interrogate"), {