2022-09-15 18:10:16 +08:00
import re
from collections import namedtuple
import torch
2022-10-03 19:25:36 +08:00
from lark import Lark , Transformer , Visitor
import functools
2022-09-15 18:10:16 +08:00
import modules . shared as shared
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
def get_learned_conditioning_prompt_schedules(prompts, steps):
    """Expand prompt-editing syntax into per-step prompt schedules.

    Each prompt may contain ``[before:after:when]``, ``[after:when]`` or
    ``[before::when]`` sections. ``when`` is a step number, or a fraction of
    ``steps`` when it is below 1; it is clamped to ``steps``.

    Args:
        prompts: iterable of prompt strings (must be hashable, as results are
            memoised per distinct prompt for the duration of this call).
        steps: total number of sampling steps.

    Returns:
        A list with one schedule per prompt, in order. A schedule is a list of
        ``[end_at_step, prompt_text]`` pairs sorted by ``end_at_step``; the
        last pair always ends at ``steps``. A prompt that the grammar cannot
        parse (for example one with an unbalanced bracket) is returned
        unchanged as the single-entry schedule ``[[steps, prompt]]`` instead
        of raising.
    """
    # Imported here so this function does not depend on an extra file-level
    # import; lark itself is already a dependency of this module.
    from lark.exceptions import LarkError

    grammar = r"""
    start: prompt
    prompt: (emphasized | scheduled | weighted | plain)*
    !emphasized: "(" prompt ")"
            | "(" prompt ":" prompt ")"
            | "[" prompt "]"
    scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
    !weighted: "{" weighted_item ("|" weighted_item)* "}"
    !weighted_item: prompt (":" prompt)?
    plain: /([^\\\[\](){}:|]|\\.)+/
    %import common.SIGNED_NUMBER -> NUMBER
    """
    parser = Lark(grammar, parser='lalr')

    def collect_steps(steps, tree):
        """Return the sorted, de-duplicated step boundaries found in ``tree``."""
        boundaries = [steps]

        class CollectSteps(Visitor):
            def scheduled(self, tree):
                # The NUMBER token is replaced in place by the resolved int
                # step, which AtStep.scheduled later compares against.
                when = float(tree.children[-1])
                if when < 1:
                    # Values below 1 are a fraction of the total step count.
                    when *= steps
                when = min(steps, int(when))
                tree.children[-1] = when
                boundaries.append(when)

        CollectSteps().visit(tree)
        return sorted(set(boundaries))

    def at_step(step, tree):
        """Render the prompt text that is in effect up to and including ``step``."""

        class AtStep(Transformer):
            def scheduled(self, args):
                if len(args) == 2:
                    # "[after:when]" form: nothing is shown before ``when``.
                    before, after, when = (), *args
                else:
                    before, after, when = args
                yield before if step <= when else after

            def start(self, args):
                def flatten(x):
                    if isinstance(x, str):
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)

                return ''.join(flatten(args[0]))

            def plain(self, args):
                yield args[0].value

            def __default__(self, data, children, meta):
                # Every other rule simply concatenates whatever its children yield.
                for child in children:
                    yield from child

        return AtStep().transform(tree)

    @functools.cache
    def get_schedule(prompt):
        try:
            tree = parser.parse(prompt)
        except LarkError:
            # Malformed scheduling syntax: use the raw prompt for every step
            # rather than failing the whole batch.
            return [[steps, prompt]]
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    return [get_schedule(prompt) for prompt in prompts]
2022-09-15 18:10:16 +08:00
# One entry of a prompt's schedule: the conditioning `cond` applies up to and
# including step `end_at_step`.
ScheduledPromptConditioning = namedtuple(
    "ScheduledPromptConditioning",
    ("end_at_step", "cond"),
)

# A batch of schedules: `schedules` holds one list of
# ScheduledPromptConditioning per prompt, `shape` is the shape of the tensor
# built from them for a single step.
ScheduledPromptBatch = namedtuple(
    "ScheduledPromptBatch",
    ("shape", "schedules"),
)
def get_learned_conditioning(prompts, steps):
    """Compute the scheduled conditionings for a batch of prompts.

    Every prompt is expanded into its step schedule, the texts of that
    schedule are passed to ``shared.sd_model.get_learned_conditioning`` in one
    call, and each result is paired with the step at which it stops applying.
    Identical prompts are encoded once and share the same schedule list.

    Returns:
        ScheduledPromptBatch whose ``schedules`` has one list of
        ScheduledPromptConditioning per prompt and whose ``shape`` is
        ``(len(prompts),)`` followed by the shape of a single conditioning.
    """
    all_prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)

    encoded_by_prompt = {}
    batch_schedules = []

    for prompt, prompt_schedule in zip(prompts, all_prompt_schedules):
        cond_schedule = encoded_by_prompt.get(prompt)

        if cond_schedule is None:
            # First occurrence of this prompt: encode all of its scheduled texts at once.
            conds = shared.sd_model.get_learned_conditioning([entry[1] for entry in prompt_schedule])
            cond_schedule = [
                ScheduledPromptConditioning(end_at_step, conds[index])
                for index, (end_at_step, _text) in enumerate(prompt_schedule)
            ]
            encoded_by_prompt[prompt] = cond_schedule

        batch_schedules.append(cond_schedule)

    batch_shape = (len(prompts),) + batch_schedules[0][0].cond.shape
    return ScheduledPromptBatch(batch_shape, batch_schedules)
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
    """Assemble the conditioning tensor for ``current_step``.

    For every schedule in ``c.schedules`` the first entry whose end step is
    not before ``current_step`` is selected; if no entry qualifies, the first
    entry of the schedule is used. The selected conditionings are written
    row by row into a zero tensor of shape ``c.shape``.
    """
    model_dtype = next(shared.sd_model.parameters()).dtype
    batch = torch.zeros(c.shape, device=shared.device, dtype=model_dtype)

    for row, schedule in enumerate(c.schedules):
        chosen = next(
            (index for index, (end_at, _cond) in enumerate(schedule) if current_step <= end_at),
            0,
        )
        batch[row] = schedule[chosen].cond

    return batch
2022-09-15 18:10:16 +08:00
2022-09-29 16:31:48 +08:00
# Tokenizer for attention syntax. Alternatives are tried in order, so the
# escaped forms must come before the bare brackets. The only capture group is
# the numeric weight of a ":<number>)" closer; the number must be a well-formed
# decimal so that float() on it cannot fail.
re_attention = re.compile(r"""
\\\(|                               # escaped opening round bracket
\\\)|                               # escaped closing round bracket
\\\[|                               # escaped opening square bracket
\\]|                                # escaped closing square bracket
\\\\|                               # escaped backslash
\\|                                 # lone backslash
\(|
\[|
:([+-]?(?:\d+\.?\d*|\.\d+))\)|      # explicit weight followed by the closing round bracket
\)|
]|
[^\\()\[\]:]+|                      # run of ordinary text
:
""", re.X)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Brackets nest and their multipliers accumulate. Brackets that are never
    closed are treated as if they were closed at the end of the text. A
    ":<weight>)" whose weight is not a well-formed number (such as "1.2.3")
    is not a weight: the ":" and the characters after it are kept as text and
    the ")" acts as a plain closing bracket. An empty result is replaced by a
    single empty text with weight 1.0.

    Example:

        'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'

    produces:

    [
        ['a ', 1.0],
        ['house', 1.5730000000000004],
        [' ', 1.1],
        ['on', 1.0],
        [' a ', 1.1],
        ['hill', 0.55],
        [', sun, ', 1.1],
        ['sky', 1.4641000000000006],
        ['.', 1.1]
    ]
    """

    res = []
    # Stacks of indexes into res marking where each still-open bracket began.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every piece of text collected since the bracket was opened.
        for entry in res[start_position:]:
            entry[1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)

        if token.startswith('\\'):
            # Escaped character: keep it as text without the backslash.
            res.append([token[1:], 1.0])
        elif token == '(':
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ')' and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Ordinary text, or a closer with no matching opener.
            res.append([token, 1.0])

    # Brackets left open apply up to the end of the text.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    return res