diff --git a/configs/config.py b/configs/config.py index 05819b9..c089997 100644 --- a/configs/config.py +++ b/configs/config.py @@ -5,7 +5,13 @@ import json from multiprocessing import cpu_count import torch - +try: + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + if torch.xpu.is_available(): + from infer.modules.ipex import ipex_init + ipex_init() +except Exception: + pass import logging logger = logging.getLogger(__name__) @@ -103,12 +109,22 @@ class Config: except Exception: return False + @staticmethod + def has_xpu() -> bool: + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return True + else: + return False + def use_fp32_config(self): for config_file in version_config_list: self.json_config[config_file]["train"]["fp16_run"] = False def device_config(self) -> tuple: if torch.cuda.is_available(): + if self.has_xpu(): + self.device = self.instead = "xpu:0" + self.is_half = True i_device = int(self.device.split(":")[-1]) self.gpu_name = torch.cuda.get_device_name(i_device) if ( diff --git a/docs/en/README.en.md b/docs/en/README.en.md index cce3b92..5d7559a 100644 --- a/docs/en/README.en.md +++ b/docs/en/README.en.md @@ -43,6 +43,7 @@ This repository has the following features: + Use the UVR5 model to quickly separate vocals and instruments. + Use the most powerful High-pitch Voice Extraction Algorithm [InterSpeech2023-RMVPE](#Credits) to prevent the muted sound problem. Provides the best results (significantly) and is faster, with even lower resource consumption than Crepe_full. + AMD/Intel graphics cards acceleration supported. ++ Intel ARC graphics cards acceleration with IPEX supported. ## Preparing the environment The following commands need to be executed in the environment of Python version 3.8 or higher. @@ -77,6 +78,9 @@ for Nvidia graphics cards for AMD/Intel graphics cards: pip install -r requirements-dml.txt +for Intel ARC graphics cards on Linux / WSL using Python 3.10: + pip install -r requirements-ipex.txt + ``` ------ @@ -124,6 +128,9 @@ https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx ``` + +Intel ARC graphics cards users needs to run `source /opt/intel/oneapi/setvars.sh` command before starting Webui. + Then use this command to start Webui: ```bash python infer-web.py diff --git a/docs/tr/README.tr.md b/docs/tr/README.tr.md index 846b929..a420b41 100644 --- a/docs/tr/README.tr.md +++ b/docs/tr/README.tr.md @@ -42,6 +42,7 @@ Bu depo aşağıdaki özelliklere sahiptir: + UVR5 modelini kullanarak hızla vokalleri ve enstrümanları ayırma. + En güçlü Yüksek tiz Ses Çıkarma Algoritması [InterSpeech2023-RMVPE](#Krediler) sessiz ses sorununu önlemek için kullanılır. En iyi sonuçları (önemli ölçüde) sağlar ve Crepe_full'den daha hızlı çalışır, hatta daha düşük kaynak tüketimi sağlar. + AMD/Intel grafik kartları hızlandırması desteklenir. ++ Intel ARC grafik kartları hızlandırması IPEX ile desteklenir. ## Ortamın Hazırlanması Aşağıdaki komutlar, Python sürümü 3.8 veya daha yüksek olan bir ortamda çalıştırılmalıdır. @@ -73,11 +74,12 @@ Ayrıca bunları pip kullanarak da kurabilirsiniz: Nvidia grafik kartları için pip install -r requirements.txt - - AMD/Intel grafik kartları için: pip install -r requirements-dml.txt +Intel ARC grafik kartları için Linux / WSL ile Python 3.10 kullanarak: + pip install -r requirements-ipex.txt + ``` ------ @@ -125,6 +127,9 @@ https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx ``` + +Intel ARC grafik kartları kullanıcıları Webui'yi başlatmadan önce `source /opt/intel/oneapi/setvars.sh` komutunu çalıştırmalı. + Daha sonra bu komutu kullanarak Webui'yi başlatabilirsiniz: ```bash python infer-web.py diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py index d02d8cf..425cc9c 100644 --- a/infer/lib/infer_pack/models.py +++ b/infer/lib/infer_pack/models.py @@ -13,6 +13,7 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm from infer.lib.infer_pack import attentions, commons, modules from infer.lib.infer_pack.commons import get_padding, init_weights +has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available()) class TextEncoder256(nn.Module): def __init__( @@ -1156,7 +1157,10 @@ class DiscriminatorP(torch.nn.Module): b, c, t = x.shape if t % self.period != 0: # pad first n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") + if has_xpu and x.dtype == torch.bfloat16: + x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16) + else: + x = F.pad(x, (0, n_pad), "reflect") t = t + n_pad x = x.view(b, c, t // self.period, self.period) diff --git a/infer/lib/rmvpe.py b/infer/lib/rmvpe.py index 57e4de5..a4fcb1d 100644 --- a/infer/lib/rmvpe.py +++ b/infer/lib/rmvpe.py @@ -2,6 +2,14 @@ import pdb, os import numpy as np import torch +try: + #Fix "Torch not compiled with CUDA enabled" + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + if torch.xpu.is_available(): + from infer.modules.ipex import ipex_init + ipex_init() +except Exception: + pass import torch.nn as nn import torch.nn.functional as F from librosa.util import normalize, pad_center, tiny diff --git a/infer/modules/ipex/__init__.py b/infer/modules/ipex/__init__.py new file mode 100644 index 0000000..3207452 --- /dev/null +++ b/infer/modules/ipex/__init__.py @@ -0,0 +1,165 @@ +import os +import sys +import contextlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from .hijacks import ipex_hijacks +from .attention import attention_init + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +def ipex_init(): # pylint: disable=too-many-statements + try: + #Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + + #Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + + #RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + #AMP: + torch.cuda.amp = torch.xpu.amp + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False + try: + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + + #C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 + + #Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.version.cuda = "11.7" + torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] + torch.cuda.get_device_properties.major = 11 + torch.cuda.get_device_properties.minor = 7 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 + + ipex_hijacks() + attention_init() + except Exception as e: + return False, e + return True, None diff --git a/infer/modules/ipex/attention.py b/infer/modules/ipex/attention.py new file mode 100644 index 0000000..d7335bf --- /dev/null +++ b/infer/modules/ipex/attention.py @@ -0,0 +1,128 @@ +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +original_torch_bmm = torch.bmm +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] + block_multiply = 2.4 if input.dtype == torch.float32 else 1.2 + block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the input_tokens + while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB + split_2_slice_size = input_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the input_tokens + while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) + else: + hidden_states[start_idx:end_idx] = original_torch_bmm( + input[start_idx:end_idx], + mat2[start_idx:end_idx], + out=out + ) + else: + return original_torch_bmm(input, mat2, out=out) + return hidden_states + +original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + shape_one, batch_size_attention, query_tokens, shape_four = query.shape + block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 + block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB + split_slice_size = batch_size_attention + if block_size >= 4000: + do_split = True + #Find something divisible with the shape_one + while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB + split_2_slice_size = query_tokens + if split_block_size >= 4000: + do_split_2 = True + #Find something divisible with the batch_size_attention + while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx, start_idx_2:end_idx_2], + key[:, start_idx:end_idx, start_idx_2:end_idx_2], + value[:, start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx], + key[:, start_idx:end_idx], + value[:, start_idx:end_idx], + attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + return original_scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal + ) + return hidden_states + +def attention_init(): + #ARC GPUs can't allocate more than 4GB to a single block: + torch.bmm = torch_bmm + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/infer/modules/ipex/gradscaler.py b/infer/modules/ipex/gradscaler.py new file mode 100644 index 0000000..5302121 --- /dev/null +++ b/infer/modules/ipex/gradscaler.py @@ -0,0 +1,179 @@ +from collections import defaultdict +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +OptState = ipex.cpu.autocast._grad_scaler.OptState +_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator +_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state + +def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + # sync grad to master weight + if hasattr(optimizer, "sync_grad"): + optimizer.sync_grad() + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # -: is there a way to split by device and dtype without appending in the inner loop? + to_unscale = to_unscale.to("cpu") + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + core._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get("cpu"), + per_device_inv_scale.get("cpu"), + ) + + return per_device_found_inf._per_device_tensors + +def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) + optimizer_state["stage"] = OptState.UNSCALED + +def update(self, new_scale=None): + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device="cpu", non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + to_device = _scale.device + _scale = _scale.to("cpu") + _growth_tracker = _growth_tracker.to("cpu") + + core._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + _scale = _scale.to(to_device) + _growth_tracker = _growth_tracker.to(to_device) + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + +def gradscaler_init(): + torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ + torch.xpu.amp.GradScaler.unscale_ = unscale_ + torch.xpu.amp.GradScaler.update = update + return torch.xpu.amp.GradScaler diff --git a/infer/modules/ipex/hijacks.py b/infer/modules/ipex/hijacks.py new file mode 100644 index 0000000..78d7e03 --- /dev/null +++ b/infer/modules/ipex/hijacks.py @@ -0,0 +1,196 @@ +import contextlib +import importlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return + +class CondFunc: # pylint: disable=missing-class-docstring + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-1, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) + +_utils = torch.utils.data._utils +def _shutdown_workers(self): + if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: + return + if hasattr(self, "_shutdown") and not self._shutdown: + self._shutdown = True + try: + if hasattr(self, '_pin_memory_thread'): + self._pin_memory_thread_done_event.set() + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: # pylint: disable=invalid-name + w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: # pylint: disable=invalid-name + q.cancel_join_thread() + q.close() + finally: + if self._worker_pids_set: + torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + for w in self._workers: # pylint: disable=invalid-name + if w.is_alive(): + w.terminate() + +class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods + def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument + if isinstance(device_ids, list) and len(device_ids) > 1: + print("IPEX backend doesn't support DataParallel on multiple XPU devices") + return module.to("xpu") + +def return_null_context(*args, **kwargs): # pylint: disable=unused-argument + return contextlib.nullcontext() + +def check_device(device): + return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) + +def return_xpu(device): + return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + +def ipex_no_cuda(orig_func, *args, **kwargs): + torch.cuda.is_available = lambda: False + orig_func(*args, **kwargs) + torch.cuda.is_available = torch.xpu.is_available + +original_autocast = torch.autocast +def ipex_autocast(*args, **kwargs): + if len(args) > 0 and args[0] == "cuda": + return original_autocast("xpu", *args[1:], **kwargs) + else: + return original_autocast(*args, **kwargs) + +original_torch_cat = torch.cat +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(tensor, *args, **kwargs) + +original_interpolate = torch.nn.functional.interpolate +def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments + if antialias or align_corners is not None: + return_device = tensor.device + return_dtype = tensor.dtype + return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) + else: + return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + +original_linalg_solve = torch.linalg.solve +def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name + if A.device != torch.device("cpu") or B.device != torch.device("cpu"): + return_device = A.device + return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) + else: + return original_linalg_solve(A, B, *args, **kwargs) + +def ipex_hijacks(): + CondFunc('torch.Tensor.to', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.Tensor.cuda', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.empty', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.load', + lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), + lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) + CondFunc('torch.randn', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.ones', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.zeros', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.tensor', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.linspace', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + + CondFunc('torch.batch_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + CondFunc('torch.instance_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + + #Functions with dtype errors: + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + + #Diffusers Float64 (ARC GPUs doesn't support double or Float64): + if not torch.xpu.has_fp64_dtype(): + CondFunc('torch.from_numpy', + lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), + lambda orig_func, ndarray: ndarray.dtype == float) + + #Broken functions when torch.cuda.is_available is True: + CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', + lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), + lambda orig_func, *args, **kwargs: True) + + #Functions that make compile mad with CondFunc: + torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + torch.nn.DataParallel = DummyDataParallel + torch.autocast = ipex_autocast + torch.cat = torch_cat + torch.linalg.solve = linalg_solve + torch.nn.functional.interpolate = interpolate + torch.backends.cuda.sdp_kernel = return_null_context diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 1c2ab2b..7fed4f4 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -17,6 +17,18 @@ n_gpus = len(hps.gpus.split("-")) from random import randint, shuffle import torch +try: + import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + if torch.xpu.is_available(): + from infer.modules.ipex import ipex_init + from infer.modules.ipex.gradscaler import gradscaler_init + from torch.xpu.amp import autocast + GradScaler = gradscaler_init() + ipex_init() + else: + from torch.cuda.amp import GradScaler, autocast +except Exception: + from torch.cuda.amp import GradScaler, autocast torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = False @@ -25,7 +37,6 @@ from time import time as ttime import torch.distributed as dist import torch.multiprocessing as mp -from torch.cuda.amp import GradScaler, autocast from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader @@ -185,7 +196,9 @@ def run(rank, n_gpus, hps): ) # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) - if torch.cuda.is_available(): + if hasattr(torch, "xpu") and torch.xpu.is_available(): + pass + elif torch.cuda.is_available(): net_g = DDP(net_g, device_ids=[rank]) net_d = DDP(net_d, device_ids=[rank]) else: @@ -212,19 +225,33 @@ def run(rank, n_gpus, hps): if hps.pretrainG != "": if rank == 0: logger.info("loaded pretrained %s" % (hps.pretrainG)) - logger.info( - net_g.module.load_state_dict( - torch.load(hps.pretrainG, map_location="cpu")["model"] - ) - ) ##测试不加载优化器 + if hasattr(net_g, "module"): + logger.info( + net_g.module.load_state_dict( + torch.load(hps.pretrainG, map_location="cpu")["model"] + ) + ) ##测试不加载优化器 + else: + logger.info( + net_g.load_state_dict( + torch.load(hps.pretrainG, map_location="cpu")["model"] + ) + ) ##测试不加载优化器 if hps.pretrainD != "": if rank == 0: logger.info("loaded pretrained %s" % (hps.pretrainD)) - logger.info( - net_d.module.load_state_dict( - torch.load(hps.pretrainD, map_location="cpu")["model"] + if hasattr(net_d, "module"): + logger.info( + net_d.module.load_state_dict( + torch.load(hps.pretrainD, map_location="cpu")["model"] + ) + ) + else: + logger.info( + net_d.load_state_dict( + torch.load(hps.pretrainD, map_location="cpu")["model"] + ) ) - ) scheduler_g = torch.optim.lr_scheduler.ExponentialLR( optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 diff --git a/infer/modules/uvr5/modules.py b/infer/modules/uvr5/modules.py index f627ca9..f63ac6a 100644 --- a/infer/modules/uvr5/modules.py +++ b/infer/modules/uvr5/modules.py @@ -76,10 +76,18 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format infos.append("%s->Success" % (os.path.basename(inp_path))) yield "\n".join(infos) except: - infos.append( - "%s->%s" % (os.path.basename(inp_path), traceback.format_exc()) - ) - yield "\n".join(infos) + try: + if done == 0: + pre_fun._path_audio_( + inp_path, save_root_ins, save_root_vocal, format0 + ) + infos.append("%s->Success" % (os.path.basename(inp_path))) + yield "\n".join(infos) + except: + infos.append( + "%s->%s" % (os.path.basename(inp_path), traceback.format_exc()) + ) + yield "\n".join(infos) except: infos.append(traceback.format_exc()) yield "\n".join(infos) diff --git a/infer/modules/vc/pipeline.py b/infer/modules/vc/pipeline.py index 6fcefca..0399f29 100644 --- a/infer/modules/vc/pipeline.py +++ b/infer/modules/vc/pipeline.py @@ -362,7 +362,7 @@ class Pipeline(object): ) pitch = pitch[:p_len] pitchf = pitchf[:p_len] - if self.device == "mps": + if self.device == "mps" or "xpu" in self.device: pitchf = pitchf.astype(np.float32) pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long() pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float() diff --git a/requirements-ipex.txt b/requirements-ipex.txt new file mode 100644 index 0000000..1a96cf0 --- /dev/null +++ b/requirements-ipex.txt @@ -0,0 +1,54 @@ +torch==2.0.1a0 +intel_extension_for_pytorch==2.0.110+xpu +torchvision==0.15.2a0 +https://github.com/Disty0/Retrieval-based-Voice-Conversion-WebUI/releases/download/torchaudio_wheels_for_ipex/torchaudio-2.0.2+31de77d-cp310-cp310-linux_x86_64.whl +-f https://developer.intel.com/ipex-whl-stable-xpu +joblib>=1.1.0 +numba==0.56.4 +numpy==1.23.5 +scipy +librosa==0.9.1 +llvmlite==0.39.0 +fairseq==0.12.2 +faiss-cpu==1.7.3 +gradio==3.34.0 +Cython +pydub>=0.25.1 +soundfile>=0.12.1 +ffmpeg-python>=0.2.0 +tensorboardX +Jinja2>=3.1.2 +json5 +Markdown +matplotlib>=3.7.0 +matplotlib-inline>=0.1.3 +praat-parselmouth>=0.4.2 +Pillow>=9.1.1 +resampy>=0.4.2 +scikit-learn +tensorboard +tqdm>=4.63.1 +tornado>=6.1 +Werkzeug>=2.2.3 +uc-micro-py>=1.0.1 +sympy>=1.11.1 +tabulate>=0.8.10 +PyYAML>=6.0 +pyasn1>=0.4.8 +pyasn1-modules>=0.2.8 +fsspec>=2022.11.0 +absl-py>=1.2.0 +audioread +uvicorn>=0.21.1 +colorama>=0.4.5 +pyworld==0.3.2 +httpx +onnxruntime; sys_platform == 'darwin' +onnxruntime-gpu; sys_platform != 'darwin' +torchcrepe==0.0.20 +fastapi==0.88 +ffmpy==0.3.1 +python-dotenv>=1.0.0 +av +PySimpleGUI +sounddevice \ No newline at end of file