diff --git a/README.md b/README.md index 0c49d6f2a..63f8d0003 100644 --- a/README.md +++ b/README.md @@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt. A tab with settings, allowing you to use UI to edit more than half of parameters that previously were commandline. Settings are saved to config.js file. Settings that remain as commandline options are ones that are required at startup. + +### Attention +Using `()` in prompt decreases model's attention to enclosed words, and `[]` increases it. You can combine +multiple modifiers: + +![](images/attention-3.jpg) diff --git a/images/attention-3.jpg b/images/attention-3.jpg new file mode 100644 index 000000000..7c7ef0d3a Binary files /dev/null and b/images/attention-3.jpg differ diff --git a/webui.py b/webui.py index b3375e98a..a0fa23c4a 100644 --- a/webui.py +++ b/webui.py @@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir): print(traceback.format_exc(), file=sys.stderr) -class TextInversionEmbeddings: +class StableDiffuionModelHijack: ids_lookup = {} word_embeddings = {} word_embeddings_checksums = {} - fixes = [] + fixes = None used_custom_terms = [] dir_mtime = None - def load(self, dir, model): + def load_textual_inversion_embeddings(self, dir, model): mt = os.path.getmtime(dir) if self.dir_mtime is not None and mt <= self.dir_mtime: return @@ -469,6 +469,7 @@ class TextInversionEmbeddings: self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}' ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] + first_id = ids[0] if first_id not in self.ids_lookup: self.ids_lookup[first_id] = [] @@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.embeddings = embeddings self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length + self.token_mults = {} + + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult def forward(self, text): self.embeddings.fixes = [] @@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): cache = {} batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] + batch_multipliers = [] for tokens in batch_tokens: tuple_tokens = tuple(tokens) if tuple_tokens in cache: - remade_tokens, fixes = cache[tuple_tokens] + remade_tokens, fixes, multipliers = cache[tuple_tokens] else: fixes = [] remade_tokens = [] + multipliers = [] + mult = 1.0 i = 0 while i < len(tokens): @@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): possible_matches = self.embeddings.ids_lookup.get(token, None) - if possible_matches is None: + mult_change = self.token_mults.get(token) + if mult_change is not None: + mult *= mult_change + elif possible_matches is None: remade_tokens.append(token) + multipliers.append(mult) else: found = False for ids, word in possible_matches: if tokens[i:i+len(ids)] == ids: fixes.append((len(remade_tokens), word)) remade_tokens.append(777) + multipliers.append(mult) i += len(ids) - 1 found = True self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word])) @@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if not found: remade_tokens.append(token) + multipliers.append(mult) i += 1 remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes) + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] remade_batch_tokens.append(remade_tokens) self.embeddings.fixes.append(fixes) + batch_multipliers.append(multipliers) tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device) outputs = self.wrapped.transformer(input_ids=tokens) z = outputs.last_hidden_state + + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise + batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device) + original_mean = z.mean() + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + new_mean = z.mean() + z *= original_mean / new_mean + return z @@ -562,24 +601,19 @@ class EmbeddingsWithFixes(nn.Module): def forward(self, input_ids): batch_fixes = self.embeddings.fixes - self.embeddings.fixes = [] + self.embeddings.fixes = None inputs_embeds = self.wrapped(input_ids) - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, word in fixes: - tensor[offset] = self.embeddings.word_embeddings[word] + if batch_fixes is not None: + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, word in fixes: + tensor[offset] = self.embeddings.word_embeddings[word] + return inputs_embeds -def get_learned_conditioning_with_embeddings(model, prompts): - if os.path.exists(cmd_opts.embeddings_dir): - text_inversion_embeddings.load(cmd_opts.embeddings_dir, model) - - return model.get_learned_conditioning(prompts) - - def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" @@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments]) if os.path.exists(cmd_opts.embeddings_dir): - text_inversion_embeddings.load(cmd_opts.embeddings_dir, model) + model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model) output_images = [] with torch.no_grad(), autocast("cuda"), model.ema_scope(): @@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, uc = model.get_learned_conditioning(len(prompts) * [""]) c = model.get_learned_conditioning(prompts) - if len(text_inversion_embeddings.used_custom_terms) > 0: - comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms])) + if len(model_hijack.used_custom_terms) > 0: + comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms])) # we manually generate all input noises because each one should have a specific seed x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds) @@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = (model if cmd_opts.no_half else model.half()).to(device) -text_inversion_embeddings = TextInversionEmbeddings() -if os.path.exists(cmd_opts.embeddings_dir): - text_inversion_embeddings.hijack(model) +model_hijack = StableDiffuionModelHijack() +model_hijack.hijack(model) demo = gr.TabbedInterface( interface_list=[x[0] for x in interfaces],