mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
added VAE selection to checkpoint user metadata
This commit is contained in:
parent
31a9966b9d
commit
4560176640
@ -1,3 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
@ -177,3 +179,20 @@ def parse_prompts(prompts):
|
||||
|
||||
return res, extra_data
|
||||
|
||||
|
||||
def get_user_metadata(filename):
|
||||
if filename is None:
|
||||
return {}
|
||||
|
||||
basename, ext = os.path.splitext(filename)
|
||||
metadata_filename = basename + '.json'
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||
|
||||
return metadata
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import collections
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
|
||||
@ -16,6 +16,7 @@ checkpoint_info = None
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
return base_vae
|
||||
@ -100,6 +101,16 @@ def resolve_vae(checkpoint_file):
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||
|
||||
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
||||
vae_metadata = metadata.get("vae", None)
|
||||
if vae_metadata is not None and vae_metadata != "Automatic":
|
||||
if vae_metadata == "None":
|
||||
return None, None
|
||||
|
||||
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
||||
if vae_from_metadata is not None:
|
||||
return vae_from_metadata, "from user metadata"
|
||||
|
||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
|
@ -2,7 +2,7 @@ import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
from modules import shared, ui_extra_networks_user_metadata, errors
|
||||
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
|
||||
from modules.images import read_info_from_image, save_image_with_geninfo
|
||||
from modules.ui import up_down_symbol
|
||||
import gradio as gr
|
||||
@ -101,16 +101,7 @@ class ExtraNetworksPage:
|
||||
|
||||
def read_user_metadata(self, item):
|
||||
filename = item.get("filename", None)
|
||||
basename, ext = os.path.splitext(filename)
|
||||
metadata_filename = basename + '.json'
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||
metadata = extra_networks.get_user_metadata(filename)
|
||||
|
||||
desc = metadata.get("description", None)
|
||||
if desc is not None:
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
|
||||
from modules import shared, ui_extra_networks, sd_models
|
||||
from modules.ui_extra_networks import quote_js
|
||||
from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
|
||||
|
||||
|
||||
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
def allowed_directories_for_previews(self):
|
||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||
|
||||
def create_user_metadata_editor(self, ui, tabname):
|
||||
return CheckpointUserMetadataEditor(ui, tabname, self)
|
||||
|
60
modules/ui_extra_networks_checkpoints_user_metadata.py
Normal file
60
modules/ui_extra_networks_checkpoints_user_metadata.py
Normal file
@ -0,0 +1,60 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import ui_extra_networks_user_metadata, sd_vae
|
||||
from modules.ui_common import create_refresh_button
|
||||
|
||||
|
||||
class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
|
||||
def __init__(self, ui, tabname, page):
|
||||
super().__init__(ui, tabname, page)
|
||||
|
||||
self.select_vae = None
|
||||
|
||||
def save_user_metadata(self, name, desc, notes, vae):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
user_metadata["description"] = desc
|
||||
user_metadata["notes"] = notes
|
||||
user_metadata["vae"] = vae
|
||||
|
||||
self.write_user_metadata(name, user_metadata)
|
||||
|
||||
def put_values_into_components(self, name):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
values = super().put_values_into_components(name)
|
||||
|
||||
return [
|
||||
*values[0:5],
|
||||
user_metadata.get('vae', ''),
|
||||
]
|
||||
|
||||
def create_editor(self):
|
||||
self.create_default_editor_elems()
|
||||
|
||||
with gr.Row():
|
||||
self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
|
||||
create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
|
||||
|
||||
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
||||
|
||||
self.create_default_buttons()
|
||||
|
||||
viewed_components = [
|
||||
self.edit_name,
|
||||
self.edit_description,
|
||||
self.html_filedata,
|
||||
self.html_preview,
|
||||
self.edit_notes,
|
||||
self.select_vae,
|
||||
]
|
||||
|
||||
self.button_edit\
|
||||
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
|
||||
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
|
||||
|
||||
edited_components = [
|
||||
self.edit_description,
|
||||
self.edit_notes,
|
||||
self.select_vae,
|
||||
]
|
||||
|
||||
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|
Loading…
Reference in New Issue
Block a user