Add support for Tensorboard for training embeddings

This commit is contained in:
Melan 2022-10-20 16:26:16 +02:00
parent 7f8ab1ee8f
commit 29e74d6e71
2 changed files with 34 additions and 1 deletions

View File

@ -254,6 +254,10 @@ options_templates.update(options_section(('training', "Training"), {
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {

View File

@ -7,9 +7,11 @@ import tqdm
import html import html
import datetime import datetime
import csv import csv
import numpy as np
import torchvision.transforms
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
@ -199,6 +201,19 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values, **values,
}) })
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
if shared.opts.training_enable_tensorboard:
tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
if shared.opts.training_enable_tensorboard:
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
@ -252,6 +267,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
if shared.opts.training_enable_tensorboard:
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
tensorboard_writer = SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
flush_secs=shared.opts.training_tensorboard_flush_every)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar: for i, entries in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
@ -270,6 +291,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
del x del x
losses[embedding.step % losses.shape[0]] = loss.item() losses[embedding.step % losses.shape[0]] = loss.item()
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
@ -285,6 +307,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding.save(last_saved_file) embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True embedding_yet_to_be_embedded = True
if shared.opts.training_enable_tensorboard:
tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step)
tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step)
tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step)
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step)
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}", "loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
@ -349,6 +377,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
image.save(last_saved_image) image.save(last_saved_image)
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
last_saved_image += f", prompt: {preview_text}" last_saved_image += f", prompt: {preview_text}"