diff --git a/modules/api/api.py b/modules/api/api.py index bb87d795a..b3d85e46e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -5,10 +5,9 @@ import uvicorn from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, HTTPException import modules.shared as shared -from modules import devices from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers +from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid from modules.extras import run_extras, run_pnginfo @@ -179,6 +178,16 @@ class Api: progress = min(progress, 1) + # copy from check_progress_call of ui.py + + if shared.parallel_processing_allowed: + if shared.state.sampling_step - shared.state.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.state.current_latent is not None: + if shared.opts.show_progress_grid: + shared.state.current_image = samples_to_image_grid(shared.state.current_latent) + else: + shared.state.current_image = sample_to_image(shared.state.current_latent) + shared.state.current_image_sampling_step = shared.state.sampling_step + current_image = None if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) diff --git a/modules/shared.py b/modules/shared.py index 1ccb269ae..04aaa6485 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import json import os import sys from collections import OrderedDict +import time import gradio as gr import tqdm @@ -135,6 +136,7 @@ class State: current_image = None current_image_sampling_step = 0 textinfo = None + time_start = None need_restart = False def skip(self): @@ -172,6 +174,7 @@ class State: self.skipped = False self.interrupted = False self.textinfo = None + self.time_start = time.time() devices.torch_gc()