Make Tensorboard a late import (it was implicitly installed by basicsr)

This commit is contained in:
Aarni Koskela 2023-12-30 19:44:05 +02:00
parent c9174253fb
commit 1465dab715

View File

@ -11,7 +11,6 @@ import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
def tensorboard_setup(log_directory):
from torch.utils.tensorboard import SummaryWriter
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
return SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
tensorboard_writer = None
if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)
try:
tensorboard_writer = tensorboard_setup(log_directory)
except ImportError:
errors.report("Error initializing tensorboard", exc_info=True)
pin_memory = shared.opts.pin_memory
@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
if tensorboard_writer and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: