[Bugfix][API] - Fix API response for colab users

This commit is contained in:
Stephen 2022-10-24 11:16:07 -04:00 committed by AUTOMATIC1111
parent cbb857b675
commit db9ab1a46b
2 changed files with 19 additions and 8 deletions

View File

@ -7,6 +7,7 @@ import uvicorn
from fastapi import Body, APIRouter, HTTPException from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json from pydantic import BaseModel, Field, Json
from typing import List
import json import json
import io import io
import base64 import base64
@ -15,12 +16,12 @@ from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json parameters: Json
info: Json info: Json
class ImageToImageResponse(BaseModel): class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json parameters: Json
info: Json info: Json
@ -41,6 +42,9 @@ class Api:
# convert base64 to PIL image # convert base64 to PIL image
return Image.open(io.BytesIO(imgdata)) return Image.open(io.BytesIO(imgdata))
def __processed_info_to_json(self, processed):
return json.dumps(processed.info)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@ -65,7 +69,7 @@ class Api:
i.save(buffer, format="png") i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue())) b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
@ -111,7 +115,12 @@ class Api:
i.save(buffer, format="png") i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue())) b64images.append(base64.b64encode(buffer.getvalue()))
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) if (not img2imgreq.include_init_images):
# remove img2imgreq.init_images and img2imgreq.mask
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def extrasapi(self): def extrasapi(self):
raise NotImplementedError raise NotImplementedError

View File

@ -31,6 +31,7 @@ class ModelDef(BaseModel):
field_alias: str field_alias: str
field_type: Any field_type: Any
field_value: Any field_value: Any
field_exclude: bool = False
class PydanticModelGenerator: class PydanticModelGenerator:
@ -68,7 +69,7 @@ class PydanticModelGenerator:
field=underscore(k), field=underscore(k),
field_alias=k, field_alias=k,
field_type=field_type_generator(k, v), field_type=field_type_generator(k, v),
field_value=v.default field_value=v.default,
) )
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
] ]
@ -78,7 +79,8 @@ class PydanticModelGenerator:
field=underscore(fields["key"]), field=underscore(fields["key"]),
field_alias=fields["key"], field_alias=fields["key"],
field_type=fields["type"], field_type=fields["type"],
field_value=fields["default"])) field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False))
def generate_model(self): def generate_model(self):
""" """
@ -86,7 +88,7 @@ class PydanticModelGenerator:
from the json and overrides provided at initialization from the json and overrides provided at initialization
""" """
fields = { fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
} }
DynamicModel = create_model(self._model_name, **fields) DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True DynamicModel.__config__.allow_population_by_field_name = True
@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img", "StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img, StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model() ).generate_model()