import os
import sys
import traceback

# Accepted command lines (argv[1], the device name, is currently ignored):
#   <script> <device> <n_part> <i_part> <exp_dir> <version>           -> 6 entries
#   <script> <device> <n_part> <i_part> <i_gpu> <exp_dir> <version>   -> 7 entries
n_part = int(sys.argv[2])  # total number of parallel workers
i_part = int(sys.argv[3])  # index of this worker, 0 <= i_part < n_part
if len(sys.argv) == 6:
    # No GPU index given. The previous check `== 5` could never succeed here,
    # since it then read sys.argv[5], which does not exist for 5 entries.
    exp_dir = sys.argv[4]
    version = sys.argv[5]
else:
    i_gpu = sys.argv[4]
    exp_dir = sys.argv[5]
    # Set before torch is imported below so that CUDA only exposes this GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
    version = sys.argv[6]
# These imports intentionally come after argument parsing: CUDA_VISIBLE_DEVICES
# has to be in the environment before CUDA is initialised to take effect.
import torch
import torch.nn.functional as F
import soundfile as sf
import numpy as np
from fairseq import checkpoint_utils

# Pick the compute backend. `device` is kept as a plain string because later
# code checks `device not in ["mps", "cpu"]` to decide on fp16. (A previous
# `torch.device(...)` assignment here was dead code, always overwritten.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Append-mode log file for this experiment; it stays open for the lifetime of
# the process so printt can keep writing to it.
# NOTE(review): the name mentions f0 although this script extracts HuBERT
# features; presumably shared with another step — confirm before renaming.
f = open("%s/extract_f0_feature.log" % exp_dir, "a+")


def printt(strr):
    """Echo *strr* to stdout and append it as a single line to the log file.

    The log is flushed after each message, presumably so a parent process
    can follow progress while this script runs.
    """
    print(strr)
    line = "%s\n" % strr
    f.write(line)
    f.flush()
printt(sys.argv)
model_path = "hubert_base.pt"
printt(exp_dir)

# Input directory of 16 kHz wavs for this experiment.
wavPath = "%s/1_16k_wavs" % exp_dir
# Output directory depends on the model version; the directory names suggest
# 256-dim features for v1 and 768-dim features otherwise.
if version == "v1":
    outPath = "%s/3_feature256" % exp_dir
else:
    outPath = "%s/3_feature768" % exp_dir
os.makedirs(outPath, exist_ok=True)
# wave must be 16k, hop_size=320
def readwave(wav_path, normalize=False):
    """Load a 16 kHz audio file as a float32 tensor of shape (1, num_samples).

    Two-dimensional audio (samples x channels, as returned by soundfile for
    multi-channel files) is down-mixed to mono by averaging the channels.

    Args:
        wav_path: path of an audio file readable by soundfile.
        normalize: if True, apply a parameter-free layer norm over the whole
            waveform before returning it.

    Returns:
        torch.Tensor of shape (1, num_samples), dtype float32.

    Raises:
        ValueError: if the sample rate is not 16000 Hz, or the audio is not
            one-dimensional after down-mixing.
    """
    wav, sr = sf.read(wav_path)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if sr != 16000:
        raise ValueError("%s: expected 16000 Hz audio, got %s Hz" % (wav_path, sr))
    feats = torch.from_numpy(wav).float()
    if feats.dim() == 2:  # (samples, channels) -> mono
        feats = feats.mean(-1)
    if feats.dim() != 1:
        raise ValueError(
            "%s: expected 1-D audio after down-mix, got %d dims" % (wav_path, feats.dim())
        )
    if normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
    return feats.view(1, -1)
# HuBERT model
printt("load model(s) from {}".format(model_path))
# Abort early, with a download hint, if the HuBERT checkpoint is missing.
if not os.access(model_path, os.F_OK):
    printt(
        "Error: Extracting is shut down because %s does not exist, you may download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
        % model_path
    )
    # sys.exit rather than the site-provided `exit`, which is not guaranteed
    # to exist. Exit status 0 is kept for compatibility with existing callers.
    sys.exit(0)

models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    [model_path],
    suffix="",
)
model = models[0]
model = model.to(device)
printt("move model to %s" % device)
# Half precision only on CUDA; MPS and CPU keep the model in fp32.
if device not in ["mps", "cpu"]:
    model = model.half()
model.eval()
# Shard the sorted file list: worker i_part of n_part takes every n_part-th
# entry starting at index i_part, so parallel workers never overlap.
todo = sorted(os.listdir(wavPath))[i_part::n_part]
n = max(1, len(todo) // 10)  # print at most about ten progress lines
if len(todo) == 0:
    printt("no-feature-todo")
else:
    printt("all-feature-%s" % len(todo))
    for idx, file in enumerate(todo):
        try:
            if file.endswith(".wav"):
                wav_path = "%s/%s" % (wavPath, file)
                # Swap only the extension: str.replace("wav", "npy") would also
                # rewrite any "wav" inside the stem (e.g. "wav1.wav").
                out_path = "%s/%s.npy" % (outPath, file[: -len(".wav")])

                if os.path.exists(out_path):
                    continue  # already extracted; makes reruns resumable

                feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
                # All-False mask: no position of the input is padding.
                padding_mask = torch.zeros(feats.shape, dtype=torch.bool)
                inputs = {
                    # Input dtype must match the model (fp16 only on CUDA).
                    "source": feats.half().to(device)
                    if device not in ["mps", "cpu"]
                    else feats.to(device),
                    "padding_mask": padding_mask.to(device),
                    # v1 reads layer 9 (then final_proj), other versions layer 12.
                    "output_layer": 9 if version == "v1" else 12,
                }
                with torch.no_grad():
                    logits = model.extract_features(**inputs)
                    feats = (
                        model.final_proj(logits[0]) if version == "v1" else logits[0]
                    )

                feats = feats.squeeze(0).float().cpu().numpy()
                # Features containing NaN are not written, only reported.
                if not np.isnan(feats).any():
                    np.save(out_path, feats, allow_pickle=False)
                else:
                    printt("%s-contains nan" % file)
                if idx % n == 0:
                    # Current index first, total second, matching the labels.
                    printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
        except Exception:
            # Log and continue so one bad file does not abort the whole shard;
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            printt(traceback.format_exc())
    printt("all-feature-done")