mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
Merge pull request #3264 from Melanpan/tensorboard
Add support for Tensorboard (training)
This commit is contained in:
commit
1849f6eb80
@ -24,7 +24,6 @@ from statistics import stdev, mean
|
||||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
@ -498,6 +497,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
if clip_grad:
|
||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
|
||||
@ -632,6 +634,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
|
||||
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
epoch_num = hypernetwork.step // len(ds)
|
||||
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||
|
||||
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
@ -673,6 +683,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
|
@ -373,6 +373,9 @@ options_templates.update(options_section(('training', "Training"), {
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
|
@ -12,6 +12,7 @@ import csv
|
||||
import safetensors.torch
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||
import modules.textual_inversion.dataset
|
||||
@ -294,6 +295,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
**values,
|
||||
})
|
||||
|
||||
def tensorboard_setup(log_directory):
|
||||
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
|
||||
return SummaryWriter(
|
||||
log_dir=os.path.join(log_directory, "tensorboard"),
|
||||
flush_secs=shared.opts.training_tensorboard_flush_every)
|
||||
|
||||
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
|
||||
tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
|
||||
tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
|
||||
tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
|
||||
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
|
||||
|
||||
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
|
||||
tensorboard_writer.add_scalar(tag=tag,
|
||||
scalar_value=value, global_step=step)
|
||||
|
||||
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
|
||||
# Convert a pil image to a torch tensor
|
||||
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
|
||||
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
|
||||
len(pil_image.getbands()))
|
||||
img_tensor = img_tensor.permute((2, 0, 1))
|
||||
|
||||
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
|
||||
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
assert model_name, f"{name} not selected"
|
||||
@ -372,6 +397,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = tensorboard_setup(log_directory)
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
@ -535,6 +563,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||
|
Loading…
Reference in New Issue
Block a user