use dataclass for StableDiffusionProcessing

This commit is contained in:
AUTOMATIC1111 2023-08-13 08:24:16 +03:00
parent fa9370b741
commit 599f61a1e0
2 changed files with 172 additions and 143 deletions

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import json import json
import logging import logging
import math import math
import os import os
import sys import sys
import hashlib import hashlib
from dataclasses import dataclass, field
import torch import torch
import numpy as np import numpy as np
@ -11,7 +13,7 @@ from PIL import Image, ImageOps
import random import random
import cv2 import cv2
from skimage import exposure from skimage import exposure
from typing import Any, Dict, List from typing import Any
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
@ -104,106 +106,126 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
@dataclass(repr=False)
class StableDiffusionProcessing: class StableDiffusionProcessing:
""" sd_model: object = None
The first set of parameters: sd_model -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing outpath_samples: str = None
""" outpath_grids: str = None
prompt: str = ""
prompt_for_display: str = None
negative_prompt: str = ""
styles: list[str] = field(default_factory=list)
seed: int = -1
subseed: int = -1
subseed_strength: float = 0
seed_resize_from_h: int = -1
seed_resize_from_w: int = -1
seed_enable_extras: bool = True
sampler_name: str = None
batch_size: int = 1
n_iter: int = 1
steps: int = 50
cfg_scale: float = 7.0
width: int = 512
height: int = 512
restore_faces: bool = None
tiling: bool = None
do_not_save_samples: bool = False
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = 0
ddim_discretize: str = None
s_min_uncond: float = None
s_churn: float = None
s_tmax: float = None
s_tmin: float = None
s_noise: float = None
override_settings: dict[str, Any] = None
override_settings_restore_afterwards: bool = True
sampler_index: int = None
refiner_checkpoint: str = None
refiner_switch_at: float = None
token_merging_ratio = 0
token_merging_ratio_hr = 0
disable_extra_networks: bool = False
script_args: list = None
cached_uc = [None, None] cached_uc = [None, None]
cached_c = [None, None] cached_c = [None, None]
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, refiner_checkpoint: str = None, refiner_switch_at: float = None, script_args: list = None): sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
if sampler_index is not None: is_using_inpainting_conditioning: bool = field(default=False, init=False)
paste_to: tuple | None = field(default=None, init=False)
is_hr_pass: bool = field(default=False, init=False)
c: tuple = field(default=None, init=False)
uc: tuple = field(default=None, init=False)
rng: rng.ImageRNG | None = field(default=None, init=False)
step_multiplier: int = field(default=1, init=False)
color_corrections: list = field(default=None, init=False)
scripts: list = field(default=None, init=False)
all_prompts: list = field(default=None, init=False)
all_negative_prompts: list = field(default=None, init=False)
all_seeds: list = field(default=None, init=False)
all_subseeds: list = field(default=None, init=False)
iteration: int = field(default=0, init=False)
main_prompt: str = field(default=None, init=False)
main_negative_prompt: str = field(default=None, init=False)
prompts: list = field(default=None, init=False)
negative_prompts: list = field(default=None, init=False)
seeds: list = field(default=None, init=False)
subseeds: list = field(default=None, init=False)
extra_network_data: dict = field(default=None, init=False)
user: str = field(default=None, init=False)
sd_model_name: str = field(default=None, init=False)
sd_model_hash: str = field(default=None, init=False)
sd_vae_name: str = field(default=None, init=False)
sd_vae_hash: str = field(default=None, init=False)
def __post_init__(self):
if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_name: str = sampler_name
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
self.cfg_scale: float = cfg_scale
self.width: int = width
self.height: int = height
self.restore_faces: bool = restore_faces
self.tiling: bool = tiling
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images
self.eta = eta
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None self.sampler_noise_scheduler_override = None
self.ddim_discretize = ddim_discretize or opts.ddim_discretize self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
self.s_min_uncond = s_min_uncond or opts.s_min_uncond self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
self.s_churn = s_churn or opts.s_churn self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
self.s_tmin = s_tmin or opts.s_tmin self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf') self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
self.s_noise = s_noise if s_noise is not None else opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.extra_generation_params = self.extra_generation_params or {}
self.override_settings_restore_afterwards = override_settings_restore_afterwards self.override_settings = self.override_settings or {}
self.refiner_checkpoint = refiner_checkpoint self.script_args = self.script_args or {}
self.refiner_switch_at = refiner_switch_at
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
self.token_merging_ratio = 0
self.token_merging_ratio_hr = 0
self.refiner_checkpoint_info = None self.refiner_checkpoint_info = None
if not seed_enable_extras: if not self.seed_enable_extras:
self.subseed = -1 self.subseed = -1
self.subseed_strength = 0 self.subseed_strength = 0
self.seed_resize_from_h = 0 self.seed_resize_from_h = 0
self.seed_resize_from_w = 0 self.seed_resize_from_w = 0
self.scripts = None
self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False
self.sampler = None
self.main_prompt = None
self.main_negative_prompt = None
self.prompts = None
self.negative_prompts = None
self.extra_network_data = None
self.seeds = None
self.subseeds = None
self.step_multiplier = 1
self.cached_uc = StableDiffusionProcessing.cached_uc self.cached_uc = StableDiffusionProcessing.cached_uc
self.cached_c = StableDiffusionProcessing.cached_c self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None
self.c = None
self.rng: rng.ImageRNG = None
self.user = None
self.sd_model_name = None
self.sd_model_hash = None
self.sd_vae_name = None
self.sd_vae_hash = None
@property @property
def sd_model(self): def sd_model(self):
return shared.sd_model return shared.sd_model
# No-op setter for the `sd_model` property: any assigned value is discarded,
# and reads keep returning shared.sd_model via the getter.
# NOTE(review): presumably exists so the dataclass-generated __init__ can assign
# the `sd_model` field without raising AttributeError — confirm.
@sd_model.setter
def sd_model(self, value):
pass
def txt2img_image_conditioning(self, x, width=None, height=None): def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@ -932,49 +954,51 @@ def old_hires_fix_first_pass_dimensions(width, height):
return width, height return width, height
@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None enable_hr: bool = False
denoising_strength: float = 0.75
firstphase_width: int = 0
firstphase_height: int = 0
hr_scale: float = 2.0
hr_upscaler: str = None
hr_second_pass_steps: int = 0
hr_resize_x: int = 0
hr_resize_y: int = 0
hr_checkpoint_name: str = None
hr_sampler_name: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
cached_hr_uc = [None, None] cached_hr_uc = [None, None]
cached_hr_c = [None, None] cached_hr_c = [None, None]
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): hr_checkpoint_info: dict = field(default=None, init=False)
super().__init__(**kwargs) hr_upscale_to_x: int = field(default=0, init=False)
self.enable_hr = enable_hr hr_upscale_to_y: int = field(default=0, init=False)
self.denoising_strength = denoising_strength truncate_x: int = field(default=0, init=False)
self.hr_scale = hr_scale truncate_y: int = field(default=0, init=False)
self.hr_upscaler = hr_upscaler applied_old_hires_behavior_to: tuple = field(default=None, init=False)
self.hr_second_pass_steps = hr_second_pass_steps latent_scale_mode: dict = field(default=None, init=False)
self.hr_resize_x = hr_resize_x hr_c: tuple | None = field(default=None, init=False)
self.hr_resize_y = hr_resize_y hr_uc: tuple | None = field(default=None, init=False)
self.hr_upscale_to_x = hr_resize_x all_hr_prompts: list = field(default=None, init=False)
self.hr_upscale_to_y = hr_resize_y all_hr_negative_prompts: list = field(default=None, init=False)
self.hr_checkpoint_name = hr_checkpoint_name hr_prompts: list = field(default=None, init=False)
self.hr_checkpoint_info = None hr_negative_prompts: list = field(default=None, init=False)
self.hr_sampler_name = hr_sampler_name hr_extra_network_data: list = field(default=None, init=False)
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
self.latent_scale_mode = None
if firstphase_width != 0 or firstphase_height != 0: def __post_init__(self):
super().__post_init__()
if self.firstphase_width != 0 or self.firstphase_height != 0:
self.hr_upscale_to_x = self.width self.hr_upscale_to_x = self.width
self.hr_upscale_to_y = self.height self.hr_upscale_to_y = self.height
self.width = firstphase_width self.width = self.firstphase_width
self.height = firstphase_height self.height = self.firstphase_height
self.truncate_x = 0
self.truncate_y = 0
self.applied_old_hires_behavior_to = None
self.hr_prompts = None
self.hr_negative_prompts = None
self.hr_extra_network_data = None
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
self.hr_c = None
self.hr_uc = None
def calculate_target_resolution(self): def calculate_target_resolution(self):
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
@ -1252,7 +1276,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return super().get_conds() return super().get_conds()
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts() res = super().parse_extra_network_prompts()
@ -1265,32 +1288,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return res return res
@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None init_images: list = None
resize_mode: int = 0
denoising_strength: float = 0.75
image_cfg_scale: float = None
mask: Any = None
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None
latent_mask: Image = None
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): image_mask: Any = field(default=None, init=False)
super().__init__(**kwargs)
self.init_images = init_images nmask: torch.Tensor = field(default=None, init=False)
self.resize_mode: int = resize_mode image_conditioning: torch.Tensor = field(default=None, init=False)
self.denoising_strength: float = denoising_strength init_img_hash: str = field(default=None, init=False)
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None mask_for_overlay: Image = field(default=None, init=False)
self.init_latent = None init_latent: torch.Tensor = field(default=None, init=False)
self.image_mask = mask
self.latent_mask = None def __post_init__(self):
self.mask_for_overlay = None super().__post_init__()
self.mask_blur_x = mask_blur_x
self.mask_blur_y = mask_blur_y self.image_mask = self.mask
if mask_blur is not None:
self.mask_blur = mask_blur
self.inpainting_fill = inpainting_fill
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
self.inpainting_mask_invert = inpainting_mask_invert
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
self.mask = None self.mask = None
self.nmask = None self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
self.image_conditioning = None
@property @property
def mask_blur(self): def mask_blur(self):
@ -1300,15 +1328,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
@mask_blur.setter @mask_blur.setter
def mask_blur(self, value): def mask_blur(self, value):
self.mask_blur_x = value if isinstance(value, int):
self.mask_blur_y = value self.mask_blur_x = value
self.mask_blur_y = value
@mask_blur.deleter
def mask_blur(self):
del self.mask_blur_x
del self.mask_blur_y
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None crop_region = None

View File

@ -305,5 +305,8 @@ class Sampler:
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
# Placeholder on the Sampler class: unconditionally raises NotImplementedError.
# NOTE(review): presumably concrete sampler subclasses override this — confirm.
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()
# Placeholder on the Sampler class: unconditionally raises NotImplementedError.
# Takes an extra `noise` argument compared with `sample`.
# NOTE(review): presumably concrete sampler subclasses override this — confirm.
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()