Merge pull request #10997 from AUTOMATIC1111/fix-conds-caching-with-extra-network

fix conds caching with extra network
This commit is contained in:
AUTOMATIC1111 2023-06-04 12:07:41 +03:00 committed by GitHub
commit 1c6dca9383
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 15 deletions

View File

@ -32,6 +32,9 @@ class ExtraNetworkParams:
else: else:
self.positional.append(item) self.positional.append(item)
def __eq__(self, other):
return self.items == other.items
class ExtraNetwork: class ExtraNetwork:
def __init__(self, name): def __init__(self, name):

View File

@ -171,6 +171,7 @@ class StableDiffusionProcessing:
self.prompts = None self.prompts = None
self.negative_prompts = None self.negative_prompts = None
self.extra_network_data = None
self.seeds = None self.seeds = None
self.subseeds = None self.subseeds = None
@ -311,7 +312,7 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
def get_conds_with_caching(self, function, required_prompts, steps, cache): def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data):
""" """
Returns the result of calling function(shared.sd_model, required_prompts, steps) Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before. using a cache to store the result if the same arguments have been used before.
@ -321,26 +322,24 @@ class StableDiffusionProcessing:
have been used before. The second element is where the previously have been used before. The second element is where the previously
computed result is stored. computed result is stored.
""" """
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
return cache[1] return cache[1]
with devices.autocast(): with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps) cache[1] = function(shared.sd_model, required_prompts, steps)
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
return cache[1] return cache[1]
def setup_conds(self): def setup_conds(self):
sampler_config = sd_samplers.find_sampler_config(self.sampler_name) sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
return extra_network_data
class Processed: class Processed:
@ -681,7 +680,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
extra_network_data = None
for n in range(p.n_iter): for n in range(p.n_iter):
p.iteration = n p.iteration = n
@ -702,11 +700,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(p.prompts) == 0: if len(p.prompts) == 0:
break break
extra_network_data = p.parse_extra_network_prompts() p.parse_extra_network_prompts()
if not p.disable_extra_networks: if not p.disable_extra_networks:
with devices.autocast(): with devices.autocast():
extra_networks.activate(p, extra_network_data) extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None: if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@ -828,8 +826,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and extra_network_data: if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, extra_network_data) extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc() devices.torch_gc()
@ -1101,8 +1099,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
super().setup_conds() super().setup_conds()
if self.enable_hr: if self.enable_hr:
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc) self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts() res = super().parse_extra_network_prompts()