mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-03 04:02:56 +08:00
patch DDPM.register_betas so that users can put given_betas in model yaml
This commit is contained in:
parent
5ef669de08
commit
d9d94141dc
@ -7,7 +7,7 @@ import threading
|
|||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf, ListConfig
|
||||||
from os import mkdir
|
from os import mkdir
|
||||||
from urllib import request
|
from urllib import request
|
||||||
import ldm.modules.midas as midas
|
import ldm.modules.midas as midas
|
||||||
@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
|
|||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
@ -132,6 +133,7 @@ def setup_model():
|
|||||||
os.makedirs(model_path, exist_ok=True)
|
os.makedirs(model_path, exist_ok=True)
|
||||||
|
|
||||||
enable_midas_autodownload()
|
enable_midas_autodownload()
|
||||||
|
patch_given_betas()
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_tiles(use_short=False):
|
def checkpoint_tiles(use_short=False):
|
||||||
@ -453,6 +455,17 @@ def enable_midas_autodownload():
|
|||||||
midas.api.load_model = load_model_wrapper
|
midas.api.load_model = load_model_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def patch_given_betas():
|
||||||
|
original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
|
||||||
|
def patched_register_schedule(*args, **kwargs):
|
||||||
|
if args[1] is not None and isinstance(args[1], ListConfig):
|
||||||
|
modified_args = list(args) # Convert args tuple to a list
|
||||||
|
modified_args[1] = np.array(args[1]) # Modify the desired element
|
||||||
|
args = tuple(modified_args) # Convert the list back to a tuple
|
||||||
|
original_register_schedule(*args, **kwargs)
|
||||||
|
ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
|
||||||
|
|
||||||
|
|
||||||
def repair_config(sd_config):
|
def repair_config(sd_config):
|
||||||
|
|
||||||
if not hasattr(sd_config.model.params, "use_ema"):
|
if not hasattr(sd_config.model.params, "use_ema"):
|
||||||
|
Loading…
Reference in New Issue
Block a user