Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-04-04 03:29:00 +08:00)

Merge 27e35f13faa79c90dddae45eea324baffcc13ace into 374bb6cc384d2a19422c0b07d69de0a41d1f3f4d

This commit is contained in: commit c692122513
@@ -1,6 +1,6 @@
 import inspect
 from collections import namedtuple
-import numpy as np
+from contextlib import nullcontext
 import torch
 from PIL import Image
 from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
@@ -59,15 +59,11 @@ def samples_to_images_tensor(sample, approximation=None, model=None):

     return x_sample


-def single_sample_to_image(sample, approximation=None):
+def single_sample_to_image(sample, approximation=None, non_blocking=False):
     x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5

     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
-    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
-    x_sample = x_sample.astype(np.uint8)
-
-    return Image.fromarray(x_sample)
+    x_sample = 255. * x_sample.permute(1, 2, 0)
+    return x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=non_blocking)


 def decode_first_stage(model, x):
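For reference, a standalone sketch of the conversion this hunk switches to: scale, clamp, and permute stay on the source device, and a single .to() call performs the cast and the device-to-host copy, which can be asynchronous. The helper name chw_float_to_hwc_uint8 is invented for illustration; the real code does this inside single_sample_to_image.

# Standalone sketch; chw_float_to_hwc_uint8 is a hypothetical helper, not part of the webui.
import torch

def chw_float_to_hwc_uint8(x_sample: torch.Tensor, non_blocking: bool = False) -> torch.Tensor:
    # x_sample: CHW float tensor already scaled to roughly [0, 1]
    x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
    x_sample = 255. * x_sample.permute(1, 2, 0)  # CHW -> HWC, scale to 0..255
    # Cast and device move happen in one call; with non_blocking=True the
    # device-to-host copy may still be in flight when this returns, so the
    # caller must synchronize the stream before calling .numpy() on the result.
    return x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=non_blocking)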
@@ -76,12 +72,27 @@ def decode_first_stage(model, x):
     return samples_to_images_tensor(x, approx_index, model)


+if torch.cuda.is_available():
+    lp_stream = torch.cuda.Stream()
+    live_preview_stream_context = torch.cuda.stream(lp_stream)
+else:
+    lp_stream = None
+    live_preview_stream_context = nullcontext()
+
 def sample_to_image(samples, index=0, approximation=None):
-    return single_sample_to_image(samples[index], approximation)
+    with live_preview_stream_context:
+        sample = single_sample_to_image(samples[index], approximation, non_blocking=lp_stream is not None)
+    if lp_stream is not None:
+        lp_stream.synchronize()
+    return Image.fromarray(sample.numpy())


 def samples_to_image_grid(samples, approximation=None):
-    return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+    with live_preview_stream_context:
+        sample_tensors = [single_sample_to_image(sample, approximation, non_blocking=lp_stream is not None) for sample in samples]
+    if lp_stream is not None:
+        lp_stream.synchronize()
+    return images.image_grid([Image.fromarray(sample.numpy()) for sample in sample_tensors])


 def images_tensor_to_samples(image, approximation=None, model=None):
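Putting the pieces together, here is a self-contained sketch of the dedicated-stream pattern the hunk above introduces. The function decode_preview is a hypothetical stand-in for sample_to_image with the decode step inlined; the idea is the same: queue the preview conversion on a side CUDA stream so it does not serialize with sampling, then synchronize that stream before reading the bytes on the host.

# Self-contained sketch; decode_preview is a hypothetical stand-in and does not
# exist in the webui codebase.
from contextlib import nullcontext

import torch
from PIL import Image

if torch.cuda.is_available():
    lp_stream = torch.cuda.Stream()
    live_preview_stream_context = torch.cuda.stream(lp_stream)
else:
    lp_stream = None
    live_preview_stream_context = nullcontext()


def decode_preview(sample: torch.Tensor) -> Image.Image:
    # sample: CHW float tensor in roughly [-1, 1], e.g. a decoded latent preview
    with live_preview_stream_context:
        x = torch.clamp(sample * 0.5 + 0.5, 0.0, 1.0)
        x = (255. * x.permute(1, 2, 0)).to(device='cpu', dtype=torch.uint8,
                                           non_blocking=lp_stream is not None)
    if lp_stream is not None:
        # Wait for the asynchronous device-to-host copy queued on the side stream
        # before reading the tensor's memory on the host.
        lp_stream.synchronize()
    return Image.fromarray(x.numpy())

Without the synchronize, x.numpy() could observe a device-to-host copy that has not finished; with it, the wait only covers the preview stream, so work already queued on the default sampling stream is unaffected.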