mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
task manager added
based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py * classified * this way, gc.collect() will work as intended.
This commit is contained in:
parent
98cb284eb1
commit
1d3dae1471
@ -3,7 +3,7 @@ from functools import wraps
|
||||
import html
|
||||
import time
|
||||
|
||||
from modules import shared, progress, errors, devices, fifo_lock, profiling
|
||||
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager
|
||||
|
||||
queue_lock = fifo_lock.FIFOLock()
|
||||
|
||||
@ -34,7 +34,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
progress.start_task(id_task)
|
||||
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
res = manager.task.run_and_wait_result(func, *args, **kwargs)
|
||||
progress.record_results(id_task, res)
|
||||
finally:
|
||||
progress.finish_task(id_task)
|
||||
|
@ -463,11 +463,17 @@ def configure_for_tests():
|
||||
def start():
|
||||
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
|
||||
import webui
|
||||
|
||||
from modules import manager
|
||||
|
||||
if '--nowebui' in sys.argv:
|
||||
webui.api_only()
|
||||
else:
|
||||
webui.webui()
|
||||
|
||||
manager.task.main_loop()
|
||||
return
|
||||
|
||||
|
||||
def dump_sysinfo():
|
||||
from modules import sysinfo
|
||||
|
83
modules/manager.py
Normal file
83
modules/manager.py
Normal file
@ -0,0 +1,83 @@
|
||||
#
|
||||
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
|
||||
#
|
||||
# Original author comment:
|
||||
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
|
||||
# Other gradio calls (like those from extensions) are not influenced.
|
||||
# By using one single thread to process all major calls, model moving is significantly faster.
|
||||
#
|
||||
# 2024/09/28 classified,
|
||||
|
||||
import random
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class Task:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
class TaskManager:
|
||||
last_exception = None
|
||||
pending_tasks = []
|
||||
finished_tasks = OrderedDict()
|
||||
lock = None
|
||||
running = False
|
||||
|
||||
def __init__(self):
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def work(self, task):
|
||||
try:
|
||||
task.result = task.func(*task.args, **task.kwargs)
|
||||
except Exception as e:
|
||||
task.exception = e
|
||||
self.last_exception = e
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
|
||||
def main_loop(self):
|
||||
self.running = True
|
||||
while self.running:
|
||||
time.sleep(0.01)
|
||||
if len(self.pending_tasks) > 0:
|
||||
with self.lock:
|
||||
task = self.pending_tasks.pop(0)
|
||||
|
||||
self.work(task)
|
||||
|
||||
self.finished_tasks[task.task_id] = task
|
||||
|
||||
|
||||
def push_task(self, func, *args, **kwargs):
|
||||
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
|
||||
task_id = args[0]
|
||||
else:
|
||||
task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
|
||||
task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
|
||||
self.pending_tasks.append(task)
|
||||
|
||||
return task.task_id
|
||||
|
||||
|
||||
def run_and_wait_result(self, func, *args, **kwargs):
|
||||
current_id = self.push_task(func, *args, **kwargs)
|
||||
|
||||
while True:
|
||||
time.sleep(0.01)
|
||||
if current_id in self.finished_tasks:
|
||||
finished = self.finished_tasks.pop(current_id)
|
||||
if finished.exception is not None:
|
||||
raise finished.exception
|
||||
|
||||
return finished.result
|
||||
|
||||
|
||||
task = TaskManager()
|
21
webui.py
21
webui.py
@ -6,6 +6,8 @@ import time
|
||||
from modules import timer
|
||||
from modules import initialize_util
|
||||
from modules import initialize
|
||||
from modules import manager
|
||||
from threading import Thread
|
||||
|
||||
startup_timer = timer.startup_timer
|
||||
startup_timer.record("launcher")
|
||||
@ -14,6 +16,8 @@ initialize.imports()
|
||||
|
||||
initialize.check_versions()
|
||||
|
||||
initialize.initialize()
|
||||
|
||||
|
||||
def create_api(app):
|
||||
from modules.api.api import Api
|
||||
@ -23,12 +27,10 @@ def create_api(app):
|
||||
return api
|
||||
|
||||
|
||||
def api_only():
|
||||
def _api_only():
|
||||
from fastapi import FastAPI
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
initialize.initialize()
|
||||
|
||||
app = FastAPI()
|
||||
initialize_util.setup_middleware(app)
|
||||
api = create_api(app)
|
||||
@ -83,11 +85,10 @@ For more information see: https://github.com/AUTOMATIC1111/stable-diffusion-webu
|
||||
{"!"*25} Warning {"!"*25}''')
|
||||
|
||||
|
||||
def webui():
|
||||
def _webui():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
launch_api = cmd_opts.api
|
||||
initialize.initialize()
|
||||
|
||||
from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
|
||||
|
||||
@ -177,6 +178,7 @@ def webui():
|
||||
print("Stopping server...")
|
||||
# If we catch a keyboard interrupt, we want to stop the server and exit.
|
||||
shared.demo.close()
|
||||
manager.task.stop()
|
||||
break
|
||||
|
||||
# disable auto launch webui in browser for subsequent UI Reload
|
||||
@ -193,6 +195,13 @@ def webui():
|
||||
initialize.initialize_rest(reload_script_modules=True)
|
||||
|
||||
|
||||
def api_only():
|
||||
Thread(target=_api_only, daemon=True).start()
|
||||
|
||||
|
||||
def webui():
|
||||
Thread(target=_webui, daemon=True).start()
|
||||
|
||||
if __name__ == "__main__":
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
@ -200,3 +209,5 @@ if __name__ == "__main__":
|
||||
api_only()
|
||||
else:
|
||||
webui()
|
||||
|
||||
manager.task.main_loop()
|
||||
|
Loading…
Reference in New Issue
Block a user