diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index c060cccb2..b475b9f81 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -1,6 +1,6 @@
 import inspect
 from collections import namedtuple
-import numpy as np
+from contextlib import nullcontext
 import torch
 from PIL import Image
 from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
@@ -59,15 +59,19 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
 
     return x_sample
 
 
-def single_sample_to_image(sample, approximation=None):
+def single_sample_to_tensor(sample, approximation=None, non_blocking=False):
+    """Decode one latent into a uint8 HWC tensor on the CPU; with non_blocking=True the caller must synchronize the stream before reading it."""
     x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
 
     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
-    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
-    x_sample = x_sample.astype(np.uint8)
-
-    return Image.fromarray(x_sample)
+    x_sample = 255. * x_sample.permute(1, 2, 0)
+    return x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=non_blocking)
+
+
+def single_sample_to_image(sample, approximation=None):
+    """Decode one latent into a PIL image (blocking); kept with its original signature and return type for existing callers."""
+    return Image.fromarray(single_sample_to_tensor(sample, approximation).numpy())
 
 
 def decode_first_stage(model, x):
@@ -76,12 +80,50 @@ def decode_first_stage(model, x):
     return samples_to_images_tensor(x, approx_index, model)
 
 
+# side CUDA streams used for live previews, created lazily and keyed by torch.device
+live_preview_streams = {}
+
+
+def live_preview_stream(device):
+    """Return the side CUDA stream used for live previews on `device`, or None when `device` is not a CUDA device."""
+    if device.type != 'cuda':
+        return None
+    if device not in live_preview_streams:
+        live_preview_streams[device] = torch.cuda.Stream(device=device)
+    return live_preview_streams[device]
+
+
+def samples_to_preview_tensors(samples, approximation=None):
+    """Decode every latent in `samples` into a uint8 CPU tensor, using a side CUDA stream when the latents live on a CUDA device."""
+    if len(samples) == 0:
+        return []
+
+    lp_stream = live_preview_stream(samples[0].device)
+    if lp_stream is None:
+        stream_context = nullcontext()
+    else:
+        # the latents were queued on the current stream; the side stream must not read them before they are written
+        lp_stream.wait_stream(torch.cuda.current_stream(lp_stream.device))
+        # a stream context object keeps per-use state, so build a fresh one for every call instead of sharing one
+        stream_context = torch.cuda.stream(lp_stream)
+
+    with stream_context:
+        tensors = [single_sample_to_tensor(sample, approximation, non_blocking=lp_stream is not None) for sample in samples]
+
+    if lp_stream is not None:
+        # non-blocking device-to-host copies may only be read once the stream has finished
+        lp_stream.synchronize()
+
+    return tensors
+
+
 def sample_to_image(samples, index=0, approximation=None):
-    return single_sample_to_image(samples[index], approximation)
+    return Image.fromarray(samples_to_preview_tensors([samples[index]], approximation)[0].numpy())
 
 
 def samples_to_image_grid(samples, approximation=None):
-    return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+    tensors = samples_to_preview_tensors(samples, approximation)
+    return images.image_grid([Image.fromarray(tensor.numpy()) for tensor in tensors])
 
 
 def images_tensor_to_samples(image, approximation=None, model=None):