diff --git a/html/card-no-preview.png b/html/card-no-preview.png
new file mode 100644
index 000000000..e2beb2692
Binary files /dev/null and b/html/card-no-preview.png differ
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
new file mode 100644
index 000000000..7314b0630
--- /dev/null
+++ b/html/extra-networks-card.html
@@ -0,0 +1,11 @@
+
+
diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html
new file mode 100644
index 000000000..389358d6c
--- /dev/null
+++ b/html/extra-networks-no-cards.html
@@ -0,0 +1,8 @@
+
+
Nothing here. Add some content to the following directories:
+
+
+
+
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
new file mode 100644
index 000000000..71e522d16
--- /dev/null
+++ b/javascript/extraNetworks.js
@@ -0,0 +1,60 @@
+
+function setupExtraNetworksForTab(tabname){
+ gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
+
+ gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
+ gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
+}
+
+var activePromptTextarea = null;
+var activePositivePromptTextarea = null;
+
+function setupExtraNetworks(){
+ setupExtraNetworksForTab('txt2img')
+ setupExtraNetworksForTab('img2img')
+
+ function registerPrompt(id, isNegative){
+ var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
+
+ if (activePromptTextarea == null){
+ activePromptTextarea = textarea
+ }
+ if (activePositivePromptTextarea == null && ! isNegative){
+ activePositivePromptTextarea = textarea
+ }
+
+ textarea.addEventListener("focus", function(){
+ activePromptTextarea = textarea;
+ if(! isNegative) activePositivePromptTextarea = textarea;
+ });
+ }
+
+ registerPrompt('txt2img_prompt')
+ registerPrompt('txt2img_neg_prompt', true)
+ registerPrompt('img2img_prompt')
+ registerPrompt('img2img_neg_prompt', true)
+}
+
+onUiLoaded(setupExtraNetworks)
+
+function cardClicked(textToAdd, allowNegativePrompt){
+ textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea
+
+ textarea.value = textarea.value + " " + textToAdd
+ updateInput(textarea)
+
+ return false
+}
+
+function saveCardPreview(event, tabname, filename){
+ textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
+ button = gradioApp().getElementById(tabname + '_save_preview')
+
+ textarea.value = filename
+ updateInput(textarea)
+
+ button.click()
+
+ event.stopPropagation()
+ event.preventDefault()
+}
diff --git a/javascript/hints.js b/javascript/hints.js
index e746e20d5..f4079f961 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -21,6 +21,8 @@ titles = {
"\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field",
+ "\u{1f3b4}": "Show extra networks",
+
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
diff --git a/javascript/ui.js b/javascript/ui.js
index 3ba90ca88..a7e754394 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt]
}
-
-
opts = {}
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
@@ -239,11 +237,14 @@ onUiUpdate(function(){
return
}
+
prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative"
- textarea.addEventListener("input", () => update_token_counter(id_button));
+ textarea.addEventListener("input", function(){
+ update_token_counter(id_button);
+ });
}
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
@@ -261,10 +262,8 @@ onUiUpdate(function(){
})
}
}
-
})
-
onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
diff --git a/modules/api/api.py b/modules/api/api.py
index 9814bbc28..2c371e6e7 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -480,7 +480,7 @@ class Api:
def train_hypernetwork(self, args: dict):
try:
shared.state.begin()
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -491,16 +491,15 @@ class Api:
except Exception as e:
error = e
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
- return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
except AssertionError as msg:
shared.state.end()
- return TrainResponse(info = "train embedding error: {error}".format(error = error))
+ return TrainResponse(info="train embedding error: {error}".format(error=error))
def get_memory(self):
try:
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
new file mode 100644
index 000000000..1978673d7
--- /dev/null
+++ b/modules/extra_networks.py
@@ -0,0 +1,147 @@
+import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+ extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+ extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+ def __init__(self, items=None):
+ self.items = items or []
+
+
+class ExtraNetwork:
+ def __init__(self, name):
+ self.name = name
+
+ def activate(self, p, params_list):
+ """
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+ Passes arguments related to this extra network in params_list.
+ User passes arguments by specifying this in his prompt:
+
+
+
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
+ separated by colon.
+
+ Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
+ in this case, all effects of this extra networks should be disabled.
+
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+ > "1girl, "
+
+ params_list will be:
+
+ [
+ ExtraNetworkParams(items=["agm", "1.1"]),
+ ExtraNetworkParams(items=["ray"])
+ ]
+
+ """
+ raise NotImplementedError
+
+ def deactivate(self, p):
+ """
+ Called at the end of processing for housekeeping. No need to do anything here.
+ """
+
+ raise NotImplementedError
+
+
+def activate(p, extra_network_data):
+ """call activate for extra networks in extra_network_data in specified order, then call
+ activate for all remaining registered networks with an empty argument list"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ print(f"Skipping unknown extra network: {extra_network_name}")
+ continue
+
+ try:
+ extra_network.activate(p, extra_network_args)
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.activate(p, [])
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+ """call deactivate for extra networks in extra_network_data in specified order, then call
+ deactivate for all remaining registered networks"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating extra network {extra_network_name}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+ res = defaultdict(list)
+
+ def found(m):
+ name = m.group(1)
+ args = m.group(2)
+
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
+
+ return ""
+
+ prompt = re.sub(re_extra_net, found, prompt)
+
+ return prompt, res
+
+
+def parse_prompts(prompts):
+ res = []
+ extra_data = None
+
+ for prompt in prompts:
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
+
+ if extra_data is None:
+ extra_data = parsed_extra_data
+
+ res.append(updated_prompt)
+
+ return res, extra_data
+
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
new file mode 100644
index 000000000..6a0c4ba87
--- /dev/null
+++ b/modules/extra_networks_hypernet.py
@@ -0,0 +1,21 @@
+from modules import extra_networks
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('hypernet')
+
+ def activate(self, p, params_list):
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ hypernetwork.load_hypernetworks(names, multipliers)
+
+ def deactivate(p, self):
+ pass
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a381ff599..46e12dc6c 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
from modules import ui
settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
@@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Clip skip" not in res:
res["Clip skip"] = "1"
- if "Hypernet strength" not in res:
- res["Hypernet strength"] = "1"
-
- if "Hypernet" in res:
- hypernet_name = res["Hypernet"]
- hypernet_hash = res.get("Hypernet hash", None)
- res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+ hypernet = res.get("Hypernet", None)
+ if hypernet is not None:
+ res["Prompt"] += f""""""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74e785824..80a47c791 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,7 +25,6 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
- multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__()
+ self.multiplier = 1.0
+
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
state_dict[to] = x
def forward(self, x):
- return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
def trainables(self):
layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
return layer_structure
-def apply_strength(value=None):
- HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork:
for param in layer.parameters():
param.requires_grad = mode
+ def to(self, device):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.to(device)
+
+ return self
+
+ def set_multiplier(self, multiplier):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.multiplier = multiplier
+
+ return self
+
def eval(self):
for k, layers in self.layers.items():
for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- print("Loaded existing optimizer from checkpoint")
- print(f"Optimizer name is {self.optimizer_name}")
+ if shared.opts.print_hypernet_extra:
+ print("Loaded existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
else:
self.optimizer_name = "AdamW"
- print("No saved optimizer exists in checkpoint")
+ if shared.opts.print_hypernet_extra:
+ print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path):
return res
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- # Prevent any file named "None.pt" from being loaded.
- if path is not None and filename != "None":
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+ path = shared.hypernetworks.get(name, None)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print("Unloading hypernetwork")
+ if path is None:
+ return None
- shared.loaded_hypernetwork = None
+ hypernetwork = Hypernetwork()
+
+ try:
+ hypernetwork.load(path)
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+ already_loaded = {}
+
+ for hypernetwork in shared.loaded_hypernetworks:
+ if hypernetwork.name in names:
+ already_loaded[hypernetwork.name] = hypernetwork
+
+ shared.loaded_hypernetworks.clear()
+
+ for i, name in enumerate(names):
+ hypernetwork = already_loaded.get(name, None)
+ if hypernetwork is None:
+ hypernetwork = load_hypernetwork(name)
+
+ if hypernetwork is None:
+ continue
+
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+ shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
return applicable[0]
-def apply_hypernetwork(hypernetwork, context, layer=None):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None:
- return context, context
+ return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
+ context_k = hypernetwork_layers[0](context_k)
+ context_v = hypernetwork_layers[1](context_v)
+ return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+ context_k = context
+ context_v = context
+ for hypernetwork in hypernetworks:
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
return context_k, context_v
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+ hypernetwork = Hypernetwork()
+ hypernetwork.load(path)
+ shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
else:
images_dir = None
- hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 81e3f519b..76599f5ad 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
@@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(*args):
-
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
except Exception:
raise
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
diff --git a/modules/processing.py b/modules/processing.py
index a3e9f7095..b5deeacf5 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
- "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork':
- shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights() # make onchange call for changing SD model
+ sd_models.reload_model_weights()
if k == 'sd_vae':
- sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ sd_vae.reload_vae_weights()
res = process_images_inner(p)
@@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights()
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights()
return res
@@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
cache[0] = (required_prompts, steps)
return cache[1]
+ p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
+
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ extra_networks.activate(p, extra_network_data)
+
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
@@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ extra_networks.deactivate(p, extra_network_data)
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index cdc63ed74..4fa54329d 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
del context, context_k, context_v, x
@@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
@@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) * self.scale
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
diff --git a/modules/shared.py b/modules/shared.py
index 2f3664542..c0e11f184 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -23,6 +23,7 @@ demo = None
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
+
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
-loaded_hypernetwork = None
+loaded_hypernetworks = []
def reload_hypernetworks():
@@ -153,8 +154,6 @@ def reload_hypernetworks():
global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
- hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
-
class State:
@@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
- "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
- "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -661,3 +658,17 @@ mem_mon.start()
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
+
+
+def html_path(filename):
+ return os.path.join(script_path, "html", filename)
+
+
+def html(filename):
+ path = html_path(filename)
+
+ if os.path.exists(path):
+ with open(path, encoding="utf8") as file:
+ return file.read()
+
+ return ""
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5a7be4228..4e90f690f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -50,6 +50,7 @@ class Embedding:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
+ self.filename = None
def save(self, filename):
embedding_data = {
@@ -182,6 +183,7 @@ class EmbeddingDatabase:
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
+ embedding.filename = path
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
diff --git a/modules/ui.py b/modules/ui.py
index 06c11848a..d23b2b8e9 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️
+extra_networks_symbol = '\U0001F3B4' # 🎴
def plaintext_to_html(text):
@@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
def update_token_counter(text, steps):
try:
+ text, _ = extra_networks.parse_prompt(text)
+
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
@@ -354,10 +357,10 @@ def create_toprow(is_img2img):
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ paste = ToolButton(value=paste_symbol, elem_id="paste")
+ clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+
token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter")
@@ -395,11 +398,14 @@ def create_toprow(is_img2img):
outputs=[],
)
- with gr.Row():
+ with gr.Row(elem_id=f"{id_part}_styles_row"):
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button
+ prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply")
+ save_style = ToolButton(value=save_style_symbol, elem_id="style_create")
+
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
def setup_progressbar(*args, **kwargs):
@@ -616,11 +622,15 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
+ with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
@@ -794,14 +804,20 @@ def create_ui():
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+ with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
@@ -1064,6 +1080,8 @@ def create_ui():
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
@@ -1666,10 +1684,8 @@ def create_ui():
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
- if os.path.exists("html/licenses.html"):
- with open("html/licenses.html", encoding="utf8") as file:
- with gr.TabItem("Licenses"):
- gr.HTML(file.read(), elem_id="licenses")
+ with gr.TabItem("Licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@@ -1756,11 +1772,9 @@ def create_ui():
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
- if os.path.exists("html/footer.html"):
- with open("html/footer.html", encoding="utf8") as file:
- footer = file.read()
- footer = footer.format(versions=versions_html())
- gr.HTML(footer, elem_id="footer")
+ footer = shared.html("footer.html")
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 97acff062..463244256 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent):
return "button"
+class ToolButtonTop(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool-top", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
new file mode 100644
index 000000000..253e90f7b
--- /dev/null
+++ b/modules/ui_extra_networks.py
@@ -0,0 +1,149 @@
+import os.path
+
+from modules import shared
+import gradio as gr
+import json
+
+from modules.generation_parameters_copypaste import image_from_url_text
+
+extra_pages = []
+
+
+def register_page(page):
+ """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions"""
+
+ extra_pages.append(page)
+
+
+class ExtraNetworksPage:
+ def __init__(self, title):
+ self.title = title
+ self.card_page = shared.html("extra-networks-card.html")
+ self.allow_negative_prompt = False
+
+ def refresh(self):
+ pass
+
+ def create_html(self, tabname):
+ items_html = ''
+
+ for item in self.list_items():
+ items_html += self.create_html_for_item(item, tabname)
+
+ if items_html == '':
+ dirs = "".join([f"{x}" for x in self.allowed_directories_for_previews()])
+ items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+ res = ""
+
+ return res
+
+ def list_items(self):
+ raise NotImplementedError()
+
+ def allowed_directories_for_previews(self):
+ return []
+
+ def create_html_for_item(self, item, tabname):
+ preview = item.get("preview", None)
+
+ args = {
+ "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
+ "prompt": json.dumps(item["prompt"]),
+ "tabname": json.dumps(tabname),
+ "local_preview": json.dumps(item["local_preview"]),
+ "name": item["name"],
+ "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+ }
+
+ return self.card_page.format(**args)
+
+
+def intialize():
+ extra_pages.clear()
+
+
+class ExtraNetworksUi:
+ def __init__(self):
+ self.pages = None
+ self.stored_extra_pages = None
+
+ self.button_save_preview = None
+ self.preview_target_filename = None
+
+ self.tabname = None
+
+
+def create_ui(container, button, tabname):
+ ui = ExtraNetworksUi()
+ ui.pages = []
+ ui.stored_extra_pages = extra_pages.copy()
+ ui.tabname = tabname
+
+ with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+ button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
+ for page in ui.stored_extra_pages:
+ with gr.Tab(page.title):
+ page_elem = gr.HTML(page.create_html(ui.tabname))
+ ui.pages.append(page_elem)
+
+ ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
+ ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+
+ button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
+ button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+
+ def refresh():
+ res = []
+
+ for pg in ui.stored_extra_pages:
+ pg.refresh()
+ res.append(pg.create_html(ui.tabname))
+
+ return res
+
+ button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
+
+ return ui
+
+
+def path_is_parent(parent_path, child_path):
+ parent_path = os.path.abspath(parent_path)
+ child_path = os.path.abspath(child_path)
+
+ return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+
+
+def setup_ui(ui, gallery):
+ def save_preview(index, images, filename):
+ if len(images) == 0:
+ print("There is no image in gallery to save as a preview.")
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ index = int(index)
+ index = 0 if index < 0 else index
+ index = len(images) - 1 if index >= len(images) else index
+
+ img_info = images[index if index >= 0 else 0]
+ image = image_from_url_text(img_info)
+
+ is_allowed = False
+ for extra_page in ui.stored_extra_pages:
+ if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ is_allowed = True
+ break
+
+ assert is_allowed, f'writing to {filename} is not allowed'
+
+ image.save(filename)
+
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ ui.button_save_preview.click(
+ fn=save_preview,
+ _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+ inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
+ outputs=[*ui.pages]
+ )
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
new file mode 100644
index 000000000..312dbaf04
--- /dev/null
+++ b/modules/ui_extra_networks_hypernets.py
@@ -0,0 +1,34 @@
+import os
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Hypernetworks')
+
+ def refresh(self):
+ shared.reload_hypernetworks()
+
+ def list_items(self):
+ for name, path in shared.hypernetworks.items():
+ path, ext = os.path.splitext(path)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ break
+
+ yield {
+ "name": name,
+ "filename": path,
+ "preview": preview,
+ "prompt": f"",
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.hypernetwork_dir]
+
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
new file mode 100644
index 000000000..e4a6e3bfb
--- /dev/null
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -0,0 +1,32 @@
+import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Textual Inversion')
+ self.allow_negative_prompt = True
+
+ def refresh(self):
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+ def list_items(self):
+ for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+ path, ext = os.path.splitext(embedding.filename)
+ preview_file = path + ".preview.png"
+
+ preview = None
+ if os.path.isfile(preview_file):
+ preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+
+ yield {
+ "name": embedding.name,
+ "filename": embedding.filename,
+ "preview": preview,
+ "prompt": embedding.name,
+ "local_preview": path + ".preview.png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/script.js b/script.js
index 3345e32b4..97e0bfcf9 100644
--- a/script.js
+++ b/script.js
@@ -13,6 +13,7 @@ function get_uiCurrentTabContent() {
}
uiUpdateCallbacks = []
+uiLoadedCallbacks = []
uiTabChangeCallbacks = []
optionsChangedCallbacks = []
let uiCurrentTab = null
@@ -20,6 +21,9 @@ let uiCurrentTab = null
function onUiUpdate(callback){
uiUpdateCallbacks.push(callback)
}
+function onUiLoaded(callback){
+ uiLoadedCallbacks.push(callback)
+}
function onUiTabChange(callback){
uiTabChangeCallbacks.push(callback)
}
@@ -38,8 +42,15 @@ function executeCallbacks(queue, m) {
queue.forEach(function(x){runCallback(x, m)})
}
+var executedOnLoaded = false;
+
document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){
+ if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
+ executedOnLoaded = true;
+ executeCallbacks(uiLoadedCallbacks);
+ }
+
executeCallbacks(uiUpdateCallbacks, m);
const newTab = get_uiCurrentTab();
if ( newTab && ( newTab !== uiCurrentTab ) ) {
@@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() {
/**
* Add a ctrl+enter as a shortcut to start a generation
*/
- document.addEventListener('keydown', function(e) {
+document.addEventListener('keydown', function(e) {
var handled = false;
if (e.key !== undefined) {
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 6629f5d5f..b1badec90 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -11,7 +11,6 @@ import modules.scripts as scripts
import gradio as gr
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
-from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs):
raise RuntimeError(f"Unknown checkpoint: {x}")
-def apply_hypernetwork(p, x, xs):
- if x.lower() in ["", "none"]:
- name = None
- else:
- name = hypernetwork.find_closest_hypernetwork_name(x)
- if not name:
- raise RuntimeError(f"Unknown hypernetwork: {x}")
- hypernetwork.load_hypernetwork(name)
-
-
-def apply_hypernetwork_strength(p, x, xs):
- hypernetwork.apply_strength(x)
-
-
-def confirm_hypernetworks(p, xs):
- for x in xs:
- if x.lower() in ["", "none"]:
- continue
- if not hypernetwork.find_closest_hypernetwork_name(x):
- raise RuntimeError(f"Unknown hypernetwork: {x}")
-
-
def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x
@@ -208,8 +185,6 @@ axis_options = [
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
- AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
- AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
@@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
class SharedSettingsStackHelper(object):
def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
- self.hypernetwork = opts.sd_hypernetwork
self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb):
@@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object):
modules.sd_models.reload_model_weights()
modules.sd_vae.reload_vae_weights()
- hypernetwork.load_hypernetwork(self.hypernetwork)
- hypernetwork.apply_strength()
-
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
diff --git a/style.css b/style.css
index 3a515ebdc..5e8bc2ca2 100644
--- a/style.css
+++ b/style.css
@@ -132,13 +132,6 @@
}
#roll_col > button {
- min-width: 2em;
- min-height: 2em;
- max-width: 2em;
- max-height: 2em;
- flex-grow: 0;
- padding-left: 0.25em;
- padding-right: 0.25em;
margin: 0.1em 0;
}
@@ -146,9 +139,10 @@
min-width: 0 !important;
max-width: 8em !important;
margin-right: 1em;
+ gap: 0;
}
#interrogate, #deepbooru{
- margin: 0em 0.25em 0.9em 0.25em;
+ margin: 0em 0.25em 0.5em 0.25em;
min-width: 8em;
max-width: 8em;
}
@@ -157,8 +151,17 @@
min-width: 8em !important;
}
+#txt2img_styles_row, #img2img_styles_row{
+ gap: 0.25em;
+ margin-top: 0.5em;
+}
+
+#txt2img_styles_row > button, #img2img_styles_row > button{
+ margin: 0;
+}
+
#txt2img_styles, #img2img_styles{
- margin-top: 1em;
+ padding: 0;
}
#txt2img_styles ul, #img2img_styles ul{
@@ -635,16 +638,20 @@ canvas[key="mask"] {
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
}
-.gr-button-tool{
+.gr-button-tool, .gr-button-tool-top{
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
+}
+
+.gr-button-tool{
+ margin: 0.6em 0em 0.55em 0;
+}
+
+.gr-button-tool-top, #settings .gr-button-tool{
margin: 1.6em 0.7em 0.55em 0;
}
-#tab_modelmerger .gr-button-tool{
- margin: 0.6em 0em 0.55em 0;
-}
#modelmerger_results_container{
margin-top: 1em;
@@ -763,81 +770,88 @@ footer {
line-height: 2.4em;
}
-/* The following handles localization for right-to-left (RTL) languages like Arabic.
-The rtl media type will only be activated by the logic in javascript/localization.js.
-If you change anything above, you need to make sure it is RTL compliant by just running
-your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
-Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
-@media rtl {
- /* this part was added manually */
- :host {
- direction: rtl;
- }
- select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
- direction: ltr;
- }
- #script_list > label > select,
- #x_type > label > select,
- #y_type > label > select {
- direction: rtl;
- }
- .gr-radio, .gr-checkbox{
- margin-left: 0.25em;
- }
-
- /* automatically generated with few manual modifications */
- .performance .time {
- margin-right: unset;
- margin-left: 0;
- }
- .justify-center.overflow-x-scroll {
- justify-content: right;
- }
- .justify-center.overflow-x-scroll button:first-of-type {
- margin-left: unset;
- margin-right: auto;
- }
- .justify-center.overflow-x-scroll button:last-of-type {
- margin-right: unset;
- margin-left: auto;
- }
- #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
- margin-right: unset;
- margin-left: 8em;
- }
- #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
- right: unset;
- left: 0;
- }
- .progressDiv .progress{
- padding: 0 0 0 8px;
- text-align: left;
- }
- #lightboxModal{
- left: unset;
- right: 0;
- }
- .modalPrev, .modalNext{
- border-radius: 3px 0 0 3px;
- }
- .modalNext {
- right: unset;
- left: 0;
- border-radius: 0 3px 3px 0;
- }
- #imageARPreview{
- left:unset;
- right:0px;
- }
- #txt2img_skip, #img2img_skip{
- right: unset;
- left: 0px;
- }
- #context-menu{
- box-shadow:-1px 1px 2px #CE6400;
- }
- .gr-box > div > div > input.gr-text-input{
- right: unset;
- left: 0.5em;
- }
+#txt2img_extra_networks, #img2img_extra_networks{
+ margin-top: -1em;
}
+
+.extra-networks > div > [id *= '_extra_']{
+ margin: 0.3em;
+}
+
+.extra-network-cards .nocards{
+ margin: 1.25em 0.5em 0.5em 0.5em;
+}
+
+.extra-network-cards .nocards h1{
+ font-size: 1.5em;
+ margin-bottom: 1em;
+}
+
+.extra-network-cards .nocards li{
+ margin-left: 0.5em;
+}
+
+.extra-network-cards .card{
+ display: inline-block;
+ margin: 0.5em;
+ width: 16em;
+ height: 24em;
+ box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
+ border-radius: 0.2em;
+ position: relative;
+
+ background-size: auto 100%;
+ background-position: center;
+ overflow: hidden;
+ cursor: pointer;
+
+ background-image: url('./file=html/card-no-preview.png')
+}
+
+.extra-network-cards .card:hover{
+ box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);
+}
+
+.extra-network-cards .card .actions .additional{
+ display: none;
+}
+
+.extra-network-cards .card .actions{
+ position: absolute;
+ bottom: 0;
+ left: 0;
+ right: 0;
+ padding: 0.5em;
+ color: white;
+ background: rgba(0,0,0,0.5);
+ box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);
+ text-shadow: 0 0 0.2em black;
+}
+
+.extra-network-cards .card .actions:hover{
+ box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
+}
+
+.extra-network-cards .card .actions .name{
+ font-size: 1.7em;
+ font-weight: bold;
+ line-break: anywhere;
+}
+
+.extra-network-cards .card .actions:hover .additional{
+ display: block;
+}
+
+.extra-network-cards .card ul{
+ margin: 0.25em 0 0.75em 0.25em;
+ cursor: unset;
+}
+
+.extra-network-cards .card ul a{
+ cursor: pointer;
+}
+
+.extra-network-cards .card ul a:hover{
+ color: red;
+}
+
diff --git a/webui.py b/webui.py
index 865a73006..e8dd822a6 100644
--- a/webui.py
+++ b/webui.py
@@ -9,16 +9,18 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from modules import import_hook, errors
+from modules import import_hook, errors, extra_networks
+from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
import torch
+
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
import modules.extras
import modules.face_restoration
@@ -84,10 +86,17 @@ def initialize():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
- shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.reload_hypernetworks()
+
+ ui_extra_networks.intialize()
+ ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+ ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+
+ extra_networks.initialize()
+ extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
try:
@@ -209,6 +218,15 @@ def webui():
modules.sd_models.list_models()
+ shared.reload_hypernetworks()
+
+ ui_extra_networks.intialize()
+ ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+ ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+
+ extra_networks.initialize()
+ extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+
if __name__ == "__main__":
if cmd_opts.nowebui: