mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
use dataclass for StableDiffusionProcessing
This commit is contained in:
parent
fa9370b741
commit
599f61a1e0
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@ -11,7 +13,7 @@ from PIL import Image, ImageOps
|
||||
import random
|
||||
import cv2
|
||||
from skimage import exposure
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
||||
@ -104,106 +106,126 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class StableDiffusionProcessing:
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
"""
|
||||
sd_model: object = None
|
||||
outpath_samples: str = None
|
||||
outpath_grids: str = None
|
||||
prompt: str = ""
|
||||
prompt_for_display: str = None
|
||||
negative_prompt: str = ""
|
||||
styles: list[str] = field(default_factory=list)
|
||||
seed: int = -1
|
||||
subseed: int = -1
|
||||
subseed_strength: float = 0
|
||||
seed_resize_from_h: int = -1
|
||||
seed_resize_from_w: int = -1
|
||||
seed_enable_extras: bool = True
|
||||
sampler_name: str = None
|
||||
batch_size: int = 1
|
||||
n_iter: int = 1
|
||||
steps: int = 50
|
||||
cfg_scale: float = 7.0
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
restore_faces: bool = None
|
||||
tiling: bool = None
|
||||
do_not_save_samples: bool = False
|
||||
do_not_save_grid: bool = False
|
||||
extra_generation_params: dict[str, Any] = None
|
||||
overlay_images: list = None
|
||||
eta: float = None
|
||||
do_not_reload_embeddings: bool = False
|
||||
denoising_strength: float = 0
|
||||
ddim_discretize: str = None
|
||||
s_min_uncond: float = None
|
||||
s_churn: float = None
|
||||
s_tmax: float = None
|
||||
s_tmin: float = None
|
||||
s_noise: float = None
|
||||
override_settings: dict[str, Any] = None
|
||||
override_settings_restore_afterwards: bool = True
|
||||
sampler_index: int = None
|
||||
refiner_checkpoint: str = None
|
||||
refiner_switch_at: float = None
|
||||
token_merging_ratio = 0
|
||||
token_merging_ratio_hr = 0
|
||||
disable_extra_networks: bool = False
|
||||
|
||||
script_args: list = None
|
||||
|
||||
cached_uc = [None, None]
|
||||
cached_c = [None, None]
|
||||
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, refiner_checkpoint: str = None, refiner_switch_at: float = None, script_args: list = None):
|
||||
if sampler_index is not None:
|
||||
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
||||
is_using_inpainting_conditioning: bool = field(default=False, init=False)
|
||||
paste_to: tuple | None = field(default=None, init=False)
|
||||
|
||||
is_hr_pass: bool = field(default=False, init=False)
|
||||
|
||||
c: tuple = field(default=None, init=False)
|
||||
uc: tuple = field(default=None, init=False)
|
||||
|
||||
rng: rng.ImageRNG | None = field(default=None, init=False)
|
||||
step_multiplier: int = field(default=1, init=False)
|
||||
color_corrections: list = field(default=None, init=False)
|
||||
|
||||
scripts: list = field(default=None, init=False)
|
||||
all_prompts: list = field(default=None, init=False)
|
||||
all_negative_prompts: list = field(default=None, init=False)
|
||||
all_seeds: list = field(default=None, init=False)
|
||||
all_subseeds: list = field(default=None, init=False)
|
||||
iteration: int = field(default=0, init=False)
|
||||
main_prompt: str = field(default=None, init=False)
|
||||
main_negative_prompt: str = field(default=None, init=False)
|
||||
|
||||
prompts: list = field(default=None, init=False)
|
||||
negative_prompts: list = field(default=None, init=False)
|
||||
seeds: list = field(default=None, init=False)
|
||||
subseeds: list = field(default=None, init=False)
|
||||
extra_network_data: dict = field(default=None, init=False)
|
||||
|
||||
user: str = field(default=None, init=False)
|
||||
|
||||
sd_model_name: str = field(default=None, init=False)
|
||||
sd_model_hash: str = field(default=None, init=False)
|
||||
sd_vae_name: str = field(default=None, init=False)
|
||||
sd_vae_hash: str = field(default=None, init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
self.prompt: str = prompt
|
||||
self.prompt_for_display: str = None
|
||||
self.negative_prompt: str = (negative_prompt or "")
|
||||
self.styles: list = styles or []
|
||||
self.seed: int = seed
|
||||
self.subseed: int = subseed
|
||||
self.subseed_strength: float = subseed_strength
|
||||
self.seed_resize_from_h: int = seed_resize_from_h
|
||||
self.seed_resize_from_w: int = seed_resize_from_w
|
||||
self.sampler_name: str = sampler_name
|
||||
self.batch_size: int = batch_size
|
||||
self.n_iter: int = n_iter
|
||||
self.steps: int = steps
|
||||
self.cfg_scale: float = cfg_scale
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
self.restore_faces: bool = restore_faces
|
||||
self.tiling: bool = tiling
|
||||
self.do_not_save_samples: bool = do_not_save_samples
|
||||
self.do_not_save_grid: bool = do_not_save_grid
|
||||
self.extra_generation_params: dict = extra_generation_params or {}
|
||||
self.overlay_images = overlay_images
|
||||
self.eta = eta
|
||||
self.do_not_reload_embeddings = do_not_reload_embeddings
|
||||
self.paste_to = None
|
||||
self.color_corrections = None
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.sampler_noise_scheduler_override = None
|
||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
||||
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
||||
self.s_churn = s_churn or opts.s_churn
|
||||
self.s_tmin = s_tmin or opts.s_tmin
|
||||
self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
|
||||
self.s_noise = s_noise if s_noise is not None else opts.s_noise
|
||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||
self.refiner_checkpoint = refiner_checkpoint
|
||||
self.refiner_switch_at = refiner_switch_at
|
||||
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
|
||||
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
|
||||
self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
|
||||
self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
|
||||
self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
|
||||
|
||||
self.extra_generation_params = self.extra_generation_params or {}
|
||||
self.override_settings = self.override_settings or {}
|
||||
self.script_args = self.script_args or {}
|
||||
|
||||
self.is_using_inpainting_conditioning = False
|
||||
self.disable_extra_networks = False
|
||||
self.token_merging_ratio = 0
|
||||
self.token_merging_ratio_hr = 0
|
||||
self.refiner_checkpoint_info = None
|
||||
|
||||
if not seed_enable_extras:
|
||||
if not self.seed_enable_extras:
|
||||
self.subseed = -1
|
||||
self.subseed_strength = 0
|
||||
self.seed_resize_from_h = 0
|
||||
self.seed_resize_from_w = 0
|
||||
|
||||
self.scripts = None
|
||||
self.script_args = script_args
|
||||
self.all_prompts = None
|
||||
self.all_negative_prompts = None
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
self.iteration = 0
|
||||
self.is_hr_pass = False
|
||||
self.sampler = None
|
||||
self.main_prompt = None
|
||||
self.main_negative_prompt = None
|
||||
|
||||
self.prompts = None
|
||||
self.negative_prompts = None
|
||||
self.extra_network_data = None
|
||||
self.seeds = None
|
||||
self.subseeds = None
|
||||
|
||||
self.step_multiplier = 1
|
||||
self.cached_uc = StableDiffusionProcessing.cached_uc
|
||||
self.cached_c = StableDiffusionProcessing.cached_c
|
||||
self.uc = None
|
||||
self.c = None
|
||||
self.rng: rng.ImageRNG = None
|
||||
|
||||
self.user = None
|
||||
|
||||
self.sd_model_name = None
|
||||
self.sd_model_hash = None
|
||||
self.sd_vae_name = None
|
||||
self.sd_vae_hash = None
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
return shared.sd_model
|
||||
|
||||
@sd_model.setter
|
||||
def sd_model(self, value):
|
||||
pass
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
|
||||
@ -932,49 +954,51 @@ def old_hires_fix_first_pass_dimensions(width, height):
|
||||
return width, height
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
enable_hr: bool = False
|
||||
denoising_strength: float = 0.75
|
||||
firstphase_width: int = 0
|
||||
firstphase_height: int = 0
|
||||
hr_scale: float = 2.0
|
||||
hr_upscaler: str = None
|
||||
hr_second_pass_steps: int = 0
|
||||
hr_resize_x: int = 0
|
||||
hr_resize_y: int = 0
|
||||
hr_checkpoint_name: str = None
|
||||
hr_sampler_name: str = None
|
||||
hr_prompt: str = ''
|
||||
hr_negative_prompt: str = ''
|
||||
|
||||
cached_hr_uc = [None, None]
|
||||
cached_hr_c = [None, None]
|
||||
|
||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.enable_hr = enable_hr
|
||||
self.denoising_strength = denoising_strength
|
||||
self.hr_scale = hr_scale
|
||||
self.hr_upscaler = hr_upscaler
|
||||
self.hr_second_pass_steps = hr_second_pass_steps
|
||||
self.hr_resize_x = hr_resize_x
|
||||
self.hr_resize_y = hr_resize_y
|
||||
self.hr_upscale_to_x = hr_resize_x
|
||||
self.hr_upscale_to_y = hr_resize_y
|
||||
self.hr_checkpoint_name = hr_checkpoint_name
|
||||
self.hr_checkpoint_info = None
|
||||
self.hr_sampler_name = hr_sampler_name
|
||||
self.hr_prompt = hr_prompt
|
||||
self.hr_negative_prompt = hr_negative_prompt
|
||||
self.all_hr_prompts = None
|
||||
self.all_hr_negative_prompts = None
|
||||
self.latent_scale_mode = None
|
||||
hr_checkpoint_info: dict = field(default=None, init=False)
|
||||
hr_upscale_to_x: int = field(default=0, init=False)
|
||||
hr_upscale_to_y: int = field(default=0, init=False)
|
||||
truncate_x: int = field(default=0, init=False)
|
||||
truncate_y: int = field(default=0, init=False)
|
||||
applied_old_hires_behavior_to: tuple = field(default=None, init=False)
|
||||
latent_scale_mode: dict = field(default=None, init=False)
|
||||
hr_c: tuple | None = field(default=None, init=False)
|
||||
hr_uc: tuple | None = field(default=None, init=False)
|
||||
all_hr_prompts: list = field(default=None, init=False)
|
||||
all_hr_negative_prompts: list = field(default=None, init=False)
|
||||
hr_prompts: list = field(default=None, init=False)
|
||||
hr_negative_prompts: list = field(default=None, init=False)
|
||||
hr_extra_network_data: list = field(default=None, init=False)
|
||||
|
||||
if firstphase_width != 0 or firstphase_height != 0:
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.firstphase_width != 0 or self.firstphase_height != 0:
|
||||
self.hr_upscale_to_x = self.width
|
||||
self.hr_upscale_to_y = self.height
|
||||
self.width = firstphase_width
|
||||
self.height = firstphase_height
|
||||
|
||||
self.truncate_x = 0
|
||||
self.truncate_y = 0
|
||||
self.applied_old_hires_behavior_to = None
|
||||
|
||||
self.hr_prompts = None
|
||||
self.hr_negative_prompts = None
|
||||
self.hr_extra_network_data = None
|
||||
self.width = self.firstphase_width
|
||||
self.height = self.firstphase_height
|
||||
|
||||
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
|
||||
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
|
||||
self.hr_c = None
|
||||
self.hr_uc = None
|
||||
|
||||
def calculate_target_resolution(self):
|
||||
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
||||
@ -1252,7 +1276,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
return super().get_conds()
|
||||
|
||||
|
||||
def parse_extra_network_prompts(self):
|
||||
res = super().parse_extra_network_prompts()
|
||||
|
||||
@ -1265,32 +1288,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
return res
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
init_images: list = None
|
||||
resize_mode: int = 0
|
||||
denoising_strength: float = 0.75
|
||||
image_cfg_scale: float = None
|
||||
mask: Any = None
|
||||
mask_blur_x: int = 4
|
||||
mask_blur_y: int = 4
|
||||
mask_blur: int = None
|
||||
inpainting_fill: int = 0
|
||||
inpaint_full_res: bool = True
|
||||
inpaint_full_res_padding: int = 0
|
||||
inpainting_mask_invert: int = 0
|
||||
initial_noise_multiplier: float = None
|
||||
latent_mask: Image = None
|
||||
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
image_mask: Any = field(default=None, init=False)
|
||||
|
||||
self.init_images = init_images
|
||||
self.resize_mode: int = resize_mode
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||
self.init_latent = None
|
||||
self.image_mask = mask
|
||||
self.latent_mask = None
|
||||
self.mask_for_overlay = None
|
||||
self.mask_blur_x = mask_blur_x
|
||||
self.mask_blur_y = mask_blur_y
|
||||
if mask_blur is not None:
|
||||
self.mask_blur = mask_blur
|
||||
self.inpainting_fill = inpainting_fill
|
||||
self.inpaint_full_res = inpaint_full_res
|
||||
self.inpaint_full_res_padding = inpaint_full_res_padding
|
||||
self.inpainting_mask_invert = inpainting_mask_invert
|
||||
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
|
||||
nmask: torch.Tensor = field(default=None, init=False)
|
||||
image_conditioning: torch.Tensor = field(default=None, init=False)
|
||||
init_img_hash: str = field(default=None, init=False)
|
||||
mask_for_overlay: Image = field(default=None, init=False)
|
||||
init_latent: torch.Tensor = field(default=None, init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
self.image_mask = self.mask
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.image_conditioning = None
|
||||
self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
|
||||
|
||||
@property
|
||||
def mask_blur(self):
|
||||
@ -1300,15 +1328,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
|
||||
@mask_blur.setter
|
||||
def mask_blur(self, value):
|
||||
self.mask_blur_x = value
|
||||
self.mask_blur_y = value
|
||||
|
||||
@mask_blur.deleter
|
||||
def mask_blur(self):
|
||||
del self.mask_blur_x
|
||||
del self.mask_blur_y
|
||||
if isinstance(value, int):
|
||||
self.mask_blur_x = value
|
||||
self.mask_blur_y = value
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
crop_region = None
|
||||
|
||||
|
@ -305,5 +305,8 @@ class Sampler:
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
Loading…
Reference in New Issue
Block a user