2022-09-15 18:10:16 +08:00
import re
from collections import namedtuple
2022-10-06 05:11:30 +08:00
from typing import List
2022-10-04 23:49:51 +08:00
import lark
2022-09-15 18:10:16 +08:00
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
2022-10-04 23:49:51 +08:00
# Lark grammar for the prompt-editing syntax, used by
# get_learned_conditioning_prompt_schedules():
#   start      -- a sequence of prompts, plus runs of stray "[", "]", "(", ")", ":"
#                 characters (lets prompts with unbalanced brackets still parse)
#   prompt     -- any mix of emphasized / scheduled / alternate / plain / whitespace
#   emphasized -- "(x)", "(x:y)" and "[x]"; the "!" prefix keeps the bracket and colon
#                 tokens, so attention syntax survives verbatim into the scheduled text
#   scheduled  -- "[from:to:when]" or "[to:when]", optional whitespace before NUMBER;
#                 the transformer relies on exactly 4 children (before, after, ws, NUMBER)
#   alternate  -- "[a|b|...]": two or more "|"-separated options
#   plain      -- text without unescaped \ [ ] ( ) : | ; a backslash escapes the next char
#   NUMBER     -- lark's common.SIGNED_NUMBER, so a leading sign is accepted
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
2022-09-15 18:10:16 +08:00
def get_learned_conditioning_prompt_schedules(prompts, steps):
    """Turn each prompt's prompt-editing syntax into a per-step schedule.

    For every prompt returns a list of ``[end_at_step, text]`` pairs, where ``text``
    is what the prompt reads as up to and including sampling step ``end_at_step``.
    The last pair always ends at ``steps``. A prompt that ``schedule_parser`` cannot
    parse is returned verbatim as the single-entry schedule ``[[steps, prompt]]``.
    Duplicate prompts are parsed once and share the same schedule list.

    Syntax:
      [from:to:when]  -- ``from`` until step ``when``, then ``to``
      [to:when]       -- nothing until step ``when``, then ``to``
      [from::when]    -- ``from`` until step ``when``, then nothing
      [a|b|...]       -- cycle through the options, one per step
    A ``when`` below 1 is a fraction of ``steps``; every ``when`` is capped at ``steps``.

    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
    >>> g("test")
    [[10, 'test']]
    >>> g("a [b:3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [b: 3]")
    [[3, 'a '], [10, 'a b']]
    >>> g("a [[[b]]:2]")
    [[2, 'a '], [10, 'a [[b]]']]
    >>> g("[(a:2):3]")
    [[3, ''], [10, '(a:2)']]
    >>> g("a [b : c : 1] d")
    [[1, 'a b  d'], [10, 'a  c  d']]
    >>> g("a[b:[c:d:2]:1]e")
    [[1, 'abe'], [2, 'ace'], [10, 'ade']]
    >>> g("a [unbalanced")
    [[10, 'a [unbalanced']]
    >>> g("a [b:.5] c")
    [[5, 'a  c'], [10, 'a b c']]
    >>> g("a [{b|d{:.5] c")  # not handling this right now
    [[5, 'a  c'], [10, 'a {b|d{ c']]
    >>> g("((a][:b:c [d:3]")
    [[3, '((a][:b:c '], [10, '((a][:b:c d']]
    >>> g("[a|(b:1.1)]")
    [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
    >>> g("[a (b)|c]")
    [[1, 'a (b)'], [2, 'c'], [3, 'a (b)'], [4, 'c'], [5, 'a (b)'], [6, 'c'], [7, 'a (b)'], [8, 'c'], [9, 'a (b)'], [10, 'c']]
    """

    def collect_steps(steps, tree):
        """Return the sorted, de-duplicated steps at which the prompt text may change."""
        res = [steps]

        class CollectSteps(lark.Visitor):
            def scheduled(self, tree):
                # Normalise the NUMBER token *in place* to an int step, because
                # at_step() later compares the current step against it directly.
                when = float(tree.children[-1])
                if when < 1:
                    when *= steps
                when = min(steps, int(when))
                tree.children[-1] = when
                res.append(when)

            def alternate(self, tree):
                # Alternation can change the text on every single step.
                res.extend(range(1, steps + 1))

        CollectSteps().visit(tree)
        return sorted(set(res))

    def at_step(step, tree):
        """Render the parsed prompt as it reads at sampling step ``step``."""
        class AtStep(lark.Transformer):
            def scheduled(self, args):
                before, after, _, when = args
                # `before` is None for the "[to:when]" form, which shows nothing first.
                yield before or () if step <= when else after

            def alternate(self, args):
                # Yield the whole selected branch; start()'s flatten() walks it.
                # Taking only next(branch) kept just the branch's first child
                # (dropping e.g. "(b)" from "[a (b)|c]") and failed on empty branches.
                yield args[(step - 1) % len(args)]

            def start(self, args):
                def flatten(x):
                    if isinstance(x, str):
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)

                return ''.join(flatten(args))

            def plain(self, args):
                yield args[0].value

            def __default__(self, data, children, meta):
                for child in children:
                    yield child

        return AtStep().transform(tree)

    def get_schedule(prompt):
        try:
            tree = schedule_parser.parse(prompt)
        except lark.exceptions.LarkError:
            # Unparseable prompts are used verbatim for every step.
            return [[steps, prompt]]
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    # Parse each distinct prompt only once.
    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
    return [promptdict[prompt] for prompt in prompts]
2022-09-15 18:10:16 +08:00
# One entry of a prompt schedule: `cond` is used up to and including `end_at_step`.
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", "end_at_step cond")
2022-10-04 23:49:51 +08:00
def get_learned_conditioning(model, prompts, steps):
    """Convert a list of prompts into a list of prompt schedules.

    Each schedule is a list of ScheduledPromptConditioning, giving the condition
    (``cond``) and the sampling step (``end_at_step``) at which that condition is
    to be replaced by the next one.

    Input:
    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)

    Output:
    [
        [
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0523,  ..., -0.4901, -0.3066,  0.0674], ..., [ 0.3317, -0.5102, -0.4066,  ...,  0.4119, -0.7647, -1.0160]], device='cuda:0'))
        ],
        [
            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.0192,  0.3867, -0.4644,  ...,  0.1135, -0.3696, -0.4625]], device='cuda:0')),
            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.7352, -0.4356, -0.7888,  ...,  0.6994, -0.4312, -1.2593]], device='cuda:0'))
        ]
    ]
    """
    res = []

    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
    # Repeated prompts reuse the schedule computed for their first occurrence,
    # so the model is queried once per distinct prompt.
    cache = {}

    for prompt, prompt_schedule in zip(prompts, prompt_schedules):
        if prompt in cache:
            res.append(cache[prompt])
            continue

        # Encode all texts of this prompt's schedule in a single model call.
        texts = [text for _, text in prompt_schedule]
        conds = model.get_learned_conditioning(texts)

        cond_schedule = [
            ScheduledPromptConditioning(end_at_step, conds[i])
            for i, (end_at_step, _) in enumerate(prompt_schedule)
        ]

        cache[prompt] = cond_schedule
        res.append(cond_schedule)

    return res
# Splits a prompt into composable sub-prompts on the whole word "AND".
re_AND = re.compile(r"\bAND\b")
# Splits a sub-prompt into (text, weight) for the "text :weight" syntax; the weight
# group is None when no trailing number is given. re.S lets the text span newlines:
# without it a multi-line sub-prompt never matched, so its weight was silently
# ignored and its trailing whitespace kept.
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$", re.S)
2022-10-06 04:16:27 +08:00
def get_multicond_prompt_list(prompts):
    """Split every prompt on AND into weighted sub-prompts, de-duplicating the texts.

    Returns ``(res_indexes, prompt_flat_list, prompt_indexes)``:
      res_indexes      -- per prompt, one ``(index, weight)`` pair per sub-prompt
      prompt_flat_list -- each distinct sub-prompt text once, in first-seen order
      prompt_indexes   -- text -> its position in ``prompt_flat_list``
    The weight comes from a trailing ``:number`` matched by ``re_weight``; it is
    1.0 when there is none or when ``re_weight`` does not match at all.
    """
    res_indexes = []
    prompt_flat_list = []
    prompt_indexes = {}

    def index_of(text):
        # First sighting of a text appends it to the flat list.
        if text not in prompt_indexes:
            prompt_indexes[text] = len(prompt_flat_list)
            prompt_flat_list.append(text)
        return prompt_indexes[text]

    for prompt in prompts:
        pairs = []
        for subprompt in re_AND.split(prompt):
            found = re_weight.search(subprompt)
            if found is None:
                text, weight = subprompt, 1.0
            else:
                text, weight = found.groups()
            weight = 1.0 if weight is None else float(weight)
            pairs.append((index_of(text), weight))
        res_indexes.append(pairs)

    return res_indexes, prompt_flat_list, prompt_indexes
class ComposableScheduledPromptConditioning:
    """A prompt schedule paired with the weight it contributes.

    get_multicond_learned_conditioning builds one of these per AND-separated
    sub-prompt: ``schedules`` is that sub-prompt's list of
    ScheduledPromptConditioning and ``weight`` its multiplier (default 1.0).
    """

    def __init__(self, schedules, weight=1.0):
        self.schedules: List[ScheduledPromptConditioning] = schedules
        self.weight: float = weight
class MulticondLearnedConditioning:
    """Composable conditionings for a whole batch of prompts.

    ``batch[i]`` is the list of ComposableScheduledPromptConditioning for the
    AND-separated sub-prompts of prompt ``i``.
    """

    def __init__(self, shape, batch):
        # the shape field is needed to send this object to DDIM/PLMS
        self.shape: tuple = shape
        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
2022-09-15 18:10:16 +08:00
2022-10-06 04:16:27 +08:00
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
    """Like get_learned_conditioning, but with AND-composed, weighted sub-prompts.

    Every prompt is split on the AND separator. Each distinct sub-prompt text in the
    batch is encoded once, and the result pairs every prompt with one weighted
    ComposableScheduledPromptConditioning per sub-prompt.
    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
    """
    res_indexes, prompt_flat_list, _ = get_multicond_prompt_list(prompts)
    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)

    batch = [
        [ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in pairs]
        for pairs in res_indexes
    ]
    return MulticondLearnedConditioning(shape=(len(prompts),), batch=batch)
2022-10-06 05:11:30 +08:00
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
    """Stack, for each schedule in ``c``, the conditioning active at ``current_step``.

    Within a schedule, the first entry with ``current_step <= end_at_step`` is used.
    If ``current_step`` is past every ``end_at_step``, the schedule's last entry is used.

    Returns a tensor of shape ``(len(c),) + c[0][0].cond.shape`` on the device and
    with the dtype of ``c[0][0].cond``. Assumes every cond has that same shape.
    """
    param = c[0][0].cond
    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
    for i, cond_schedule in enumerate(c):
        # Past the end of the schedule keep the final conditioning; falling back to
        # the first entry would revert a prompt edit to its starting text.
        target_index = len(cond_schedule) - 1
        for current, (end_at, _) in enumerate(cond_schedule):
            if current_step <= end_at:
                target_index = current
                break
        res[i] = cond_schedule[target_index].cond

    return res
2022-09-15 18:10:16 +08:00
2022-10-06 04:16:27 +08:00
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
    """Select the conditionings active at ``current_step`` for a multi-cond batch.

    Returns ``(conds_list, tensor)``:
      conds_list -- per batch entry, one ``(row, weight)`` pair per sub-prompt, where
                    ``row`` indexes the first dimension of ``tensor``
      tensor     -- all selected conds stacked, converted to the device/dtype of the
                    first cond of the first sub-prompt of the batch
    Within a schedule, the first entry with ``current_step <= end_at_step`` is used.
    Past the end of the schedule, the last entry is used.
    """
    param = c.batch[0][0].schedules[0].cond

    tensors = []
    conds_list = []

    for composable_prompts in c.batch:
        conds_for_batch = []

        for composable_prompt in composable_prompts:
            schedules = composable_prompt.schedules
            # Past the end of the schedule keep the final conditioning rather than
            # reverting to the first one.
            target_index = len(schedules) - 1
            for current, (end_at, _) in enumerate(schedules):
                if current_step <= end_at:
                    target_index = current
                    break

            conds_for_batch.append((len(tensors), composable_prompt.weight))
            tensors.append(schedules[target_index].cond)

        conds_list.append(conds_for_batch)

    # If prompts have wildly different lengths above the token limit, we get tensors
    # of different shapes and can't torch.stack them. Pad the shorter ones by
    # repeating their last row (repeat([n, 1]) assumes conds are at most 2-D).
    token_count = max(x.shape[0] for x in tensors)

    def pad_to_token_count(tensor):
        missing = token_count - tensor.shape[0]
        if missing == 0:
            return tensor
        return torch.vstack([tensor, tensor[-1:].repeat([missing, 1])])

    tensors = [pad_to_token_count(t) for t in tensors]

    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
2022-09-29 16:31:48 +08:00
# Tokenizer for parse_prompt_attention(); alternatives are tried in order:
#   \( \) \[ \] \\     escaped characters (kept literally)
#   \                  a lone backslash
#   ( [                opening brackets
#   :NUMBER)           explicit weight closing a round bracket, captured in group 1.
#                      Only well-formed decimals are accepted ("1", "1.", ".5", "-2.5"),
#                      so float(group(1)) cannot fail on inputs like "1.2.3" or "."
#   ) ]                closing brackets
#   [^\()[]:]+         a run of ordinary text
#   :                  any other colon
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?(?:\d+(?:\.\d*)?|\.\d+))\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

# Splits text on the whole word BREAK, swallowing the whitespace around it.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
2022-09-29 16:31:48 +08:00
def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - multiplies attention to abc by 3.12 (a value below 1 decreases it)
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      BREAK - as a whole word in plain text, emitted as a separate ["BREAK", -1] pair
              with the surrounding whitespace dropped (enclosing brackets still
              scale that -1)
      anything else - just text

    Brackets left unclosed apply to the rest of the string, and adjacent pairs with
    equal weights are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention(r'\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every chunk emitted since the matching opening bracket.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        # `token`, not `text`: don't shadow the parameter being tokenized.
        token = m.group(0)
        weight = m.group(1)

        if token.startswith('\\'):
            # Escaped character: keep it literally (a lone backslash yields "").
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text, a stray colon or an unmatched closing bracket: emit it as
            # text, inserting a ["BREAK", -1] marker wherever BREAK splits it.
            for i, part in enumerate(re_break.split(token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Brackets never closed still apply to everything after them.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
2022-10-04 23:49:51 +08:00
if __name__ != "__main__":
    # torch is only needed by the tensor helpers; skipping the import when the
    # module is run directly keeps the doctest run fast.
    import torch
else:
    import doctest
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)