mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-02-07 06:02:49 +08:00
Refactor GPU cache during training (#108)
This commit is contained in:
parent
297d92bf5d
commit
5d5ab5465f
@ -230,39 +230,50 @@ def train_and_evaluate(
|
|||||||
|
|
||||||
net_g.train()
|
net_g.train()
|
||||||
net_d.train()
|
net_d.train()
|
||||||
if cache == [] or hps.if_cache_data_in_gpu == False: # 第一个epoch把cache全部填满训练集
|
|
||||||
# print("caching")
|
# Prepare data iterator
|
||||||
for batch_idx, info in enumerate(train_loader):
|
if hps.if_cache_data_in_gpu == True:
|
||||||
if hps.if_f0 == 1:
|
# Use Cache
|
||||||
(
|
data_iterator = cache
|
||||||
phone,
|
if cache == []:
|
||||||
phone_lengths,
|
# Make new cache
|
||||||
pitch,
|
for batch_idx, info in enumerate(train_loader):
|
||||||
pitchf,
|
# Unpack
|
||||||
spec,
|
|
||||||
spec_lengths,
|
|
||||||
wave,
|
|
||||||
wave_lengths,
|
|
||||||
sid,
|
|
||||||
) = info
|
|
||||||
else:
|
|
||||||
phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
phone, phone_lengths = phone.cuda(
|
|
||||||
rank, non_blocking=True
|
|
||||||
), phone_lengths.cuda(rank, non_blocking=True)
|
|
||||||
if hps.if_f0 == 1:
|
if hps.if_f0 == 1:
|
||||||
pitch, pitchf = pitch.cuda(rank, non_blocking=True), pitchf.cuda(
|
(
|
||||||
rank, non_blocking=True
|
phone,
|
||||||
)
|
phone_lengths,
|
||||||
sid = sid.cuda(rank, non_blocking=True)
|
pitch,
|
||||||
spec, spec_lengths = spec.cuda(
|
pitchf,
|
||||||
rank, non_blocking=True
|
spec,
|
||||||
), spec_lengths.cuda(rank, non_blocking=True)
|
spec_lengths,
|
||||||
wave, wave_lengths = wave.cuda(
|
wave,
|
||||||
rank, non_blocking=True
|
wave_lengths,
|
||||||
), wave_lengths.cuda(rank, non_blocking=True)
|
sid,
|
||||||
if hps.if_cache_data_in_gpu == True:
|
) = info
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
phone,
|
||||||
|
phone_lengths,
|
||||||
|
spec,
|
||||||
|
spec_lengths,
|
||||||
|
wave,
|
||||||
|
wave_lengths,
|
||||||
|
sid,
|
||||||
|
) = info
|
||||||
|
# Load on CUDA
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
phone = phone.cuda(rank, non_blocking=True)
|
||||||
|
phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
|
||||||
|
if hps.if_f0 == 1:
|
||||||
|
pitch = pitch.cuda(rank, non_blocking=True)
|
||||||
|
pitchf = pitchf.cuda(rank, non_blocking=True)
|
||||||
|
sid = sid.cuda(rank, non_blocking=True)
|
||||||
|
spec = spec.cuda(rank, non_blocking=True)
|
||||||
|
spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
|
||||||
|
wave = wave.cuda(rank, non_blocking=True)
|
||||||
|
wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
|
||||||
|
# Cache on list
|
||||||
if hps.if_f0 == 1:
|
if hps.if_f0 == 1:
|
||||||
cache.append(
|
cache.append(
|
||||||
(
|
(
|
||||||
@ -295,372 +306,211 @@ def train_and_evaluate(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
else:
|
||||||
if hps.if_f0 == 1:
|
# Load shuffled cache
|
||||||
(
|
shuffle(cache)
|
||||||
y_hat,
|
else:
|
||||||
ids_slice,
|
# Loader
|
||||||
x_mask,
|
data_iterator = enumerate(train_loader)
|
||||||
z_mask,
|
|
||||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
|
||||||
) = net_g(
|
|
||||||
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
y_hat,
|
|
||||||
ids_slice,
|
|
||||||
x_mask,
|
|
||||||
z_mask,
|
|
||||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
|
||||||
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
|
|
||||||
mel = spec_to_mel_torch(
|
|
||||||
spec,
|
|
||||||
hps.data.filter_length,
|
|
||||||
hps.data.n_mel_channels,
|
|
||||||
hps.data.sampling_rate,
|
|
||||||
hps.data.mel_fmin,
|
|
||||||
hps.data.mel_fmax,
|
|
||||||
)
|
|
||||||
y_mel = commons.slice_segments(
|
|
||||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
|
||||||
)
|
|
||||||
with autocast(enabled=False):
|
|
||||||
y_hat_mel = mel_spectrogram_torch(
|
|
||||||
y_hat.float().squeeze(1),
|
|
||||||
hps.data.filter_length,
|
|
||||||
hps.data.n_mel_channels,
|
|
||||||
hps.data.sampling_rate,
|
|
||||||
hps.data.hop_length,
|
|
||||||
hps.data.win_length,
|
|
||||||
hps.data.mel_fmin,
|
|
||||||
hps.data.mel_fmax,
|
|
||||||
)
|
|
||||||
if hps.train.fp16_run == True:
|
|
||||||
y_hat_mel = y_hat_mel.half()
|
|
||||||
wave = commons.slice_segments(
|
|
||||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
|
||||||
) # slice
|
|
||||||
|
|
||||||
# Discriminator
|
# Run steps
|
||||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
for batch_idx, info in data_iterator:
|
||||||
with autocast(enabled=False):
|
# Data
|
||||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
## Unpack
|
||||||
y_d_hat_r, y_d_hat_g
|
if hps.if_f0 == 1:
|
||||||
)
|
(
|
||||||
optim_d.zero_grad()
|
phone,
|
||||||
scaler.scale(loss_disc).backward()
|
phone_lengths,
|
||||||
scaler.unscale_(optim_d)
|
pitch,
|
||||||
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
pitchf,
|
||||||
scaler.step(optim_d)
|
spec,
|
||||||
|
spec_lengths,
|
||||||
|
wave,
|
||||||
|
wave_lengths,
|
||||||
|
sid,
|
||||||
|
) = info
|
||||||
|
else:
|
||||||
|
phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
|
||||||
|
## Load on CUDA
|
||||||
|
if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
|
||||||
|
phone = phone.cuda(rank, non_blocking=True)
|
||||||
|
phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
|
||||||
|
if hps.if_f0 == 1:
|
||||||
|
pitch = pitch.cuda(rank, non_blocking=True)
|
||||||
|
pitchf = pitchf.cuda(rank, non_blocking=True)
|
||||||
|
sid = sid.cuda(rank, non_blocking=True)
|
||||||
|
spec = spec.cuda(rank, non_blocking=True)
|
||||||
|
spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
|
||||||
|
wave = wave.cuda(rank, non_blocking=True)
|
||||||
|
wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
|
||||||
|
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
# Calculate
|
||||||
# Generator
|
with autocast(enabled=hps.train.fp16_run):
|
||||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
|
||||||
with autocast(enabled=False):
|
|
||||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
|
||||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
|
||||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
|
||||||
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
|
||||||
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
|
||||||
optim_g.zero_grad()
|
|
||||||
scaler.scale(loss_gen_all).backward()
|
|
||||||
scaler.unscale_(optim_g)
|
|
||||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
|
||||||
scaler.step(optim_g)
|
|
||||||
scaler.update()
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
if global_step % hps.train.log_interval == 0:
|
|
||||||
lr = optim_g.param_groups[0]["lr"]
|
|
||||||
logger.info(
|
|
||||||
"Train Epoch: {} [{:.0f}%]".format(
|
|
||||||
epoch, 100.0 * batch_idx / len(train_loader)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Amor For Tensorboard display
|
|
||||||
if loss_mel > 50:
|
|
||||||
loss_mel = 50
|
|
||||||
if loss_kl > 5:
|
|
||||||
loss_kl = 5
|
|
||||||
|
|
||||||
logger.info([global_step, lr])
|
|
||||||
logger.info(
|
|
||||||
f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
|
|
||||||
)
|
|
||||||
scalar_dict = {
|
|
||||||
"loss/g/total": loss_gen_all,
|
|
||||||
"loss/d/total": loss_disc,
|
|
||||||
"learning_rate": lr,
|
|
||||||
"grad_norm_d": grad_norm_d,
|
|
||||||
"grad_norm_g": grad_norm_g,
|
|
||||||
}
|
|
||||||
scalar_dict.update(
|
|
||||||
{
|
|
||||||
"loss/g/fm": loss_fm,
|
|
||||||
"loss/g/mel": loss_mel,
|
|
||||||
"loss/g/kl": loss_kl,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
scalar_dict.update(
|
|
||||||
{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
|
|
||||||
)
|
|
||||||
scalar_dict.update(
|
|
||||||
{
|
|
||||||
"loss/d_r/{}".format(i): v
|
|
||||||
for i, v in enumerate(losses_disc_r)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
scalar_dict.update(
|
|
||||||
{
|
|
||||||
"loss/d_g/{}".format(i): v
|
|
||||||
for i, v in enumerate(losses_disc_g)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
image_dict = {
|
|
||||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
|
||||||
y_mel[0].data.cpu().numpy()
|
|
||||||
),
|
|
||||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
|
||||||
y_hat_mel[0].data.cpu().numpy()
|
|
||||||
),
|
|
||||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
|
||||||
mel[0].data.cpu().numpy()
|
|
||||||
),
|
|
||||||
}
|
|
||||||
utils.summarize(
|
|
||||||
writer=writer,
|
|
||||||
global_step=global_step,
|
|
||||||
images=image_dict,
|
|
||||||
scalars=scalar_dict,
|
|
||||||
)
|
|
||||||
global_step += 1
|
|
||||||
# if global_step % hps.train.eval_interval == 0:
|
|
||||||
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
|
||||||
if hps.if_latest == 0:
|
|
||||||
utils.save_checkpoint(
|
|
||||||
net_g,
|
|
||||||
optim_g,
|
|
||||||
hps.train.learning_rate,
|
|
||||||
epoch,
|
|
||||||
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
|
|
||||||
)
|
|
||||||
utils.save_checkpoint(
|
|
||||||
net_d,
|
|
||||||
optim_d,
|
|
||||||
hps.train.learning_rate,
|
|
||||||
epoch,
|
|
||||||
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
utils.save_checkpoint(
|
|
||||||
net_g,
|
|
||||||
optim_g,
|
|
||||||
hps.train.learning_rate,
|
|
||||||
epoch,
|
|
||||||
os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
|
|
||||||
)
|
|
||||||
utils.save_checkpoint(
|
|
||||||
net_d,
|
|
||||||
optim_d,
|
|
||||||
hps.train.learning_rate,
|
|
||||||
epoch,
|
|
||||||
os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
|
|
||||||
)
|
|
||||||
|
|
||||||
else: # 后续的epoch直接使用打乱的cache
|
|
||||||
shuffle(cache)
|
|
||||||
# print("using cache")
|
|
||||||
for batch_idx, info in cache:
|
|
||||||
if hps.if_f0 == 1:
|
if hps.if_f0 == 1:
|
||||||
(
|
(
|
||||||
phone,
|
y_hat,
|
||||||
phone_lengths,
|
ids_slice,
|
||||||
pitch,
|
x_mask,
|
||||||
pitchf,
|
z_mask,
|
||||||
spec,
|
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||||
spec_lengths,
|
) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
|
||||||
wave,
|
|
||||||
wave_lengths,
|
|
||||||
sid,
|
|
||||||
) = info
|
|
||||||
else:
|
else:
|
||||||
phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
|
(
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
y_hat,
|
||||||
if hps.if_f0 == 1:
|
ids_slice,
|
||||||
(
|
x_mask,
|
||||||
y_hat,
|
z_mask,
|
||||||
ids_slice,
|
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||||
x_mask,
|
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
|
||||||
z_mask,
|
mel = spec_to_mel_torch(
|
||||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
spec,
|
||||||
) = net_g(
|
hps.data.filter_length,
|
||||||
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
|
hps.data.n_mel_channels,
|
||||||
)
|
hps.data.sampling_rate,
|
||||||
else:
|
hps.data.mel_fmin,
|
||||||
(
|
hps.data.mel_fmax,
|
||||||
y_hat,
|
)
|
||||||
ids_slice,
|
y_mel = commons.slice_segments(
|
||||||
x_mask,
|
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||||
z_mask,
|
)
|
||||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
with autocast(enabled=False):
|
||||||
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
|
y_hat_mel = mel_spectrogram_torch(
|
||||||
mel = spec_to_mel_torch(
|
y_hat.float().squeeze(1),
|
||||||
spec,
|
|
||||||
hps.data.filter_length,
|
hps.data.filter_length,
|
||||||
hps.data.n_mel_channels,
|
hps.data.n_mel_channels,
|
||||||
hps.data.sampling_rate,
|
hps.data.sampling_rate,
|
||||||
|
hps.data.hop_length,
|
||||||
|
hps.data.win_length,
|
||||||
hps.data.mel_fmin,
|
hps.data.mel_fmin,
|
||||||
hps.data.mel_fmax,
|
hps.data.mel_fmax,
|
||||||
)
|
)
|
||||||
y_mel = commons.slice_segments(
|
if hps.train.fp16_run == True:
|
||||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
y_hat_mel = y_hat_mel.half()
|
||||||
|
wave = commons.slice_segments(
|
||||||
|
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||||
|
) # slice
|
||||||
|
|
||||||
|
# Discriminator
|
||||||
|
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
||||||
|
with autocast(enabled=False):
|
||||||
|
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||||
|
y_d_hat_r, y_d_hat_g
|
||||||
)
|
)
|
||||||
with autocast(enabled=False):
|
optim_d.zero_grad()
|
||||||
y_hat_mel = mel_spectrogram_torch(
|
scaler.scale(loss_disc).backward()
|
||||||
y_hat.float().squeeze(1),
|
scaler.unscale_(optim_d)
|
||||||
hps.data.filter_length,
|
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
||||||
hps.data.n_mel_channels,
|
scaler.step(optim_d)
|
||||||
hps.data.sampling_rate,
|
|
||||||
hps.data.hop_length,
|
with autocast(enabled=hps.train.fp16_run):
|
||||||
hps.data.win_length,
|
# Generator
|
||||||
hps.data.mel_fmin,
|
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
||||||
hps.data.mel_fmax,
|
with autocast(enabled=False):
|
||||||
|
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
||||||
|
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
||||||
|
loss_fm = feature_loss(fmap_r, fmap_g)
|
||||||
|
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
||||||
|
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
||||||
|
optim_g.zero_grad()
|
||||||
|
scaler.scale(loss_gen_all).backward()
|
||||||
|
scaler.unscale_(optim_g)
|
||||||
|
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||||
|
scaler.step(optim_g)
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
if global_step % hps.train.log_interval == 0:
|
||||||
|
lr = optim_g.param_groups[0]["lr"]
|
||||||
|
logger.info(
|
||||||
|
"Train Epoch: {} [{:.0f}%]".format(
|
||||||
|
epoch, 100.0 * batch_idx / len(train_loader)
|
||||||
)
|
)
|
||||||
if hps.train.fp16_run == True:
|
)
|
||||||
y_hat_mel = y_hat_mel.half()
|
# Amor For Tensorboard display
|
||||||
wave = commons.slice_segments(
|
if loss_mel > 50:
|
||||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
loss_mel = 50
|
||||||
) # slice
|
if loss_kl > 5:
|
||||||
|
loss_kl = 5
|
||||||
|
|
||||||
# Discriminator
|
logger.info([global_step, lr])
|
||||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
logger.info(
|
||||||
with autocast(enabled=False):
|
f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
|
||||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
)
|
||||||
y_d_hat_r, y_d_hat_g
|
scalar_dict = {
|
||||||
)
|
"loss/g/total": loss_gen_all,
|
||||||
optim_d.zero_grad()
|
"loss/d/total": loss_disc,
|
||||||
scaler.scale(loss_disc).backward()
|
"learning_rate": lr,
|
||||||
scaler.unscale_(optim_d)
|
"grad_norm_d": grad_norm_d,
|
||||||
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
"grad_norm_g": grad_norm_g,
|
||||||
scaler.step(optim_d)
|
}
|
||||||
|
scalar_dict.update(
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
{
|
||||||
# Generator
|
"loss/g/fm": loss_fm,
|
||||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
"loss/g/mel": loss_mel,
|
||||||
with autocast(enabled=False):
|
"loss/g/kl": loss_kl,
|
||||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
|
||||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
|
||||||
|
|
||||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
|
||||||
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
|
||||||
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
|
||||||
optim_g.zero_grad()
|
|
||||||
scaler.scale(loss_gen_all).backward()
|
|
||||||
scaler.unscale_(optim_g)
|
|
||||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
|
||||||
scaler.step(optim_g)
|
|
||||||
scaler.update()
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
if global_step % hps.train.log_interval == 0:
|
|
||||||
lr = optim_g.param_groups[0]["lr"]
|
|
||||||
logger.info(
|
|
||||||
"Train Epoch: {} [{:.0f}%]".format(
|
|
||||||
epoch, 100.0 * batch_idx / len(train_loader)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Amor For Tensorboard display
|
|
||||||
if loss_mel > 50:
|
|
||||||
loss_mel = 50
|
|
||||||
if loss_kl > 5:
|
|
||||||
loss_kl = 5
|
|
||||||
|
|
||||||
logger.info([global_step, lr])
|
|
||||||
logger.info(
|
|
||||||
f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
|
|
||||||
)
|
|
||||||
scalar_dict = {
|
|
||||||
"loss/g/total": loss_gen_all,
|
|
||||||
"loss/d/total": loss_disc,
|
|
||||||
"learning_rate": lr,
|
|
||||||
"grad_norm_d": grad_norm_d,
|
|
||||||
"grad_norm_g": grad_norm_g,
|
|
||||||
}
|
}
|
||||||
scalar_dict.update(
|
)
|
||||||
{
|
|
||||||
"loss/g/fm": loss_fm,
|
|
||||||
"loss/g/mel": loss_mel,
|
|
||||||
"loss/g/kl": loss_kl,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
scalar_dict.update(
|
scalar_dict.update(
|
||||||
{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
|
{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
|
||||||
)
|
|
||||||
scalar_dict.update(
|
|
||||||
{
|
|
||||||
"loss/d_r/{}".format(i): v
|
|
||||||
for i, v in enumerate(losses_disc_r)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
scalar_dict.update(
|
|
||||||
{
|
|
||||||
"loss/d_g/{}".format(i): v
|
|
||||||
for i, v in enumerate(losses_disc_g)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
image_dict = {
|
|
||||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
|
||||||
y_mel[0].data.cpu().numpy()
|
|
||||||
),
|
|
||||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
|
||||||
y_hat_mel[0].data.cpu().numpy()
|
|
||||||
),
|
|
||||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
|
||||||
mel[0].data.cpu().numpy()
|
|
||||||
),
|
|
||||||
}
|
|
||||||
utils.summarize(
|
|
||||||
writer=writer,
|
|
||||||
global_step=global_step,
|
|
||||||
images=image_dict,
|
|
||||||
scalars=scalar_dict,
|
|
||||||
)
|
|
||||||
global_step += 1
|
|
||||||
# if global_step % hps.train.eval_interval == 0:
|
|
||||||
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
|
||||||
if hps.if_latest == 0:
|
|
||||||
utils.save_checkpoint(
|
|
||||||
net_g,
|
|
||||||
optim_g,
|
|
||||||
hps.train.learning_rate,
|
|
||||||
epoch,
|
|
||||||
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
|
|
||||||
)
|
)
|
||||||
utils.save_checkpoint(
|
scalar_dict.update(
|
||||||
net_d,
|
{"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
|
||||||
optim_d,
|
|
||||||
hps.train.learning_rate,
|
|
||||||
epoch,
|
|
||||||
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
|
||||||
)
|
)
|
||||||
else:
|
scalar_dict.update(
|
||||||
utils.save_checkpoint(
|
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
|
||||||
net_g,
|
|
||||||
optim_g,
|
|
||||||
hps.train.learning_rate,
|
|
||||||
epoch,
|
|
||||||
os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
|
|
||||||
)
|
)
|
||||||
utils.save_checkpoint(
|
image_dict = {
|
||||||
net_d,
|
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||||
optim_d,
|
y_mel[0].data.cpu().numpy()
|
||||||
hps.train.learning_rate,
|
),
|
||||||
epoch,
|
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||||
os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
|
y_hat_mel[0].data.cpu().numpy()
|
||||||
|
),
|
||||||
|
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||||
|
mel[0].data.cpu().numpy()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
utils.summarize(
|
||||||
|
writer=writer,
|
||||||
|
global_step=global_step,
|
||||||
|
images=image_dict,
|
||||||
|
scalars=scalar_dict,
|
||||||
)
|
)
|
||||||
|
global_step += 1
|
||||||
|
# /Run steps
|
||||||
|
|
||||||
|
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
||||||
|
if hps.if_latest == 0:
|
||||||
|
utils.save_checkpoint(
|
||||||
|
net_g,
|
||||||
|
optim_g,
|
||||||
|
hps.train.learning_rate,
|
||||||
|
epoch,
|
||||||
|
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
|
||||||
|
)
|
||||||
|
utils.save_checkpoint(
|
||||||
|
net_d,
|
||||||
|
optim_d,
|
||||||
|
hps.train.learning_rate,
|
||||||
|
epoch,
|
||||||
|
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
utils.save_checkpoint(
|
||||||
|
net_g,
|
||||||
|
optim_g,
|
||||||
|
hps.train.learning_rate,
|
||||||
|
epoch,
|
||||||
|
os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
|
||||||
|
)
|
||||||
|
utils.save_checkpoint(
|
||||||
|
net_d,
|
||||||
|
optim_d,
|
||||||
|
hps.train.learning_rate,
|
||||||
|
epoch,
|
||||||
|
os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
|
||||||
|
)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.info("====> Epoch: {}".format(epoch))
|
logger.info("====> Epoch: {}".format(epoch))
|
||||||
|
Loading…
Reference in New Issue
Block a user