mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-28 11:04:54 +08:00
change hypernets to use sha256 hashes
This commit is contained in:
parent
a95f135308
commit
f9ac3352cb
@ -12,7 +12,7 @@ import torch
|
|||||||
import tqdm
|
import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from modules import devices, processing, sd_models, shared, sd_samplers
|
from modules import devices, processing, sd_models, shared, sd_samplers, hashes
|
||||||
from modules.textual_inversion import textual_inversion, logging
|
from modules.textual_inversion import textual_inversion, logging
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@ -225,7 +225,7 @@ class Hypernetwork:
|
|||||||
|
|
||||||
torch.save(state_dict, filename)
|
torch.save(state_dict, filename)
|
||||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
optimizer_saved_dict['hash'] = self.shorthash()
|
||||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||||
|
|
||||||
@ -237,32 +237,33 @@ class Hypernetwork:
|
|||||||
state_dict = torch.load(filename, map_location='cpu')
|
state_dict = torch.load(filename, map_location='cpu')
|
||||||
|
|
||||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||||
print(self.layer_structure)
|
self.optional_info = state_dict.get('optional_info', None)
|
||||||
optional_info = state_dict.get('optional_info', None)
|
|
||||||
if optional_info is not None:
|
|
||||||
print(f"INFO:\n {optional_info}\n")
|
|
||||||
self.optional_info = optional_info
|
|
||||||
self.activation_func = state_dict.get('activation_func', None)
|
self.activation_func = state_dict.get('activation_func', None)
|
||||||
print(f"Activation function is {self.activation_func}")
|
|
||||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||||
print(f"Weight initialization is {self.weight_init}")
|
|
||||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
|
||||||
self.dropout_structure = state_dict.get('dropout_structure', None)
|
self.dropout_structure = state_dict.get('dropout_structure', None)
|
||||||
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
||||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
|
||||||
self.activate_output = state_dict.get('activate_output', True)
|
self.activate_output = state_dict.get('activate_output', True)
|
||||||
print(f"Activate last layer is set to {self.activate_output}")
|
|
||||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||||
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
|
||||||
if self.dropout_structure is None:
|
if self.dropout_structure is None:
|
||||||
print("Using previous dropout structure")
|
|
||||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||||
print(f"Dropout structure is set to {self.dropout_structure}")
|
|
||||||
|
|
||||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
if shared.opts.print_hypernet_extra:
|
||||||
|
if self.optional_info is not None:
|
||||||
|
print(f" INFO:\n {self.optional_info}\n")
|
||||||
|
|
||||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
print(f" Layer structure: {self.layer_structure}")
|
||||||
|
print(f" Activation function: {self.activation_func}")
|
||||||
|
print(f" Weight initialization: {self.weight_init}")
|
||||||
|
print(f" Layer norm: {self.add_layer_norm}")
|
||||||
|
print(f" Dropout usage: {self.use_dropout}" )
|
||||||
|
print(f" Activate last layer: {self.activate_output}")
|
||||||
|
print(f" Dropout structure: {self.dropout_structure}")
|
||||||
|
|
||||||
|
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||||
|
|
||||||
|
if self.shorthash() == optimizer_saved_dict.get('hash', None):
|
||||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||||
else:
|
else:
|
||||||
self.optimizer_state_dict = None
|
self.optimizer_state_dict = None
|
||||||
@ -289,6 +290,11 @@ class Hypernetwork:
|
|||||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
|
def shorthash(self):
|
||||||
|
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||||
|
|
||||||
|
return sha256[0:10]
|
||||||
|
|
||||||
|
|
||||||
def list_hypernetworks(path):
|
def list_hypernetworks(path):
|
||||||
res = {}
|
res = {}
|
||||||
@ -296,7 +302,7 @@ def list_hypernetworks(path):
|
|||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
# Prevent a hypothetical "None.pt" from being listed.
|
# Prevent a hypothetical "None.pt" from being listed.
|
||||||
if name != "None":
|
if name != "None":
|
||||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
res[name] = filename
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
|
"Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
|
||||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
|
@ -125,7 +125,7 @@ def list_models():
|
|||||||
def get_closet_checkpoint_match(search_string):
|
def get_closet_checkpoint_match(search_string):
|
||||||
checkpoint_info = checkpoint_alisases.get(search_string, None)
|
checkpoint_info = checkpoint_alisases.get(search_string, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return
|
return checkpoint_info
|
||||||
|
|
||||||
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
|
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
|
||||||
if found:
|
if found:
|
||||||
|
@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
|
Loading…
Reference in New Issue
Block a user