This commit is contained in:
AUTOMATIC1111 2023-12-02 23:42:06 +03:00
parent ac02216e54
commit 051375258c
15 changed files with 244 additions and 135 deletions

View File

@ -108,14 +108,6 @@ function get_img2img_tab_index() {
function create_submit_args(args) {
var res = Array.from(args);
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
// If gradio at some point stops sending outputs, this may break something
if (Array.isArray(res[res.length - 3])) {
res[res.length - 3] = null;
}
return res;
}
@ -183,7 +175,6 @@ function submit_extras() {
res[0] = id;
console.log(res);
return res;
}

View File

@ -207,7 +207,7 @@ class Api:
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
#api_middleware(self.app) # XXX this will have to be fixed
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)

View File

@ -1,6 +1,6 @@
import inspect
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, Field, create_model, ConfigDict
from typing import Any, Optional, Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
@ -92,9 +92,7 @@ class PydanticModelGenerator:
fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
DynamicModel = create_model(self._model_name, __config__=ConfigDict(populate_by_name=True, frozen=True), **fields)
return DynamicModel
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
@ -232,6 +230,9 @@ class SamplerItem(BaseModel):
options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
class Config:
protected_namespaces = ()
name: str = Field(title="Name")
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
@ -242,6 +243,9 @@ class LatentUpscalerModeItem(BaseModel):
name: str = Field(title="Name")
class SDModelItem(BaseModel):
class Config:
protected_namespaces = ()
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
hash: Optional[str] = Field(title="Short hash")
@ -250,6 +254,9 @@ class SDModelItem(BaseModel):
config: Optional[str] = Field(title="Config file")
class SDVaeItem(BaseModel):
class Config:
protected_namespaces = ()
model_name: str = Field(title="Model Name")
filename: str = Field(title="Filename")

View File

@ -109,7 +109,7 @@ def check_versions():
expected_torch_version = "2.0.0"
expected_xformers_version = "0.0.20"
expected_gradio_version = "3.41.2"
expected_gradio_version = "4.7.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
print_error_explanation(f"""

View File

@ -0,0 +1,166 @@
import inspect
import warnings
from functools import wraps
import gradio as gr
import gradio.component_meta
from modules import scripts, ui_tempdir, patches
class GradioDeprecationWarning(DeprecationWarning):
pass
def add_classes_to_gradio_component(comp):
"""
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
"""
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(getattr(comp, 'elem_classes', None) or [])]
if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')
def IOComponent_init(self, *args, **kwargs):
self.webui_tooltip = kwargs.pop('tooltip', None)
if scripts.scripts_current is not None:
scripts.scripts_current.before_component(self, **kwargs)
scripts.script_callbacks.before_component_callback(self, **kwargs)
res = original_IOComponent_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
scripts.script_callbacks.after_component_callback(self, **kwargs)
if scripts.scripts_current is not None:
scripts.scripts_current.after_component(self, **kwargs)
return res
def Block_get_config(self):
config = original_Block_get_config(self)
webui_tooltip = getattr(self, 'webui_tooltip', None)
if webui_tooltip:
config["webui_tooltip"] = webui_tooltip
config.pop('example_inputs', None)
return config
def BlockContext_init(self, *args, **kwargs):
if scripts.scripts_current is not None:
scripts.scripts_current.before_component(self, **kwargs)
scripts.script_callbacks.before_component_callback(self, **kwargs)
res = original_BlockContext_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
scripts.script_callbacks.after_component_callback(self, **kwargs)
if scripts.scripts_current is not None:
scripts.scripts_current.after_component(self, **kwargs)
return res
def Blocks_get_config_file(self, *args, **kwargs):
config = original_Blocks_get_config_file(self, *args, **kwargs)
for comp_config in config["components"]:
if "example_inputs" in comp_config:
comp_config["example_inputs"] = {"serialized": []}
return config
original_IOComponent_init = patches.patch(__name__, obj=gr.components.Component, field="__init__", replacement=IOComponent_init)
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
ui_tempdir.install_ui_tempdir_override()
def gradio_component_meta_create_or_modify_pyi(component_class, class_name, events):
if hasattr(component_class, 'webui_do_not_create_gradio_pyi_thank_you'):
return
gradio_component_meta_create_or_modify_pyi_original(component_class, class_name, events)
# this prevents creation of .pyi files in webui dir
gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gradio.component_meta, 'create_or_modify_pyi', gradio_component_meta_create_or_modify_pyi)
# this function is broken and does not seem to do anything useful
gradio.component_meta.updateable = lambda x: x
def repair(grclass):
if not getattr(grclass, 'EVENTS', None):
return
@wraps(grclass.__init__)
def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs):
if source:
kwargs["sources"] = [source]
allowed_kwargs = inspect.signature(original).parameters
fixed_kwargs = {}
for k, v in kwargs.items():
if k in allowed_kwargs:
fixed_kwargs[k] = v
else:
warnings.warn(f"unexpected argument for {grclass.__name__}: {k}", GradioDeprecationWarning, stacklevel=2)
original(self, *args, **fixed_kwargs)
self.webui_tooltip = tooltip
for event in self.EVENTS:
replaced_event = getattr(self, str(event))
def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
if _js:
xkwargs['js'] = _js
return replaced_event(*xargs, **xkwargs)
setattr(self, str(event), fun)
grclass.__init__ = __repaired_init__
grclass.update = gr.update
for component in set(gr.components.__all__ + gr.layouts.__all__):
repair(getattr(gr, component, None))
class Dependency(gr.events.Dependency):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def then(*xargs, _js=None, **xkwargs):
if _js:
xkwargs['js'] = _js
return original_then(*xargs, **xkwargs)
original_then = self.then
self.then = then
gr.events.Dependency = Dependency
gr.Box = gr.Group

View File

@ -1,83 +0,0 @@
import gradio as gr
from modules import scripts, ui_tempdir, patches
def add_classes_to_gradio_component(comp):
"""
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
"""
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')
def IOComponent_init(self, *args, **kwargs):
self.webui_tooltip = kwargs.pop('tooltip', None)
if scripts.scripts_current is not None:
scripts.scripts_current.before_component(self, **kwargs)
scripts.script_callbacks.before_component_callback(self, **kwargs)
res = original_IOComponent_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
scripts.script_callbacks.after_component_callback(self, **kwargs)
if scripts.scripts_current is not None:
scripts.scripts_current.after_component(self, **kwargs)
return res
def Block_get_config(self):
config = original_Block_get_config(self)
webui_tooltip = getattr(self, 'webui_tooltip', None)
if webui_tooltip:
config["webui_tooltip"] = webui_tooltip
config.pop('example_inputs', None)
return config
def BlockContext_init(self, *args, **kwargs):
if scripts.scripts_current is not None:
scripts.scripts_current.before_component(self, **kwargs)
scripts.script_callbacks.before_component_callback(self, **kwargs)
res = original_BlockContext_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
scripts.script_callbacks.after_component_callback(self, **kwargs)
if scripts.scripts_current is not None:
scripts.scripts_current.after_component(self, **kwargs)
return res
def Blocks_get_config_file(self, *args, **kwargs):
config = original_Blocks_get_config_file(self, *args, **kwargs)
for comp_config in config["components"]:
if "example_inputs" in comp_config:
comp_config["example_inputs"] = {"serialized": []}
return config
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
ui_tempdir.install_ui_tempdir_override()

View File

@ -34,7 +34,7 @@ def imports():
shared_init.initialize()
startup_timer.record("initialize shared")
from modules import processing, gradio_extensons, ui # noqa: F401
from modules import processing, gradio_extensions, ui # noqa: F401
startup_timer.record("other imports")

View File

@ -4,6 +4,8 @@ import signal
import sys
import re
import starlette
from modules.timer import startup_timer
@ -183,8 +185,7 @@ def configure_opts_onchange():
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.user_middleware.insert(0, starlette.middleware.Middleware(GZipMiddleware, minimum_size=1000))
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
@ -202,5 +203,6 @@ def configure_cors_middleware(app):
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
if cmd_opts.cors_allow_origins_regex:
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options)
app.user_middleware.insert(0, starlette.middleware.Middleware(CORSMiddleware, **cors_options))

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import base64
import io
import time
@ -55,11 +56,11 @@ class ProgressResponse(BaseModel):
active: bool = Field(title="Whether the task is being worked on right now")
queued: bool = Field(title="Whether the task is in queue")
completed: bool = Field(title="Whether the task has already finished")
progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
eta: float = Field(default=None, title="ETA in secs")
live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
progress: float | None = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
eta: float | None = Field(default=None, title="ETA in secs")
live_preview: str | None = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
id_live_preview: int | None = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.")
def setup_progress_api(app):

View File

@ -12,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
from modules import gradio_extensions
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
@ -33,7 +33,7 @@ from modules.generation_parameters_copypaste import image_from_url_text
create_setting_component = ui_settings.create_setting_component
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gradio_extensions.GradioDeprecationWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@ -265,7 +265,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
dummy_component = gr.Label(visible=False)
dummy_component = gr.Textbox(visible=False)
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
extra_tabs.__enter__()
@ -311,7 +311,7 @@ def create_ui():
with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
with enable_hr.extra():
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution")
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)

View File

@ -1,7 +1,12 @@
from functools import wraps
import gradio as gr
from modules import gradio_extensions
class FormComponent:
webui_do_not_create_gradio_pyi_thank_you = True
def get_expected_parent(self):
return gr.components.Form
@ -9,12 +14,13 @@ class FormComponent:
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
class ToolButton(FormComponent, gr.Button):
class ToolButton(gr.Button, FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, *args, **kwargs):
classes = kwargs.pop("elem_classes", [])
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
@wraps(gr.Button.__init__)
def __init__(self, value="", *args, elem_classes=None, **kwargs):
elem_classes = elem_classes or []
super().__init__(value=value, *args, elem_classes=["tool", *elem_classes], **kwargs)
def get_block_name(self):
return "button"
@ -22,7 +28,9 @@ class ToolButton(FormComponent, gr.Button):
class ResizeHandleRow(gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
webui_do_not_create_gradio_pyi_thank_you = True
@wraps(gr.Row.__init__)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -32,79 +40,92 @@ class ResizeHandleRow(gr.Row):
return "row"
class FormRow(FormComponent, gr.Row):
class FormRow(gr.Row, FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "row"
class FormColumn(FormComponent, gr.Column):
class FormColumn(gr.Column, FormComponent):
"""Same as gr.Column but fits inside gradio forms"""
def get_block_name(self):
return "column"
class FormGroup(FormComponent, gr.Group):
class FormGroup(gr.Group, FormComponent):
"""Same as gr.Group but fits inside gradio forms"""
def get_block_name(self):
return "group"
class FormHTML(FormComponent, gr.HTML):
class FormHTML(gr.HTML, FormComponent):
"""Same as gr.HTML but fits inside gradio forms"""
def get_block_name(self):
return "html"
class FormColorPicker(FormComponent, gr.ColorPicker):
class FormColorPicker(gr.ColorPicker, FormComponent):
"""Same as gr.ColorPicker but fits inside gradio forms"""
def get_block_name(self):
return "colorpicker"
class DropdownMulti(FormComponent, gr.Dropdown):
class DropdownMulti(gr.Dropdown, FormComponent):
"""Same as gr.Dropdown but always multiselect"""
@wraps(gr.Dropdown.__init__)
def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs)
kwargs['multiselect'] = True
super().__init__(**kwargs)
def get_block_name(self):
return "dropdown"
class DropdownEditable(FormComponent, gr.Dropdown):
class DropdownEditable(gr.Dropdown, FormComponent):
"""Same as gr.Dropdown but allows editing value"""
@wraps(gr.Dropdown.__init__)
def __init__(self, **kwargs):
super().__init__(allow_custom_value=True, **kwargs)
kwargs['allow_custom_value'] = True
super().__init__(**kwargs)
def get_block_name(self):
return "dropdown"
class InputAccordion(gr.Checkbox):
class InputAccordionImpl(gr.Checkbox):
"""A gr.Accordion that can be used as an input - returns True if open, False if closed.
Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
"""
webui_do_not_create_gradio_pyi_thank_you = True
global_index = 0
def __init__(self, value, **kwargs):
@wraps(gr.Checkbox.__init__)
def __init__(self, value=None, setup=False, **kwargs):
if not setup:
super().__init__(value=value, **kwargs)
return
self.accordion_id = kwargs.get('elem_id')
if self.accordion_id is None:
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
InputAccordion.global_index += 1
self.accordion_id = f"input-accordion-{InputAccordionImpl.global_index}"
InputAccordionImpl.global_index += 1
kwargs_checkbox = {
**kwargs,
"elem_id": f"{self.accordion_id}-checkbox",
"visible": False,
}
super().__init__(value, **kwargs_checkbox)
super().__init__(value=value, **kwargs_checkbox)
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
@ -115,6 +136,7 @@ class InputAccordion(gr.Checkbox):
"elem_classes": ['input-accordion'],
"open": value,
}
self.accordion = gr.Accordion(**kwargs_accordion)
def extra(self):
@ -143,3 +165,6 @@ class InputAccordion(gr.Checkbox):
def get_block_name(self):
return "checkbox"
def InputAccordion(value=None, **kwargs):
return InputAccordionImpl(value=value, setup=True, **kwargs)

View File

@ -382,7 +382,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
page_elem.change(fn=None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}')
editor = page.create_user_metadata_editor(ui, tabname)
editor.create_ui()

View File

@ -31,7 +31,7 @@ def check_tmp_file(gradio, filename):
return False
def save_pil_to_file(self, pil_image, dir=None, format="png"):
def save_pil_to_file(pil_image, dir=None, format="png", cache_dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
@ -61,7 +61,7 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
def install_ui_tempdir_override():
"""override save to file function so that it also writes PNG info"""
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
gradio.processing_utils.save_pil_to_cache = save_pil_to_file
def on_tmpdir_changed():

View File

@ -8,7 +8,7 @@ clean-fid
einops
fastapi>=0.90.1
gfpgan
gradio==3.41.2
gradio==4.7.1
inflection
jsonmerge
kornia

View File

@ -5,9 +5,9 @@ basicsr==1.4.2
blendmodes==2022
clean-fid==0.1.35
einops==0.4.1
fastapi==0.94.0
fastapi==0.104.1
gfpgan==1.3.8
gradio==3.41.2
gradio==4.7.1
httpcore==0.15
inflection==0.5.1
jsonmerge==1.8.0