Refactor GPU cache during training (#108)

tarepan authored 2023-04-22 21:05:00 +09:00 · committed by GitHub
parent 297d92bf5d
commit 5d5ab5465f

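The refactor below keeps training batches resident on the GPU when hps.if_cache_data_in_gpu is set: the first epoch reads the DataLoader once, moves each batch to the device, and appends it to cache; later epochs just shuffle and replay that list, while the uncached path keeps streaming from the loader. A minimal, self-contained sketch of that pattern follows; run_epochs, to_device and train_step are illustrative placeholders rather than code from this repository, and it assumes each batch is a tuple of tensors.

    # Minimal sketch of the GPU batch-cache pattern (placeholder names, not repo code).
    from random import shuffle


    def to_device(batch, device):
        # Move every tensor in the batch to the target device once.
        return tuple(t.to(device, non_blocking=True) for t in batch)


    def run_epochs(train_loader, num_epochs, cache_in_gpu, device, train_step):
        cache = []  # survives across epochs when caching is enabled
        for epoch in range(num_epochs):
            if cache_in_gpu:
                if not cache:
                    # First epoch: read the loader once and keep each batch on the GPU.
                    for batch_idx, batch in enumerate(train_loader):
                        cache.append((batch_idx, to_device(batch, device)))
                else:
                    # Later epochs: replay the resident batches in a new order.
                    shuffle(cache)
                data_iterator = cache
            else:
                # No caching: stream batches from the DataLoader every epoch.
                data_iterator = enumerate(train_loader)

            for batch_idx, batch in data_iterator:
                if not cache_in_gpu:
                    batch = to_device(batch, device)
                train_step(batch)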

@@ -230,39 +230,50 @@ def train_and_evaluate(
    net_g.train()
    net_d.train()

    # Prepare data iterator
    if hps.if_cache_data_in_gpu == True:
        # Use Cache
        data_iterator = cache
        if cache == []:
            # Make new cache
            for batch_idx, info in enumerate(train_loader):
                # Unpack
                if hps.if_f0 == 1:
                    (
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                else:
                    (
                        phone,
                        phone_lengths,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                # Load on CUDA
                if torch.cuda.is_available():
                    phone = phone.cuda(rank, non_blocking=True)
                    phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
                    if hps.if_f0 == 1:
                        pitch = pitch.cuda(rank, non_blocking=True)
                        pitchf = pitchf.cuda(rank, non_blocking=True)
                    sid = sid.cuda(rank, non_blocking=True)
                    spec = spec.cuda(rank, non_blocking=True)
                    spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
                    wave = wave.cuda(rank, non_blocking=True)
                    wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
                # Cache on list
                if hps.if_f0 == 1:
                    cache.append(
                        (
@@ -295,372 +306,211 @@ def train_and_evaluate(
                            ),
                        )
                    )
        else:
            # Load shuffled cache
            shuffle(cache)
    else:
        # Loader
        data_iterator = enumerate(train_loader)

    # Run steps
    for batch_idx, info in data_iterator:
        # Data
        ## Unpack
        if hps.if_f0 == 1:
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                wave_lengths,
                sid,
            ) = info
        else:
            phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
        ## Load on CUDA
        if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
            phone = phone.cuda(rank, non_blocking=True)
            phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
            if hps.if_f0 == 1:
                pitch = pitch.cuda(rank, non_blocking=True)
                pitchf = pitchf.cuda(rank, non_blocking=True)
            sid = sid.cuda(rank, non_blocking=True)
            spec = spec.cuda(rank, non_blocking=True)
            spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
            wave = wave.cuda(rank, non_blocking=True)
            wave_lengths = wave_lengths.cuda(rank, non_blocking=True)

        # Calculate
        with autocast(enabled=hps.train.fp16_run):
            if hps.if_f0 == 1:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
            else:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            with autocast(enabled=False):
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.float().squeeze(1),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                if hps.train.fp16_run == True:
                    y_hat_mel = y_hat_mel.half()
            wave = commons.slice_segments(
                wave, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
        optim_d.zero_grad()
        scaler.scale(loss_disc).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                # Amor For Tensorboard display
                if loss_mel > 50:
                    loss_mel = 50
                if loss_kl > 5:
                    loss_kl = 5
                logger.info([global_step, lr])
                logger.info(
                    f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
                )
                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/kl": loss_kl,
                    }
                )
                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )
        global_step += 1
    # /Run steps

    if epoch % hps.save_every_epoch == 0 and rank == 0:
        if hps.if_latest == 0:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
            )
        else:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
            )

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))