clear GenerationParametersList before batch

clears any generation parameters that are with the attribute to_be_clear_before_batch = True
prevent buildup of some parameters
This commit is contained in:
w-e-w 2024-11-24 20:07:00 +09:00
parent 025080218f
commit ac8c05398b
2 changed files with 18 additions and 3 deletions

View File

@ -457,7 +457,7 @@ class StableDiffusionProcessing:
opts.emphasis,
)
def apply_generation_params_states(self, generation_params_states):
def apply_generation_params_list(self, generation_params_states):
"""add and apply generation_params_states to self.extra_generation_params"""
for key, value in generation_params_states.items():
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
@ -465,6 +465,12 @@ class StableDiffusionProcessing:
else:
self.extra_generation_params[key] = value
def clear_marked_generation_params(self):
"""clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
for key, value in list(self.extra_generation_params.items()):
if getattr(value, 'to_be_clear_before_batch', False):
self.extra_generation_params.pop(key)
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@ -491,7 +497,7 @@ class StableDiffusionProcessing:
if len(cache) == 3:
generation_params_states, cached_cached_params = cache[2]
if cached_params == cached_cached_params:
self.apply_generation_params_states(generation_params_states)
self.apply_generation_params_list(generation_params_states)
return cache[1]
cache = caches[0]
@ -500,7 +506,7 @@ class StableDiffusionProcessing:
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
generation_params_states = model_hijack.extract_generation_params_states()
self.apply_generation_params_states(generation_params_states)
self.apply_generation_params_list(generation_params_states)
if len(cache) == 2:
cache.append((generation_params_states, cached_params))
else:
@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted or state.stopping_generation:
break
p.clear_marked_generation_params() # clean up some generation params are tagged to be cleared before batch
sd_models.reload_model_weights() # model can be changed for example by refiner
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]

View File

@ -308,9 +308,17 @@ class GenerationParametersList(list):
if return str, the value will be written to infotext, if return None will be ignored.
"""
def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
super().__init__(*args, **kwargs)
self._to_be_clear_before_batch = to_be_clear_before_batch
def __call__(self, *args, **kwargs):
return ', '.join(sorted(set(self), key=natural_sort_key))
@property
def to_be_clear_before_batch(self):
return self._to_be_clear_before_batch
def __add__(self, other):
if isinstance(other, GenerationParametersList):
return self.__class__([*self, *other])