add allow specify the task id and get the location of task in the queue of pending task

This commit is contained in:
gayshub 2023-12-15 16:57:17 +08:00
parent 4afaaf8a02
commit 1242ba08e1
4 changed files with 41 additions and 4 deletions

View File

@ -33,7 +33,7 @@ from typing import Dict, List, Any
import piexif import piexif
import piexif.helper import piexif.helper
from contextlib import closing from contextlib import closing
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
def script_name_to_index(name, scripts): def script_name_to_index(name, scripts):
try: try:
@ -337,6 +337,10 @@ class Api:
return script_args return script_args
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
task_id = create_task_id("text2img")
if txt2imgreq.force_task_id != None:
task_id = txt2imgreq.force_task_id
script_runner = scripts.scripts_txt2img script_runner = scripts.scripts_txt2img
if not script_runner.scripts: if not script_runner.scripts:
script_runner.initialize_scripts(False) script_runner.initialize_scripts(False)
@ -363,6 +367,8 @@ class Api:
send_images = args.pop('send_images', True) send_images = args.pop('send_images', True)
args.pop('save_images', None) args.pop('save_images', None)
add_task_to_queue(task_id)
with self.queue_lock: with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.is_api = True p.is_api = True
@ -372,12 +378,14 @@ class Api:
try: try:
shared.state.begin(job="scripts_txt2img") shared.state.begin(job="scripts_txt2img")
start_task(task_id)
if selectable_scripts is not None: if selectable_scripts is not None:
p.script_args = script_args p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else: else:
p.script_args = tuple(script_args) # Need to pass args as tuple here p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p) processed = process_images(p)
finish_task(task_id)
finally: finally:
shared.state.end() shared.state.end()
shared.total_tqdm.clear() shared.total_tqdm.clear()
@ -387,6 +395,10 @@ class Api:
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
task_id = create_task_id("img2img")
if img2imgreq.force_task_id != None:
task_id = img2imgreq.force_task_id
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
@ -423,6 +435,8 @@ class Api:
send_images = args.pop('send_images', True) send_images = args.pop('send_images', True)
args.pop('save_images', None) args.pop('save_images', None)
add_task_to_queue(task_id)
with self.queue_lock: with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images] p.init_images = [decode_base64_to_image(x) for x in init_images]
@ -433,12 +447,14 @@ class Api:
try: try:
shared.state.begin(job="scripts_img2img") shared.state.begin(job="scripts_img2img")
start_task(task_id)
if selectable_scripts is not None: if selectable_scripts is not None:
p.script_args = script_args p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else: else:
p.script_args = tuple(script_args) # Need to pass args as tuple here p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p) processed = process_images(p)
finish_task(task_id)
finally: finally:
shared.state.end() shared.state.end()
shared.total_tqdm.clear() shared.total_tqdm.clear()
@ -514,7 +530,7 @@ class Api:
if shared.state.current_image and not req.skip_current_image: if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image) current_image = encode_pil_to_base64(shared.state.current_image)
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
def interrogateapi(self, interrogatereq: models.InterrogateRequest): def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image image_b64 = interrogatereq.image

View File

@ -109,6 +109,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True}, {"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False}, {"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
] ]
).generate_model() ).generate_model()
@ -126,6 +127,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True}, {"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False}, {"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
] ]
).generate_model() ).generate_model()

View File

@ -1023,6 +1023,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_sampler_name: str = None hr_sampler_name: str = None
hr_prompt: str = '' hr_prompt: str = ''
hr_negative_prompt: str = '' hr_negative_prompt: str = ''
force_task_id: str = None
cached_hr_uc = [None, None] cached_hr_uc = [None, None]
cached_hr_c = [None, None] cached_hr_c = [None, None]
@ -1358,6 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
inpainting_mask_invert: int = 0 inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None initial_noise_multiplier: float = None
latent_mask: Image = None latent_mask: Image = None
force_task_id: string = None
image_mask: Any = field(default=None, init=False) image_mask: Any = field(default=None, init=False)

View File

@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
from modules.shared import opts from modules.shared import opts
import modules.shared as shared import modules.shared as shared
from collections import OrderedDict
import string
import random
from typing import List
current_task = None current_task = None
pending_tasks = {} pending_tasks = OrderedDict()
finished_tasks = [] finished_tasks = []
recorded_results = [] recorded_results = []
recorded_results_limit = 2 recorded_results_limit = 2
@ -34,6 +37,11 @@ def finish_task(id_task):
if len(finished_tasks) > 16: if len(finished_tasks) > 16:
finished_tasks.pop(0) finished_tasks.pop(0)
def create_task_id(task_type):
N = 7
res = ''.join(random.choices(string.ascii_uppercase +
string.digits, k=N))
return f"task({task_type}-{res})"
def record_results(id_task, res): def record_results(id_task, res):
recorded_results.append((id_task, res)) recorded_results.append((id_task, res))
@ -44,6 +52,9 @@ def record_results(id_task, res):
def add_task_to_queue(id_job): def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time() pending_tasks[id_job] = time.time()
class PendingTasksResponse(BaseModel):
size: int = Field(title="Pending task size")
tasks: List[str] = Field(title="Pending task ids")
class ProgressRequest(BaseModel): class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@ -63,8 +74,14 @@ class ProgressResponse(BaseModel):
def setup_progress_api(app): def setup_progress_api(app):
app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"])
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
def get_pending_tasks():
pending_tasks_ids = [x for x in pending_tasks]
pending_len = len(pending_tasks_ids)
return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
def progressapi(req: ProgressRequest): def progressapi(req: ProgressRequest):
active = req.id_task == current_task active = req.id_task == current_task