Merge branch 'dev' into test-fp8

Kohaku-Blueleaf 2023-11-19 15:24:57 +08:00
commit b60e1088db
16 changed files with 246 additions and 21 deletions

View File

@ -147,7 +147,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
- CodeFormer - https://github.com/sczhou/CodeFormer

View File

@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)

# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
    '''
    return a tuple of two values: the input dimension decomposed by the number closest to factor;
    the second value is greater than or equal to the first.

    In LoRA with Kronecker Product, the first value is used for the weight scale and
    the second value for the weight itself.
    Because of the non-commutative property, A⊗B ≠ B⊗A, the meaning of the two matrices is slightly different.

    examples)
    factor
        -1               2                4               8               16
    127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
    128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
    250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
    360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
    512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
    1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
    '''
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
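
For quick reference, here is a minimal usage sketch of the helper above; the expected pairs come from the docstring table, and the import statement mirrors the one that appears in the OFT module further down in this diff.

# Usage sketch for factorization(); expected values are taken from the docstring table above.
from lyco_helpers import factorization

assert factorization(128) == (8, 16)              # default factor=-1: most balanced divisor pair
assert factorization(1024, factor=4) == (4, 256)  # an explicit factor that divides the dimension is used directly
assert factorization(127) == (1, 127)             # prime dimensions fall back to (1, dimension)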

View File

@ -0,0 +1,97 @@
import torch
import network
from lyco_helpers import factorization
from einops import rearrange


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
            return NetworkModuleOFT(net, weights)

        return None


# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.lin_module = None
        self.org_module: list[torch.Module] = [self.sd_module]

        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
            self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
            self.alpha = weights.w["alpha"]  # alpha is constraint
            self.dim = self.oft_blocks.shape[0]  # lora dim
        # LyCORIS
        elif "oft_diag" in weights.w.keys():
            self.is_kohya = False
            self.oft_blocks = weights.w["oft_diag"]
            # self.alpha is unused
            self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported

        if is_linear:
            self.out_dim = self.sd_module.out_features
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim

        if self.is_kohya:
            self.constraint = self.alpha * self.out_dim
            self.num_blocks = self.dim
            self.block_size = self.out_dim // self.dim
        else:
            self.constraint = None
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

    def calc_updown_kb(self, orig_weight, multiplier):
        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
        oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix

        R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
        R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)

        # This errors out for MultiheadAttention, might need to be handled up-stream
        merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
        merged_weight = torch.einsum(
            'k n m, k n ... -> k m ...',
            R,
            merged_weight
        )
        merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')

        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        output_shape = orig_weight.shape
        return self.finalize_updown(updown, orig_weight, output_shape)

    def calc_updown(self, orig_weight):
        # if alpha is a very small number as in coft, calc_scale() will return an almost zero number so we ignore it
        multiplier = self.multiplier()
        return self.calc_updown_kb(orig_weight, multiplier)

    # override to remove the multiplier/scale factor; it's already multiplied in get_weight
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        if ex_bias is not None:
            ex_bias = ex_bias * self.multiplier()

        return updown, ex_bias
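
To make the block-diagonal math in calc_updown_kb easier to follow, here is a standalone sketch of the same computation with made-up shapes; num_blocks, block_size and the Linear weight size are illustrative, not taken from a real checkpoint. Note the module builds R = I + multiplier·Q from a skew-symmetric Q, which stays close to an orthogonal rotation for small values, rather than using a full Cayley transform.

import torch
from einops import rearrange

# Illustrative shapes only: rotate each block of output dimensions, then take the delta.
num_blocks, block_size = 4, 8
out_dim = num_blocks * block_size                       # 32
orig_weight = torch.randn(out_dim, 16)                  # e.g. a Linear weight (out_features, in_features)
oft_blocks = torch.randn(num_blocks, block_size, block_size) * 0.01
multiplier = 1.0

Q = oft_blocks - oft_blocks.transpose(1, 2)             # skew-symmetric blocks
R = Q * multiplier + torch.eye(block_size)              # near-identity block rotations

w = rearrange(orig_weight, '(k n) ... -> k n ...', k=num_blocks, n=block_size)
w = torch.einsum('k n m, k n ... -> k m ...', R, w)     # apply each block's rotation
merged = rearrange(w, 'k m ... -> (k m) ...')

updown = merged - orig_weight                           # the delta that gets added to the weight
assert updown.shape == orig_weight.shape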

View File

@ -11,6 +11,7 @@ import network_ia3
import network_lokr
import network_full
import network_norm
import network_oft

import torch
from typing import Union

@ -28,6 +29,7 @@ module_types = [
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
    network_oft.ModuleTypeOFT(),
]
@ -189,6 +191,17 @@ def load_network(name, network_on_disk):
            key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # kohya_ss OFT module
        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        # KohakuBlueLeaf OFT module
        if sd_module is None and "oft_diag" in key:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue
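
As a rough illustration of the remapping above: a kohya-ss OFT key prefixed with oft_unet is rewritten so it can be looked up in the model's network_layer_mapping. The key string below is hypothetical, chosen only to show the prefix rewrite.

# Hypothetical key name; only the oft_unet -> diffusion_model prefix rewrite is taken from the code above.
key_network_without_network_parts = "oft_unet_input_blocks_4_1_proj_in"
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
assert key == "diffusion_model_input_blocks_4_1_proj_in"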

View File

@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def create_item(self, name, index=None, enable_filter=True):
        lora_on_disk = networks.available_networks.get(name)
        if lora_on_disk is None:
            return

        path, ext = os.path.splitext(lora_on_disk.filename)

@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
        return item

    def list_items(self):
        for index, name in enumerate(networks.available_networks):
        # instantiate a list to protect against concurrent modification
        names = list(networks.available_networks)
        for index, name in enumerate(names):
            item = self.create_item(name, index)
            if item is not None:
                yield item
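
The snapshot into names exists because iterating a dict that another thread mutates (for example during a refresh) raises RuntimeError; a minimal reproduction with a plain dict rather than webui state:

# Minimal illustration (plain dict, not webui state) of why the list() snapshot matters.
available = {"lora_a": 1, "lora_b": 2}

def items_unsafe():
    for name in available:          # iterates the live dict
        yield name

gen = items_unsafe()
next(gen)
available["lora_c"] = 3             # e.g. a refresh adding a network mid-iteration
# next(gen)                         # would raise: dictionary changed size during iteration

names = list(available)             # snapshot first; later mutation no longer breaks iteration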

View File

@ -1,16 +1,41 @@
import os
import logging

try:
    from tqdm.auto import tqdm

    class TqdmLoggingHandler(logging.Handler):
        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            try:
                msg = self.format(record)
                tqdm.write(msg)
                self.flush()
            except Exception:
                self.handleError(record)

    TQDM_IMPORTED = True
except ImportError:
    # tqdm does not exist before first launch
    # I will import once the UI finishes setting up the environment and reloads.
    TQDM_IMPORTED = False


def setup_logging(loglevel):
    if loglevel is None:
        loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")

    loghandlers = []

    if TQDM_IMPORTED:
        loghandlers.append(TqdmLoggingHandler())

    if loglevel:
        log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=loghandlers
        )
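
A brief usage sketch of the handler above, assuming tqdm is installed and setup_logging is imported from this module: log records are routed through tqdm.write so they print above an active progress bar instead of corrupting it.

import logging
from tqdm.auto import tqdm

setup_logging("INFO")                # or None to fall back to the SD_WEBUI_LOG_LEVEL env var
log = logging.getLogger("example")

for _ in tqdm(range(3)):
    log.info("printed via tqdm.write, above the progress bar")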

View File

@ -78,7 +78,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
image_data.close()
devices.torch_gc()
shared.state.end()
return outputs, ui_common.plaintext_to_html(infotext), ''

View File

@ -110,7 +110,7 @@ class ImageRNG:
        self.is_first = True

    def first(self):
        noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
        noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w) // 8)

        xs = []
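
The new line casts the seed-resize dimensions to int before dividing by 8, the latent downscale factor of the SD VAE. A small illustration with assumed values (4 latent channels chosen only for the example):

# Illustration of the seed-resize noise shape; channel count is assumed for the example.
channels = 4
seed_resize_from_h, seed_resize_from_w = 512, 768
noise_shape = (channels, int(seed_resize_from_h) // 8, int(seed_resize_from_w) // 8)
assert noise_shape == (4, 64, 96)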

View File

@ -235,7 +235,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
"extra_networks_card_order_field": OptionInfo("Name", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
@ -273,6 +273,8 @@ options_templates.update(options_section(('ui', "User interface"), {
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
}))

View File

@ -4,6 +4,7 @@ import os
import sys
from functools import reduce
import warnings
from contextlib import ExitStack
import gradio as gr
import gradio.utils
@ -270,7 +271,11 @@ def create_ui():
        extra_tabs.__enter__()

        with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
            with gr.Column(variant='compact', elem_id="txt2img_settings"):
            with ExitStack() as stack:
                if shared.opts.txt2img_settings_accordion:
                    stack.enter_context(gr.Accordion("Open for Settings", open=False))

                stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))

                scripts.scripts_txt2img.prepare_ui()

                for category in ordered_ui_categories():
@ -489,7 +494,11 @@ def create_ui():
        extra_tabs.__enter__()

        with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
            with gr.Column(variant='compact', elem_id="img2img_settings"):
            with ExitStack() as stack:
                if shared.opts.img2img_settings_accordion:
                    stack.enter_context(gr.Accordion("Open for Settings", open=False))

                stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))

                copy_image_buttons = []
                copy_image_destinations = {}
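
The ExitStack in both hunks above lets the Accordion wrapper be entered only when the corresponding option is set, while the Column is always entered. A minimal standalone sketch of the pattern, with plain context managers standing in for the gradio containers:

from contextlib import ExitStack, contextmanager

@contextmanager
def ctx(name):
    print(f"enter {name}")
    yield
    print(f"exit {name}")

wrap_in_accordion = True                        # stands in for shared.opts.*_settings_accordion

with ExitStack() as stack:
    if wrap_in_accordion:
        stack.enter_context(ctx("Accordion"))   # optional outer wrapper
    stack.enter_context(ctx("Column"))          # always-entered inner container
    print("build the settings UI here")
# both contexts exit in reverse order when the with block ends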

View File

@ -279,6 +279,7 @@ class ExtraNetworksPage:
            "date_created": int(stat.st_ctime or 0),
            "date_modified": int(stat.st_mtime or 0),
            "name": pth.name.lower(),
            "path": str(pth.parent).lower(),
        }

    def find_preview(self, path):

@ -382,7 +383,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
        related_tabs.append(tab)

    edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
    dropdown_sort = gr.Dropdown(choices=['Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
    dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
    button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
    checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)

View File

@ -17,6 +17,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    def create_item(self, name, index=None, enable_filter=True):
        checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
        if checkpoint is None:
            return

        path, ext = os.path.splitext(checkpoint.filename)

        return {
            "name": checkpoint.name_for_extra,
@ -32,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
        }

    def list_items(self):
        # instantiate a list to protect against concurrent modification
        names = list(sd_models.checkpoints_list)
        for index, name in enumerate(names):
            yield self.create_item(name, index)
            item = self.create_item(name, index)
            if item is not None:
                yield item

    def allowed_directories_for_previews(self):
        return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]

View File

@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
        shared.reload_hypernetworks()

    def create_item(self, name, index=None, enable_filter=True):
        full_path = shared.hypernetworks[name]
        full_path = shared.hypernetworks.get(name)
        if full_path is None:
            return

        path, ext = os.path.splitext(full_path)
        sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
        shorthash = sha256[0:10] if sha256 else None

@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
        }

    def list_items(self):
        for index, name in enumerate(shared.hypernetworks):
            yield self.create_item(name, index)
        # instantiate a list to protect against concurrent modification
        names = list(shared.hypernetworks)
        for index, name in enumerate(names):
            item = self.create_item(name, index)
            if item is not None:
                yield item

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.hypernetwork_dir]

View File

@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
    def create_item(self, name, index=None, enable_filter=True):
        embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
        if embedding is None:
            return

        path, ext = os.path.splitext(embedding.filename)

        return {
@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
        }

    def list_items(self):
        for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
            yield self.create_item(name, index)
        # instantiate a list to protect against concurrent modification
        names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
        for index, name in enumerate(names):
            item = self.create_item(name, index)
            if item is not None:
                yield item

    def allowed_directories_for_previews(self):
        return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)

View File

@ -68,10 +68,10 @@ class UiPromptStyles:
                self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")

            with gr.Row():
                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])

            with gr.Row():
                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])

            with gr.Row():
                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)

View File

@ -133,9 +133,18 @@ document.addEventListener('keydown', function(e) {
    if (isEnter && isModifierKey) {
        if (interruptButton.style.display === 'block') {
            interruptButton.click();
            setTimeout(function() {
                generateButton.click();
            }, 500);
            const callback = (mutationList) => {
                for (const mutation of mutationList) {
                    if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
                        if (interruptButton.style.display === 'none') {
                            generateButton.click();
                            observer.disconnect();
                        }
                    }
                }
            };
            const observer = new MutationObserver(callback);
            observer.observe(interruptButton, {attributes: true});
        } else {
            generateButton.click();
        }