mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
extra networks UI
rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>
This commit is contained in:
parent
e33cace2c2
commit
40ff6db532
BIN
html/card-no-preview.png
Normal file
BIN
html/card-no-preview.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 82 KiB |
11
html/extra-networks-card.html
Normal file
11
html/extra-networks-card.html
Normal file
@ -0,0 +1,11 @@
|
||||
<div class='card' {preview_html} onclick='return cardClicked({prompt}, {allow_negative_prompt})'>
|
||||
<div class='actions'>
|
||||
<div class='additional'>
|
||||
<ul>
|
||||
<a href="#" title="replace preview image with currently selected in gallery" onclick='return saveCardPreview(event, {tabname}, {local_preview})'>replace preview</a>
|
||||
</ul>
|
||||
</div>
|
||||
<span class='name'>{name}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
8
html/extra-networks-no-cards.html
Normal file
8
html/extra-networks-no-cards.html
Normal file
@ -0,0 +1,8 @@
|
||||
<div class='nocards'>
|
||||
<h1>Nothing here. Add some content to the following directories:</h1>
|
||||
|
||||
<ul>
|
||||
{dirs}
|
||||
</ul>
|
||||
</div>
|
||||
|
60
javascript/extraNetworks.js
Normal file
60
javascript/extraNetworks.js
Normal file
@ -0,0 +1,60 @@
|
||||
|
||||
function setupExtraNetworksForTab(tabname){
|
||||
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
||||
|
||||
gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
|
||||
gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
|
||||
}
|
||||
|
||||
var activePromptTextarea = null;
|
||||
var activePositivePromptTextarea = null;
|
||||
|
||||
function setupExtraNetworks(){
|
||||
setupExtraNetworksForTab('txt2img')
|
||||
setupExtraNetworksForTab('img2img')
|
||||
|
||||
function registerPrompt(id, isNegative){
|
||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||
|
||||
if (activePromptTextarea == null){
|
||||
activePromptTextarea = textarea
|
||||
}
|
||||
if (activePositivePromptTextarea == null && ! isNegative){
|
||||
activePositivePromptTextarea = textarea
|
||||
}
|
||||
|
||||
textarea.addEventListener("focus", function(){
|
||||
activePromptTextarea = textarea;
|
||||
if(! isNegative) activePositivePromptTextarea = textarea;
|
||||
});
|
||||
}
|
||||
|
||||
registerPrompt('txt2img_prompt')
|
||||
registerPrompt('txt2img_neg_prompt', true)
|
||||
registerPrompt('img2img_prompt')
|
||||
registerPrompt('img2img_neg_prompt', true)
|
||||
}
|
||||
|
||||
onUiLoaded(setupExtraNetworks)
|
||||
|
||||
function cardClicked(textToAdd, allowNegativePrompt){
|
||||
textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea
|
||||
|
||||
textarea.value = textarea.value + " " + textToAdd
|
||||
updateInput(textarea)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
function saveCardPreview(event, tabname, filename){
|
||||
textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
|
||||
button = gradioApp().getElementById(tabname + '_save_preview')
|
||||
|
||||
textarea.value = filename
|
||||
updateInput(textarea)
|
||||
|
||||
button.click()
|
||||
|
||||
event.stopPropagation()
|
||||
event.preventDefault()
|
||||
}
|
@ -21,6 +21,8 @@ titles = {
|
||||
"\U0001F5D1": "Clear prompt",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
"\u{1f4d2}": "Paste available values into the field",
|
||||
"\u{1f3b4}": "Show extra networks",
|
||||
|
||||
|
||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||
|
@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
|
||||
return [prompt, negative_prompt]
|
||||
}
|
||||
|
||||
|
||||
|
||||
opts = {}
|
||||
onUiUpdate(function(){
|
||||
if(Object.keys(opts).length != 0) return;
|
||||
@ -239,11 +237,14 @@ onUiUpdate(function(){
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
prompt.parentElement.insertBefore(counter, prompt)
|
||||
counter.classList.add("token-counter")
|
||||
prompt.parentElement.style.position = "relative"
|
||||
|
||||
textarea.addEventListener("input", () => update_token_counter(id_button));
|
||||
textarea.addEventListener("input", function(){
|
||||
update_token_counter(id_button);
|
||||
});
|
||||
}
|
||||
|
||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
||||
@ -261,10 +262,8 @@ onUiUpdate(function(){
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
|
||||
onOptionsChanged(function(){
|
||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||
|
@ -480,7 +480,7 @@ class Api:
|
||||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
shared.loaded_hypernetworks = []
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
@ -491,7 +491,6 @@ class Api:
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if not apply_optimizations:
|
||||
|
147
modules/extra_networks.py
Normal file
147
modules/extra_networks.py
Normal file
@ -0,0 +1,147 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from modules import errors
|
||||
|
||||
extra_network_registry = {}
|
||||
|
||||
|
||||
def initialize():
|
||||
extra_network_registry.clear()
|
||||
|
||||
|
||||
def register_extra_network(extra_network):
|
||||
extra_network_registry[extra_network.name] = extra_network
|
||||
|
||||
|
||||
class ExtraNetworkParams:
|
||||
def __init__(self, items=None):
|
||||
self.items = items or []
|
||||
|
||||
|
||||
class ExtraNetwork:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def activate(self, p, params_list):
|
||||
"""
|
||||
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
|
||||
Passes arguments related to this extra network in params_list.
|
||||
User passes arguments by specifying this in his prompt:
|
||||
|
||||
<name:arg1:arg2:arg3>
|
||||
|
||||
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
||||
separated by colon.
|
||||
|
||||
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
|
||||
in this case, all effects of this extra networks should be disabled.
|
||||
|
||||
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
||||
|
||||
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
|
||||
|
||||
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
|
||||
|
||||
params_list will be:
|
||||
|
||||
[
|
||||
ExtraNetworkParams(items=["agm", "1.1"]),
|
||||
ExtraNetworkParams(items=["ray"])
|
||||
]
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def deactivate(self, p):
|
||||
"""
|
||||
Called at the end of processing for housekeeping. No need to do anything here.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def activate(p, extra_network_data):
|
||||
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||
activate for all remaining registered networks with an empty argument list"""
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
print(f"Skipping unknown extra network: {extra_network_name}")
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, extra_network_args)
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, [])
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name}")
|
||||
|
||||
|
||||
def deactivate(p, extra_network_data):
|
||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||
deactivate for all remaining registered networks"""
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating extra network {extra_network_name}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
|
||||
|
||||
|
||||
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
|
||||
|
||||
|
||||
def parse_prompt(prompt):
|
||||
res = defaultdict(list)
|
||||
|
||||
def found(m):
|
||||
name = m.group(1)
|
||||
args = m.group(2)
|
||||
|
||||
res[name].append(ExtraNetworkParams(items=args.split(":")))
|
||||
|
||||
return ""
|
||||
|
||||
prompt = re.sub(re_extra_net, found, prompt)
|
||||
|
||||
return prompt, res
|
||||
|
||||
|
||||
def parse_prompts(prompts):
|
||||
res = []
|
||||
extra_data = None
|
||||
|
||||
for prompt in prompts:
|
||||
updated_prompt, parsed_extra_data = parse_prompt(prompt)
|
||||
|
||||
if extra_data is None:
|
||||
extra_data = parsed_extra_data
|
||||
|
||||
res.append(updated_prompt)
|
||||
|
||||
return res, extra_data
|
||||
|
21
modules/extra_networks_hypernet.py
Normal file
21
modules/extra_networks_hypernet.py
Normal file
@ -0,0 +1,21 @@
|
||||
from modules import extra_networks
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('hypernet')
|
||||
|
||||
def activate(self, p, params_list):
|
||||
names = []
|
||||
multipliers = []
|
||||
for params in params_list:
|
||||
assert len(params.items) > 0
|
||||
|
||||
names.append(params.items[0])
|
||||
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||
|
||||
hypernetwork.load_hypernetworks(names, multipliers)
|
||||
|
||||
def deactivate(p, self):
|
||||
pass
|
@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
|
||||
from modules import ui
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'inpainting_mask_weight': 'Conditional mask weight',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
if "Hypernet strength" not in res:
|
||||
res["Hypernet strength"] = "1"
|
||||
|
||||
if "Hypernet" in res:
|
||||
hypernet_name = res["Hypernet"]
|
||||
hypernet_hash = res.get("Hypernet hash", None)
|
||||
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
||||
hypernet = res.get("Hypernet", None)
|
||||
if hypernet is not None:
|
||||
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
||||
|
||||
if "Hires resize-1" not in res:
|
||||
res["Hires resize-1"] = 0
|
||||
|
@ -25,7 +25,6 @@ from statistics import stdev, mean
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
|
||||
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
||||
super().__init__()
|
||||
|
||||
self.multiplier = 1.0
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||
state_dict[to] = x
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
|
||||
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
||||
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
|
||||
return layer_structure
|
||||
|
||||
|
||||
def apply_strength(value=None):
|
||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||
|
||||
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
||||
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
||||
if layer_structure is None:
|
||||
@ -192,6 +190,20 @@ class Hypernetwork:
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = mode
|
||||
|
||||
def to(self, device):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.multiplier = multiplier
|
||||
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
@ -269,10 +281,12 @@ class Hypernetwork:
|
||||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
else:
|
||||
self.optimizer_name = "AdamW"
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
@ -306,23 +320,43 @@ def list_hypernetworks(path):
|
||||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(filename):
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
# Prevent any file named "None.pt" from being loaded.
|
||||
if path is not None and filename != "None":
|
||||
print(f"Loading hypernetwork {filename}")
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
def load_hypernetwork(name):
|
||||
path = shared.hypernetworks.get(name, None)
|
||||
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
hypernetwork = Hypernetwork()
|
||||
|
||||
try:
|
||||
hypernetwork.load(path)
|
||||
except Exception:
|
||||
print(f"Error loading hypernetwork {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
if shared.loaded_hypernetwork is not None:
|
||||
print("Unloading hypernetwork")
|
||||
return None
|
||||
|
||||
shared.loaded_hypernetwork = None
|
||||
return hypernetwork
|
||||
|
||||
|
||||
def load_hypernetworks(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
|
||||
for hypernetwork in shared.loaded_hypernetworks:
|
||||
if hypernetwork.name in names:
|
||||
already_loaded[hypernetwork.name] = hypernetwork
|
||||
|
||||
shared.loaded_hypernetworks.clear()
|
||||
|
||||
for i, name in enumerate(names):
|
||||
hypernetwork = already_loaded.get(name, None)
|
||||
if hypernetwork is None:
|
||||
hypernetwork = load_hypernetwork(name)
|
||||
|
||||
if hypernetwork is None:
|
||||
continue
|
||||
|
||||
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
|
||||
shared.loaded_hypernetworks.append(hypernetwork)
|
||||
|
||||
|
||||
def find_closest_hypernetwork_name(search: str):
|
||||
@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
|
||||
return applicable[0]
|
||||
|
||||
|
||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is None:
|
||||
return context, context
|
||||
return context_k, context_v
|
||||
|
||||
if layer is not None:
|
||||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = hypernetwork_layers[0](context)
|
||||
context_v = hypernetwork_layers[1](context)
|
||||
context_k = hypernetwork_layers[0](context_k)
|
||||
context_v = hypernetwork_layers[1](context_v)
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def apply_hypernetworks(hypernetworks, context, layer=None):
|
||||
context_k = context
|
||||
context_v = context
|
||||
for hypernetwork in hypernetworks:
|
||||
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
|
||||
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
|
||||
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
|
||||
@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
template_file = template_file.path
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
hypernetwork = Hypernetwork()
|
||||
hypernetwork.load(path)
|
||||
shared.loaded_hypernetworks = [hypernetwork]
|
||||
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
initial_step = hypernetwork.step or 0
|
||||
|
@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||
|
||||
@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||
|
||||
|
||||
def train_hypernetwork(*args):
|
||||
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
shared.loaded_hypernetworks = []
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
|
||||
|
||||
@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
sd_hijack.apply_optimizations()
|
||||
|
@ -13,7 +13,7 @@ from skimage import exposure
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
try:
|
||||
for k, v in p.override_settings.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork':
|
||||
shared.reload_hypernetworks() # make onchange call for changing hypernet
|
||||
|
||||
if k == 'sd_model_checkpoint':
|
||||
sd_models.reload_model_weights() # make onchange call for changing SD model
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights() # make onchange call for changing VAE
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
||||
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
|
||||
if k == 'sd_vae': sd_vae.reload_vae_weights()
|
||||
if k == 'sd_model_checkpoint':
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
return res
|
||||
|
||||
@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
cache[0] = (required_prompts, steps)
|
||||
return cache[1]
|
||||
|
||||
p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
|
||||
|
||||
with torch.no_grad(), p.sd_model.ema_scope():
|
||||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
|
||||
extra_networks.activate(p, extra_network_data)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
extra_networks.deactivate(p, extra_network_data)
|
||||
devices.torch_gc()
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k = self.to_k(context_k) * self.scale
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
|
@ -23,6 +23,7 @@ demo = None
|
||||
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = {}
|
||||
loaded_hypernetwork = None
|
||||
loaded_hypernetworks = []
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
@ -153,8 +154,6 @@ def reload_hypernetworks():
|
||||
global hypernetworks
|
||||
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
||||
|
||||
|
||||
|
||||
class State:
|
||||
@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
@ -661,3 +658,17 @@ mem_mon.start()
|
||||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
||||
|
||||
def html_path(filename):
|
||||
return os.path.join(script_path, "html", filename)
|
||||
|
||||
|
||||
def html(filename):
|
||||
path = html_path(filename)
|
||||
|
||||
if os.path.exists(path):
|
||||
with open(path, encoding="utf8") as file:
|
||||
return file.read()
|
||||
|
||||
return ""
|
||||
|
@ -50,6 +50,7 @@ class Embedding:
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.optimizer_state_dict = None
|
||||
self.filename = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
@ -182,6 +183,7 @@ class EmbeddingDatabase:
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
embedding.filename = path
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path
|
||||
|
||||
@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
apply_style_symbol = '\U0001f4cb' # 📋
|
||||
clear_prompt_symbol = '\U0001F5D1' # 🗑️
|
||||
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||
|
||||
|
||||
def plaintext_to_html(text):
|
||||
@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||
|
||||
def update_token_counter(text, steps):
|
||||
try:
|
||||
text, _ = extra_networks.parse_prompt(text)
|
||||
|
||||
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
||||
|
||||
@ -354,10 +357,10 @@ def create_toprow(is_img2img):
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
|
||||
|
||||
with gr.Column(scale=1, elem_id="roll_col"):
|
||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
|
||||
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
|
||||
clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
||||
paste = ToolButton(value=paste_symbol, elem_id="paste")
|
||||
clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
||||
extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
|
||||
|
||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
|
||||
@ -395,11 +398,14 @@ def create_toprow(is_img2img):
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(elem_id=f"{id_part}_styles_row"):
|
||||
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
|
||||
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
|
||||
|
||||
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button
|
||||
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply")
|
||||
save_style = ToolButton(value=save_style_symbol, elem_id="style_create")
|
||||
|
||||
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
|
||||
|
||||
|
||||
def setup_progressbar(*args, **kwargs):
|
||||
@ -616,11 +622,15 @@ def create_ui():
|
||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
|
||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
|
||||
|
||||
dummy_component = gr.Label(visible=False)
|
||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
|
||||
from modules import ui_extra_networks
|
||||
extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
|
||||
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||
for category in ordered_ui_categories():
|
||||
@ -794,14 +804,20 @@ def create_ui():
|
||||
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
||||
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||
|
||||
modules.scripts.scripts_current = modules.scripts.scripts_img2img
|
||||
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
|
||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
|
||||
|
||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
|
||||
from modules import ui_extra_networks
|
||||
extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
|
||||
|
||||
with FormRow().style(equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
||||
copy_image_buttons = []
|
||||
@ -1064,6 +1080,8 @@ def create_ui():
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
||||
|
||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||
|
||||
img2img_paste_fields = [
|
||||
(img2img_prompt, "Prompt"),
|
||||
(img2img_negative_prompt, "Negative prompt"),
|
||||
@ -1666,10 +1684,8 @@ def create_ui():
|
||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||
|
||||
if os.path.exists("html/licenses.html"):
|
||||
with open("html/licenses.html", encoding="utf8") as file:
|
||||
with gr.TabItem("Licenses"):
|
||||
gr.HTML(file.read(), elem_id="licenses")
|
||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||
|
||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||
|
||||
@ -1756,9 +1772,7 @@ def create_ui():
|
||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
||||
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||
|
||||
if os.path.exists("html/footer.html"):
|
||||
with open("html/footer.html", encoding="utf8") as file:
|
||||
footer = file.read()
|
||||
footer = shared.html("footer.html")
|
||||
footer = footer.format(versions=versions_html())
|
||||
gr.HTML(footer, elem_id="footer")
|
||||
|
||||
|
@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
return "button"
|
||||
|
||||
|
||||
class ToolButtonTop(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(variant="tool-top", **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
class FormRow(gr.Row, gr.components.FormComponent):
|
||||
"""Same as gr.Row but fits inside gradio forms"""
|
||||
|
||||
|
149
modules/ui_extra_networks.py
Normal file
149
modules/ui_extra_networks.py
Normal file
@ -0,0 +1,149 @@
|
||||
import os.path
|
||||
|
||||
from modules import shared
|
||||
import gradio as gr
|
||||
import json
|
||||
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
|
||||
extra_pages = []
|
||||
|
||||
|
||||
def register_page(page):
|
||||
"""registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions"""
|
||||
|
||||
extra_pages.append(page)
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
def __init__(self, title):
|
||||
self.title = title
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
self.allow_negative_prompt = False
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def create_html(self, tabname):
|
||||
items_html = ''
|
||||
|
||||
for item in self.list_items():
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
if items_html == '':
|
||||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
res = "<div class='extra-network-cards'>" + items_html + "</div>"
|
||||
|
||||
return res
|
||||
|
||||
def list_items(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return []
|
||||
|
||||
def create_html_for_item(self, item, tabname):
|
||||
preview = item.get("preview", None)
|
||||
|
||||
args = {
|
||||
"preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
|
||||
"prompt": json.dumps(item["prompt"]),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
"name": item["name"],
|
||||
"allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
|
||||
}
|
||||
|
||||
return self.card_page.format(**args)
|
||||
|
||||
|
||||
def intialize():
|
||||
extra_pages.clear()
|
||||
|
||||
|
||||
class ExtraNetworksUi:
|
||||
def __init__(self):
|
||||
self.pages = None
|
||||
self.stored_extra_pages = None
|
||||
|
||||
self.button_save_preview = None
|
||||
self.preview_target_filename = None
|
||||
|
||||
self.tabname = None
|
||||
|
||||
|
||||
def create_ui(container, button, tabname):
|
||||
ui = ExtraNetworksUi()
|
||||
ui.pages = []
|
||||
ui.stored_extra_pages = extra_pages.copy()
|
||||
ui.tabname = tabname
|
||||
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
||||
button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
|
||||
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title):
|
||||
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||
|
||||
button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
|
||||
button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
|
||||
|
||||
def refresh():
|
||||
res = []
|
||||
|
||||
for pg in ui.stored_extra_pages:
|
||||
pg.refresh()
|
||||
res.append(pg.create_html(ui.tabname))
|
||||
|
||||
return res
|
||||
|
||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||
|
||||
return ui
|
||||
|
||||
|
||||
def path_is_parent(parent_path, child_path):
|
||||
parent_path = os.path.abspath(parent_path)
|
||||
child_path = os.path.abspath(child_path)
|
||||
|
||||
return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
|
||||
|
||||
|
||||
def setup_ui(ui, gallery):
|
||||
def save_preview(index, images, filename):
|
||||
if len(images) == 0:
|
||||
print("There is no image in gallery to save as a preview.")
|
||||
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||
|
||||
index = int(index)
|
||||
index = 0 if index < 0 else index
|
||||
index = len(images) - 1 if index >= len(images) else index
|
||||
|
||||
img_info = images[index if index >= 0 else 0]
|
||||
image = image_from_url_text(img_info)
|
||||
|
||||
is_allowed = False
|
||||
for extra_page in ui.stored_extra_pages:
|
||||
if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
|
||||
is_allowed = True
|
||||
break
|
||||
|
||||
assert is_allowed, f'writing to {filename} is not allowed'
|
||||
|
||||
image.save(filename)
|
||||
|
||||
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||
|
||||
ui.button_save_preview.click(
|
||||
fn=save_preview,
|
||||
_js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
|
||||
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
||||
outputs=[*ui.pages]
|
||||
)
|
34
modules/ui_extra_networks_hypernets.py
Normal file
34
modules/ui_extra_networks_hypernets.py
Normal file
@ -0,0 +1,34 @@
|
||||
import os
|
||||
|
||||
from modules import shared, ui_extra_networks
|
||||
|
||||
|
||||
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Hypernetworks')
|
||||
|
||||
def refresh(self):
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
def list_items(self):
|
||||
for name, path in shared.hypernetworks.items():
|
||||
path, ext = os.path.splitext(path)
|
||||
previews = [path + ".png", path + ".preview.png"]
|
||||
|
||||
preview = None
|
||||
for file in previews:
|
||||
if os.path.isfile(file):
|
||||
preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
|
||||
break
|
||||
|
||||
yield {
|
||||
"name": name,
|
||||
"filename": path,
|
||||
"preview": preview,
|
||||
"prompt": f"<hypernet:{name}:1.0>",
|
||||
"local_preview": path + ".png",
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [shared.cmd_opts.hypernetwork_dir]
|
||||
|
32
modules/ui_extra_networks_textual_inversion.py
Normal file
32
modules/ui_extra_networks_textual_inversion.py
Normal file
@ -0,0 +1,32 @@
|
||||
import os
|
||||
|
||||
from modules import ui_extra_networks, sd_hijack
|
||||
|
||||
|
||||
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Textual Inversion')
|
||||
self.allow_negative_prompt = True
|
||||
|
||||
def refresh(self):
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def list_items(self):
|
||||
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
|
||||
path, ext = os.path.splitext(embedding.filename)
|
||||
preview_file = path + ".preview.png"
|
||||
|
||||
preview = None
|
||||
if os.path.isfile(preview_file):
|
||||
preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
|
||||
|
||||
yield {
|
||||
"name": embedding.name,
|
||||
"filename": embedding.filename,
|
||||
"preview": preview,
|
||||
"prompt": embedding.name,
|
||||
"local_preview": path + ".preview.png",
|
||||
}
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
11
script.js
11
script.js
@ -13,6 +13,7 @@ function get_uiCurrentTabContent() {
|
||||
}
|
||||
|
||||
uiUpdateCallbacks = []
|
||||
uiLoadedCallbacks = []
|
||||
uiTabChangeCallbacks = []
|
||||
optionsChangedCallbacks = []
|
||||
let uiCurrentTab = null
|
||||
@ -20,6 +21,9 @@ let uiCurrentTab = null
|
||||
function onUiUpdate(callback){
|
||||
uiUpdateCallbacks.push(callback)
|
||||
}
|
||||
function onUiLoaded(callback){
|
||||
uiLoadedCallbacks.push(callback)
|
||||
}
|
||||
function onUiTabChange(callback){
|
||||
uiTabChangeCallbacks.push(callback)
|
||||
}
|
||||
@ -38,8 +42,15 @@ function executeCallbacks(queue, m) {
|
||||
queue.forEach(function(x){runCallback(x, m)})
|
||||
}
|
||||
|
||||
var executedOnLoaded = false;
|
||||
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
|
||||
executedOnLoaded = true;
|
||||
executeCallbacks(uiLoadedCallbacks);
|
||||
}
|
||||
|
||||
executeCallbacks(uiUpdateCallbacks, m);
|
||||
const newTab = get_uiCurrentTab();
|
||||
if ( newTab && ( newTab !== uiCurrentTab ) ) {
|
||||
|
@ -11,7 +11,6 @@ import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs):
|
||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||
|
||||
|
||||
def apply_hypernetwork(p, x, xs):
|
||||
if x.lower() in ["", "none"]:
|
||||
name = None
|
||||
else:
|
||||
name = hypernetwork.find_closest_hypernetwork_name(x)
|
||||
if not name:
|
||||
raise RuntimeError(f"Unknown hypernetwork: {x}")
|
||||
hypernetwork.load_hypernetwork(name)
|
||||
|
||||
|
||||
def apply_hypernetwork_strength(p, x, xs):
|
||||
hypernetwork.apply_strength(x)
|
||||
|
||||
|
||||
def confirm_hypernetworks(p, xs):
|
||||
for x in xs:
|
||||
if x.lower() in ["", "none"]:
|
||||
continue
|
||||
if not hypernetwork.find_closest_hypernetwork_name(x):
|
||||
raise RuntimeError(f"Unknown hypernetwork: {x}")
|
||||
|
||||
|
||||
def apply_clip_skip(p, x, xs):
|
||||
opts.data["CLIP_stop_at_last_layers"] = x
|
||||
|
||||
@ -208,8 +185,6 @@ axis_options = [
|
||||
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
||||
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
|
||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
|
||||
AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
|
||||
AxisOption("Sigma Churn", float, apply_field("s_churn")),
|
||||
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
||||
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
||||
@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
|
||||
class SharedSettingsStackHelper(object):
|
||||
def __enter__(self):
|
||||
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
||||
self.hypernetwork = opts.sd_hypernetwork
|
||||
self.vae = opts.sd_vae
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object):
|
||||
modules.sd_models.reload_model_weights()
|
||||
modules.sd_vae.reload_vae_weights()
|
||||
|
||||
hypernetwork.load_hypernetwork(self.hypernetwork)
|
||||
hypernetwork.apply_strength()
|
||||
|
||||
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
|
||||
|
||||
|
||||
|
164
style.css
164
style.css
@ -132,13 +132,6 @@
|
||||
}
|
||||
|
||||
#roll_col > button {
|
||||
min-width: 2em;
|
||||
min-height: 2em;
|
||||
max-width: 2em;
|
||||
max-height: 2em;
|
||||
flex-grow: 0;
|
||||
padding-left: 0.25em;
|
||||
padding-right: 0.25em;
|
||||
margin: 0.1em 0;
|
||||
}
|
||||
|
||||
@ -146,9 +139,10 @@
|
||||
min-width: 0 !important;
|
||||
max-width: 8em !important;
|
||||
margin-right: 1em;
|
||||
gap: 0;
|
||||
}
|
||||
#interrogate, #deepbooru{
|
||||
margin: 0em 0.25em 0.9em 0.25em;
|
||||
margin: 0em 0.25em 0.5em 0.25em;
|
||||
min-width: 8em;
|
||||
max-width: 8em;
|
||||
}
|
||||
@ -157,8 +151,17 @@
|
||||
min-width: 8em !important;
|
||||
}
|
||||
|
||||
#txt2img_styles_row, #img2img_styles_row{
|
||||
gap: 0.25em;
|
||||
margin-top: 0.5em;
|
||||
}
|
||||
|
||||
#txt2img_styles_row > button, #img2img_styles_row > button{
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
#txt2img_styles, #img2img_styles{
|
||||
margin-top: 1em;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
#txt2img_styles ul, #img2img_styles ul{
|
||||
@ -635,16 +638,20 @@ canvas[key="mask"] {
|
||||
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
|
||||
}
|
||||
|
||||
.gr-button-tool{
|
||||
.gr-button-tool, .gr-button-tool-top{
|
||||
max-width: 2.5em;
|
||||
min-width: 2.5em !important;
|
||||
height: 2.4em;
|
||||
}
|
||||
|
||||
.gr-button-tool{
|
||||
margin: 0.6em 0em 0.55em 0;
|
||||
}
|
||||
|
||||
.gr-button-tool-top, #settings .gr-button-tool{
|
||||
margin: 1.6em 0.7em 0.55em 0;
|
||||
}
|
||||
|
||||
#tab_modelmerger .gr-button-tool{
|
||||
margin: 0.6em 0em 0.55em 0;
|
||||
}
|
||||
|
||||
#modelmerger_results_container{
|
||||
margin-top: 1em;
|
||||
@ -763,81 +770,88 @@ footer {
|
||||
line-height: 2.4em;
|
||||
}
|
||||
|
||||
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
||||
The rtl media type will only be activated by the logic in javascript/localization.js.
|
||||
If you change anything above, you need to make sure it is RTL compliant by just running
|
||||
your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
|
||||
Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
|
||||
@media rtl {
|
||||
/* this part was added manually */
|
||||
:host {
|
||||
direction: rtl;
|
||||
}
|
||||
select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
|
||||
direction: ltr;
|
||||
}
|
||||
#script_list > label > select,
|
||||
#x_type > label > select,
|
||||
#y_type > label > select {
|
||||
direction: rtl;
|
||||
}
|
||||
.gr-radio, .gr-checkbox{
|
||||
margin-left: 0.25em;
|
||||
#txt2img_extra_networks, #img2img_extra_networks{
|
||||
margin-top: -1em;
|
||||
}
|
||||
|
||||
/* automatically generated with few manual modifications */
|
||||
.performance .time {
|
||||
margin-right: unset;
|
||||
margin-left: 0;
|
||||
.extra-networks > div > [id *= '_extra_']{
|
||||
margin: 0.3em;
|
||||
}
|
||||
.justify-center.overflow-x-scroll {
|
||||
justify-content: right;
|
||||
|
||||
.extra-network-cards .nocards{
|
||||
margin: 1.25em 0.5em 0.5em 0.5em;
|
||||
}
|
||||
.justify-center.overflow-x-scroll button:first-of-type {
|
||||
margin-left: unset;
|
||||
margin-right: auto;
|
||||
|
||||
.extra-network-cards .nocards h1{
|
||||
font-size: 1.5em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.justify-center.overflow-x-scroll button:last-of-type {
|
||||
margin-right: unset;
|
||||
margin-left: auto;
|
||||
|
||||
.extra-network-cards .nocards li{
|
||||
margin-left: 0.5em;
|
||||
}
|
||||
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
|
||||
margin-right: unset;
|
||||
margin-left: 8em;
|
||||
|
||||
.extra-network-cards .card{
|
||||
display: inline-block;
|
||||
margin: 0.5em;
|
||||
width: 16em;
|
||||
height: 24em;
|
||||
box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
|
||||
border-radius: 0.2em;
|
||||
position: relative;
|
||||
|
||||
background-size: auto 100%;
|
||||
background-position: center;
|
||||
overflow: hidden;
|
||||
cursor: pointer;
|
||||
|
||||
background-image: url('./file=html/card-no-preview.png')
|
||||
}
|
||||
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
|
||||
right: unset;
|
||||
|
||||
.extra-network-cards .card:hover{
|
||||
box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);
|
||||
}
|
||||
|
||||
.extra-network-cards .card .actions .additional{
|
||||
display: none;
|
||||
}
|
||||
|
||||
.extra-network-cards .card .actions{
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
}
|
||||
.progressDiv .progress{
|
||||
padding: 0 0 0 8px;
|
||||
text-align: left;
|
||||
}
|
||||
#lightboxModal{
|
||||
left: unset;
|
||||
right: 0;
|
||||
padding: 0.5em;
|
||||
color: white;
|
||||
background: rgba(0,0,0,0.5);
|
||||
box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);
|
||||
text-shadow: 0 0 0.2em black;
|
||||
}
|
||||
.modalPrev, .modalNext{
|
||||
border-radius: 3px 0 0 3px;
|
||||
|
||||
.extra-network-cards .card .actions:hover{
|
||||
box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
|
||||
}
|
||||
.modalNext {
|
||||
right: unset;
|
||||
left: 0;
|
||||
border-radius: 0 3px 3px 0;
|
||||
|
||||
.extra-network-cards .card .actions .name{
|
||||
font-size: 1.7em;
|
||||
font-weight: bold;
|
||||
line-break: anywhere;
|
||||
}
|
||||
#imageARPreview{
|
||||
left:unset;
|
||||
right:0px;
|
||||
|
||||
.extra-network-cards .card .actions:hover .additional{
|
||||
display: block;
|
||||
}
|
||||
#txt2img_skip, #img2img_skip{
|
||||
right: unset;
|
||||
left: 0px;
|
||||
|
||||
.extra-network-cards .card ul{
|
||||
margin: 0.25em 0 0.75em 0.25em;
|
||||
cursor: unset;
|
||||
}
|
||||
#context-menu{
|
||||
box-shadow:-1px 1px 2px #CE6400;
|
||||
}
|
||||
.gr-box > div > div > input.gr-text-input{
|
||||
right: unset;
|
||||
left: 0.5em;
|
||||
|
||||
.extra-network-cards .card ul a{
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.extra-network-cards .card ul a:hover{
|
||||
color: red;
|
||||
}
|
||||
|
||||
|
26
webui.py
26
webui.py
@ -9,16 +9,18 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
|
||||
from modules import import_hook, errors
|
||||
from modules import import_hook, errors, extra_networks
|
||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||
from modules.paths import script_path
|
||||
|
||||
import torch
|
||||
|
||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
|
||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
|
||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.extras
|
||||
import modules.face_restoration
|
||||
@ -84,10 +86,17 @@ def initialize():
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
|
||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
||||
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
||||
|
||||
try:
|
||||
@ -209,6 +218,15 @@ def webui():
|
||||
|
||||
modules.sd_models.list_models()
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
ui_extra_networks.intialize()
|
||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
|
||||
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if cmd_opts.nowebui:
|
||||
|
Loading…
Reference in New Issue
Block a user