mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
add hypernetwork multipliers
This commit is contained in:
parent
a10b0e11fc
commit
354ef0da3b
@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
|
||||
def __init__(self, dim, state_dict=None):
|
||||
super().__init__()
|
||||
|
||||
@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module):
|
||||
self.to(devices.device)
|
||||
|
||||
def forward(self, x):
|
||||
return x + (self.linear2(self.linear1(x)))
|
||||
return x + (self.linear2(self.linear1(x))) * self.multiplier
|
||||
|
||||
|
||||
def apply_strength(value=None):
|
||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||
|
||||
|
||||
class Hypernetwork:
|
||||
|
@ -238,7 +238,8 @@ options_templates.update(options_section(('training', "Training"), {
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
@ -348,6 +349,8 @@ class Options:
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
|
||||
func()
|
||||
|
||||
def dumpjson(self):
|
||||
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
||||
return json.dumps(d)
|
||||
|
@ -1244,7 +1244,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
def refresh():
|
||||
info.refresh()
|
||||
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
|
||||
res.choices = refreshed_args["choices"]
|
||||
|
||||
for k, v in refreshed_args.items():
|
||||
setattr(res, k, v)
|
||||
|
||||
return gr.update(**(refreshed_args or {}))
|
||||
|
||||
refresh_button.click(
|
||||
|
@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs):
|
||||
hypernetwork.load_hypernetwork(name)
|
||||
|
||||
|
||||
def apply_hypernetwork_strength(p, x, xs):
|
||||
hypernetwork.apply_strength(x)
|
||||
|
||||
|
||||
def confirm_hypernetworks(p, xs):
|
||||
for x in xs:
|
||||
if x.lower() in ["", "none"]:
|
||||
@ -165,6 +169,7 @@ axis_options = [
|
||||
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
|
||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
|
||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
|
||||
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
|
||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
|
||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
|
||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
|
||||
@ -250,7 +255,7 @@ class Script(scripts.Script):
|
||||
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
|
||||
|
||||
draw_legend = gr.Checkbox(label='Draw legend', value=True)
|
||||
include_lone_images = gr.Checkbox(label='Include Separate Images', value=True)
|
||||
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
|
||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
|
||||
|
||||
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
|
||||
@ -377,6 +382,8 @@ class Script(scripts.Script):
|
||||
modules.sd_models.reload_model_weights(shared.sd_model)
|
||||
|
||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
||||
hypernetwork.apply_strength()
|
||||
|
||||
|
||||
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
|
||||
|
||||
|
@ -522,6 +522,9 @@ canvas[key="mask"] {
|
||||
z-index: 200;
|
||||
width: 8em;
|
||||
}
|
||||
#quicksettings .gr-box > div > div > input.gr-text-input {
|
||||
top: -1.12em;
|
||||
}
|
||||
|
||||
.row.gr-compact{
|
||||
overflow: visible;
|
||||
|
2
webui.py
2
webui.py
@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
|
||||
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||
|
||||
|
||||
def initialize():
|
||||
modelloader.cleanup_models()
|
||||
modules.sd_models.setup_model()
|
||||
@ -86,6 +85,7 @@ def initialize():
|
||||
shared.sd_model = modules.sd_models.load_model()
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||
|
||||
|
||||
def webui():
|
||||
|
Loading…
Reference in New Issue
Block a user