Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-01-21 13:50:12 +08:00)
Add learning-rate schedule syntax support for the gradient clipping value
This commit is contained in:
Parent commit: 1618df41ba
This commit: 16451ca573
@ -383,11 +383,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
ititial_step = hypernetwork.step or 0
|
ititial_step = hypernetwork.step or 0
|
||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
clip_grad_mode_value = clip_grad_mode == "value"
|
clip_grad_mode_value = clip_grad_mode == "value"
|
||||||
clip_grad_mode_norm = clip_grad_mode == "norm"
|
clip_grad_mode_norm = clip_grad_mode == "norm"
|
||||||
|
clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
|
||||||
|
if clip_grad_enabled:
|
||||||
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||||
|
|
||||||
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||||
|
|
||||||
@ -407,6 +411,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if clip_grad_enabled:
|
||||||
|
clip_grad_sched.step(hypernetwork.step)
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||||
@ -430,9 +437,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
||||||
|
|
||||||
if clip_grad_mode_value:
|
if clip_grad_mode_value:
|
||||||
torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_value)
|
torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
|
||||||
elif clip_grad_mode_norm:
|
elif clip_grad_mode_norm:
|
||||||
torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_value)
|
torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
@ -51,14 +51,19 @@ class LearnRateScheduler:
|
|||||||
|
|
||||||
self.finished = False
|
self.finished = False
|
||||||
|
|
||||||
def apply(self, optimizer, step_number):
|
def step(self, step_number):
|
||||||
if step_number <= self.end_step:
|
if step_number <= self.end_step:
|
||||||
return
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||||
except Exception:
|
except StopIteration:
|
||||||
self.finished = True
|
self.finished = True
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply(self, optimizer, step_number):
|
||||||
|
if not self.step(step_number):
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -255,9 +255,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
ititial_step = embedding.step or 0
|
ititial_step = embedding.step or 0
|
||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
clip_grad_mode_value = clip_grad_mode == "value"
|
clip_grad_mode_value = clip_grad_mode == "value"
|
||||||
clip_grad_mode_norm = clip_grad_mode == "norm"
|
clip_grad_mode_norm = clip_grad_mode == "norm"
|
||||||
|
clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
|
||||||
|
if clip_grad_enabled:
|
||||||
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||||
@ -273,6 +276,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if clip_grad_enabled:
|
||||||
|
clip_grad_sched.step(embedding.step)
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
c = cond_model([entry.cond_text for entry in entries])
|
c = cond_model([entry.cond_text for entry in entries])
|
||||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||||
@ -285,9 +291,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
if clip_grad_mode_value:
|
if clip_grad_mode_value:
|
||||||
torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value)
|
torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
|
||||||
elif clip_grad_mode_norm:
|
elif clip_grad_mode_norm:
|
||||||
torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value)
|
torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
@ -1305,7 +1305,9 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
|
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
|
||||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||||
|
with gr.Row():
|
||||||
|
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
|
||||||
|
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="1.0", show_label=False)
|
||||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
@ -1313,9 +1315,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
with gr.Row():
|
|
||||||
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
|
|
||||||
clip_grad_value = gr.Number(value=1.0, show_label=False)
|
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user