add option to show/hide warnings

removed hiding warnings from LDSR
fixed/reworked few places that produced warnings
This commit is contained in:
AUTOMATIC 2023-01-18 23:04:24 +03:00
parent 889b851a52
commit 924e222004
10 changed files with 71 additions and 32 deletions

View File

@ -1,7 +1,6 @@
import os import os
import gc import gc
import time import time
import warnings
import numpy as np import numpy as np
import torch import torch
@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack from modules import shared, sd_hijack
warnings.filterwarnings("ignore", category=UserWarning)
cached_ldsr_model: torch.nn.Module = None cached_ldsr_model: torch.nn.Module = None

View File

@ -11,7 +11,7 @@ ignore_ids_for_localization={
train_embedding: 'OPTION', train_embedding: 'OPTION',
train_hypernetwork: 'OPTION', train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION', txt2img_styles: 'OPTION',
img2img_styles 'OPTION', img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN', setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN', setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN', setting_realesrgan_enabled_models: 'SPAN',

View File

@ -12,7 +12,7 @@ import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
pbar = tqdm.tqdm(total=steps - initial_step) pbar = tqdm.tqdm(total=steps - initial_step)
try: try:
sd_hijack_checkpoint.add()
for i in range((steps-initial_step) * gradient_step): for i in range((steps-initial_step) * gradient_step):
if scheduler.finished: if scheduler.finished:
break break
@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close() pbar.close()
hypernetwork.eval() hypernetwork.eval()
#report_statistics(loss_dict) #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name hypernetwork.optimizer_name = optimizer_name

View File

@ -69,12 +69,6 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def fix_checkpoint():
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
fixes = None fixes = None
comments = [] comments = []
@ -107,8 +101,6 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
fix_checkpoint()
def flatten(el): def flatten(el):
flattened = [flatten(children) for children in el.children()] flattened = [flatten(children) for children in el.children()]
res = [el] res = [el]

View File

@ -1,10 +1,46 @@
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel
def BasicTransformerBlock_forward(self, x, context=None): def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context) return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x): def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x) return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb): def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb) return checkpoint(self._forward, x, emb)
stored = []
def add():
if len(stored) != 0:
return
stored.extend([
ldm.modules.attention.BasicTransformerBlock.forward,
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
])
ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
def remove():
if len(stored) == 0:
return
ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
stored.clear()

View File

@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
"show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),

View File

@ -15,7 +15,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pbar = tqdm.tqdm(total=steps - initial_step) pbar = tqdm.tqdm(total=steps - initial_step)
try: try:
sd_hijack_checkpoint.add()
for i in range((steps-initial_step) * gradient_step): for i in range((steps-initial_step) * gradient_step):
if scheduler.finished: if scheduler.finished:
break break
@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close() pbar.close()
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed shared.parallel_processing_allowed = old_parallel_processing_allowed
sd_hijack_checkpoint.remove()
return embedding, filename return embedding, filename
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None

View File

@ -11,6 +11,7 @@ import tempfile
import time import time
import traceback import traceback
from functools import partial, reduce from functools import partial, reduce
import warnings
import gradio as gr import gradio as gr
import gradio.routes import gradio.routes
@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init() mimetypes.init()
mimetypes.add_type('application/javascript', '.js') mimetypes.add_type('application/javascript', '.js')
@ -417,17 +420,16 @@ def apply_setting(key, value):
return value return value
def update_generation_info(args): def update_generation_info(generation_info, html_info, img_index):
generation_info, html_info, img_index = args
try: try:
generation_info = json.loads(generation_info) generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]): if img_index < 0 or img_index >= len(generation_info["infotexts"]):
return html_info return html_info, gr.update()
return plaintext_to_html(generation_info["infotexts"][img_index]) return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
except Exception: except Exception:
pass pass
# if the json parse or anything else fails, just return the old html_info # if the json parse or anything else fails, just return the old html_info
return html_info return html_info, gr.update()
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
@ -508,10 +510,9 @@ Requested path was: {f}
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click( generation_info_button.click(
fn=update_generation_info, fn=update_generation_info,
_js="(x, y) => [x, y, selected_gallery_index()]", _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info], inputs=[generation_info, html_info, html_info],
outputs=[html_info], outputs=[html_info, html_info],
preprocess=False
) )
save.click( save.click(
@ -526,7 +527,8 @@ Requested path was: {f}
outputs=[ outputs=[
download_files, download_files,
html_log, html_log,
] ],
show_progress=False,
) )
save_zip.click( save_zip.click(
@ -588,7 +590,7 @@ def create_ui():
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"): with gr.Column(variant='compact', elem_id="txt2img_settings"):
@ -768,7 +770,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow().style(equal_height=False): with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"): with gr.Column(variant='compact', elem_id="img2img_settings"):
@ -1768,7 +1770,10 @@ def create_ui():
if saved_value is None: if saved_value is None:
ui_settings[key] = getattr(obj, field) ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value): elif condition and not condition(saved_value):
print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') pass
# this warning is generally not useful;
# print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else: else:
setattr(obj, field, saved_value) setattr(obj, field, saved_value)
if init_field is not None: if init_field is not None:

View File

@ -116,7 +116,7 @@ class Script(scripts.Script):
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file")) file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])

View File

@ -299,9 +299,8 @@ input[type="range"]{
} }
/* more gradio's garbage cleanup */ /* more gradio's garbage cleanup */
.min-h-\[4rem\] { .min-h-\[4rem\] { min-height: unset !important; }
min-height: unset !important; .min-h-\[6rem\] { min-height: unset !important; }
}
.progressDiv{ .progressDiv{
position: absolute; position: absolute;