mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-01 20:35:06 +08:00
257 lines
8.2 KiB
Python
257 lines
8.2 KiB
Python
import re
|
|
import dataclasses
|
|
import os
|
|
import gradio as gr
|
|
|
|
from modules import errors, shared
|
|
|
|
|
|
@dataclasses.dataclass
class PostprocessedImageSharedInfo:
    """State shared by a PostprocessedImage and every copy derived from it.

    PostprocessedImage.create_copy() assigns the *same* instance to the copy's
    ``shared`` attribute, so a value set here is visible to all related images.
    """

    # Target output size; None (the default) means it has not been set.
    # Annotated as optional because None is the default value.
    target_width: "int | None" = None
    target_height: "int | None" = None
|
|
|
|
|
|
class PostprocessedImage:
    """An image travelling through the postprocessing script pipeline.

    Attributes:
        image: the image object being processed.
        info: dict of extra information attached by scripts (copied, not
            shared, by create_copy()).
        shared: PostprocessedImageSharedInfo shared with all copies.
        extra_images: additional images produced by a script; the runner turns
            them into PostprocessedImage objects and then clears this list.
        nametags: strings joined by get_suffix() into a filename suffix.
        disable_processing: when True, the runner does not call scripts'
            process() on this image.
        caption: starts as None; not read in this class (set by scripts
            elsewhere — TODO confirm).
    """

    def __init__(self, image):
        self.image = image
        self.info = {}
        self.shared = PostprocessedImageSharedInfo()
        self.extra_images = []
        self.nametags = []
        self.disable_processing = False
        self.caption = None

    def get_suffix(self, used_suffixes=None):
        """Return a filename suffix built from this image's nametags.

        The base suffix is ``"-" + "-".join(nametags)``, or ``""`` when there
        are no nametags. If the base suffix is already a key of
        ``used_suffixes``, a counter is appended (``-1``, ``-2``, ...) until an
        unused suffix is found. The chosen suffix is recorded in
        ``used_suffixes``.

        Args:
            used_suffixes: dict whose keys are suffixes already handed out;
                updated in place. When omitted a fresh dict is used, so the
                base suffix is always returned.

        Returns:
            A suffix string not previously present in ``used_suffixes``.
        """
        used_suffixes = {} if used_suffixes is None else used_suffixes

        suffix = "-".join(self.nametags)
        if suffix:
            suffix = "-" + suffix

        candidate = suffix
        counter = 1
        # No upper bound on the counter: giving up would return a suffix that
        # is already in use and make output filenames collide. The loop always
        # terminates because used_suffixes is finite.
        while candidate in used_suffixes:
            candidate = suffix + "-" + str(counter)
            counter += 1

        used_suffixes[candidate] = 1
        return candidate

    def create_copy(self, new_image, *, nametags=None, disable_processing=False):
        """Create a PostprocessedImage for ``new_image`` derived from this one.

        The copy shares this image's ``shared`` object (same instance) and gets
        shallow copies of ``nametags`` and ``info``. ``extra_images`` and
        ``caption`` are not carried over (they keep their __init__ values).

        Args:
            new_image: the image for the new object.
            nametags: optional list of nametags appended after the inherited ones.
            disable_processing: value for the copy's ``disable_processing``.

        Returns:
            The new PostprocessedImage.
        """
        pp = PostprocessedImage(new_image)
        pp.shared = self.shared
        pp.nametags = self.nametags.copy()
        pp.info = self.info.copy()
        pp.disable_processing = disable_processing

        if nametags is not None:
            pp.nametags += nametags

        return pp
|
|
|
|
|
|
class ScriptPostprocessing:
    """Base class for postprocessing scripts.

    A subclass sets ``name``, builds its gradio controls in ``ui()`` and
    modifies images in ``process()``. It may optionally inspect them first in
    ``process_firstpass()``. The runner (ScriptPostprocessingRunner) fills in
    ``filename``, ``controls``, ``args_from``, ``args_to`` and ``group``.
    """

    # Path of the file this script was loaded from (set by the runner).
    filename = None
    # Dict returned by ui(): parameter name -> gradio component (set by the runner).
    controls = None
    # This script's slice [args_from:args_to] of the flat list of UI inputs.
    args_from = None
    args_to = None

    # define if the script should be used only in extras or main UI
    extra_only = None
    main_ui_only = None

    # Secondary sort key in the postprocessing UI. Scripts are ordered first by
    # the user's operation-order setting, then by this value (lower first).
    order = 1000

    # Title of the script. It is matched against the ordering/disable settings,
    # is used as the key for this script's arguments, and is part of element ids.
    name = None

    # The gradio container holding all of this script's UI. The runner's
    # setup_ui() creates it with gr.Row().
    group = None

    def ui(self):
        """
        Create this script's gradio UI elements. See https://gradio.app/docs/#components

        Return a dictionary that maps parameter names to the components used in
        processing. The current values of those components are passed to
        process() and process_firstpass() as keyword arguments with the same names.
        """

        pass

    def process(self, pp: PostprocessedImage, **args):
        """
        Postprocess the image held by ``pp``; modify ``pp`` in place (the
        return value is not used).

        ``args`` maps the parameter names from ui() to the current values of
        the corresponding components.
        """

        pass

    def process_firstpass(self, pp: PostprocessedImage, **args):
        """
        Called for all scripts before any script's process(). Scripts can examine
        the image here and set fields of the pp object to communicate things to
        other scripts.

        ``args`` maps the parameter names from ui() to the current values of
        the corresponding components.
        """

        pass

    def image_changed(self):
        # Hook invoked by ScriptPostprocessingRunner.image_changed(); presumably
        # fired when the source image changes (TODO confirm). Default: no-op.
        pass

    tab_name = ''  # used by ScriptPostprocessingForMainUI; appended to element ids by elem_id_suffix()

    # Used by elem_id_suffix(): whitespace becomes '_', then every character
    # other than lowercase ASCII letters, digits and '_' is removed.
    replace_pattern = re.compile(r'\s')
    rm_pattern = re.compile(r'[^a-z_0-9]')

    def elem_id(self, item_id):
        """
        Helper function to generate an id for an HTML element.

        Builds 'extras_{self.name.lower()}_{item_id}' and passes it through
        elem_id_suffix(). That turns whitespace into '_', removes every
        character except lowercase ASCII letters, digits and '_', and appends
        tab_name. Note that uppercase characters in item_id are removed, not
        lowercased. tab_name is set to '_img2img' or '_txt2img' when used by
        ScriptPostprocessingForMainUI.

        Extensions should use this function to generate element IDs.
        """
        return self.elem_id_suffix(f'extras_{self.name.lower()}_{item_id}')

    def elem_id_suffix(self, base_id):
        """
        Sanitize base_id and append tab_name to it.

        Whitespace is replaced with '_', then every character other than
        lowercase ASCII letters, digits and '_' is removed.

        Extensions that already have their own specific element IDs, and wish to
        keep those IDs the same when possible, should use this function.
        """
        base_id = self.rm_pattern.sub('', self.replace_pattern.sub('_', base_id))
        return f'{base_id}{self.tab_name}'
|
|
|
|
|
|
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call ``func(*args, **kwargs)`` and report any exception instead of raising it.

    Args:
        func: the callable to invoke.
        filename: script filename, used only in the error report.
        funcname: function name, used only in the error report.
        *args, **kwargs: forwarded to ``func``.
        default: value returned when ``func`` raises an ``Exception``.

    Returns:
        ``func``'s result, or ``default`` if it raised (the error is reported
        through ``errors.display``).
    """
    try:
        return func(*args, **kwargs)
    except Exception as exc:
        errors.display(exc, f"calling {filename}/{funcname}")

    return default
|
|
|
|
|
|
class ScriptPostprocessingRunner:
    """Loads, orders, builds UI for, and runs the postprocessing scripts."""

    def __init__(self):
        # Loaded lazily by scripts_in_preferred_order().
        self.scripts = None
        self.ui_created = False

    def initialize_scripts(self, scripts_data):
        """Instantiate one script object per entry of ``scripts_data``.

        Each entry must provide ``script_class`` (called with no arguments) and
        ``path`` (stored as the script's ``filename``).
        """
        self.scripts = []

        for script_data in scripts_data:
            script: ScriptPostprocessing = script_data.script_class()
            script.filename = script_data.path
            self.scripts.append(script)

    def create_script_ui(self, script, inputs):
        """Build ``script``'s controls and append them to ``inputs`` in place.

        Sets ``script.args_from``/``script.args_to`` to the slice of ``inputs``
        occupied by the script's controls, and tags each control with the
        script's file name. If ``script.ui()`` raises (the error is reported by
        wrap_call) or returns None, the script gets no controls instead of
        aborting the whole UI build.
        """
        script.args_from = len(inputs)
        script.args_to = len(inputs)

        controls = wrap_call(script.ui, script.filename, "ui")
        # wrap_call returns None when ui() raised, and a ui() that is not
        # overridden returns None too; treat both as "no controls".
        script.controls = controls if controls is not None else {}

        for control in script.controls.values():
            control.custom_script_source = os.path.basename(script.filename)

        inputs += list(script.controls.values())
        script.args_to = len(inputs)

    def scripts_in_preferred_order(self):
        """Return the scripts enabled for extras, in display/run order.

        Scripts named in ``shared.opts.postprocessing_disable_in_extras`` and
        scripts with ``main_ui_only`` set are excluded. The rest are sorted by:
        1. their position in ``shared.opts.postprocessing_operation_order``
           (unlisted scripts go last);
        2. ``script.order``;
        3. ``script.name``;
        4. load order.
        Scripts are loaded on first call.
        """
        if self.scripts is None:
            import modules.scripts
            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

        scripts_order = shared.opts.postprocessing_operation_order
        scripts_filter_out = set(shared.opts.postprocessing_disable_in_extras)

        def script_score(name):
            # Index of the first matching entry in the user's order setting.
            for i, possible_match in enumerate(scripts_order):
                if possible_match == name:
                    return i

            return len(self.scripts)

        filtered_scripts = [script for script in self.scripts if script.name not in scripts_filter_out and not script.main_ui_only]
        script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(filtered_scripts)}

        return sorted(filtered_scripts, key=lambda x: script_scores[x.name])

    def setup_ui(self):
        """Create every script's UI inside its own gr.Row.

        Returns:
            The flat list of all scripts' input components, in script order.
        """
        inputs = []

        for script in self.scripts_in_preferred_order():
            with gr.Row() as group:
                self.create_script_ui(script, inputs)

            script.group = group

        self.ui_created = True
        return inputs

    def run(self, pp: PostprocessedImage, args):
        """Run all enabled scripts on ``pp`` and on the extra images they produce.

        Args:
            pp: the image to process; modified in place.
            args: flat list of UI values; each script reads its own
                ``args[script.args_from:script.args_to]`` slice.

        First, every script's process_firstpass() runs on ``pp``. Then, script
        by script, process() runs on ``pp`` and on every extra image created by
        earlier scripts, skipping images with ``disable_processing`` set.
        Processing stops early when ``shared.state.skipped`` is set. On return,
        ``pp.extra_images`` holds all extra images that were produced.
        """
        scripts = []

        for script in self.scripts_in_preferred_order():
            script_args = args[script.args_from:script.args_to]
            # Pair parameter names from ui() with their values, in order.
            process_args = dict(zip(script.controls, script_args))
            scripts.append((script, process_args))

        for script, process_args in scripts:
            script.process_firstpass(pp, **process_args)

        all_images = [pp]

        for script, process_args in scripts:
            if shared.state.skipped:
                break

            shared.state.job = script.name

            # Iterate a snapshot: extra images appended below are only handled
            # by the following scripts, not by the one that created them.
            for single_image in all_images.copy():

                if not single_image.disable_processing:
                    script.process(single_image, **process_args)

                for extra_image in single_image.extra_images:
                    if not isinstance(extra_image, PostprocessedImage):
                        extra_image = single_image.create_copy(extra_image)

                    all_images.append(extra_image)

                single_image.extra_images.clear()

        pp.extra_images = all_images[1:]

    def create_args_for_run(self, scripts_args):
        """Build the flat argument list for run() from per-script dicts.

        Args:
            scripts_args: dict mapping script name -> dict of parameter name ->
                value. Missing scripts or parameters are left as None.

        Returns:
            A list of length max(args_to) over enabled scripts (empty when no
            script is enabled). The UI is built in a throwaway gr.Blocks first
            if it has not been created yet, so args_from/args_to are set.
        """
        if not self.ui_created:
            with gr.Blocks(analytics_enabled=False):
                self.setup_ui()

        scripts = self.scripts_in_preferred_order()
        # default=0: with no enabled scripts, return [] instead of raising ValueError.
        args = [None] * max((x.args_to for x in scripts), default=0)

        for script in scripts:
            script_args_dict = scripts_args.get(script.name, None)
            if script_args_dict is not None:

                for i, name in enumerate(script.controls):
                    args[script.args_from + i] = script_args_dict.get(name, None)

        return args

    def image_changed(self):
        """Call image_changed() on every enabled script, in preferred order."""
        for script in self.scripts_in_preferred_order():
            script.image_changed()
|
|
|