mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 15:15:05 +08:00
rework torchsde._brownian.brownian_interval replacement to use device.randn_local and respect the NV setting.
This commit is contained in:
parent
84b6fcd02c
commit
fca42949a3
@ -71,14 +71,17 @@ def enable_tf32():
|
|||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
cpu = torch.device("cpu")
|
cpu: torch.device = torch.device("cpu")
|
||||||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
device: torch.device = None
|
||||||
dtype = torch.float16
|
device_interrogate: torch.device = None
|
||||||
dtype_vae = torch.float16
|
device_gfpgan: torch.device = None
|
||||||
dtype_unet = torch.float16
|
device_esrgan: torch.device = None
|
||||||
|
device_codeformer: torch.device = None
|
||||||
|
dtype: torch.dtype = torch.float16
|
||||||
|
dtype_vae: torch.dtype = torch.float16
|
||||||
|
dtype_unet: torch.dtype = torch.float16
|
||||||
unet_needs_upcast = False
|
unet_needs_upcast = False
|
||||||
|
|
||||||
|
|
||||||
@ -94,6 +97,10 @@ nv_rng = None
|
|||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||||
|
|
||||||
|
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
|
||||||
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
manual_seed(seed)
|
manual_seed(seed)
|
||||||
@ -107,7 +114,27 @@ def randn(seed, shape):
|
|||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def randn_local(seed, shape):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||||
|
|
||||||
|
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
|
||||||
|
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
if opts.randn_source == "NV":
|
||||||
|
rng = rng_philox.Generator(seed)
|
||||||
|
return torch.asarray(rng.randn(shape), device=device)
|
||||||
|
|
||||||
|
local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
|
||||||
|
local_generator = torch.Generator(local_device).manual_seed(int(seed))
|
||||||
|
return torch.randn(shape, device=local_device, generator=local_generator).to(device)
|
||||||
|
|
||||||
|
|
||||||
def randn_like(x):
|
def randn_like(x):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||||
|
|
||||||
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
if opts.randn_source == "NV":
|
||||||
@ -120,6 +147,10 @@ def randn_like(x):
|
|||||||
|
|
||||||
|
|
||||||
def randn_without_seed(shape):
|
def randn_without_seed(shape):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||||
|
|
||||||
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
if opts.randn_source == "NV":
|
||||||
@ -132,6 +163,7 @@ def randn_without_seed(shape):
|
|||||||
|
|
||||||
|
|
||||||
def manual_seed(seed):
|
def manual_seed(seed):
|
||||||
|
"""Set up a global random number generator using the specified seed."""
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
if opts.randn_source == "NV":
|
if opts.randn_source == "NV":
|
||||||
|
@ -2,10 +2,8 @@ from collections import namedtuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
|
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
|
||||||
|
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
|
|
||||||
@ -85,11 +83,13 @@ class InterruptedException(BaseException):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if opts.randn_source == "CPU":
|
def replace_torchsde_browinan():
|
||||||
import torchsde._brownian.brownian_interval
|
import torchsde._brownian.brownian_interval
|
||||||
|
|
||||||
def torchsde_randn(size, dtype, device, seed):
|
def torchsde_randn(size, dtype, device, seed):
|
||||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
return devices.randn_local(seed, size).to(device=device, dtype=dtype)
|
||||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
|
||||||
|
|
||||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||||
|
|
||||||
|
|
||||||
|
replace_torchsde_browinan()
|
||||||
|
Loading…
Reference in New Issue
Block a user