mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-25 06:19:00 +08:00
Merge branch 'dev' into test-fp8
This commit is contained in:
commit
b60e1088db
@ -147,7 +147,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
|
|||||||
## Credits
|
## Credits
|
||||||
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
||||||
|
|
||||||
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
|
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
|
||||||
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
||||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||||
- CodeFormer - https://github.com/sczhou/CodeFormer
|
- CodeFormer - https://github.com/sczhou/CodeFormer
|
||||||
|
@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
|
|||||||
up = up.reshape(up.size(0), -1)
|
up = up.reshape(up.size(0), -1)
|
||||||
down = down.reshape(down.size(0), -1)
|
down = down.reshape(down.size(0), -1)
|
||||||
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
|
||||||
|
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
||||||
|
'''
|
||||||
|
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||||
|
second value is higher or equal than first value.
|
||||||
|
|
||||||
|
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
||||||
|
secon value is a value for weight.
|
||||||
|
|
||||||
|
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||||
|
|
||||||
|
examples)
|
||||||
|
factor
|
||||||
|
-1 2 4 8 16 ...
|
||||||
|
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
|
||||||
|
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
|
||||||
|
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
|
||||||
|
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
|
||||||
|
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
|
||||||
|
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
|
||||||
|
'''
|
||||||
|
|
||||||
|
if factor > 0 and (dimension % factor) == 0:
|
||||||
|
m = factor
|
||||||
|
n = dimension // factor
|
||||||
|
if m > n:
|
||||||
|
n, m = m, n
|
||||||
|
return m, n
|
||||||
|
if factor < 0:
|
||||||
|
factor = dimension
|
||||||
|
m, n = 1, dimension
|
||||||
|
length = m + n
|
||||||
|
while m<n:
|
||||||
|
new_m = m + 1
|
||||||
|
while dimension%new_m != 0:
|
||||||
|
new_m += 1
|
||||||
|
new_n = dimension // new_m
|
||||||
|
if new_m + new_n > length or new_m>factor:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
m, n = new_m, new_n
|
||||||
|
if m > n:
|
||||||
|
n, m = m, n
|
||||||
|
return m, n
|
||||||
|
|
||||||
|
97
extensions-builtin/Lora/network_oft.py
Normal file
97
extensions-builtin/Lora/network_oft.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
import network
|
||||||
|
from lyco_helpers import factorization
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeOFT(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
|
||||||
|
return NetworkModuleOFT(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
|
||||||
|
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
|
||||||
|
class NetworkModuleOFT(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
self.lin_module = None
|
||||||
|
self.org_module: list[torch.Module] = [self.sd_module]
|
||||||
|
|
||||||
|
# kohya-ss
|
||||||
|
if "oft_blocks" in weights.w.keys():
|
||||||
|
self.is_kohya = True
|
||||||
|
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||||
|
self.alpha = weights.w["alpha"] # alpha is constraint
|
||||||
|
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||||
|
# LyCORIS
|
||||||
|
elif "oft_diag" in weights.w.keys():
|
||||||
|
self.is_kohya = False
|
||||||
|
self.oft_blocks = weights.w["oft_diag"]
|
||||||
|
# self.alpha is unused
|
||||||
|
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||||
|
|
||||||
|
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||||
|
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||||
|
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||||
|
|
||||||
|
if is_linear:
|
||||||
|
self.out_dim = self.sd_module.out_features
|
||||||
|
elif is_conv:
|
||||||
|
self.out_dim = self.sd_module.out_channels
|
||||||
|
elif is_other_linear:
|
||||||
|
self.out_dim = self.sd_module.embed_dim
|
||||||
|
|
||||||
|
if self.is_kohya:
|
||||||
|
self.constraint = self.alpha * self.out_dim
|
||||||
|
self.num_blocks = self.dim
|
||||||
|
self.block_size = self.out_dim // self.dim
|
||||||
|
else:
|
||||||
|
self.constraint = None
|
||||||
|
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||||
|
|
||||||
|
def calc_updown_kb(self, orig_weight, multiplier):
|
||||||
|
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||||
|
|
||||||
|
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
|
||||||
|
|
||||||
|
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||||
|
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||||
|
merged_weight = torch.einsum(
|
||||||
|
'k n m, k n ... -> k m ...',
|
||||||
|
R,
|
||||||
|
merged_weight
|
||||||
|
)
|
||||||
|
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||||
|
|
||||||
|
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||||
|
output_shape = orig_weight.shape
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
|
||||||
|
multiplier = self.multiplier()
|
||||||
|
return self.calc_updown_kb(orig_weight, multiplier)
|
||||||
|
|
||||||
|
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
||||||
|
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||||
|
if self.bias is not None:
|
||||||
|
updown = updown.reshape(self.bias.shape)
|
||||||
|
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
updown = updown.reshape(output_shape)
|
||||||
|
|
||||||
|
if len(output_shape) == 4:
|
||||||
|
updown = updown.reshape(output_shape)
|
||||||
|
|
||||||
|
if orig_weight.size().numel() == updown.size().numel():
|
||||||
|
updown = updown.reshape(orig_weight.shape)
|
||||||
|
|
||||||
|
if ex_bias is not None:
|
||||||
|
ex_bias = ex_bias * self.multiplier()
|
||||||
|
|
||||||
|
return updown, ex_bias
|
@ -11,6 +11,7 @@ import network_ia3
|
|||||||
import network_lokr
|
import network_lokr
|
||||||
import network_full
|
import network_full
|
||||||
import network_norm
|
import network_norm
|
||||||
|
import network_oft
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@ -28,6 +29,7 @@ module_types = [
|
|||||||
network_full.ModuleTypeFull(),
|
network_full.ModuleTypeFull(),
|
||||||
network_norm.ModuleTypeNorm(),
|
network_norm.ModuleTypeNorm(),
|
||||||
network_glora.ModuleTypeGLora(),
|
network_glora.ModuleTypeGLora(),
|
||||||
|
network_oft.ModuleTypeOFT(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -189,6 +191,17 @@ def load_network(name, network_on_disk):
|
|||||||
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
|
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
|
||||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
|
# kohya_ss OFT module
|
||||||
|
elif sd_module is None and "oft_unet" in key_network_without_network_parts:
|
||||||
|
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
|
||||||
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
|
# KohakuBlueLeaf OFT module
|
||||||
|
if sd_module is None and "oft_diag" in key:
|
||||||
|
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
|
||||||
|
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||||
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
if sd_module is None:
|
if sd_module is None:
|
||||||
keys_failed_to_match[key_network] = key
|
keys_failed_to_match[key_network] = key
|
||||||
continue
|
continue
|
||||||
|
@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
lora_on_disk = networks.available_networks.get(name)
|
lora_on_disk = networks.available_networks.get(name)
|
||||||
|
if lora_on_disk is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||||
|
|
||||||
@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
return item
|
return item
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(networks.available_networks):
|
# instantiate a list to protect against concurrent modification
|
||||||
|
names = list(networks.available_networks)
|
||||||
|
for index, name in enumerate(names):
|
||||||
item = self.create_item(name, index)
|
item = self.create_item(name, index)
|
||||||
|
|
||||||
if item is not None:
|
if item is not None:
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
@ -1,16 +1,41 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
class TqdmLoggingHandler(logging.Handler):
|
||||||
|
def __init__(self, level=logging.INFO):
|
||||||
|
super().__init__(level)
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
try:
|
||||||
|
msg = self.format(record)
|
||||||
|
tqdm.write(msg)
|
||||||
|
self.flush()
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
TQDM_IMPORTED = True
|
||||||
|
except ImportError:
|
||||||
|
# tqdm does not exist before first launch
|
||||||
|
# I will import once the UI finishes seting up the enviroment and reloads.
|
||||||
|
TQDM_IMPORTED = False
|
||||||
|
|
||||||
def setup_logging(loglevel):
|
def setup_logging(loglevel):
|
||||||
if loglevel is None:
|
if loglevel is None:
|
||||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||||
|
|
||||||
|
loghandlers = []
|
||||||
|
|
||||||
|
if TQDM_IMPORTED:
|
||||||
|
loghandlers.append(TqdmLoggingHandler())
|
||||||
|
|
||||||
if loglevel:
|
if loglevel:
|
||||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=log_level,
|
level=log_level,
|
||||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||||
datefmt='%Y-%m-%d %H:%M:%S',
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
handlers=loghandlers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
image_data.close()
|
image_data.close()
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
shared.state.end()
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ class ImageRNG:
|
|||||||
self.is_first = True
|
self.is_first = True
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
|
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
|
||||||
|
|
||||||
xs = []
|
xs = []
|
||||||
|
|
||||||
|
@ -235,7 +235,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||||
"extra_networks_card_order_field": OptionInfo("Name", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||||
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||||
@ -273,6 +273,8 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||||
|
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
|
||||||
|
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
|
||||||
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
|
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import warnings
|
import warnings
|
||||||
|
from contextlib import ExitStack
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
@ -270,7 +271,11 @@ def create_ui():
|
|||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with ExitStack() as stack:
|
||||||
|
if shared.opts.txt2img_settings_accordion:
|
||||||
|
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||||
|
stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
|
||||||
|
|
||||||
scripts.scripts_txt2img.prepare_ui()
|
scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
@ -489,7 +494,11 @@ def create_ui():
|
|||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
with ExitStack() as stack:
|
||||||
|
if shared.opts.img2img_settings_accordion:
|
||||||
|
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||||
|
stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
|
||||||
|
|
||||||
copy_image_buttons = []
|
copy_image_buttons = []
|
||||||
copy_image_destinations = {}
|
copy_image_destinations = {}
|
||||||
|
|
||||||
|
@ -279,6 +279,7 @@ class ExtraNetworksPage:
|
|||||||
"date_created": int(stat.st_ctime or 0),
|
"date_created": int(stat.st_ctime or 0),
|
||||||
"date_modified": int(stat.st_mtime or 0),
|
"date_modified": int(stat.st_mtime or 0),
|
||||||
"name": pth.name.lower(),
|
"name": pth.name.lower(),
|
||||||
|
"path": str(pth.parent).lower(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def find_preview(self, path):
|
def find_preview(self, path):
|
||||||
@ -382,7 +383,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
|||||||
related_tabs.append(tab)
|
related_tabs.append(tab)
|
||||||
|
|
||||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||||
dropdown_sort = gr.Dropdown(choices=['Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||||
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
|
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
|
||||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||||
|
@ -17,6 +17,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(checkpoint.filename)
|
path, ext = os.path.splitext(checkpoint.filename)
|
||||||
return {
|
return {
|
||||||
"name": checkpoint.name_for_extra,
|
"name": checkpoint.name_for_extra,
|
||||||
@ -32,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
|
# instantiate a list to protect against concurrent modification
|
||||||
names = list(sd_models.checkpoints_list)
|
names = list(sd_models.checkpoints_list)
|
||||||
for index, name in enumerate(names):
|
for index, name in enumerate(names):
|
||||||
yield self.create_item(name, index)
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||||
|
@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
full_path = shared.hypernetworks[name]
|
full_path = shared.hypernetworks.get(name)
|
||||||
|
if full_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(full_path)
|
path, ext = os.path.splitext(full_path)
|
||||||
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
||||||
shorthash = sha256[0:10] if sha256 else None
|
shorthash = sha256[0:10] if sha256 else None
|
||||||
@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(shared.hypernetworks):
|
# instantiate a list to protect against concurrent modification
|
||||||
yield self.create_item(name, index)
|
names = list(shared.hypernetworks)
|
||||||
|
for index, name in enumerate(names):
|
||||||
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return [shared.cmd_opts.hypernetwork_dir]
|
return [shared.cmd_opts.hypernetwork_dir]
|
||||||
|
@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
||||||
|
if embedding is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(embedding.filename)
|
path, ext = os.path.splitext(embedding.filename)
|
||||||
return {
|
return {
|
||||||
@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
|
# instantiate a list to protect against concurrent modification
|
||||||
yield self.create_item(name, index)
|
names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
|
||||||
|
for index, name in enumerate(names):
|
||||||
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
||||||
|
@ -68,10 +68,10 @@ class UiPromptStyles:
|
|||||||
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
|
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
|
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
||||||
|
15
script.js
15
script.js
@ -133,9 +133,18 @@ document.addEventListener('keydown', function(e) {
|
|||||||
if (isEnter && isModifierKey) {
|
if (isEnter && isModifierKey) {
|
||||||
if (interruptButton.style.display === 'block') {
|
if (interruptButton.style.display === 'block') {
|
||||||
interruptButton.click();
|
interruptButton.click();
|
||||||
setTimeout(function() {
|
const callback = (mutationList) => {
|
||||||
generateButton.click();
|
for (const mutation of mutationList) {
|
||||||
}, 500);
|
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||||
|
if (interruptButton.style.display === 'none') {
|
||||||
|
generateButton.click();
|
||||||
|
observer.disconnect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const observer = new MutationObserver(callback);
|
||||||
|
observer.observe(interruptButton, {attributes: true});
|
||||||
} else {
|
} else {
|
||||||
generateButton.click();
|
generateButton.click();
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user