fix bugs and optimizations

This commit is contained in:
AngelBottomless 2022-10-21 01:00:41 +09:00 committed by GitHub
parent a71e021236
commit 108be15500
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values # if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue # if i < 1: continue
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
if activation_func in HypernetworkModule.activation_dict: if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]()) linears.append(HypernetworkModule.activation_dict[activation_func]())
else: else:
print("Invalid key {} encountered as activation function!".format(activation_func)) print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout: # if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3)) # linears.append(torch.nn.Dropout(p=0.3))
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears) self.linear = torch.nn.Sequential(*linears)
@ -115,11 +115,24 @@ class Hypernetwork:
for k, layers in self.layers.items(): for k, layers in self.layers.items():
for layer in layers: for layer in layers:
layer.train()
res += layer.trainables() res += layer.trainables()
return res return res
def eval(self):
for k, layers in self.layers.items():
for layer in layers:
layer.eval()
for items in self.weights():
items.requires_grad = False
def train(self):
for k, layers in self.layers.items():
for layer in layers:
layer.train()
for items in self.weights():
items.requires_grad = True
def save(self, filename): def save(self, filename):
state_dict = {} state_dict = {}
@ -290,10 +303,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
losses = torch.zeros((32,)) losses = torch.zeros((32,))
last_saved_file = "<none>" last_saved_file = "<none>"
@ -304,10 +313,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
hypernetwork.train()
for i, entries in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
@ -328,8 +337,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item() losses[hypernetwork.step % losses.shape[0]] = loss.item()
optimizer.zero_grad() optimizer.zero_grad(set_to_none=True)
loss.backward() loss.backward()
del loss
optimizer.step() optimizer.step()
mean_loss = losses.mean() mean_loss = losses.mean()
if torch.isnan(mean_loss): if torch.isnan(mean_loss):
@ -346,44 +356,47 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
}) })
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
torch.cuda.empty_cache()
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
with torch.no_grad():
hypernetwork.eval()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
optimizer.zero_grad() p = processing.StableDiffusionProcessingTxt2Img(
shared.sd_model.cond_stage_model.to(devices.device) sd_model=shared.sd_model,
shared.sd_model.first_stage_model.to(devices.device) do_not_save_grid=True,
do_not_save_samples=True,
)
p = processing.StableDiffusionProcessingTxt2Img( if preview_from_txt2img:
sd_model=shared.sd_model, p.prompt = preview_prompt
do_not_save_grid=True, p.negative_prompt = preview_negative_prompt
do_not_save_samples=True, p.steps = preview_steps
) p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
if preview_from_txt2img: preview_text = p.prompt
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None
processed = processing.process_images(p) if unload:
image = processed.images[0] if len(processed.images)>0 else None shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
if unload: if image is not None:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.state.current_image = image
shared.sd_model.first_stage_model.to(devices.cpu) image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
if image is not None: hypernetwork.train()
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step shared.state.job_no = hypernetwork.step