mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
remove duplicate code for log loss, add step, make it read from options rather than gradio input
This commit is contained in:
parent
326fe7d44b
commit
03d62538ae
@ -15,6 +15,7 @@ import torch
|
|||||||
from torch import einsum
|
from torch import einsum
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
|
|
||||||
@ -210,7 +211,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
@ -263,19 +264,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||||
hypernetwork.save(last_saved_file)
|
hypernetwork.save(last_saved_file)
|
||||||
|
|
||||||
if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
|
|
||||||
|
|
||||||
with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
|
|
||||||
|
|
||||||
csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
|
|
||||||
|
|
||||||
if write_csv_header:
|
|
||||||
csv_writer.writeheader()
|
|
||||||
|
|
||||||
csv_writer.writerow({"step": hypernetwork.step,
|
|
||||||
"loss": f"{losses.mean():.7f}",
|
"loss": f"{losses.mean():.7f}",
|
||||||
"learn_rate": scheduler.learn_rate})
|
"learn_rate": scheduler.learn_rate
|
||||||
|
})
|
||||||
|
|
||||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||||
|
@ -236,7 +236,8 @@ options_templates.update(options_section(('training', "Training"), {
|
|||||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
|
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
|
@ -173,6 +173,32 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
|
if shared.opts.training_write_csv_every == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if step % shared.opts.training_write_csv_every != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||||
|
|
||||||
|
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
|
||||||
|
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
|
||||||
|
|
||||||
|
if write_csv_header:
|
||||||
|
csv_writer.writeheader()
|
||||||
|
|
||||||
|
epoch = step // epoch_len
|
||||||
|
epoch_step = step - epoch * epoch_len
|
||||||
|
|
||||||
|
csv_writer.writerow({
|
||||||
|
"step": step + 1,
|
||||||
|
"epoch": epoch + 1,
|
||||||
|
"epoch_step": epoch_step + 1,
|
||||||
|
**values,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert embedding_name, 'embedding not selected'
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
@ -257,20 +283,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
embedding.save(last_saved_file)
|
embedding.save(last_saved_file)
|
||||||
|
|
||||||
if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
|
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
||||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
|
|
||||||
|
|
||||||
with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
|
|
||||||
|
|
||||||
csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"])
|
|
||||||
|
|
||||||
if write_csv_header:
|
|
||||||
csv_writer.writeheader()
|
|
||||||
|
|
||||||
csv_writer.writerow({"epoch": epoch_num + 1,
|
|
||||||
"epoch_step": epoch_step - 1,
|
|
||||||
"loss": f"{losses.mean():.7f}",
|
"loss": f"{losses.mean():.7f}",
|
||||||
"learn_rate": scheduler.learn_rate})
|
"learn_rate": scheduler.learn_rate
|
||||||
|
})
|
||||||
|
|
||||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
@ -1172,7 +1172,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
write_csv_every = gr.Number(label='Save an csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0)
|
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||||
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||||
@ -1251,7 +1250,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
steps,
|
steps,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
write_csv_every,
|
|
||||||
template_file,
|
template_file,
|
||||||
save_image_with_stored_embedding,
|
save_image_with_stored_embedding,
|
||||||
preview_from_txt2img,
|
preview_from_txt2img,
|
||||||
@ -1274,7 +1272,6 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
steps,
|
steps,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
write_csv_every,
|
|
||||||
template_file,
|
template_file,
|
||||||
preview_from_txt2img,
|
preview_from_txt2img,
|
||||||
*txt2img_preview_params,
|
*txt2img_preview_params,
|
||||||
|
Loading…
Reference in New Issue
Block a user