stable-diffusion-webui/modules/api/api.py

130 lines
4.5 KiB
Python
Raw Normal View History

from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
2022-10-19 13:19:01 +08:00
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
2022-10-19 03:04:56 +08:00
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
from typing import List
import json
import io
import base64
from PIL import Image
2022-10-19 13:19:01 +08:00
def sampler_to_index(name):
    """Find a sampler by case-insensitive name.

    Args:
        name: Sampler name to look for.

    Returns:
        The first ``(index, sampler)`` pair from ``enumerate(all_samplers)`` whose
        ``sampler.name`` equals ``name`` ignoring case, or ``None`` when nothing
        matches or ``name`` is not a string.
    """
    # A non-string name can never match; report "not found" rather than
    # failing with AttributeError on ``.lower()``.
    if not isinstance(name, str):
        return None

    wanted = name.lower()  # invariant across the scan, so compute it once
    for index, sampler in enumerate(all_samplers):
        if sampler.name.lower() == wanted:
            return index, sampler
    return None
2022-10-19 03:04:56 +08:00
class TextToImageResponse(BaseModel):
    # Response body returned by the txt2img endpoint.

    # Generated images, one base64 string per image.
    images: List[str] = Field(
        default=None,
        title="Image",
        description="The generated image in base64 format.",
    )
    # JSON text describing the request the images were generated from.
    parameters: Json
    # JSON text produced by the processing result.
    info: Json
class ImageToImageResponse(BaseModel):
    # Response body returned by the img2img endpoint.

    # Generated images, one base64 string per image.
    images: List[str] = Field(
        default=None,
        title="Image",
        description="The generated image in base64 format.",
    )
    # JSON text describing the request the images were generated from.
    parameters: Json
    # JSON text produced by the processing result.
    info: Json
class Api:
    """Registers the txt2img / img2img HTTP endpoints on an application and serves them.

    ``queue_lock`` is entered as a context manager around every ``process_images`` call.
    """

    def __init__(self, app, queue_lock):
        """Store the collaborators and register the generation routes on ``app``.

        Args:
            app: Object providing ``add_api_route`` and ``include_router``
                (presumably a FastAPI application - confirm with the caller).
            queue_lock: Context manager entered around each ``process_images`` call.
        """
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])

    def __base64_to_image(self, base64_string):
        """Decode a base64 string into a PIL image.

        If the string contains a comma, the text before it is treated as a prefix
        and only the segment following the first comma is decoded.
        """
        if "," in base64_string:
            base64_string = base64_string.split(",")[1]
        imgdata = base64.b64decode(base64_string)
        return Image.open(io.BytesIO(imgdata))

    @staticmethod
    def _encode_images(images):
        """Serialize each image as PNG and return the payloads as base64 text.

        The result is decoded to ``str`` because the response models declare
        ``List[str]`` while ``base64.b64encode`` returns ``bytes``.
        """
        encoded = []
        for image in images:
            buffer = io.BytesIO()
            image.save(buffer, format="png")
            encoded.append(base64.b64encode(buffer.getvalue()).decode("ascii"))
        return encoded

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Handle POST /sdapi/v1/txt2img.

        Args:
            txt2imgreq: Parsed request model; its fields become the constructor
                arguments of ``StableDiffusionProcessingTxt2Img``.

        Returns:
            A ``TextToImageResponse`` with the base64 PNG images, the request
            parameters as JSON text, and ``processed.js()`` as ``info``.

        Raises:
            HTTPException: 404 when the requested sampler name is unknown.
        """
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        # Work on a copy so the caller's request stays untouched for the echo below.
        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
        })
        p = StableDiffusionProcessingTxt2Img(**vars(populate))

        with self.queue_lock:
            processed = process_images(p)

        b64images = self._encode_images(processed.images)
        return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Handle POST /sdapi/v1/img2img.

        Args:
            img2imgreq: Parsed request model. ``init_images`` (and ``mask``, when
                given) are base64 strings decoded via ``__base64_to_image``.

        Returns:
            An ``ImageToImageResponse`` with the base64 PNG images, the request
            parameters as JSON text, and ``processed.js()`` as ``info``. Unless
            ``include_init_images`` is truthy, ``init_images`` and ``mask`` are
            cleared on the request before it is echoed back.

        Raises:
            HTTPException: 404 when the sampler name is unknown or no init image
                was supplied (``None`` or an empty list).
        """
        sampler_index = sampler_to_index(img2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        init_images = img2imgreq.init_images
        # An empty list is as unusable as a missing one: reject it up front
        # instead of handing an empty batch to the processing pipeline.
        if not init_images:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = self.__base64_to_image(mask)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask,
        })
        p = StableDiffusionProcessingImg2Img(**vars(populate))

        # Every supplied image is decoded (so a malformed one still fails the
        # request), but only the last one is used, repeated to fill the batch.
        # NOTE(review): this mirrors the previous behaviour; if the processing
        # class accepts a list of distinct init images, pass `decoded` directly.
        decoded = [self.__base64_to_image(img) for img in init_images]
        p.init_images = [decoded[-1]] * p.batch_size

        with self.queue_lock:
            processed = process_images(p)

        b64images = self._encode_images(processed.images)

        # Avoid echoing potentially large base64 inputs unless explicitly requested.
        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())

    def extrasapi(self):
        """Placeholder for the extras endpoint; not implemented."""
        raise NotImplementedError

    def pnginfoapi(self):
        """Placeholder for the PNG-info endpoint; not implemented."""
        raise NotImplementedError

    def launch(self, server_name, port):
        """Attach ``self.router`` to the app and run it with uvicorn on ``server_name:port``."""
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)