mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 07:05:06 +08:00
add option to select hypernetwork modules when creating
This commit is contained in:
parent
5ba23cb41f
commit
d682444ecc
@ -42,7 +42,7 @@ class Hypernetwork:
|
|||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
|
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None, enable_sizes=None):
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.layers = {}
|
self.layers = {}
|
||||||
@ -50,7 +50,7 @@ class Hypernetwork:
|
|||||||
self.sd_checkpoint = None
|
self.sd_checkpoint = None
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
|
|
||||||
for size in [320, 640, 768, 1280]:
|
for size in enable_sizes or [320, 640, 768, 1280]:
|
||||||
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
|
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
|
@ -9,11 +9,11 @@ from modules import sd_hijack, shared
|
|||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name):
|
def create_hypernetwork(name, enable_sizes):
|
||||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name)
|
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
|
||||||
hypernet.save(fn)
|
hypernet.save(fn)
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
@ -1037,6 +1037,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
|
||||||
|
|
||||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||||
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
@ -1114,6 +1115,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
fn=modules.hypernetworks.ui.create_hypernetwork,
|
fn=modules.hypernetworks.ui.create_hypernetwork,
|
||||||
inputs=[
|
inputs=[
|
||||||
new_hypernetwork_name,
|
new_hypernetwork_name,
|
||||||
|
new_hypernetwork_sizes,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user