mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 07:05:06 +08:00
a4a5475cfa
Implements variable dropout rate from #4549. Fixes the hypernetwork multiplier being modifiable during training, and also fixes user errors by setting the multiplier to lower values for training. Changes a function name to match the torch.nn.Module standard. Fixes an RNG reset issue when generating previews by restoring the RNG state.
42 lines
1.6 KiB
Python
42 lines
1.6 KiB
Python
import html
|
|
import os
|
|
import re
|
|
|
|
import gradio as gr
|
|
import modules.hypernetworks.hypernetwork
|
|
from modules import devices, sd_hijack, shared
|
|
|
|
# Activation-function names that are filtered out of ``keys`` below.
not_available = ["hardswish", "multiheadattention"]

# Every key of HypernetworkModule.activation_dict (in dict order) except the
# excluded names above; presumably used as the list of activation choices
# offered to the user — confirm against callers.
keys = [
    name
    for name in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict
    if name not in not_available
]
|
|
|
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    """Create a hypernetwork via the backend and build the UI response.

    Every argument is passed through positionally, in the same order, to
    ``modules.hypernetworks.hypernetwork.create_hypernetwork``.

    Returns:
        A 3-tuple: a ``gr.Dropdown.update`` whose choices are the sorted
        keys of ``shared.hypernetworks``, a ``"Created: <filename>"`` status
        string, and an empty string.
    """
    created_path = modules.hypernetworks.hypernetwork.create_hypernetwork(
        name,
        enable_sizes,
        overwrite_old,
        layer_structure,
        activation_func,
        weight_init,
        add_layer_norm,
        use_dropout,
        dropout_structure,
    )

    # Iterating a dict yields its keys, so sorted() over it gives the sorted names.
    refreshed_dropdown = gr.Dropdown.update(choices=sorted(shared.hypernetworks))
    return refreshed_dropdown, f"Created: {created_path}", ""
|
|
|
|
|
|
def train_hypernetwork(*args):
    """Run hypernetwork training and return a status message for the UI.

    ``args`` are forwarded unchanged to
    ``modules.hypernetworks.hypernetwork.train_hypernetwork``, whose result
    is unpacked as ``(hypernetwork, filename)``.

    ``sd_hijack`` optimizations are undone before training. Afterwards,
    whether training returns normally or raises, the following cleanup runs:
    the hypernetwork that was loaded on entry is restored,
    ``shared.sd_model.cond_stage_model`` and ``first_stage_model`` are moved
    to ``devices.device``, and the ``sd_hijack`` optimizations are applied
    again.

    Returns:
        ``(message, "")``. ``message`` states whether training was
        interrupted or finished, the step count reached
        (``hypernetwork.step``), and the HTML-escaped save path.

    Raises:
        AssertionError: if ``shared.cmd_opts.lowvram`` is set.
        Any exception from the training routine propagates unchanged once
        the cleanup described above has run.
    """
    initial_hypernetwork = shared.loaded_hypernetwork

    # Explicit check instead of a bare ``assert`` so the guard is not stripped
    # under ``python -O``; AssertionError is kept so callers see the same type.
    if shared.cmd_opts.lowvram:
        raise AssertionError('Training models with lowvram is not possible')

    try:
        sd_hijack.undo_optimizations()

        hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)

        res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
        return res, ""
    finally:
        # Runs on success and on error alike, restoring the global state
        # that was in place before training started.
        shared.loaded_hypernetwork = initial_hypernetwork
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        sd_hijack.apply_optimizations()
|
|
|