From 847f869c67c7108e3e792fc193331d0e6acca29c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 21:00:52 +0300 Subject: [PATCH] experimental optimization --- modules/processing.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 61e97077c..a408d622e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -544,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] + cached_uc = [None, None] + cached_c = [None, None] + + def get_conds_with_caching(function, required_prompts, steps, cache): + """ + Returns the result of calling function(shared.sd_model, required_prompts, steps) + using a cache to store the result if the same arguments have been used before. + + cache is an array containing two elements. The first element is a tuple + representing the previously used arguments, or None if no arguments + have been used before. The second element is where the previously + computed result is stored. + """ + + if cache[0] is not None and (required_prompts, steps) == cache[0]: + return cache[1] + + with devices.autocast(): + cache[1] = function(shared.sd_model, required_prompts, steps) + + cache[0] = (required_prompts, steps) + return cache[1] + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) @@ -571,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) - with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps) - c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) + uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc) + c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: