stable-diffusion-webui/modules/gradio_extensons.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

143 lines
5.3 KiB
Python
Raw Permalink Normal View History

2023-11-27 19:30:49 +08:00
from inspect import signature
from functools import wraps
2023-08-04 12:50:17 +08:00
import gradio as gr
from modules import scripts, ui_tempdir, patches
2023-08-04 12:50:17 +08:00
def add_classes_to_gradio_component(comp):
    """
    Tag a gradio block with CSS classes used by the webui stylesheet.

    A ``gradio-<block name>`` class (e.g. ``gradio-button`` for ``gr.Button``) is
    placed in front of the block's existing ``elem_classes``. Blocks whose
    ``multiselect`` attribute is truthy also get a trailing ``multiselect`` class.
    ``comp.elem_classes`` is replaced with a new list.
    """
    type_class = f"gradio-{comp.get_block_name()}"
    comp.elem_classes = [type_class] + list(comp.elem_classes or [])

    is_multiselect = getattr(comp, 'multiselect', False)
    if is_multiselect:
        comp.elem_classes.append('multiselect')
def IOComponent_init(self, *args, **kwargs):
    """
    Replacement for ``gr.components.base.Component.__init__``.

    Removes the webui-only ``tooltip`` keyword (kept as ``self.webui_tooltip``),
    then runs, in order: the active script runner's ``before_component`` (if any),
    the global ``before_component_callback``, the original constructor, the
    ``gradio-*`` CSS tagging, the global ``after_component_callback`` and finally
    the active script runner's ``after_component`` (if any).

    Returns whatever the original constructor returned.
    """
    self.webui_tooltip = kwargs.pop('tooltip', None)

    runner = scripts.scripts_current
    if runner is not None:
        runner.before_component(self, **kwargs)
    scripts.script_callbacks.before_component_callback(self, **kwargs)

    outcome = original_IOComponent_init(self, *args, **kwargs)

    add_classes_to_gradio_component(self)

    # "after" hooks run in the opposite order to the "before" hooks
    scripts.script_callbacks.after_component_callback(self, **kwargs)
    runner = scripts.scripts_current  # looked up again here rather than reusing the earlier value
    if runner is not None:
        runner.after_component(self, **kwargs)

    return outcome
def Block_get_config(self):
    """
    Replacement for ``gr.blocks.Block.get_config``.

    Takes the dict produced by the original method, removes any
    ``example_inputs`` entry and, when the block has a truthy ``webui_tooltip``
    attribute, stores it under the ``"webui_tooltip"`` key. The same dict is
    returned.
    """
    cfg = original_Block_get_config(self)
    cfg.pop('example_inputs', None)

    tooltip = getattr(self, 'webui_tooltip', None)
    if tooltip:
        cfg["webui_tooltip"] = tooltip

    return cfg
def BlockContext_init(self, *args, **kwargs):
    """
    Replacement for ``gr.blocks.BlockContext.__init__``.

    Runs the original constructor, then tags the new layout block with the
    ``gradio-*`` CSS classes. Returns the original constructor's result.
    """
    outcome = original_BlockContext_init(self, *args, **kwargs)
    add_classes_to_gradio_component(self)
    return outcome
def Blocks_get_config_file(self, *args, **kwargs):
    """
    Replacement for ``gr.blocks.Blocks.get_config_file``.

    Every component config that has an ``example_inputs`` key gets it replaced
    with an empty ``{"serialized": []}`` placeholder (a fresh dict per
    component). The config from the original method is returned after the edit.
    """
    config_file = original_Blocks_get_config_file(self, *args, **kwargs)

    with_examples = (entry for entry in config_file["components"] if "example_inputs" in entry)
    for entry in with_examples:
        entry["example_inputs"] = {"serialized": []}

    return config_file
2023-11-27 19:30:49 +08:00
def gradio_component_compatibility_layer(component_function):
    """
    Wrap ``component_function`` so that undeclared keyword arguments are dropped.

    Any keyword whose name is not a parameter of ``component_function`` is
    silently discarded before the call; positional arguments pass through
    unchanged. Presumably this lets code written for an older gradio API keep
    constructing components with keywords the installed version no longer
    accepts — confirm against callers.

    :param component_function: the callable to wrap (in this module, a gradio
        component or layout ``__init__``).
    :return: the wrapper, carrying ``functools.wraps`` metadata of the original.
    :raises ValueError, TypeError: from ``inspect.signature`` when the callable
        has no introspectable signature. This is raised when wrapping, not on
        every call.
    """
    # The wrapped callable never changes, so resolve its parameter names once
    # instead of re-running inspect.signature on every component construction.
    accepted_names = frozenset(signature(component_function).parameters)

    @wraps(component_function)
    def patched_function(*args, **kwargs):
        valid_kwargs = {k: v for k, v in kwargs.items() if k in accepted_names}
        return component_function(*args, **valid_kwargs)

    return patched_function
# Names of the chaining methods on an event's return value that are also wrapped
# for legacy-keyword compatibility.
sub_events = ['then', 'success']


def gradio_component_events_compatibility_layer(component_function):
    """
    Wrap a component event method (e.g. ``change``) for legacy keyword support.

    On each call the wrapper:

    * maps the old ``_js`` keyword onto ``js``. An explicit ``js`` wins, and
      ``_js`` is always removed.
    * drops every keyword that ``component_function`` does not declare.
    * replaces the methods named in ``sub_events`` on the returned object with
      ``gradio_component_sub_events_compatibility_layer`` wrappers, so chained
      calls get the same treatment.

    :param component_function: the event method to wrap.
    :return: the wrapper, carrying ``functools.wraps`` metadata of the original.
    :raises ValueError, TypeError: from ``inspect.signature`` when the callable
        has no introspectable signature. This is raised when wrapping, not on
        every call.
    """
    # Resolved once per wrapped event method instead of on every event registration.
    accepted_names = frozenset(signature(component_function).parameters)

    @wraps(component_function)
    def patched_function(*args, **kwargs):
        # `_js` is popped unconditionally; it only fills `js` when `js` is absent.
        legacy_js = kwargs.pop('_js', None)
        kwargs.setdefault('js', legacy_js)

        valid_kwargs = {k: v for k, v in kwargs.items() if k in accepted_names}
        result = component_function(*args, **valid_kwargs)

        # Patch the chaining methods on this particular result object (instance attributes).
        for sub_event in sub_events:
            sub_event_function = getattr(result, sub_event, None)
            if sub_event_function:
                setattr(result, sub_event, gradio_component_sub_events_compatibility_layer(sub_event_function))

        return result

    return patched_function
def gradio_component_sub_events_compatibility_layer(component_function):
    """
    Wrap a chained event method (one named in ``sub_events``) for legacy keywords.

    The wrapper maps the old ``_js`` keyword onto ``js``. An explicit ``js``
    takes precedence, and ``_js`` is always removed. Keywords that
    ``component_function`` does not declare are then dropped before it is
    called, and its result is returned unchanged.
    """
    @wraps(component_function)
    def patched_function(*args, **kwargs):
        legacy_js = kwargs.pop('_js', None)
        kwargs.setdefault('js', legacy_js)

        accepted = signature(component_function).parameters
        filtered = {}
        for name, value in kwargs.items():
            if name in accepted:
                filtered[name] = value

        return component_function(*args, **filtered)

    return patched_function
# Wrap every public gradio component/layout constructor and event method so that
# keywords the installed gradio does not declare (and the legacy `_js` event
# keyword) are handled by the compatibility layers instead of raising.
# Best-effort: failures are printed and the remaining components/events are
# still processed.
#
# Names are sorted so the patch order is deterministic. A subclass patched after
# its parent picks up the parent's wrapper through getattr, so hash-dependent
# ordering would make the result vary between runs.
#
# NOTE: the original `__init__` returned by patches.patch is deliberately not
# stored in original_IOComponent_init; that name holds the base Component.__init__
# that IOComponent_init delegates to.
for component_name in sorted(set(gr.components.__all__ + gr.layouts.__all__)):
    try:
        component = getattr(gr, component_name)
        patches.patch(
            f'{__name__}.{component_name}',
            obj=component,
            field="__init__",
            replacement=gradio_component_compatibility_layer(component.__init__),
        )
        # dict.fromkeys de-duplicates while keeping the declared order
        component_events = dict.fromkeys(getattr(component, 'EVENTS', ()))
    except Exception as e:
        print(f"Failed to patch gradio component {component_name}: {e}")
        continue

    # Each event is patched independently so one failure does not skip the rest.
    for component_event in component_events:
        try:
            patches.patch(
                f'{__name__}.{component_name}.{component_event}',
                obj=component,
                field=component_event,
                replacement=gradio_component_events_compatibility_layer(getattr(component, component_event)),
            )
        except Exception as e:
            print(f"Failed to patch gradio event {component_name}.{component_event}: {e}")
# NOTE(review): aliases `gr.Box` to `gr.Group`, presumably so callers written for an
# older gradio API that had `gr.Box` keep working — confirm it is absent in the
# installed gradio.
gr.Box = gr.Group
# Install the webui hooks on gradio's base classes. These run after the
# per-component loop above. Each value returned by patches.patch is stored in an
# original_* module global, and the matching replacement function defined earlier
# in this file calls it as the original implementation.
original_IOComponent_init = patches.patch(__name__, obj=gr.components.base.Component, field="__init__", replacement=IOComponent_init)
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
2023-08-04 12:50:17 +08:00
# Import-time side effect: calls modules.ui_tempdir's override installer.
# NOTE(review): what exactly it overrides lives in ui_tempdir — see that module.
ui_tempdir.install_ui_tempdir_override()
2023-11-27 19:30:49 +08:00