task manager added

based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py

 * classified
 * this way, gc.collect() will work as intended.
This commit is contained in:
Won-Kyu Park 2024-09-28 23:19:08 +09:00
parent 98cb284eb1
commit 1d3dae1471
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
4 changed files with 107 additions and 7 deletions

View File

@ -3,7 +3,7 @@ from functools import wraps
import html
import time
from modules import shared, progress, errors, devices, fifo_lock, profiling
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager
queue_lock = fifo_lock.FIFOLock()
@ -34,7 +34,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
progress.start_task(id_task)
try:
res = func(*args, **kwargs)
res = manager.task.run_and_wait_result(func, *args, **kwargs)
progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)

View File

@ -463,11 +463,17 @@ def configure_for_tests():
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
import webui
from modules import manager
if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()
manager.task.main_loop()
return
def dump_sysinfo():
from modules import sysinfo

83
modules/manager.py Normal file
View File

@ -0,0 +1,83 @@
#
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
#
# Original author comment:
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
# Other gradio calls (like those from extensions) are not influenced.
# By using one single thread to process all major calls, model moving is significantly faster.
#
# 2024/09/28 classified,
import random
import string
import threading
import time
from collections import OrderedDict
class Task:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class TaskManager:
last_exception = None
pending_tasks = []
finished_tasks = OrderedDict()
lock = None
running = False
def __init__(self):
self.lock = threading.Lock()
def work(self, task):
try:
task.result = task.func(*task.args, **task.kwargs)
except Exception as e:
task.exception = e
self.last_exception = e
def stop(self):
self.running = False
def main_loop(self):
self.running = True
while self.running:
time.sleep(0.01)
if len(self.pending_tasks) > 0:
with self.lock:
task = self.pending_tasks.pop(0)
self.work(task)
self.finished_tasks[task.task_id] = task
def push_task(self, func, *args, **kwargs):
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
task_id = args[0]
else:
task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
self.pending_tasks.append(task)
return task.task_id
def run_and_wait_result(self, func, *args, **kwargs):
current_id = self.push_task(func, *args, **kwargs)
while True:
time.sleep(0.01)
if current_id in self.finished_tasks:
finished = self.finished_tasks.pop(current_id)
if finished.exception is not None:
raise finished.exception
return finished.result
task = TaskManager()

View File

@ -6,6 +6,8 @@ import time
from modules import timer
from modules import initialize_util
from modules import initialize
from modules import manager
from threading import Thread
startup_timer = timer.startup_timer
startup_timer.record("launcher")
@ -14,6 +16,8 @@ initialize.imports()
initialize.check_versions()
initialize.initialize()
def create_api(app):
from modules.api.api import Api
@ -23,12 +27,10 @@ def create_api(app):
return api
def api_only():
def _api_only():
from fastapi import FastAPI
from modules.shared_cmd_options import cmd_opts
initialize.initialize()
app = FastAPI()
initialize_util.setup_middleware(app)
api = create_api(app)
@ -83,11 +85,10 @@ For more information see: https://github.com/AUTOMATIC1111/stable-diffusion-webu
{"!"*25} Warning {"!"*25}''')
def webui():
def _webui():
from modules.shared_cmd_options import cmd_opts
launch_api = cmd_opts.api
initialize.initialize()
from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
@ -177,6 +178,7 @@ def webui():
print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
shared.demo.close()
manager.task.stop()
break
# disable auto launch webui in browser for subsequent UI Reload
@ -193,6 +195,13 @@ def webui():
initialize.initialize_rest(reload_script_modules=True)
def api_only():
Thread(target=_api_only, daemon=True).start()
def webui():
Thread(target=_webui, daemon=True).start()
if __name__ == "__main__":
from modules.shared_cmd_options import cmd_opts
@ -200,3 +209,5 @@ if __name__ == "__main__":
api_only()
else:
webui()
manager.task.main_loop()