mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-20 05:10:15 +08:00
rework #3722 to not introduce duplicate code
This commit is contained in:
parent
060ee5d3a7
commit
149784202c
@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
|
|||||||
from modules.sd_samplers import all_samplers
|
from modules.sd_samplers import all_samplers
|
||||||
from modules.extras import run_extras, run_pnginfo
|
from modules.extras import run_extras, run_pnginfo
|
||||||
|
|
||||||
# copy from wrap_gradio_gpu_call of webui.py
|
|
||||||
# because queue lock will be acquired in api handlers
|
|
||||||
# and time start needs to be set
|
|
||||||
# the function has been modified into two parts
|
|
||||||
|
|
||||||
def before_gpu_call():
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
shared.state.sampling_step = 0
|
|
||||||
shared.state.job_count = -1
|
|
||||||
shared.state.job_no = 0
|
|
||||||
shared.state.job_timestamp = shared.state.get_job_timestamp()
|
|
||||||
shared.state.current_latent = None
|
|
||||||
shared.state.current_image = None
|
|
||||||
shared.state.current_image_sampling_step = 0
|
|
||||||
shared.state.skipped = False
|
|
||||||
shared.state.interrupted = False
|
|
||||||
shared.state.textinfo = None
|
|
||||||
shared.state.time_start = time.time()
|
|
||||||
|
|
||||||
def after_gpu_call():
|
|
||||||
shared.state.job = ""
|
|
||||||
shared.state.job_count = 0
|
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
try:
|
try:
|
||||||
@ -41,8 +16,10 @@ def upscaler_to_index(name: str):
|
|||||||
except:
|
except:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||||
|
|
||||||
|
|
||||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||||
|
|
||||||
|
|
||||||
def setUpscalers(req: dict):
|
def setUpscalers(req: dict):
|
||||||
reqDict = vars(req)
|
reqDict = vars(req)
|
||||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
||||||
@ -51,6 +28,7 @@ def setUpscalers(req: dict):
|
|||||||
reqDict.pop('upscaler_2')
|
reqDict.pop('upscaler_2')
|
||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
|
|
||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app, queue_lock):
|
def __init__(self, app, queue_lock):
|
||||||
self.router = APIRouter()
|
self.router = APIRouter()
|
||||||
@ -78,10 +56,13 @@ class Api:
|
|||||||
)
|
)
|
||||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||||
# Override object param
|
# Override object param
|
||||||
before_gpu_call()
|
|
||||||
|
shared.state.begin()
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
after_gpu_call()
|
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||||
|
|
||||||
@ -119,11 +100,13 @@ class Api:
|
|||||||
imgs = [img] * p.batch_size
|
imgs = [img] * p.batch_size
|
||||||
|
|
||||||
p.init_images = imgs
|
p.init_images = imgs
|
||||||
# Override object param
|
|
||||||
before_gpu_call()
|
shared.state.begin()
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
after_gpu_call()
|
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||||
|
|
||||||
|
@ -144,9 +144,6 @@ class State:
|
|||||||
self.sampling_step = 0
|
self.sampling_step = 0
|
||||||
self.current_image_sampling_step = 0
|
self.current_image_sampling_step = 0
|
||||||
|
|
||||||
def get_job_timestamp(self):
|
|
||||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
|
||||||
|
|
||||||
def dict(self):
|
def dict(self):
|
||||||
obj = {
|
obj = {
|
||||||
"skipped": self.skipped,
|
"skipped": self.skipped,
|
||||||
@ -160,6 +157,25 @@ class State:
|
|||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def begin(self):
|
||||||
|
self.sampling_step = 0
|
||||||
|
self.job_count = -1
|
||||||
|
self.job_no = 0
|
||||||
|
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
self.current_latent = None
|
||||||
|
self.current_image = None
|
||||||
|
self.current_image_sampling_step = 0
|
||||||
|
self.skipped = False
|
||||||
|
self.interrupted = False
|
||||||
|
self.textinfo = None
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
self.job = ""
|
||||||
|
self.job_count = 0
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
state = State()
|
state = State()
|
||||||
|
|
||||||
|
17
webui.py
17
webui.py
@ -46,26 +46,13 @@ def wrap_queued_call(func):
|
|||||||
|
|
||||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
shared.state.sampling_step = 0
|
shared.state.begin()
|
||||||
shared.state.job_count = -1
|
|
||||||
shared.state.job_no = 0
|
|
||||||
shared.state.job_timestamp = shared.state.get_job_timestamp()
|
|
||||||
shared.state.current_latent = None
|
|
||||||
shared.state.current_image = None
|
|
||||||
shared.state.current_image_sampling_step = 0
|
|
||||||
shared.state.skipped = False
|
|
||||||
shared.state.interrupted = False
|
|
||||||
shared.state.textinfo = None
|
|
||||||
|
|
||||||
with queue_lock:
|
with queue_lock:
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
|
|
||||||
shared.state.job = ""
|
shared.state.end()
|
||||||
shared.state.job_count = 0
|
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user