mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 07:05:06 +08:00
Merge branch 'dev' into efficient-vae-methods
This commit is contained in:
commit
9ac2989edd
@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
|
|||||||
train_hypernetwork: 'OPTION',
|
train_hypernetwork: 'OPTION',
|
||||||
txt2img_styles: 'OPTION',
|
txt2img_styles: 'OPTION',
|
||||||
img2img_styles: 'OPTION',
|
img2img_styles: 'OPTION',
|
||||||
setting_random_artist_categories: 'SPAN',
|
setting_random_artist_categories: 'OPTION',
|
||||||
setting_face_restoration_model: 'SPAN',
|
setting_face_restoration_model: 'OPTION',
|
||||||
setting_realesrgan_enabled_models: 'SPAN',
|
setting_realesrgan_enabled_models: 'OPTION',
|
||||||
extras_upscaler_1: 'SPAN',
|
extras_upscaler_1: 'OPTION',
|
||||||
extras_upscaler_2: 'SPAN',
|
extras_upscaler_2: 'OPTION',
|
||||||
};
|
};
|
||||||
|
|
||||||
var re_num = /^[.\d]+$/;
|
var re_num = /^[.\d]+$/;
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
@ -177,3 +179,20 @@ def parse_prompts(prompts):
|
|||||||
|
|
||||||
return res, extra_data
|
return res, extra_data
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_metadata(filename):
|
||||||
|
if filename is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
basename, ext = os.path.splitext(filename)
|
||||||
|
metadata_filename = basename + '.json'
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
try:
|
||||||
|
if os.path.isfile(metadata_filename):
|
||||||
|
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||||
|
metadata = json.load(file)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
@ -3,7 +3,7 @@ from contextlib import closing
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import sd_samplers, images as imgutil
|
from modules import sd_samplers, images as imgutil
|
||||||
@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
mask = None
|
mask = None
|
||||||
elif mode == 2: # inpaint
|
elif mode == 2: # inpaint
|
||||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
|
||||||
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
|
|
||||||
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
elif mode == 3: # inpaint sketch
|
elif mode == 3: # inpaint sketch
|
||||||
image = inpaint_color_sketch
|
image = inpaint_color_sketch
|
||||||
|
@ -20,7 +20,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
|||||||
| "(" prompt ":" prompt ")"
|
| "(" prompt ":" prompt ")"
|
||||||
| "[" prompt "]"
|
| "[" prompt "]"
|
||||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
|
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
|
||||||
alternate: "[" prompt ("|" prompt)+ "]"
|
alternate: "[" prompt ("|" [prompt])+ "]"
|
||||||
WHITESPACE: /\s+/
|
WHITESPACE: /\s+/
|
||||||
plain: /([^\\\[\]():|]|\\.)+/
|
plain: /([^\\\[\]():|]|\\.)+/
|
||||||
%import common.SIGNED_NUMBER -> NUMBER
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
@ -53,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||||
>>> g("[a|(b:1.1)]")
|
>>> g("[a|(b:1.1)]")
|
||||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||||
|
>>> g("[fe|]male")
|
||||||
|
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
|
>>> g("[fe|||]male")
|
||||||
|
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
@ -78,7 +82,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
before, after, _, when, _ = args
|
before, after, _, when, _ = args
|
||||||
yield before or () if step <= when else after
|
yield before or () if step <= when else after
|
||||||
def alternate(self, args):
|
def alternate(self, args):
|
||||||
yield next(args[(step - 1)%len(args)])
|
args = ["" if not arg else arg for arg in args]
|
||||||
|
yield args[(step - 1) % len(args)]
|
||||||
def start(self, args):
|
def start(self, args):
|
||||||
def flatten(x):
|
def flatten(x):
|
||||||
if type(x) == str:
|
if type(x) == str:
|
||||||
|
@ -303,12 +303,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
sd_models_xl.extend_sdxl(model)
|
sd_models_xl.extend_sdxl(model)
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
del state_dict
|
|
||||||
timer.record("apply weights to model")
|
timer.record("apply weights to model")
|
||||||
|
|
||||||
if shared.opts.sd_checkpoint_cache > 0:
|
if shared.opts.sd_checkpoint_cache > 0:
|
||||||
# cache newly loaded model
|
# cache newly loaded model
|
||||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
checkpoints_loaded[checkpoint_info] = state_dict
|
||||||
|
|
||||||
|
del state_dict
|
||||||
|
|
||||||
if shared.cmd_opts.opt_channelslast:
|
if shared.cmd_opts.opt_channelslast:
|
||||||
model.to(memory_format=torch.channels_last)
|
model.to(memory_format=torch.channels_last)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
@ -16,6 +16,7 @@ checkpoint_info = None
|
|||||||
|
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
def get_base_vae(model):
|
def get_base_vae(model):
|
||||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||||
return base_vae
|
return base_vae
|
||||||
@ -100,6 +101,16 @@ def resolve_vae(checkpoint_file):
|
|||||||
if shared.cmd_opts.vae_path is not None:
|
if shared.cmd_opts.vae_path is not None:
|
||||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||||
|
|
||||||
|
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
||||||
|
vae_metadata = metadata.get("vae", None)
|
||||||
|
if vae_metadata is not None and vae_metadata != "Automatic":
|
||||||
|
if vae_metadata == "None":
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
||||||
|
if vae_from_metadata is not None:
|
||||||
|
return vae_from_metadata, "from user metadata"
|
||||||
|
|
||||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||||
|
|
||||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||||
|
@ -585,14 +585,15 @@ def create_ui():
|
|||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
(height, "Size-2"),
|
(height, "Size-2"),
|
||||||
(batch_size, "Batch size"),
|
(batch_size, "Batch size"),
|
||||||
|
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||||
(subseed, "Variation seed"),
|
(subseed, "Variation seed"),
|
||||||
(subseed_strength, "Variation seed strength"),
|
(subseed_strength, "Variation seed strength"),
|
||||||
(seed_resize_from_w, "Seed resize from-1"),
|
(seed_resize_from_w, "Seed resize from-1"),
|
||||||
(seed_resize_from_h, "Seed resize from-2"),
|
(seed_resize_from_h, "Seed resize from-2"),
|
||||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
(enable_hr, lambda d: "Denoising strength" in d),
|
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
||||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
|
||||||
(hr_scale, "Hires upscale"),
|
(hr_scale, "Hires upscale"),
|
||||||
(hr_upscaler, "Hires upscaler"),
|
(hr_upscaler, "Hires upscaler"),
|
||||||
(hr_second_pass_steps, "Hires steps"),
|
(hr_second_pass_steps, "Hires steps"),
|
||||||
@ -666,11 +667,11 @@ def create_ui():
|
|||||||
add_copy_image_controls('sketch', sketch)
|
add_copy_image_controls('sketch', sketch)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height)
|
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff')
|
||||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height)
|
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff')
|
||||||
inpaint_color_sketch_orig = gr.State(None)
|
inpaint_color_sketch_orig = gr.State(None)
|
||||||
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
||||||
|
|
||||||
@ -972,6 +973,7 @@ def create_ui():
|
|||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
(height, "Size-2"),
|
(height, "Size-2"),
|
||||||
(batch_size, "Batch size"),
|
(batch_size, "Batch size"),
|
||||||
|
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||||
(subseed, "Variation seed"),
|
(subseed, "Variation seed"),
|
||||||
(subseed_strength, "Variation seed strength"),
|
(subseed_strength, "Variation seed strength"),
|
||||||
(seed_resize_from_w, "Seed resize from-1"),
|
(seed_resize_from_w, "Seed resize from-1"),
|
||||||
|
@ -239,13 +239,13 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
|
|||||||
for comp in refresh_components:
|
for comp in refresh_components:
|
||||||
setattr(comp, k, v)
|
setattr(comp, k, v)
|
||||||
|
|
||||||
return [gr.update(**(args or {})) for _ in refresh_components]
|
return (gr.update(**(args or {})) for _ in refresh_components) if len(refresh_components) > 1 else gr.update(**(args or {}))
|
||||||
|
|
||||||
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
|
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
|
||||||
refresh_button.click(
|
refresh_button.click(
|
||||||
fn=refresh,
|
fn=refresh,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[*refresh_components]
|
outputs=refresh_components
|
||||||
)
|
)
|
||||||
return refresh_button
|
return refresh_button
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import os.path
|
|||||||
import urllib.parse
|
import urllib.parse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from modules import shared, ui_extra_networks_user_metadata, errors
|
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
|
||||||
from modules.images import read_info_from_image, save_image_with_geninfo
|
from modules.images import read_info_from_image, save_image_with_geninfo
|
||||||
from modules.ui import up_down_symbol
|
from modules.ui import up_down_symbol
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -101,16 +101,7 @@ class ExtraNetworksPage:
|
|||||||
|
|
||||||
def read_user_metadata(self, item):
|
def read_user_metadata(self, item):
|
||||||
filename = item.get("filename", None)
|
filename = item.get("filename", None)
|
||||||
basename, ext = os.path.splitext(filename)
|
metadata = extra_networks.get_user_metadata(filename)
|
||||||
metadata_filename = basename + '.json'
|
|
||||||
|
|
||||||
metadata = {}
|
|
||||||
try:
|
|
||||||
if os.path.isfile(metadata_filename):
|
|
||||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
|
||||||
metadata = json.load(file)
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
|
||||||
|
|
||||||
desc = metadata.get("description", None)
|
desc = metadata.get("description", None)
|
||||||
if desc is not None:
|
if desc is not None:
|
||||||
|
@ -3,6 +3,7 @@ import os
|
|||||||
|
|
||||||
from modules import shared, ui_extra_networks, sd_models
|
from modules import shared, ui_extra_networks, sd_models
|
||||||
from modules.ui_extra_networks import quote_js
|
from modules.ui_extra_networks import quote_js
|
||||||
|
from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||||
@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||||
|
|
||||||
|
def create_user_metadata_editor(self, ui, tabname):
|
||||||
|
return CheckpointUserMetadataEditor(ui, tabname, self)
|
||||||
|
60
modules/ui_extra_networks_checkpoints_user_metadata.py
Normal file
60
modules/ui_extra_networks_checkpoints_user_metadata.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import ui_extra_networks_user_metadata, sd_vae
|
||||||
|
from modules.ui_common import create_refresh_button
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
|
||||||
|
def __init__(self, ui, tabname, page):
|
||||||
|
super().__init__(ui, tabname, page)
|
||||||
|
|
||||||
|
self.select_vae = None
|
||||||
|
|
||||||
|
def save_user_metadata(self, name, desc, notes, vae):
|
||||||
|
user_metadata = self.get_user_metadata(name)
|
||||||
|
user_metadata["description"] = desc
|
||||||
|
user_metadata["notes"] = notes
|
||||||
|
user_metadata["vae"] = vae
|
||||||
|
|
||||||
|
self.write_user_metadata(name, user_metadata)
|
||||||
|
|
||||||
|
def put_values_into_components(self, name):
|
||||||
|
user_metadata = self.get_user_metadata(name)
|
||||||
|
values = super().put_values_into_components(name)
|
||||||
|
|
||||||
|
return [
|
||||||
|
*values[0:5],
|
||||||
|
user_metadata.get('vae', ''),
|
||||||
|
]
|
||||||
|
|
||||||
|
def create_editor(self):
|
||||||
|
self.create_default_editor_elems()
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
|
||||||
|
create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
|
||||||
|
|
||||||
|
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
||||||
|
|
||||||
|
self.create_default_buttons()
|
||||||
|
|
||||||
|
viewed_components = [
|
||||||
|
self.edit_name,
|
||||||
|
self.edit_description,
|
||||||
|
self.html_filedata,
|
||||||
|
self.html_preview,
|
||||||
|
self.edit_notes,
|
||||||
|
self.select_vae,
|
||||||
|
]
|
||||||
|
|
||||||
|
self.button_edit\
|
||||||
|
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
|
||||||
|
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
|
||||||
|
|
||||||
|
edited_components = [
|
||||||
|
self.edit_description,
|
||||||
|
self.edit_notes,
|
||||||
|
self.select_vae,
|
||||||
|
]
|
||||||
|
|
||||||
|
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|
@ -158,7 +158,7 @@ class UiSettings:
|
|||||||
loadsave.create_ui()
|
loadsave.create_ui()
|
||||||
|
|
||||||
with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
|
with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
|
||||||
gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
|
gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
|
Loading…
Reference in New Issue
Block a user