Feature: Intel ARC GPU support with IPEX (#1204)

* Initial Intel ARC support with IPEX

* Fix infer

* Fix train model

* Cleanup

* Cleanup

* Update README

* Make pylint happy

* Move dataloader fix to hijacks

* Fix torch.linalg.solve

* Fix SDP

* Add has_xpu to config.py

* Revert return_xpu fix
This commit is contained in:
Disty0 2023-09-09 07:00:29 +03:00 committed by GitHub
parent c761bda09a
commit 0c94f60093
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 817 additions and 20 deletions

View File

@ -5,7 +5,13 @@ import json
from multiprocessing import cpu_count from multiprocessing import cpu_count
import torch import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
ipex_init()
except Exception:
pass
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -103,12 +109,22 @@ class Config:
except Exception: except Exception:
return False return False
@staticmethod
def has_xpu() -> bool:
if hasattr(torch, "xpu") and torch.xpu.is_available():
return True
else:
return False
def use_fp32_config(self): def use_fp32_config(self):
for config_file in version_config_list: for config_file in version_config_list:
self.json_config[config_file]["train"]["fp16_run"] = False self.json_config[config_file]["train"]["fp16_run"] = False
def device_config(self) -> tuple: def device_config(self) -> tuple:
if torch.cuda.is_available(): if torch.cuda.is_available():
if self.has_xpu():
self.device = self.instead = "xpu:0"
self.is_half = True
i_device = int(self.device.split(":")[-1]) i_device = int(self.device.split(":")[-1])
self.gpu_name = torch.cuda.get_device_name(i_device) self.gpu_name = torch.cuda.get_device_name(i_device)
if ( if (

View File

@ -43,6 +43,7 @@ This repository has the following features:
+ Use the UVR5 model to quickly separate vocals and instruments. + Use the UVR5 model to quickly separate vocals and instruments.
+ Use the most powerful High-pitch Voice Extraction Algorithm [InterSpeech2023-RMVPE](#Credits) to prevent the muted sound problem. Provides the best results (significantly) and is faster, with even lower resource consumption than Crepe_full. + Use the most powerful High-pitch Voice Extraction Algorithm [InterSpeech2023-RMVPE](#Credits) to prevent the muted sound problem. Provides the best results (significantly) and is faster, with even lower resource consumption than Crepe_full.
+ AMD/Intel graphics cards acceleration supported. + AMD/Intel graphics cards acceleration supported.
+ Intel ARC graphics cards acceleration with IPEX supported.
## Preparing the environment ## Preparing the environment
The following commands need to be executed in the environment of Python version 3.8 or higher. The following commands need to be executed in the environment of Python version 3.8 or higher.
@ -77,6 +78,9 @@ for Nvidia graphics cards
for AMD/Intel graphics cards for AMD/Intel graphics cards
pip install -r requirements-dml.txt pip install -r requirements-dml.txt
for Intel ARC graphics cards on Linux / WSL using Python 3.10:
pip install -r requirements-ipex.txt
``` ```
------ ------
@ -124,6 +128,9 @@ https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt
https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx
``` ```
Intel ARC graphics cards users needs to run `source /opt/intel/oneapi/setvars.sh` command before starting Webui.
Then use this command to start Webui: Then use this command to start Webui:
```bash ```bash
python infer-web.py python infer-web.py

View File

@ -42,6 +42,7 @@ Bu depo aşağıdaki özelliklere sahiptir:
+ UVR5 modelini kullanarak hızla vokalleri ve enstrümanları ayırma. + UVR5 modelini kullanarak hızla vokalleri ve enstrümanları ayırma.
+ En güçlü Yüksek tiz Ses Çıkarma Algoritması [InterSpeech2023-RMVPE](#Krediler) sessiz ses sorununu önlemek için kullanılır. En iyi sonuçları (önemli ölçüde) sağlar ve Crepe_full'den daha hızlı çalışır, hatta daha düşük kaynak tüketimi sağlar. + En güçlü Yüksek tiz Ses Çıkarma Algoritması [InterSpeech2023-RMVPE](#Krediler) sessiz ses sorununu önlemek için kullanılır. En iyi sonuçları (önemli ölçüde) sağlar ve Crepe_full'den daha hızlı çalışır, hatta daha düşük kaynak tüketimi sağlar.
+ AMD/Intel grafik kartları hızlandırması desteklenir. + AMD/Intel grafik kartları hızlandırması desteklenir.
+ Intel ARC grafik kartları hızlandırması IPEX ile desteklenir.
## Ortamın Hazırlanması ## Ortamın Hazırlanması
Aşağıdaki komutlar, Python sürümü 3.8 veya daha yüksek olan bir ortamda çalıştırılmalıdır. Aşağıdaki komutlar, Python sürümü 3.8 veya daha yüksek olan bir ortamda çalıştırılmalıdır.
@ -73,11 +74,12 @@ Ayrıca bunları pip kullanarak da kurabilirsiniz:
Nvidia grafik kartları için Nvidia grafik kartları için
pip install -r requirements.txt pip install -r requirements.txt
AMD/Intel grafik kartları için AMD/Intel grafik kartları için
pip install -r requirements-dml.txt pip install -r requirements-dml.txt
Intel ARC grafik kartları için Linux / WSL ile Python 3.10 kullanarak:
pip install -r requirements-ipex.txt
``` ```
------ ------
@ -125,6 +127,9 @@ https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.pt
https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/rmvpe.onnx
``` ```
Intel ARC grafik kartları kullanıcıları Webui'yi başlatmadan önce `source /opt/intel/oneapi/setvars.sh` komutunu çalıştırmalı.
Daha sonra bu komutu kullanarak Webui'yi başlatabilirsiniz: Daha sonra bu komutu kullanarak Webui'yi başlatabilirsiniz:
```bash ```bash
python infer-web.py python infer-web.py

View File

@ -13,6 +13,7 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from infer.lib.infer_pack import attentions, commons, modules from infer.lib.infer_pack import attentions, commons, modules
from infer.lib.infer_pack.commons import get_padding, init_weights from infer.lib.infer_pack.commons import get_padding, init_weights
has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
class TextEncoder256(nn.Module): class TextEncoder256(nn.Module):
def __init__( def __init__(
@ -1156,7 +1157,10 @@ class DiscriminatorP(torch.nn.Module):
b, c, t = x.shape b, c, t = x.shape
if t % self.period != 0: # pad first if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period) n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect") if has_xpu and x.dtype == torch.bfloat16:
x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
else:
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad t = t + n_pad
x = x.view(b, c, t // self.period, self.period) x = x.view(b, c, t // self.period, self.period)

View File

@ -2,6 +2,14 @@ import pdb, os
import numpy as np import numpy as np
import torch import torch
try:
#Fix "Torch not compiled with CUDA enabled"
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
ipex_init()
except Exception:
pass
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from librosa.util import normalize, pad_center, tiny from librosa.util import normalize, pad_center, tiny

View File

@ -0,0 +1,165 @@
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
attention_init()
except Exception as e:
return False, e
return True, None

View File

@ -0,0 +1,128 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
split_slice_size = batch_size_attention
if block_size >= 4000:
do_split = True
#Find something divisible with the input_tokens
while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
split_2_slice_size = input_tokens
if split_block_size >= 4000:
do_split_2 = True
#Find something divisible with the input_tokens
while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
split_slice_size = batch_size_attention
if block_size >= 4000:
do_split = True
#Find something divisible with the shape_one
while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
split_2_slice_size = query_tokens
if split_block_size >= 4000:
do_split_2 = True
#Find something divisible with the batch_size_attention
while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@ -0,0 +1,179 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
# sync grad to master weight
if hasattr(optimizer, "sync_grad"):
optimizer.sync_grad()
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# -: is there a way to split by device and dtype without appending in the inner loop?
to_unscale = to_unscale.to("cpu")
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
core._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get("cpu"),
per_device_inv_scale.get("cpu"),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device="cpu", non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
to_device = _scale.device
_scale = _scale.to("cpu")
_growth_tracker = _growth_tracker.to("cpu")
core._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
_scale = _scale.to(to_device)
_growth_tracker = _growth_tracker.to(to_device)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def gradscaler_init():
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
torch.xpu.amp.GradScaler.unscale_ = unscale_
torch.xpu.amp.GradScaler.update = update
return torch.xpu.amp.GradScaler

View File

@ -0,0 +1,196 @@
import contextlib
import importlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.cancel_join_thread()
self._worker_result_queue.close()
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
else:
return original_autocast(*args, **kwargs)
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
else:
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
else:
return original_linalg_solve(A, B, *args, **kwargs)
def ipex_hijacks():
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#Broken functions when torch.cuda.is_available is True:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
torch.nn.DataParallel = DummyDataParallel
torch.autocast = ipex_autocast
torch.cat = torch_cat
torch.linalg.solve = linalg_solve
torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context

View File

@ -17,6 +17,18 @@ n_gpus = len(hps.gpus.split("-"))
from random import randint, shuffle from random import randint, shuffle
import torch import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
if torch.xpu.is_available():
from infer.modules.ipex import ipex_init
from infer.modules.ipex.gradscaler import gradscaler_init
from torch.xpu.amp import autocast
GradScaler = gradscaler_init()
ipex_init()
else:
from torch.cuda.amp import GradScaler, autocast
except Exception:
from torch.cuda.amp import GradScaler, autocast
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
@ -25,7 +37,6 @@ from time import time as ttime
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -185,7 +196,9 @@ def run(rank, n_gpus, hps):
) )
# net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
if torch.cuda.is_available(): if hasattr(torch, "xpu") and torch.xpu.is_available():
pass
elif torch.cuda.is_available():
net_g = DDP(net_g, device_ids=[rank]) net_g = DDP(net_g, device_ids=[rank])
net_d = DDP(net_d, device_ids=[rank]) net_d = DDP(net_d, device_ids=[rank])
else: else:
@ -212,19 +225,33 @@ def run(rank, n_gpus, hps):
if hps.pretrainG != "": if hps.pretrainG != "":
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainG)) logger.info("loaded pretrained %s" % (hps.pretrainG))
logger.info( if hasattr(net_g, "module"):
net_g.module.load_state_dict( logger.info(
torch.load(hps.pretrainG, map_location="cpu")["model"] net_g.module.load_state_dict(
) torch.load(hps.pretrainG, map_location="cpu")["model"]
) ##测试不加载优化器 )
) ##测试不加载优化器
else:
logger.info(
net_g.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
) ##测试不加载优化器
if hps.pretrainD != "": if hps.pretrainD != "":
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainD)) logger.info("loaded pretrained %s" % (hps.pretrainD))
logger.info( if hasattr(net_d, "module"):
net_d.module.load_state_dict( logger.info(
torch.load(hps.pretrainD, map_location="cpu")["model"] net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)
)
else:
logger.info(
net_d.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)
) )
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR( scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2

View File

@ -76,10 +76,18 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
infos.append("%s->Success" % (os.path.basename(inp_path))) infos.append("%s->Success" % (os.path.basename(inp_path)))
yield "\n".join(infos) yield "\n".join(infos)
except: except:
infos.append( try:
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc()) if done == 0:
) pre_fun._path_audio_(
yield "\n".join(infos) inp_path, save_root_ins, save_root_vocal, format0
)
infos.append("%s->Success" % (os.path.basename(inp_path)))
yield "\n".join(infos)
except:
infos.append(
"%s->%s" % (os.path.basename(inp_path), traceback.format_exc())
)
yield "\n".join(infos)
except: except:
infos.append(traceback.format_exc()) infos.append(traceback.format_exc())
yield "\n".join(infos) yield "\n".join(infos)

View File

@ -362,7 +362,7 @@ class Pipeline(object):
) )
pitch = pitch[:p_len] pitch = pitch[:p_len]
pitchf = pitchf[:p_len] pitchf = pitchf[:p_len]
if self.device == "mps": if self.device == "mps" or "xpu" in self.device:
pitchf = pitchf.astype(np.float32) pitchf = pitchf.astype(np.float32)
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long() pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float() pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()

54
requirements-ipex.txt Normal file
View File

@ -0,0 +1,54 @@
torch==2.0.1a0
intel_extension_for_pytorch==2.0.110+xpu
torchvision==0.15.2a0
https://github.com/Disty0/Retrieval-based-Voice-Conversion-WebUI/releases/download/torchaudio_wheels_for_ipex/torchaudio-2.0.2+31de77d-cp310-cp310-linux_x86_64.whl
-f https://developer.intel.com/ipex-whl-stable-xpu
joblib>=1.1.0
numba==0.56.4
numpy==1.23.5
scipy
librosa==0.9.1
llvmlite==0.39.0
fairseq==0.12.2
faiss-cpu==1.7.3
gradio==3.34.0
Cython
pydub>=0.25.1
soundfile>=0.12.1
ffmpeg-python>=0.2.0
tensorboardX
Jinja2>=3.1.2
json5
Markdown
matplotlib>=3.7.0
matplotlib-inline>=0.1.3
praat-parselmouth>=0.4.2
Pillow>=9.1.1
resampy>=0.4.2
scikit-learn
tensorboard
tqdm>=4.63.1
tornado>=6.1
Werkzeug>=2.2.3
uc-micro-py>=1.0.1
sympy>=1.11.1
tabulate>=0.8.10
PyYAML>=6.0
pyasn1>=0.4.8
pyasn1-modules>=0.2.8
fsspec>=2022.11.0
absl-py>=1.2.0
audioread
uvicorn>=0.21.1
colorama>=0.4.5
pyworld==0.3.2
httpx
onnxruntime; sys_platform == 'darwin'
onnxruntime-gpu; sys_platform != 'darwin'
torchcrepe==0.0.20
fastapi==0.88
ffmpy==0.3.1
python-dotenv>=1.0.0
av
PySimpleGUI
sounddevice