mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-01-30 18:22:51 +08:00
Refactor GPU cache during training (#108)
This commit is contained in:
parent
297d92bf5d
commit
5d5ab5465f
@ -230,39 +230,50 @@ def train_and_evaluate(
|
||||
|
||||
net_g.train()
|
||||
net_d.train()
|
||||
if cache == [] or hps.if_cache_data_in_gpu == False: # 第一个epoch把cache全部填满训练集
|
||||
# print("caching")
|
||||
for batch_idx, info in enumerate(train_loader):
|
||||
if hps.if_f0 == 1:
|
||||
(
|
||||
phone,
|
||||
phone_lengths,
|
||||
pitch,
|
||||
pitchf,
|
||||
spec,
|
||||
spec_lengths,
|
||||
wave,
|
||||
wave_lengths,
|
||||
sid,
|
||||
) = info
|
||||
else:
|
||||
phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
|
||||
if torch.cuda.is_available():
|
||||
phone, phone_lengths = phone.cuda(
|
||||
rank, non_blocking=True
|
||||
), phone_lengths.cuda(rank, non_blocking=True)
|
||||
|
||||
# Prepare data iterator
|
||||
if hps.if_cache_data_in_gpu == True:
|
||||
# Use Cache
|
||||
data_iterator = cache
|
||||
if cache == []:
|
||||
# Make new cache
|
||||
for batch_idx, info in enumerate(train_loader):
|
||||
# Unpack
|
||||
if hps.if_f0 == 1:
|
||||
pitch, pitchf = pitch.cuda(rank, non_blocking=True), pitchf.cuda(
|
||||
rank, non_blocking=True
|
||||
)
|
||||
sid = sid.cuda(rank, non_blocking=True)
|
||||
spec, spec_lengths = spec.cuda(
|
||||
rank, non_blocking=True
|
||||
), spec_lengths.cuda(rank, non_blocking=True)
|
||||
wave, wave_lengths = wave.cuda(
|
||||
rank, non_blocking=True
|
||||
), wave_lengths.cuda(rank, non_blocking=True)
|
||||
if hps.if_cache_data_in_gpu == True:
|
||||
(
|
||||
phone,
|
||||
phone_lengths,
|
||||
pitch,
|
||||
pitchf,
|
||||
spec,
|
||||
spec_lengths,
|
||||
wave,
|
||||
wave_lengths,
|
||||
sid,
|
||||
) = info
|
||||
else:
|
||||
(
|
||||
phone,
|
||||
phone_lengths,
|
||||
spec,
|
||||
spec_lengths,
|
||||
wave,
|
||||
wave_lengths,
|
||||
sid,
|
||||
) = info
|
||||
# Load on CUDA
|
||||
if torch.cuda.is_available():
|
||||
phone = phone.cuda(rank, non_blocking=True)
|
||||
phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
|
||||
if hps.if_f0 == 1:
|
||||
pitch = pitch.cuda(rank, non_blocking=True)
|
||||
pitchf = pitchf.cuda(rank, non_blocking=True)
|
||||
sid = sid.cuda(rank, non_blocking=True)
|
||||
spec = spec.cuda(rank, non_blocking=True)
|
||||
spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
|
||||
wave = wave.cuda(rank, non_blocking=True)
|
||||
wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
|
||||
# Cache on list
|
||||
if hps.if_f0 == 1:
|
||||
cache.append(
|
||||
(
|
||||
@ -295,372 +306,211 @@ def train_and_evaluate(
|
||||
),
|
||||
)
|
||||
)
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
if hps.if_f0 == 1:
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(
|
||||
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
|
||||
)
|
||||
else:
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = commons.slice_segments(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
)
|
||||
with autocast(enabled=False):
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.float().squeeze(1),
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
if hps.train.fp16_run == True:
|
||||
y_hat_mel = y_hat_mel.half()
|
||||
wave = commons.slice_segments(
|
||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||
) # slice
|
||||
else:
|
||||
# Load shuffled cache
|
||||
shuffle(cache)
|
||||
else:
|
||||
# Loader
|
||||
data_iterator = enumerate(train_loader)
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
||||
with autocast(enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r, y_d_hat_g
|
||||
)
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc).backward()
|
||||
scaler.unscale_(optim_d)
|
||||
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
||||
scaler.step(optim_d)
|
||||
# Run steps
|
||||
for batch_idx, info in data_iterator:
|
||||
# Data
|
||||
## Unpack
|
||||
if hps.if_f0 == 1:
|
||||
(
|
||||
phone,
|
||||
phone_lengths,
|
||||
pitch,
|
||||
pitchf,
|
||||
spec,
|
||||
spec_lengths,
|
||||
wave,
|
||||
wave_lengths,
|
||||
sid,
|
||||
) = info
|
||||
else:
|
||||
phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
|
||||
## Load on CUDA
|
||||
if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
|
||||
phone = phone.cuda(rank, non_blocking=True)
|
||||
phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
|
||||
if hps.if_f0 == 1:
|
||||
pitch = pitch.cuda(rank, non_blocking=True)
|
||||
pitchf = pitchf.cuda(rank, non_blocking=True)
|
||||
sid = sid.cuda(rank, non_blocking=True)
|
||||
spec = spec.cuda(rank, non_blocking=True)
|
||||
spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
|
||||
wave = wave.cuda(rank, non_blocking=True)
|
||||
wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
# Generator
|
||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
||||
with autocast(enabled=False):
|
||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
||||
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
||||
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if rank == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch, 100.0 * batch_idx / len(train_loader)
|
||||
)
|
||||
)
|
||||
# Amor For Tensorboard display
|
||||
if loss_mel > 50:
|
||||
loss_mel = 50
|
||||
if loss_kl > 5:
|
||||
loss_kl = 5
|
||||
|
||||
logger.info([global_step, lr])
|
||||
logger.info(
|
||||
f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
|
||||
)
|
||||
scalar_dict = {
|
||||
"loss/g/total": loss_gen_all,
|
||||
"loss/d/total": loss_disc,
|
||||
"learning_rate": lr,
|
||||
"grad_norm_d": grad_norm_d,
|
||||
"grad_norm_g": grad_norm_g,
|
||||
}
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/g/fm": loss_fm,
|
||||
"loss/g/mel": loss_mel,
|
||||
"loss/g/kl": loss_kl,
|
||||
}
|
||||
)
|
||||
|
||||
scalar_dict.update(
|
||||
{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
|
||||
)
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/d_r/{}".format(i): v
|
||||
for i, v in enumerate(losses_disc_r)
|
||||
}
|
||||
)
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/d_g/{}".format(i): v
|
||||
for i, v in enumerate(losses_disc_g)
|
||||
}
|
||||
)
|
||||
image_dict = {
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||
y_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].data.cpu().numpy()
|
||||
),
|
||||
}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
global_step += 1
|
||||
# if global_step % hps.train.eval_interval == 0:
|
||||
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
||||
if hps.if_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
|
||||
)
|
||||
|
||||
else: # 后续的epoch直接使用打乱的cache
|
||||
shuffle(cache)
|
||||
# print("using cache")
|
||||
for batch_idx, info in cache:
|
||||
# Calculate
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
if hps.if_f0 == 1:
|
||||
(
|
||||
phone,
|
||||
phone_lengths,
|
||||
pitch,
|
||||
pitchf,
|
||||
spec,
|
||||
spec_lengths,
|
||||
wave,
|
||||
wave_lengths,
|
||||
sid,
|
||||
) = info
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
|
||||
else:
|
||||
phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
if hps.if_f0 == 1:
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(
|
||||
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid
|
||||
)
|
||||
else:
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
(
|
||||
y_hat,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = commons.slice_segments(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
)
|
||||
with autocast(enabled=False):
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.float().squeeze(1),
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = commons.slice_segments(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
if hps.train.fp16_run == True:
|
||||
y_hat_mel = y_hat_mel.half()
|
||||
wave = commons.slice_segments(
|
||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||
) # slice
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
||||
with autocast(enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r, y_d_hat_g
|
||||
)
|
||||
with autocast(enabled=False):
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.float().squeeze(1),
|
||||
hps.data.filter_length,
|
||||
hps.data.n_mel_channels,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc).backward()
|
||||
scaler.unscale_(optim_d)
|
||||
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
||||
scaler.step(optim_d)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
# Generator
|
||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
||||
with autocast(enabled=False):
|
||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
||||
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
||||
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if rank == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch, 100.0 * batch_idx / len(train_loader)
|
||||
)
|
||||
if hps.train.fp16_run == True:
|
||||
y_hat_mel = y_hat_mel.half()
|
||||
wave = commons.slice_segments(
|
||||
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||
) # slice
|
||||
)
|
||||
# Amor For Tensorboard display
|
||||
if loss_mel > 50:
|
||||
loss_mel = 50
|
||||
if loss_kl > 5:
|
||||
loss_kl = 5
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
||||
with autocast(enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r, y_d_hat_g
|
||||
)
|
||||
optim_d.zero_grad()
|
||||
scaler.scale(loss_disc).backward()
|
||||
scaler.unscale_(optim_d)
|
||||
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
||||
scaler.step(optim_d)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
# Generator
|
||||
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
||||
with autocast(enabled=False):
|
||||
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
||||
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
||||
|
||||
loss_fm = feature_loss(fmap_r, fmap_g)
|
||||
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
||||
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
||||
scaler.step(optim_g)
|
||||
scaler.update()
|
||||
|
||||
if rank == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch, 100.0 * batch_idx / len(train_loader)
|
||||
)
|
||||
)
|
||||
# Amor For Tensorboard display
|
||||
if loss_mel > 50:
|
||||
loss_mel = 50
|
||||
if loss_kl > 5:
|
||||
loss_kl = 5
|
||||
|
||||
logger.info([global_step, lr])
|
||||
logger.info(
|
||||
f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
|
||||
)
|
||||
scalar_dict = {
|
||||
"loss/g/total": loss_gen_all,
|
||||
"loss/d/total": loss_disc,
|
||||
"learning_rate": lr,
|
||||
"grad_norm_d": grad_norm_d,
|
||||
"grad_norm_g": grad_norm_g,
|
||||
logger.info([global_step, lr])
|
||||
logger.info(
|
||||
f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
|
||||
)
|
||||
scalar_dict = {
|
||||
"loss/g/total": loss_gen_all,
|
||||
"loss/d/total": loss_disc,
|
||||
"learning_rate": lr,
|
||||
"grad_norm_d": grad_norm_d,
|
||||
"grad_norm_g": grad_norm_g,
|
||||
}
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/g/fm": loss_fm,
|
||||
"loss/g/mel": loss_mel,
|
||||
"loss/g/kl": loss_kl,
|
||||
}
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/g/fm": loss_fm,
|
||||
"loss/g/mel": loss_mel,
|
||||
"loss/g/kl": loss_kl,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
scalar_dict.update(
|
||||
{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
|
||||
)
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/d_r/{}".format(i): v
|
||||
for i, v in enumerate(losses_disc_r)
|
||||
}
|
||||
)
|
||||
scalar_dict.update(
|
||||
{
|
||||
"loss/d_g/{}".format(i): v
|
||||
for i, v in enumerate(losses_disc_g)
|
||||
}
|
||||
)
|
||||
image_dict = {
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||
y_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].data.cpu().numpy()
|
||||
),
|
||||
}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
global_step += 1
|
||||
# if global_step % hps.train.eval_interval == 0:
|
||||
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
||||
if hps.if_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
|
||||
scalar_dict.update(
|
||||
{"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
||||
scalar_dict.update(
|
||||
{"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
|
||||
scalar_dict.update(
|
||||
{"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
|
||||
image_dict = {
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||
y_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].data.cpu().numpy()
|
||||
),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].data.cpu().numpy()
|
||||
),
|
||||
}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
global_step += 1
|
||||
# /Run steps
|
||||
|
||||
if epoch % hps.save_every_epoch == 0 and rank == 0:
|
||||
if hps.if_latest == 0:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
net_g,
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
net_d,
|
||||
optim_d,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
logger.info("====> Epoch: {}".format(epoch))
|
||||
|
Loading…
Reference in New Issue
Block a user