extra networks UI

rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>
This commit is contained in:
AUTOMATIC 2023-01-21 08:36:07 +03:00
parent e33cace2c2
commit 40ff6db532
25 changed files with 767 additions and 216 deletions

BIN
html/card-no-preview.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

View File

@ -0,0 +1,11 @@
<div class='card' {preview_html} onclick='return cardClicked({prompt}, {allow_negative_prompt})'>
<div class='actions'>
<div class='additional'>
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick='return saveCardPreview(event, {tabname}, {local_preview})'>replace preview</a>
</ul>
</div>
<span class='name'>{name}</span>
</div>
</div>

View File

@ -0,0 +1,8 @@
<div class='nocards'>
<h1>Nothing here. Add some content to the following directories:</h1>
<ul>
{dirs}
</ul>
</div>

View File

@ -0,0 +1,60 @@
function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
}
var activePromptTextarea = null;
var activePositivePromptTextarea = null;
function setupExtraNetworks(){
setupExtraNetworksForTab('txt2img')
setupExtraNetworksForTab('img2img')
function registerPrompt(id, isNegative){
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
if (activePromptTextarea == null){
activePromptTextarea = textarea
}
if (activePositivePromptTextarea == null && ! isNegative){
activePositivePromptTextarea = textarea
}
textarea.addEventListener("focus", function(){
activePromptTextarea = textarea;
if(! isNegative) activePositivePromptTextarea = textarea;
});
}
registerPrompt('txt2img_prompt')
registerPrompt('txt2img_neg_prompt', true)
registerPrompt('img2img_prompt')
registerPrompt('img2img_neg_prompt', true)
}
onUiLoaded(setupExtraNetworks)
function cardClicked(textToAdd, allowNegativePrompt){
textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea
textarea.value = textarea.value + " " + textToAdd
updateInput(textarea)
return false
}
function saveCardPreview(event, tabname, filename){
textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
button = gradioApp().getElementById(tabname + '_save_preview')
textarea.value = filename
updateInput(textarea)
button.click()
event.stopPropagation()
event.preventDefault()
}

View File

@ -21,6 +21,8 @@ titles = {
"\U0001F5D1": "Clear prompt", "\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field", "\u{1f4d2}": "Paste available values into the field",
"\u{1f3b4}": "Show extra networks",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

View File

@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt] return [prompt, negative_prompt]
} }
opts = {} opts = {}
onUiUpdate(function(){ onUiUpdate(function(){
if(Object.keys(opts).length != 0) return; if(Object.keys(opts).length != 0) return;
@ -239,11 +237,14 @@ onUiUpdate(function(){
return return
} }
prompt.parentElement.insertBefore(counter, prompt) prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter") counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative" prompt.parentElement.style.position = "relative"
textarea.addEventListener("input", () => update_token_counter(id_button)); textarea.addEventListener("input", function(){
update_token_counter(id_button);
});
} }
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
@ -261,10 +262,8 @@ onUiUpdate(function(){
}) })
} }
} }
}) })
onOptionsChanged(function(){ onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash') elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || "" sd_checkpoint_hash = opts.sd_checkpoint_hash || ""

View File

@ -480,7 +480,7 @@ class Api:
def train_hypernetwork(self, args: dict): def train_hypernetwork(self, args: dict):
try: try:
shared.state.begin() shared.state.begin()
initial_hypernetwork = shared.loaded_hypernetwork shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations apply_optimizations = shared.opts.training_xattention_optimizations
error = None error = None
filename = '' filename = ''
@ -491,7 +491,6 @@ class Api:
except Exception as e: except Exception as e:
error = e error = e
finally: finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations: if not apply_optimizations:

147
modules/extra_networks.py Normal file
View File

@ -0,0 +1,147 @@
import re
from collections import defaultdict
from modules import errors
extra_network_registry = {}
def initialize():
extra_network_registry.clear()
def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
class ExtraNetwork:
def __init__(self, name):
self.name = name
def activate(self, p, params_list):
"""
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
Passes arguments related to this extra network in params_list.
User passes arguments by specifying this in his prompt:
<name:arg1:arg2:arg3>
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
separated by colon.
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
in this case, all effects of this extra networks should be disabled.
Can be called multiple times before deactivate() - each new call should override the previous call completely.
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
params_list will be:
[
ExtraNetworkParams(items=["agm", "1.1"]),
ExtraNetworkParams(items=["ray"])
]
"""
raise NotImplementedError
def deactivate(self, p):
"""
Called at the end of processing for housekeeping. No need to do anything here.
"""
raise NotImplementedError
def activate(p, extra_network_data):
"""call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue
try:
extra_network.activate(p, extra_network_args)
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
args = extra_network_data.get(extra_network_name, None)
if args is not None:
continue
try:
extra_network.activate(p, [])
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")
def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue
try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating extra network {extra_network_name}")
for extra_network_name, extra_network in extra_network_registry.items():
args = extra_network_data.get(extra_network_name, None)
if args is not None:
continue
try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
def parse_prompt(prompt):
res = defaultdict(list)
def found(m):
name = m.group(1)
args = m.group(2)
res[name].append(ExtraNetworkParams(items=args.split(":")))
return ""
prompt = re.sub(re_extra_net, found, prompt)
return prompt, res
def parse_prompts(prompts):
res = []
extra_data = None
for prompt in prompts:
updated_prompt, parsed_extra_data = parse_prompt(prompt)
if extra_data is None:
extra_data = parsed_extra_data
res.append(updated_prompt)
return res, extra_data

View File

@ -0,0 +1,21 @@
from modules import extra_networks
from modules.hypernetworks import hypernetwork
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('hypernet')
def activate(self, p, params_list):
names = []
multipliers = []
for params in params_list:
assert len(params.items) > 0
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
hypernetwork.load_hypernetworks(names, multipliers)
def deactivate(p, self):
pass

View File

@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
from modules import ui from modules import ui
settings_map = { settings_map = {
'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip', 'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight', 'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash', 'sd_model_checkpoint': 'Model hash',
@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Clip skip" not in res: if "Clip skip" not in res:
res["Clip skip"] = "1" res["Clip skip"] = "1"
if "Hypernet strength" not in res: hypernet = res.get("Hypernet", None)
res["Hypernet strength"] = "1" if hypernet is not None:
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hypernet" in res:
hypernet_name = res["Hypernet"]
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
if "Hires resize-1" not in res: if "Hires resize-1" not in res:
res["Hires resize-1"] = 0 res["Hires resize-1"] = 0

View File

@ -25,7 +25,6 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = { activation_dict = {
"linear": torch.nn.Identity, "linear": torch.nn.Identity,
"relu": torch.nn.ReLU, "relu": torch.nn.ReLU,
@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
add_layer_norm=False, activate_output=False, dropout_structure=None): add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__() super().__init__()
self.multiplier = 1.0
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
state_dict[to] = x state_dict[to] = x
def forward(self, x): def forward(self, x):
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1) return x + self.linear(x) * (self.multiplier if not self.training else 1)
def trainables(self): def trainables(self):
layer_structure = [] layer_structure = []
@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
return layer_structure return layer_structure
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
if layer_structure is None: if layer_structure is None:
@ -192,6 +190,20 @@ class Hypernetwork:
for param in layer.parameters(): for param in layer.parameters():
param.requires_grad = mode param.requires_grad = mode
def to(self, device):
for k, layers in self.layers.items():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
for k, layers in self.layers.items():
for layer in layers:
layer.multiplier = multiplier
return self
def eval(self): def eval(self):
for k, layers in self.layers.items(): for k, layers in self.layers.items():
for layer in layers: for layer in layers:
@ -269,10 +281,12 @@ class Hypernetwork:
self.optimizer_state_dict = None self.optimizer_state_dict = None
if self.optimizer_state_dict: if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
if shared.opts.print_hypernet_extra:
print("Loaded existing optimizer from checkpoint") print("Loaded existing optimizer from checkpoint")
print(f"Optimizer name is {self.optimizer_name}") print(f"Optimizer name is {self.optimizer_name}")
else: else:
self.optimizer_name = "AdamW" self.optimizer_name = "AdamW"
if shared.opts.print_hypernet_extra:
print("No saved optimizer exists in checkpoint") print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items(): for size, sd in state_dict.items():
@ -306,23 +320,43 @@ def list_hypernetworks(path):
return res return res
def load_hypernetwork(filename): def load_hypernetwork(name):
path = shared.hypernetworks.get(filename, None) path = shared.hypernetworks.get(name, None)
# Prevent any file named "None.pt" from being loaded.
if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
if path is None:
return None
hypernetwork = Hypernetwork()
try:
hypernetwork.load(path)
except Exception: except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr) print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: return None
if shared.loaded_hypernetwork is not None:
print("Unloading hypernetwork")
shared.loaded_hypernetwork = None return hypernetwork
def load_hypernetworks(names, multipliers=None):
already_loaded = {}
for hypernetwork in shared.loaded_hypernetworks:
if hypernetwork.name in names:
already_loaded[hypernetwork.name] = hypernetwork
shared.loaded_hypernetworks.clear()
for i, name in enumerate(names):
hypernetwork = already_loaded.get(name, None)
if hypernetwork is None:
hypernetwork = load_hypernetwork(name)
if hypernetwork is None:
continue
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str): def find_closest_hypernetwork_name(search: str):
@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
return applicable[0] return applicable[0]
def apply_hypernetwork(hypernetwork, context, layer=None): def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None: if hypernetwork_layers is None:
return context, context return context_k, context_v
if layer is not None: if layer is not None:
layer.hyper_k = hypernetwork_layers[0] layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1] layer.hyper_v = hypernetwork_layers[1]
context_k = hypernetwork_layers[0](context) context_k = hypernetwork_layers[0](context_k)
context_v = hypernetwork_layers[1](context) context_v = hypernetwork_layers[1](context_v)
return context_k, context_v
def apply_hypernetworks(hypernetworks, context, layer=None):
context_k = context
context_v = context
for hypernetwork in hypernetworks:
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
return context_k, context_v return context_k, context_v
@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
q = self.to_q(x) q = self.to_q(x)
context = default(context, x) context = default(context, x)
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
k = self.to_k(context_k) k = self.to_k(context_k)
v = self.to_v(context_v) v = self.to_v(context_v)
@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
template_file = template_file.path template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork() hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path) hypernetwork.load(path)
shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork" shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..." shared.state.textinfo = "Initializing hypernetwork training..."
@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
else: else:
images_dir = None images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0 initial_step = hypernetwork.step or 0

View File

@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"] not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(*args): def train_hypernetwork(*args):
shared.loaded_hypernetworks = []
initial_hypernetwork = shared.loaded_hypernetwork
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
except Exception: except Exception:
raise raise
finally: finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()

View File

@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try: try:
for k, v in p.override_settings.items(): for k, v in p.override_settings.items():
setattr(opts, k, v) setattr(opts, k, v)
if k == 'sd_hypernetwork':
shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint': if k == 'sd_model_checkpoint':
sd_models.reload_model_weights() # make onchange call for changing SD model sd_models.reload_model_weights()
if k == 'sd_vae': if k == 'sd_vae':
sd_vae.reload_vae_weights() # make onchange call for changing VAE sd_vae.reload_vae_weights()
res = process_images_inner(p) res = process_images_inner(p)
@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.override_settings_restore_afterwards: if p.override_settings_restore_afterwards:
for k, v in stored_opts.items(): for k, v in stored_opts.items():
setattr(opts, k, v) setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks() if k == 'sd_model_checkpoint':
if k == 'sd_model_checkpoint': sd_models.reload_model_weights() sd_models.reload_model_weights()
if k == 'sd_vae': sd_vae.reload_vae_weights()
if k == 'sd_vae':
sd_vae.reload_vae_weights()
return res return res
@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
cache[0] = (required_prompts, steps) cache[0] = (required_prompts, steps)
return cache[1] return cache[1]
p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
with torch.no_grad(), p.sd_model.ema_scope(): with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast(): with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
extra_networks.activate(p, extra_network_data)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "") processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0)) file.write(processed.infotext(p, 0))
@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
extra_networks.deactivate(p, extra_network_data)
devices.torch_gc() devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)

View File

@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x) q_in = self.to_q(x)
context = default(context, x) context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k) k_in = self.to_k(context_k)
v_in = self.to_v(context_v) v_in = self.to_v(context_v)
del context, context_k, context_v, x del context, context_k, context_v, x
@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x) q_in = self.to_q(x)
context = default(context, x) context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k) k_in = self.to_k(context_k)
v_in = self.to_v(context_v) v_in = self.to_v(context_v)
@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
q = self.to_q(x) q = self.to_q(x)
context = default(context, x) context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) * self.scale k = self.to_k(context_k) * self.scale
v = self.to_v(context_v) v = self.to_v(context_v)
del context, context_k, context_v, x del context, context_k, context_v, x
@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
q = self.to_q(x) q = self.to_q(x)
context = default(context, x) context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) k = self.to_k(context_k)
v = self.to_v(context_v) v = self.to_v(context_v)
del context, context_k, context_v, x del context, context_k, context_v, x
@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x) q_in = self.to_q(x)
context = default(context, x) context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k) k_in = self.to_k(context_k)
v_in = self.to_v(context_v) v_in = self.to_v(context_v)

View File

@ -23,6 +23,7 @@ demo = None
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {} hypernetworks = {}
loaded_hypernetwork = None loaded_hypernetworks = []
def reload_hypernetworks(): def reload_hypernetworks():
@ -153,8 +154,6 @@ def reload_hypernetworks():
global hypernetworks global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
class State: class State:
@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@ -661,3 +658,17 @@ mem_mon.start()
def listfiles(dirname): def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)] return [file for file in filenames if os.path.isfile(file)]
def html_path(filename):
return os.path.join(script_path, "html", filename)
def html(filename):
path = html_path(filename)
if os.path.exists(path):
with open(path, encoding="utf8") as file:
return file.read()
return ""

View File

@ -50,6 +50,7 @@ class Embedding:
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.optimizer_state_dict = None self.optimizer_state_dict = None
self.filename = None
def save(self, filename): def save(self, filename):
embedding_data = { embedding_data = {
@ -182,6 +183,7 @@ class EmbeddingDatabase:
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0] embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1] embedding.shape = vec.shape[-1]
embedding.filename = path
if self.expected_shape == -1 or self.expected_shape == embedding.shape: if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model) self.register_embedding(embedding, shared.sd_model)

View File

@ -20,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path from modules.paths import script_path
@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾 save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋 apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️ clear_prompt_symbol = '\U0001F5D1' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
def plaintext_to_html(text): def plaintext_to_html(text):
@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
def update_token_counter(text, steps): def update_token_counter(text, steps):
try: try:
text, _ = extra_networks.parse_prompt(text)
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
@ -354,10 +357,10 @@ def create_toprow(is_img2img):
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
with gr.Column(scale=1, elem_id="roll_col"): with gr.Column(scale=1, elem_id="roll_col"):
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = ToolButton(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create") clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter") token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter") negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
@ -395,11 +398,14 @@ def create_toprow(is_img2img):
outputs=[], outputs=[],
) )
with gr.Row(): with gr.Row(elem_id=f"{id_part}_styles_row"):
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply")
save_style = ToolButton(value=save_style_symbol, elem_id="style_create")
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
def setup_progressbar(*args, **kwargs): def setup_progressbar(*args, **kwargs):
@ -616,11 +622,15 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"): with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories(): for category in ordered_ui_categories():
@ -794,14 +804,20 @@ def create_ui():
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
with FormRow().style(equal_height=False): with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"): with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = [] copy_image_buttons = []
@ -1064,6 +1080,8 @@ def create_ui():
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
img2img_paste_fields = [ img2img_paste_fields = [
(img2img_prompt, "Prompt"), (img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"), (img2img_negative_prompt, "Negative prompt"),
@ -1666,10 +1684,8 @@ def create_ui():
download_localization = gr.Button(value='Download localization template', elem_id="download_localization") download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
if os.path.exists("html/licenses.html"):
with open("html/licenses.html", encoding="utf8") as file:
with gr.TabItem("Licenses"): with gr.TabItem("Licenses"):
gr.HTML(file.read(), elem_id="licenses") gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages") gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@ -1756,9 +1772,7 @@ def create_ui():
if os.path.exists(os.path.join(script_path, "notification.mp3")): if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
if os.path.exists("html/footer.html"): footer = shared.html("footer.html")
with open("html/footer.html", encoding="utf8") as file:
footer = file.read()
footer = footer.format(versions=versions_html()) footer = footer.format(versions=versions_html())
gr.HTML(footer, elem_id="footer") gr.HTML(footer, elem_id="footer")

View File

@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent):
return "button" return "button"
class ToolButtonTop(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool-top", **kwargs)
def get_block_name(self):
return "button"
class FormRow(gr.Row, gr.components.FormComponent): class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms""" """Same as gr.Row but fits inside gradio forms"""

View File

@ -0,0 +1,149 @@
import os.path
from modules import shared
import gradio as gr
import json
from modules.generation_parameters_copypaste import image_from_url_text
extra_pages = []
def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions"""
extra_pages.append(page)
class ExtraNetworksPage:
def __init__(self, title):
self.title = title
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
def refresh(self):
pass
def create_html(self, tabname):
items_html = ''
for item in self.list_items():
items_html += self.create_html_for_item(item, tabname)
if items_html == '':
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
res = "<div class='extra-network-cards'>" + items_html + "</div>"
return res
def list_items(self):
raise NotImplementedError()
def allowed_directories_for_previews(self):
return []
def create_html_for_item(self, item, tabname):
preview = item.get("preview", None)
args = {
"preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
"prompt": json.dumps(item["prompt"]),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
"allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
}
return self.card_page.format(**args)
def intialize():
extra_pages.clear()
class ExtraNetworksUi:
def __init__(self):
self.pages = None
self.stored_extra_pages = None
self.button_save_preview = None
self.preview_target_filename = None
self.tabname = None
def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
ui.stored_extra_pages = extra_pages.copy()
ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
for page in ui.stored_extra_pages:
with gr.Tab(page.title):
page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem)
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
def refresh():
res = []
for pg in ui.stored_extra_pages:
pg.refresh()
res.append(pg.create_html(ui.tabname))
return res
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
return ui
def path_is_parent(parent_path, child_path):
parent_path = os.path.abspath(parent_path)
child_path = os.path.abspath(child_path)
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
def setup_ui(ui, gallery):
def save_preview(index, images, filename):
if len(images) == 0:
print("There is no image in gallery to save as a preview.")
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
index = int(index)
index = 0 if index < 0 else index
index = len(images) - 1 if index >= len(images) else index
img_info = images[index if index >= 0 else 0]
image = image_from_url_text(img_info)
is_allowed = False
for extra_page in ui.stored_extra_pages:
if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
is_allowed = True
break
assert is_allowed, f'writing to {filename} is not allowed'
image.save(filename)
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
ui.button_save_preview.click(
fn=save_preview,
_js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
outputs=[*ui.pages]
)

View File

@ -0,0 +1,34 @@
import os
from modules import shared, ui_extra_networks
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Hypernetworks')
def refresh(self):
shared.reload_hypernetworks()
def list_items(self):
for name, path in shared.hypernetworks.items():
path, ext = os.path.splitext(path)
previews = [path + ".png", path + ".preview.png"]
preview = None
for file in previews:
if os.path.isfile(file):
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
break
yield {
"name": name,
"filename": path,
"preview": preview,
"prompt": f"<hypernet:{name}:1.0>",
"local_preview": path + ".png",
}
def allowed_directories_for_previews(self):
return [shared.cmd_opts.hypernetwork_dir]

View File

@ -0,0 +1,32 @@
import os
from modules import ui_extra_networks, sd_hijack
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Textual Inversion')
self.allow_negative_prompt = True
def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
def list_items(self):
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
path, ext = os.path.splitext(embedding.filename)
preview_file = path + ".preview.png"
preview = None
if os.path.isfile(preview_file):
preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
yield {
"name": embedding.name,
"filename": embedding.filename,
"preview": preview,
"prompt": embedding.name,
"local_preview": path + ".preview.png",
}
def allowed_directories_for_previews(self):
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)

View File

@ -13,6 +13,7 @@ function get_uiCurrentTabContent() {
} }
uiUpdateCallbacks = [] uiUpdateCallbacks = []
uiLoadedCallbacks = []
uiTabChangeCallbacks = [] uiTabChangeCallbacks = []
optionsChangedCallbacks = [] optionsChangedCallbacks = []
let uiCurrentTab = null let uiCurrentTab = null
@ -20,6 +21,9 @@ let uiCurrentTab = null
function onUiUpdate(callback){ function onUiUpdate(callback){
uiUpdateCallbacks.push(callback) uiUpdateCallbacks.push(callback)
} }
function onUiLoaded(callback){
uiLoadedCallbacks.push(callback)
}
function onUiTabChange(callback){ function onUiTabChange(callback){
uiTabChangeCallbacks.push(callback) uiTabChangeCallbacks.push(callback)
} }
@ -38,8 +42,15 @@ function executeCallbacks(queue, m) {
queue.forEach(function(x){runCallback(x, m)}) queue.forEach(function(x){runCallback(x, m)})
} }
var executedOnLoaded = false;
document.addEventListener("DOMContentLoaded", function() { document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){ var mutationObserver = new MutationObserver(function(m){
if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
executedOnLoaded = true;
executeCallbacks(uiLoadedCallbacks);
}
executeCallbacks(uiUpdateCallbacks, m); executeCallbacks(uiUpdateCallbacks, m);
const newTab = get_uiCurrentTab(); const newTab = get_uiCurrentTab();
if ( newTab && ( newTab !== uiCurrentTab ) ) { if ( newTab && ( newTab !== uiCurrentTab ) ) {

View File

@ -11,7 +11,6 @@ import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs):
raise RuntimeError(f"Unknown checkpoint: {x}") raise RuntimeError(f"Unknown checkpoint: {x}")
def apply_hypernetwork(p, x, xs):
if x.lower() in ["", "none"]:
name = None
else:
name = hypernetwork.find_closest_hypernetwork_name(x)
if not name:
raise RuntimeError(f"Unknown hypernetwork: {x}")
hypernetwork.load_hypernetwork(name)
def apply_hypernetwork_strength(p, x, xs):
hypernetwork.apply_strength(x)
def confirm_hypernetworks(p, xs):
for x in xs:
if x.lower() in ["", "none"]:
continue
if not hypernetwork.find_closest_hypernetwork_name(x):
raise RuntimeError(f"Unknown hypernetwork: {x}")
def apply_clip_skip(p, x, xs): def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x opts.data["CLIP_stop_at_last_layers"] = x
@ -208,8 +185,6 @@ axis_options = [
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")), AxisOption("Sigma max", float, apply_field("s_tmax")),
@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
class SharedSettingsStackHelper(object): class SharedSettingsStackHelper(object):
def __enter__(self): def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
self.vae = opts.sd_vae self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb):
@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object):
modules.sd_models.reload_model_weights() modules.sd_models.reload_model_weights()
modules.sd_vae.reload_vae_weights() modules.sd_vae.reload_vae_weights()
hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers

164
style.css
View File

@ -132,13 +132,6 @@
} }
#roll_col > button { #roll_col > button {
min-width: 2em;
min-height: 2em;
max-width: 2em;
max-height: 2em;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
margin: 0.1em 0; margin: 0.1em 0;
} }
@ -146,9 +139,10 @@
min-width: 0 !important; min-width: 0 !important;
max-width: 8em !important; max-width: 8em !important;
margin-right: 1em; margin-right: 1em;
gap: 0;
} }
#interrogate, #deepbooru{ #interrogate, #deepbooru{
margin: 0em 0.25em 0.9em 0.25em; margin: 0em 0.25em 0.5em 0.25em;
min-width: 8em; min-width: 8em;
max-width: 8em; max-width: 8em;
} }
@ -157,8 +151,17 @@
min-width: 8em !important; min-width: 8em !important;
} }
#txt2img_styles_row, #img2img_styles_row{
gap: 0.25em;
margin-top: 0.5em;
}
#txt2img_styles_row > button, #img2img_styles_row > button{
margin: 0;
}
#txt2img_styles, #img2img_styles{ #txt2img_styles, #img2img_styles{
margin-top: 1em; padding: 0;
} }
#txt2img_styles ul, #img2img_styles ul{ #txt2img_styles ul, #img2img_styles ul{
@ -635,16 +638,20 @@ canvas[key="mask"] {
background-color: rgb(31 41 55 / var(--tw-bg-opacity)); background-color: rgb(31 41 55 / var(--tw-bg-opacity));
} }
.gr-button-tool{ .gr-button-tool, .gr-button-tool-top{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em !important; min-width: 2.5em !important;
height: 2.4em; height: 2.4em;
}
.gr-button-tool{
margin: 0.6em 0em 0.55em 0;
}
.gr-button-tool-top, #settings .gr-button-tool{
margin: 1.6em 0.7em 0.55em 0; margin: 1.6em 0.7em 0.55em 0;
} }
#tab_modelmerger .gr-button-tool{
margin: 0.6em 0em 0.55em 0;
}
#modelmerger_results_container{ #modelmerger_results_container{
margin-top: 1em; margin-top: 1em;
@ -763,81 +770,88 @@ footer {
line-height: 2.4em; line-height: 2.4em;
} }
/* The following handles localization for right-to-left (RTL) languages like Arabic. #txt2img_extra_networks, #img2img_extra_networks{
The rtl media type will only be activated by the logic in javascript/localization.js. margin-top: -1em;
If you change anything above, you need to make sure it is RTL compliant by just running
your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
@media rtl {
/* this part was added manually */
:host {
direction: rtl;
}
select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
direction: ltr;
}
#script_list > label > select,
#x_type > label > select,
#y_type > label > select {
direction: rtl;
}
.gr-radio, .gr-checkbox{
margin-left: 0.25em;
} }
/* automatically generated with few manual modifications */ .extra-networks > div > [id *= '_extra_']{
.performance .time { margin: 0.3em;
margin-right: unset;
margin-left: 0;
} }
.justify-center.overflow-x-scroll {
justify-content: right; .extra-network-cards .nocards{
margin: 1.25em 0.5em 0.5em 0.5em;
} }
.justify-center.overflow-x-scroll button:first-of-type {
margin-left: unset; .extra-network-cards .nocards h1{
margin-right: auto; font-size: 1.5em;
margin-bottom: 1em;
} }
.justify-center.overflow-x-scroll button:last-of-type {
margin-right: unset; .extra-network-cards .nocards li{
margin-left: auto; margin-left: 0.5em;
} }
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
margin-right: unset; .extra-network-cards .card{
margin-left: 8em; display: inline-block;
margin: 0.5em;
width: 16em;
height: 24em;
box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
border-radius: 0.2em;
position: relative;
background-size: auto 100%;
background-position: center;
overflow: hidden;
cursor: pointer;
background-image: url('./file=html/card-no-preview.png')
} }
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
right: unset; .extra-network-cards .card:hover{
box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);
}
.extra-network-cards .card .actions .additional{
display: none;
}
.extra-network-cards .card .actions{
position: absolute;
bottom: 0;
left: 0; left: 0;
}
.progressDiv .progress{
padding: 0 0 0 8px;
text-align: left;
}
#lightboxModal{
left: unset;
right: 0; right: 0;
padding: 0.5em;
color: white;
background: rgba(0,0,0,0.5);
box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);
text-shadow: 0 0 0.2em black;
} }
.modalPrev, .modalNext{
border-radius: 3px 0 0 3px; .extra-network-cards .card .actions:hover{
box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
} }
.modalNext {
right: unset; .extra-network-cards .card .actions .name{
left: 0; font-size: 1.7em;
border-radius: 0 3px 3px 0; font-weight: bold;
line-break: anywhere;
} }
#imageARPreview{
left:unset; .extra-network-cards .card .actions:hover .additional{
right:0px; display: block;
} }
#txt2img_skip, #img2img_skip{
right: unset; .extra-network-cards .card ul{
left: 0px; margin: 0.25em 0 0.75em 0.25em;
cursor: unset;
} }
#context-menu{
box-shadow:-1px 1px 2px #CE6400; .extra-network-cards .card ul a{
} cursor: pointer;
.gr-box > div > div > input.gr-text-input{
right: unset;
left: 0.5em;
} }
.extra-network-cards .card ul a:hover{
color: red;
} }

View File

@ -9,16 +9,18 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules import import_hook, errors from modules import import_hook, errors, extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path from modules.paths import script_path
import torch import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__: if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
@ -84,10 +86,17 @@ def initialize():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.reload_hypernetworks()
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
try: try:
@ -209,6 +218,15 @@ def webui():
modules.sd_models.list_models() modules.sd_models.list_models()
shared.reload_hypernetworks()
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
if __name__ == "__main__": if __name__ == "__main__":
if cmd_opts.nowebui: if cmd_opts.nowebui: