Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-04-24 22:09:00 +08:00)
fix bugs related to variable prompt lengths

commit 77f4237d1c
parent 4999eb2ef9
modules/sd_hijack.py
@@ -89,7 +89,6 @@ class StableDiffusionModelHijack:
             layer.padding_mode = 'circular' if enable else 'zeros'
 
     def tokenize(self, text):
-        max_length = opts.max_prompt_tokens - 2
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
         return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
 
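
Note: the removed `max_length` was computed but no longer used; `tokenize` now derives the padded length from `get_target_prompt_token_count` instead. For reference, a minimal sketch of what such a helper could compute (the body below is an assumption, not taken from this commit): round the raw token count up to the next multiple of CLIP's 75-token chunk.

import math

def get_target_prompt_token_count(token_count):
    # Assumed sketch: round the prompt length up to the next multiple of 75
    # (CLIP's usable tokens per chunk, excluding BOS/EOS) so that every
    # prompt in a batch can be padded to one shared length.
    return math.ceil(max(token_count, 1) / 75) * 75
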
@@ -174,7 +173,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             if line in cache:
                 remade_tokens, fixes, multipliers = cache[line]
             else:
-                remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+                token_count = max(current_token_count, token_count)
 
                 cache[line] = (remade_tokens, fixes, multipliers)
 
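
Note: the renamed `current_token_count` feeds a running maximum, so `token_count` ends up as the longest line in the batch rather than whichever line happened to be tokenized last. A self-contained sketch of that accumulation pattern, with a stand-in `tokenize` (hypothetical, for illustration only):

def tokenize(line):
    # Stand-in: pretend each whitespace-separated word is one token.
    tokens = line.split()
    return tokens, len(tokens)

def longest_line_token_count(lines):
    token_count = 0
    for line in lines:
        _, current_token_count = tokenize(line)
        token_count = max(current_token_count, token_count)
    return token_count

print(longest_line_token_count(["a short line", "a much longer prompt line here"]))  # 6
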
@@ -265,15 +265,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
 
-        position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76]
+        target_token_count = get_target_prompt_token_count(token_count) + 2
+
+        position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76]
         position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
 
-        tokens = torch.asarray(remade_batch_tokens).to(device)
+        remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
+        tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
         outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
         z = outputs.last_hidden_state
 
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
-        batch_multipliers = torch.asarray(batch_multipliers).to(device)
+        batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
+        batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
         original_mean = z.mean()
         z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         new_mean = z.mean()
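
Note: both token ids and attention multipliers are now padded to one shared `target_token_count` before being stacked into tensors, which is what lets prompts of different lengths share a batch. A standalone illustration with assumed values (`eos_token_id = 49407` is CLIP's end-of-text id; the token lists are made up):

import torch

eos_token_id = 49407
target_token_count = 6

remade_batch_tokens = [[320, 1125, 539], [320, 1125, 539, 2368, 512]]
batch_multipliers = [[1.0, 1.1, 1.0], [1.0, 1.0, 0.9, 1.0, 1.0]]

# Pad token lists with EOS and multiplier lists with the neutral weight 1.0
# so both stack into rectangular tensors of shape (batch, target_token_count).
tokens = torch.asarray([x + [eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens])
multipliers = torch.asarray([x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers])

print(tokens.shape, multipliers.shape)  # torch.Size([2, 6]) torch.Size([2, 6])
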
modules/sd_samplers.py
@@ -142,6 +142,16 @@ class VanillaStableDiffusionSampler:
         assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
         cond = tensor
 
+        # for DDIM, shapes must match, we can't just process cond and uncond independently;
+        # filling unconditional_conditioning with repeats of the last vector to match length is
+        # not 100% correct but should work well enough
+        if unconditional_conditioning.shape[1] < cond.shape[1]:
+            last_vector = unconditional_conditioning[:, -1:]
+            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
+            unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
+        elif unconditional_conditioning.shape[1] > cond.shape[1]:
+            unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
+
         if self.mask is not None:
             img_orig = self.sampler.model.q_sample(self.init_latent, ts)
             x_dec = img_orig * self.mask + self.nmask * x_dec
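
Note: DDIM/PLMS run cond and uncond through the model together, so the new code stretches or trims `unconditional_conditioning` along the token axis until it matches `cond`. A toy demonstration of the same repeat-last-vector trick with made-up shapes (a two-chunk prompt vs. a one-chunk negative prompt):

import torch

cond = torch.randn(1, 154, 768)                       # e.g. a two-chunk prompt
unconditional_conditioning = torch.randn(1, 77, 768)  # a one-chunk negative prompt

if unconditional_conditioning.shape[1] < cond.shape[1]:
    last_vector = unconditional_conditioning[:, -1:]
    last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
    unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
elif unconditional_conditioning.shape[1] > cond.shape[1]:
    unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]

print(unconditional_conditioning.shape)  # torch.Size([1, 154, 768])
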
@@ -221,18 +231,29 @@ class CFGDenoiser(torch.nn.Module):
 
         x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
         sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
-        cond_in = torch.cat([tensor, uncond])
 
-        if shared.batch_cond_uncond:
-            x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+        if tensor.shape[1] == uncond.shape[1]:
+            cond_in = torch.cat([tensor, uncond])
+
+            if shared.batch_cond_uncond:
+                x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+            else:
+                x_out = torch.zeros_like(x_in)
+                for batch_offset in range(0, x_out.shape[0], batch_size):
+                    a = batch_offset
+                    b = a + batch_size
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
         else:
             x_out = torch.zeros_like(x_in)
-            for batch_offset in range(0, x_out.shape[0], batch_size):
+            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+            for batch_offset in range(0, tensor.shape[0], batch_size):
                 a = batch_offset
-                b = a + batch_size
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+                b = min(a + batch_size, tensor.shape[0])
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
 
-        denoised_uncond = x_out[-batch_size:]
+            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+
+        denoised_uncond = x_out[-uncond.shape[0]:]
         denoised = torch.clone(denoised_uncond)
 
         for i, conds in enumerate(conds_list):
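
Note: when cond and uncond have different token lengths they can no longer be concatenated into one `cond_in`, so the new else-branch denoises the cond samples in clamped chunks and the uncond samples in a separate call. A runnable toy version with a stand-in denoiser (the shapes and the model are assumptions, not the real sampler):

import torch

def inner_model(x, sigma, cond):
    # Stand-in for the real denoiser; only the shapes matter here.
    return x - cond.mean()

x_in = torch.randn(4, 8)    # 3 cond samples followed by 1 uncond sample
sigma_in = torch.ones(4)
tensor = torch.randn(3, 8)  # per-sample cond
uncond = torch.randn(1, 8)
batch_size = 2

x_out = torch.zeros_like(x_in)
# Evaluate the cond samples in chunks, clamping the final chunk boundary.
for batch_offset in range(0, tensor.shape[0], batch_size):
    a = batch_offset
    b = min(a + batch_size, tensor.shape[0])
    x_out[a:b] = inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
# The uncond rows go through a separate call since their cond tensor differs.
x_out[-uncond.shape[0]:] = inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)

denoised_uncond = x_out[-uncond.shape[0]:]
print(denoised_uncond.shape)  # torch.Size([1, 8])
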