mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
suggestions and fixes from the PR
This commit is contained in:
parent
d25219b7e8
commit
3ec7b705c7
@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
|||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(lora.available_loras)}, refresh=lora.list_available_loras),
|
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
@ -644,17 +644,13 @@ class SwinIR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=None, num_heads=None,
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(SwinIR, self).__init__()
|
super(SwinIR, self).__init__()
|
||||||
|
|
||||||
depths = depths or [6, 6, 6, 6]
|
|
||||||
num_heads = num_heads or [6, 6, 6, 6]
|
|
||||||
|
|
||||||
num_in_ch = in_chans
|
num_in_ch = in_chans
|
||||||
num_out_ch = in_chans
|
num_out_ch = in_chans
|
||||||
num_feat = 64
|
num_feat = 64
|
||||||
|
@ -74,12 +74,9 @@ class WindowAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||||
pretrained_window_size=None):
|
pretrained_window_size=(0, 0)):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
pretrained_window_size = pretrained_window_size or [0, 0]
|
|
||||||
|
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.window_size = window_size # Wh, Ww
|
self.window_size = window_size # Wh, Ww
|
||||||
self.pretrained_window_size = pretrained_window_size
|
self.pretrained_window_size = pretrained_window_size
|
||||||
@ -701,17 +698,13 @@ class Swin2SR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=None, num_heads=None,
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(Swin2SR, self).__init__()
|
super(Swin2SR, self).__init__()
|
||||||
|
|
||||||
depths = depths or [6, 6, 6, 6]
|
|
||||||
num_heads = num_heads or [6, 6, 6, 6]
|
|
||||||
|
|
||||||
num_in_ch = in_chans
|
num_in_ch = in_chans
|
||||||
num_out_ch = in_chans
|
num_out_ch = in_chans
|
||||||
num_feat = 64
|
num_feat = 64
|
||||||
|
@ -161,13 +161,10 @@ class Fuse_sft_block(nn.Module):
|
|||||||
class CodeFormer(VQAutoEncoder):
|
class CodeFormer(VQAutoEncoder):
|
||||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||||
codebook_size=1024, latent_size=256,
|
codebook_size=1024, latent_size=256,
|
||||||
connect_list=None,
|
connect_list=('32', '64', '128', '256'),
|
||||||
fix_modules=None):
|
fix_modules=('quantize', 'generator')):
|
||||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||||
|
|
||||||
connect_list = connect_list or ['32', '64', '128', '256']
|
|
||||||
fix_modules = fix_modules or ['quantize', 'generator']
|
|
||||||
|
|
||||||
if fix_modules is not None:
|
if fix_modules is not None:
|
||||||
for module in fix_modules:
|
for module in fix_modules:
|
||||||
for param in getattr(self, module).parameters():
|
for param in getattr(self, module).parameters():
|
||||||
|
@ -5,13 +5,13 @@ import modules.hypernetworks.hypernetwork
|
|||||||
from modules import devices, sd_hijack, shared
|
from modules import devices, sd_hijack, shared
|
||||||
|
|
||||||
not_available = ["hardswish", "multiheadattention"]
|
not_available = ["hardswish", "multiheadattention"]
|
||||||
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available]
|
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||||
|
|
||||||
return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {filename}", ""
|
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(*args):
|
def train_hypernetwork(*args):
|
||||||
|
@ -275,8 +275,8 @@ def model_wrapper(
|
|||||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_kwargs = model_kwargs or []
|
model_kwargs = model_kwargs or {}
|
||||||
classifier_kwargs = classifier_kwargs or []
|
classifier_kwargs = classifier_kwargs or {}
|
||||||
|
|
||||||
def get_model_input_time(t_continuous):
|
def get_model_input_time(t_continuous):
|
||||||
"""
|
"""
|
||||||
|
@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
|
|||||||
script_args = args[script.args_from:script.args_to]
|
script_args = args[script.args_from:script.args_to]
|
||||||
|
|
||||||
process_args = {}
|
process_args = {}
|
||||||
for (name, component), value in zip(script.controls.items(), script_args): # noqa B007
|
for (name, _component), value in zip(script.controls.items(), script_args):
|
||||||
process_args[name] = value
|
process_args[name] = value
|
||||||
|
|
||||||
script.process(pp, **process_args)
|
script.process(pp, **process_args)
|
||||||
|
@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||||
|
|
||||||
for fixes in self.hijack.fixes:
|
for fixes in self.hijack.fixes:
|
||||||
for position, embedding in fixes: # noqa: B007
|
for _position, embedding in fixes:
|
||||||
used_embeddings[embedding.name] = embedding
|
used_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
z = self.process_tokens(tokens, multipliers)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
|
@ -381,7 +381,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(hypernetworks.keys())}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", hypernetworks]}, refresh=reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
|
@ -166,8 +166,7 @@ class EmbeddingDatabase:
|
|||||||
# textual inversion embeddings
|
# textual inversion embeddings
|
||||||
if 'string_to_param' in data:
|
if 'string_to_param' in data:
|
||||||
param_dict = data['string_to_param']
|
param_dict = data['string_to_param']
|
||||||
if hasattr(param_dict, '_parameters'):
|
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
param_dict = param_dict._parameters # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
emb = next(iter(param_dict.items()))[1]
|
emb = next(iter(param_dict.items()))[1]
|
||||||
# diffuser concepts
|
# diffuser concepts
|
||||||
|
@ -1230,8 +1230,8 @@ def create_ui():
|
|||||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||||
|
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=list(shared.hypernetworks.keys()))
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
|
||||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks.keys())}, "refresh_train_hypernetwork_name")
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
||||||
|
Loading…
Reference in New Issue
Block a user