Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-02-01 03:03:00 +08:00)

Merge branch 'dev' into multiple_loaded_models

Commit 0c9b1e7969
@@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
         tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
     }
 });
+
+onUiLoaded(function() {
+    for (var comp of window.gradio_config.components) {
+        if (comp.props.webui_tooltip && comp.props.elem_id) {
+            var elem = gradioApp().getElementById(comp.props.elem_id);
+            if (elem) {
+                elem.title = comp.props.webui_tooltip;
+            }
+        }
+    }
+});
@@ -152,7 +152,11 @@ function submit() {
     showSubmitButtons('txt2img', false);
 
     var id = randomId();
-    localStorage.setItem("txt2img_task_id", id);
+    try {
+        localStorage.setItem("txt2img_task_id", id);
+    } catch (e) {
+        console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
+    }
 
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
         showSubmitButtons('txt2img', true);
@@ -171,7 +175,11 @@ function submit_img2img() {
     showSubmitButtons('img2img', false);
 
     var id = randomId();
-    localStorage.setItem("img2img_task_id", id);
+    try {
+        localStorage.setItem("img2img_task_id", id);
+    } catch (e) {
+        console.warn(`Failed to save img2img task id to localStorage: ${e}`);
+    }
 
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
         showSubmitButtons('img2img', true);
@@ -191,8 +199,6 @@ function restoreProgressTxt2img() {
     showRestoreProgressButton("txt2img", false);
     var id = localStorage.getItem("txt2img_task_id");
 
-    id = localStorage.getItem("txt2img_task_id");
-
     if (id) {
         requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
             showSubmitButtons('txt2img', true);
@@ -3,7 +3,7 @@ import html
 import threading
 import time
 
-from modules import shared, progress, errors
+from modules import shared, progress, errors, devices
 
 queue_lock = threading.Lock()
 
@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
             error_message = f'{type(e).__name__}: {e}'
             res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
 
+        devices.torch_gc()
+
         shared.state.skipped = False
         shared.state.interrupted = False
         shared.state.job_count = 0
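Note: the point of calling devices.torch_gc() on the error path is to release cached GPU memory after a failed job instead of holding it until the next successful run. A minimal sketch of what such a helper does, assuming the usual torch calls rather than quoting modules/devices.py verbatim:

import gc
import torch

def torch_gc_sketch():
    # drop Python-side references first, then ask the CUDA allocator to return cached blocks
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()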
@@ -14,7 +14,8 @@ def record_exception():
    if exception_records and exception_records[-1] == e:
        return
 
-    exception_records.append((e, tb))
+    from modules import sysinfo
+    exception_records.append(sysinfo.format_exception(e, tb))
 
    if len(exception_records) > 5:
        exception_records.pop(0)
@@ -7,7 +7,7 @@ import json
 import torch
 import tqdm
 
-from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
 from modules.ui_common import plaintext_to_html
 import gradio as gr
 import safetensors.torch
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
     return tensor
 
 
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
+    metadata = {}
+
+    for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
+        checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
+        if checkpoint_info is None:
+            continue
+
+        metadata.update(checkpoint_info.metadata)
+
+    return json.dumps(metadata, indent=4, ensure_ascii=False)
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
     shared.state.begin(job="model-merge")
 
     def fail(message):
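Note: the new extras.read_metadata helper is what the "Read metadata from selected checkpoints" button (wired up later in this diff) calls. A hedged usage sketch; the checkpoint titles are placeholders:

from modules import extras

# Unknown titles are skipped (checkpoints_list.get returns None for them), so
# passing fewer than three real checkpoints is fine; the result is a JSON string.
merged = extras.read_metadata("modelA.safetensors [aaaaaaaaaa]", "modelB.safetensors [bbbbbbbbbb]", None)
print(merged)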
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     shared.state.textinfo = "Saving"
     print(f"Saving to {output_modelname}...")
 
-    metadata = None
+    metadata = {}
+
+    if save_metadata and copy_metadata_fields:
+        if primary_model_info:
+            metadata.update(primary_model_info.metadata)
+        if secondary_model_info:
+            metadata.update(secondary_model_info.metadata)
+        if tertiary_model_info:
+            metadata.update(tertiary_model_info.metadata)
+
     if save_metadata:
-        metadata = {"format": "pt"}
+        try:
+            metadata.update(json.loads(metadata_json))
+        except Exception as e:
+            errors.display(e, "readin metadata from json")
+
+        metadata["format"] = "pt"
+
+    if save_metadata and add_merge_recipe:
         merge_recipe = {
             "type": "webui", # indicate this model was merged with webui's built-in merger
             "primary_model_hash": primary_model_info.sha256,
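Note: the order of the updates above determines precedence when keys collide: fields copied from the merged models are overridden by the JSON editor, and "format" is set last so it always survives. A small illustration with assumed values:

metadata = {}
metadata.update({"author": "baked into model A"})      # copy_metadata_fields step
metadata.update({"author": "typed into the editor"})   # metadata_json step wins over copied fields
metadata["format"] = "pt"                              # set last, so it cannot be overridden
print(metadata)  # {'author': 'typed into the editor', 'format': 'pt'}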
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
             "is_inpainting": result_is_inpainting_model,
             "is_instruct_pix2pix": result_is_instruct_pix2pix_model
         }
-        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
 
         sd_merge_models = {}
 
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
         if tertiary_model_info:
             add_model_metadata(tertiary_model_info)
 
+        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
         metadata["sd_merge_models"] = json.dumps(sd_merge_models)
 
     _, extension = os.path.splitext(output_modelname)
     if extension.lower() == ".safetensors":
-        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
+        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
     else:
         torch.save(theta_0, output_modelname)
 
@@ -10,7 +10,6 @@ from modules import sd_samplers, images as imgutil
 from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
-from modules.images import save_image
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
@@ -18,6 +17,7 @@ import modules.scripts
 
 
 def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
+    output_dir = output_dir.strip()
     processing.fix_seed(p)
 
     images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
@@ -32,11 +32,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
 
     print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
 
-    save_normally = output_dir == ''
-
-    p.do_not_save_grid = True
-    p.do_not_save_samples = not save_normally
-
     state.job_count = len(images) * p.n_iter
 
     # extract "default" params to use in case getting png info fails
@@ -111,21 +106,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
 
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
-            proc = process_images(p)
-
-            for n, processed_image in enumerate(proc.images):
-                filename = image_path.stem
-                infotext = proc.infotext(p, n)
-                relpath = os.path.dirname(os.path.relpath(image, input_dir))
-
-                if n > 0:
-                    filename += f"-{n}"
-
-                if not save_normally:
-                    os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
-                    if processed_image.mode == 'RGBA':
-                        processed_image = processed_image.convert("RGB")
-                    save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
+            if output_dir:
+                p.outpath_samples = output_dir
+                p.override_settings['save_to_dirs'] = False
+                if p.n_iter > 1 or p.batch_size > 1:
+                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
+                else:
+                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+            process_images(p)
 
 
 def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
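Note: batch saving now goes through the normal process_images pipeline via override_settings instead of a manual save_image loop, and the per-image filename pattern keeps outputs from colliding. A hedged sketch of the naming behaviour only, with assumed values:

n_iter, batch_size = 2, 1     # assumed batch settings for illustration
image_stem = "photo_001"      # assumed stem of the current input image

if n_iter > 1 or batch_size > 1:
    pattern = f'{image_stem}-[generation_number]'   # expands to photo_001-1.png, photo_001-2.png, ...
else:
    pattern = f'{image_stem}'                       # a single output keeps the original name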
@@ -19,7 +19,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
 !emphasized: "(" prompt ")"
         | "(" prompt ":" prompt ")"
         | "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
 alternate: "[" prompt ("|" prompt)+ "]"
 WHITESPACE: /\s+/
 plain: /([^\\\[\]():|]|\\.)+/
@@ -60,11 +60,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
 
     class CollectSteps(lark.Visitor):
         def scheduled(self, tree):
-            tree.children[-1] = float(tree.children[-1])
-            if tree.children[-1] < 1:
-                tree.children[-1] *= steps
-            tree.children[-1] = min(steps, int(tree.children[-1]))
-            res.append(tree.children[-1])
+            tree.children[-2] = float(tree.children[-2])
+            if tree.children[-2] < 1:
+                tree.children[-2] *= steps
+            tree.children[-2] = min(steps, int(tree.children[-2]))
+            res.append(tree.children[-2])
 
         def alternate(self, tree):
             res.extend(range(1, steps+1))
@@ -75,7 +75,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
     def at_step(step, tree):
         class AtStep(lark.Transformer):
             def scheduled(self, args):
-                before, after, _, when = args
+                before, after, _, when, _ = args
                 yield before or () if step <= when else after
             def alternate(self, args):
                 yield next(args[(step - 1)%len(args)])
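Note: the grammar now tolerates whitespace before the closing "]" of a scheduled prompt, which makes the optional trailing WHITESPACE token the last child of the parse tree; that is why CollectSteps reads children[-2] and AtStep unpacks one extra value instead of indexing children[-1]. A hedged example of the resulting schedule, following the module's existing doctest style:

from modules.prompt_parser import get_learned_conditioning_prompt_schedules

# "[fruit:vegetable:0.5 ]" (note the trailing space) now parses; with 20 steps the
# 0.5 fraction is scaled to step 10, so the prompt switches halfway through.
schedules = get_learned_conditioning_prompt_schedules(["[fruit:vegetable:0.5 ]"], 20)
# expected: [[[10, 'fruit'], [20, 'vegetable']]]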
@@ -333,7 +333,7 @@ re_attention = re.compile(r"""
 \\|
 \(|
 \[|
-:([+-]?[.\d]+)\)|
+:\s*([+-]?[.\d]+)\s*\)|
 \)|
 ]|
 [^\\()\[\]:]+|
@@ -646,6 +646,8 @@ def add_classes_to_gradio_component(comp):
 
 
 def IOComponent_init(self, *args, **kwargs):
+    self.webui_tooltip = kwargs.pop('tooltip', None)
+
     if scripts_current is not None:
         scripts_current.before_component(self, **kwargs)
 
@@ -663,8 +665,20 @@ def IOComponent_init(self, *args, **kwargs):
     return res
 
 
+def Block_get_config(self):
+    config = original_Block_get_config(self)
+
+    webui_tooltip = getattr(self, 'webui_tooltip', None)
+    if webui_tooltip:
+        config["webui_tooltip"] = webui_tooltip
+
+    return config
+
+
 original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.components.Block.get_config
 gr.components.IOComponent.__init__ = IOComponent_init
+gr.components.Block.get_config = Block_get_config
 
 
 def BlockContext_init(self, *args, **kwargs):
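Note: together with the javascript hunk at the top of this diff, these two patches form the tooltip pipeline: the patched IOComponent __init__ pops a webui-only `tooltip` kwarg, Block_get_config exposes it as `webui_tooltip` in the component config that gradio serializes into window.gradio_config, and javascript/hints.js copies it onto the DOM element's title. A hedged sketch of how a component could opt in after the patches are applied; the label and tooltip text are assumptions, not taken from this commit:

import gradio as gr

# The kwarg never reaches gradio itself because the patched __init__ pops it
# before calling through to the original constructor.
interrupt = gr.Button("Interrupt", elem_id="txt2img_interrupt", tooltip="Stop processing the current image")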
@@ -14,7 +14,8 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
+from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 import tomesd
 
@@ -32,6 +33,8 @@ class CheckpointInfo:
         self.filename = filename
         abspath = os.path.abspath(filename)
 
+        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
         if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
             name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
         elif abspath.startswith(model_path):
|
|||||||
if name.startswith("\\") or name.startswith("/"):
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
name = name[1:]
|
name = name[1:]
|
||||||
|
|
||||||
|
def read_metadata():
|
||||||
|
metadata = read_metadata_from_safetensors(filename)
|
||||||
|
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
self.metadata = {}
|
||||||
|
if self.is_safetensors:
|
||||||
|
try:
|
||||||
|
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading metadata for {filename}")
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
@ -54,15 +70,6 @@ class CheckpointInfo:
|
|||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||||
|
|
||||||
self.metadata = {}
|
|
||||||
|
|
||||||
_, ext = os.path.splitext(self.filename)
|
|
||||||
if ext.lower() == ".safetensors":
|
|
||||||
try:
|
|
||||||
self.metadata = read_metadata_from_safetensors(filename)
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, f"reading checkpoint metadata: {filename}")
|
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
for id in self.ids:
|
for id in self.ids:
|
||||||
@ -78,7 +85,7 @@ class CheckpointInfo:
|
|||||||
if self.shorthash not in self.ids:
|
if self.shorthash not in self.ids:
|
||||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
||||||
|
|
||||||
checkpoints_list.pop(self.title)
|
checkpoints_list.pop(self.title, None)
|
||||||
self.title = f'{self.name} [{self.shorthash}]'
|
self.title = f'{self.name} [{self.shorthash}]'
|
||||||
self.register()
|
self.register()
|
||||||
|
|
||||||
|
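Note: the only functional change here is the default argument to pop, which keeps re-registration from raising when the old title key is no longer in the list. In plain dict terms:

checkpoints = {"model.ckpt [abcd1234]": object()}

checkpoints.pop("model.ckpt", None)   # missing key: returns None instead of raising
# checkpoints.pop("model.ckpt")       # without the default this would raise KeyError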
@@ -109,11 +109,15 @@ def format_traceback(tb):
     return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
 
 
+def format_exception(e, tb):
+    return {"exception": str(e), "traceback": format_traceback(tb)}
+
+
 def get_exceptions():
     try:
         from modules import errors
 
-        return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
+        return list(reversed(errors.exception_records))
     except Exception as e:
         return str(e)
 
@@ -12,7 +12,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button
@@ -1083,58 +1083,7 @@ def create_ui():
             outputs=[html, generation_info, html2],
         )
 
-    def update_interp_description(value):
-        interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
-        interp_descriptions = {
-            "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
-            "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
-            "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
-        }
-        return interp_descriptions[value]
-
-    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
-        with gr.Row().style(equal_height=False):
-            with gr.Column(variant='compact'):
-                interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
-
-                with FormRow(elem_id="modelmerger_models"):
-                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
-                    create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
-
-                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
-                    create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
-
-                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
-                    create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
-
-                custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
-                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
-                interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
-                interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
-
-                with FormRow():
-                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
-                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
-                    save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
-
-                with FormRow():
-                    with gr.Column():
-                        config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
-
-                    with gr.Column():
-                        with FormRow():
-                            bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
-                            create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
-
-                with FormRow():
-                    discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
-
-                with gr.Row():
-                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
-
-            with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
-                with gr.Group(elem_id="modelmerger_results_panel"):
-                    modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+    modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
 
     with gr.Blocks(analytics_enabled=False) as train_interface:
         with gr.Row().style(equal_height=False):
@@ -1464,7 +1413,7 @@ def create_ui():
         (img2img_interface, "img2img", "img2img"),
         (extras_interface, "Extras", "extras"),
         (pnginfo_interface, "PNG Info", "pnginfo"),
-        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+        (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
         (train_interface, "Train", "train"),
     ]
 
@@ -1516,49 +1465,11 @@ def create_ui():
         settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
         demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
 
-    def modelmerger(*args):
-        try:
-            results = modules.extras.run_modelmerger(*args)
-        except Exception as e:
-            errors.report("Error loading/saving model file", exc_info=True)
-            modules.sd_models.list_models()  # to remove the potentially missing models from the list
-            return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
-        return results
-
-    modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
-    modelmerger_merge.click(
-        fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
-        _js='modelmerger',
-        inputs=[
-            dummy_component,
-            primary_model_name,
-            secondary_model_name,
-            tertiary_model_name,
-            interp_method,
-            interp_amount,
-            save_as_half,
-            custom_name,
-            checkpoint_format,
-            config_source,
-            bake_in_vae,
-            discard_weights,
-            save_metadata,
-        ],
-        outputs=[
-            primary_model_name,
-            secondary_model_name,
-            tertiary_model_name,
-            settings.component_dict['sd_model_checkpoint'],
-            modelmerger_result,
-        ]
-    )
+    modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
 
     loadsave.dump_defaults()
     demo.ui_loadsave = loadsave
 
-    # Required as a workaround for change() event not triggering when loading values from ui-config.json
-    interp_description.value = update_interp_description(interp_method.value)
-
     return demo
 
 
modules/ui_checkpoint_merger.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+import gradio as gr
+
+from modules import sd_models, sd_vae, errors, extras, call_queue
+from modules.ui_components import FormRow
+from modules.ui_common import create_refresh_button
+
+
+def update_interp_description(value):
+    interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+    interp_descriptions = {
+        "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+        "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+        "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+    }
+    return interp_descriptions[value]
+
+
+def modelmerger(*args):
+    try:
+        results = extras.run_modelmerger(*args)
+    except Exception as e:
+        errors.report("Error loading/saving model file", exc_info=True)
+        sd_models.list_models()  # to remove the potentially missing models from the list
+        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
+    return results
+
+
+class UiCheckpointMerger:
+    def __init__(self):
+        with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
+            with gr.Row().style(equal_height=False):
+                with gr.Column(variant='compact'):
+                    self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
+
+                    with FormRow(elem_id="modelmerger_models"):
+                        self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+                        create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
+                        self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+                        create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
+                        self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
+                        create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+                    self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+                    self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+                    self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+                    self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
+
+                    with FormRow():
+                        self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+                        self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+
+                    with FormRow():
+                        with gr.Column():
+                            self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+                        with gr.Column():
+                            with FormRow():
+                                self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+                                create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
+
+                    with FormRow():
+                        self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
+                    with gr.Accordion("Metadata", open=False) as metadata_editor:
+                        with FormRow():
+                            self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
+                            self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
+                            self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
+
+                        self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
+                        self.read_metadata = gr.Button("Read metadata from selected checkpoints")
+
+                    with FormRow():
+                        self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+                with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+                    with gr.Group(elem_id="modelmerger_results_panel"):
+                        self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+
+        self.metadata_editor = metadata_editor
+        self.blocks = modelmerger_interface
+
+    def setup_ui(self, dummy_component, sd_model_checkpoint_component):
+        self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
+
+        self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
+
+        self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
+        self.modelmerger_merge.click(
+            fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+            _js='modelmerger',
+            inputs=[
+                dummy_component,
+                self.primary_model_name,
+                self.secondary_model_name,
+                self.tertiary_model_name,
+                self.interp_method,
+                self.interp_amount,
+                self.save_as_half,
+                self.custom_name,
+                self.checkpoint_format,
+                self.config_source,
+                self.bake_in_vae,
+                self.discard_weights,
+                self.save_metadata,
+                self.add_merge_recipe,
+                self.copy_metadata_fields,
+                self.metadata_json,
+            ],
+            outputs=[
+                self.primary_model_name,
+                self.secondary_model_name,
+                self.tertiary_model_name,
+                sd_model_checkpoint_component,
+                self.modelmerger_result,
+            ]
+        )
+
+        # Required as a workaround for change() event not triggering when loading values from ui-config.json
+        self.interp_description.value = update_interp_description(self.interp_method.value)
@@ -23,6 +23,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
             "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
             "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
             "local_preview": f"{path}.{shared.opts.samples_format}",
+            "metadata": checkpoint.metadata,
             "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
         }
 
@@ -96,6 +96,7 @@ class UserMetadataEditor:
 
         stats = os.stat(filename)
         params = [
+            ('Filename: ', os.path.basename(filename)),
             ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
             ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
         ]
@@ -3,6 +3,7 @@ from copy import copy
 from itertools import permutations, chain
 import random
 import csv
+import os.path
 from io import StringIO
 from PIL import Image
 import numpy as np
@@ -10,7 +11,7 @@ import numpy as np
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
 from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
 from modules.shared import opts, state
 import modules.shared as shared
@@ -182,6 +183,8 @@ def do_nothing(p, x, xs):
 def format_nothing(p, opt, x):
     return ""
 
+def format_remove_path(p, opt, x):
+    return os.path.basename(x)
 
 def str_permutations(x):
     """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
@@ -223,7 +226,7 @@ axis_options = [
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
     AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
     AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
-    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
+    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),
     AxisOption("Sigma min", float, apply_field("s_tmin")),
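Note: with format_remove_path as the formatter for the "Checkpoint name" axis, grid labels show only the file name even when a checkpoint lives in a subfolder. A small illustration; the checkpoint title is a placeholder:

import os.path

def format_remove_path(p, opt, x):
    return os.path.basename(x)

print(format_remove_path(None, None, "subfolder/v1-5-pruned.safetensors [aaaaaaaaaa]"))
# -> "v1-5-pruned.safetensors [aaaaaaaaaa]"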
@@ -648,7 +651,12 @@ class Script(scripts.Script):
             y_opt.apply(pc, y, ys)
             z_opt.apply(pc, z, zs)
 
-            res = process_images(pc)
+            try:
+                res = process_images(pc)
+            except Exception as e:
+                errors.display(e, "generating image for xyz plot")
+
+                res = Processed(p, [], p.seed, "")
 
             # Sets subgrid infotexts
             subgrid_index = 1 + iz