patch DDPM.register_betas so that users can put given_betas in model yaml

This commit is contained in:
woweenie 2023-09-15 18:59:44 +02:00 committed by GitHub
parent 5ef669de08
commit d9d94141dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,7 @@ import threading
import torch import torch
import re import re
import safetensors.torch import safetensors.torch
from omegaconf import OmegaConf from omegaconf import OmegaConf, ListConfig
from os import mkdir from os import mkdir
from urllib import request from urllib import request
import ldm.modules.midas as midas import ldm.modules.midas as midas
@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer from modules.timer import Timer
import tomesd import tomesd
import numpy as np
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@ -132,6 +133,7 @@ def setup_model():
os.makedirs(model_path, exist_ok=True) os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload() enable_midas_autodownload()
patch_given_betas()
def checkpoint_tiles(use_short=False): def checkpoint_tiles(use_short=False):
@ -453,6 +455,17 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper midas.api.load_model = load_model_wrapper
def patch_given_betas():
original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
def patched_register_schedule(*args, **kwargs):
if args[1] is not None and isinstance(args[1], ListConfig):
modified_args = list(args) # Convert args tuple to a list
modified_args[1] = np.array(args[1]) # Modify the desired element
args = tuple(modified_args) # Convert the list back to a tuple
original_register_schedule(*args, **kwargs)
ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
def repair_config(sd_config): def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"): if not hasattr(sd_config.model.params, "use_ema"):