added tool for profiling code

This commit is contained in:
AUTOMATIC1111 2024-06-09 21:18:36 +03:00
parent aafbb5b403
commit 57e6d05a43
5 changed files with 78 additions and 5 deletions

View File

@ -1,8 +1,9 @@
import os.path
from functools import wraps from functools import wraps
import html import html
import time import time
from modules import shared, progress, errors, devices, fifo_lock from modules import shared, progress, errors, devices, fifo_lock, profiling
queue_lock = fifo_lock.FIFOLock() queue_lock = fifo_lock.FIFOLock()
@ -111,8 +112,13 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
else: else:
vram_html = '' vram_html = ''
if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename):
profiling_html = f"<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>"
else:
profiling_html = ''
# last item is always HTML # last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>" res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>"
return tuple(res) return tuple(res)

View File

@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any from typing import Any
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
from modules.rng import slerp # noqa: F401 from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@ -843,7 +843,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# backwards compatibility, fix sampler and scheduler if invalid # backwards compatibility, fix sampler and scheduler if invalid
sd_samplers.fix_p_invalid_sampler_and_scheduler(p) sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
res = process_images_inner(p) with profiling.Profiler():
res = process_images_inner(p)
finally: finally:
sd_models.apply_token_merging(p.sd_model, 0) sd_models.apply_token_merging(p.sd_model, 0)

46
modules/profiling.py Normal file
View File

@ -0,0 +1,46 @@
import torch
from modules import shared, ui_gradio_extensions
class Profiler:
def __init__(self):
if not shared.opts.profiling_enable:
self.profiler = None
return
activities = []
if "CPU" in shared.opts.profiling_activities:
activities.append(torch.profiler.ProfilerActivity.CPU)
if "CUDA" in shared.opts.profiling_activities:
activities.append(torch.profiler.ProfilerActivity.CUDA)
if not activities:
self.profiler = None
return
self.profiler = torch.profiler.profile(
activities=activities,
record_shapes=shared.opts.profiling_record_shapes,
profile_memory=shared.opts.profiling_profile_memory,
with_stack=shared.opts.profiling_with_stack
)
def __enter__(self):
if self.profiler:
self.profiler.__enter__()
return self
def __exit__(self, exc_type, exc, exc_tb):
if self.profiler:
shared.state.textinfo = "Finishing profile..."
self.profiler.__exit__(exc_type, exc, exc_tb)
self.profiler.export_chrome_trace(shared.opts.profiling_filename)
def webpath():
return ui_gradio_extensions.webpath(shared.opts.profiling_filename)

View File

@ -129,6 +129,22 @@ options_templates.update(options_section(('system', "System", "system"), {
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
})) }))
options_templates.update(options_section(('profiler', "Profiler", "system"), {
"profiling_explanation": OptionHTML("""
Those settings allow you to enable torch profiler when generating pictures.
Profiling allows you to see which code uses how much of computer's resources during generation.
Each generation writes its own profile to one file, overwriting previous.
The file can be viewed in <a href="chrome:tracing">Chrome</a>, or on a <a href="https://ui.perfetto.dev/">Perfetto</a> web site.
Warning: writing profile can take a lot of time, up to 30 seconds, and the file itelf can be around 500MB in size.
"""),
"profiling_enable": OptionInfo(False, "Enable profiling"),
"profiling_activities": OptionInfo(["CPU"], "Activities", gr.CheckboxGroup, {"choices": ["CPU", "CUDA"]}),
"profiling_record_shapes": OptionInfo(True, "Record shapes"),
"profiling_profile_memory": OptionInfo(True, "Profile memory"),
"profiling_with_stack": OptionInfo(True, "Include python stack"),
"profiling_filename": OptionInfo("trace.json", "Profile filename"),
}))
options_templates.update(options_section(('API', "API", "system"), { options_templates.update(options_section(('API', "API", "system"), {
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),

View File

@ -279,7 +279,7 @@ input[type="checkbox"].input-accordion-checkbox{
display: inline-block; display: inline-block;
} }
.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr { .html-log .performance p.time, .performance p.vram, .performance p.profile, .performance p.time abbr, .performance p.vram abbr {
margin-bottom: 0; margin-bottom: 0;
color: var(--block-title-text-color); color: var(--block-title-text-color);
} }
@ -291,6 +291,10 @@ input[type="checkbox"].input-accordion-checkbox{
margin-left: auto; margin-left: auto;
} }
.html-log .performance p.profile {
margin-left: 0.5em;
}
.html-log .performance .measurement{ .html-log .performance .measurement{
color: var(--body-text-color); color: var(--body-text-color);
font-weight: bold; font-weight: bold;