2022-09-10 04:16:02 +08:00
import csv
2023-11-27 19:39:50 +08:00
import fnmatch
2022-09-11 22:35:12 +08:00
import os
2022-09-10 04:16:02 +08:00
import os . path
2022-09-11 22:35:12 +08:00
import typing
import shutil
2022-09-10 04:16:02 +08:00
2022-09-11 22:35:12 +08:00
class PromptStyle(typing.NamedTuple):
    """A named style whose text is merged into the positive and negative prompts.

    ``path`` is the CSV file the style belongs to, or ``None`` when it has not
    been assigned one (e.g. the database's "None" placeholder style).
    """

    name: str
    # May be None for list-divider entries created by StyleDatabase.reload.
    prompt: str
    negative_prompt: str
    path: typing.Optional[str] = None
2022-09-14 22:56:21 +08:00
def merge_prompts(style_prompt: str, prompt: str) -> str:
    """Combine one style's text with a prompt.

    If ``style_prompt`` contains the ``{prompt}`` placeholder, every occurrence
    is replaced with ``prompt`` verbatim. Otherwise both texts are stripped and
    joined with ", " (prompt first, style second), skipping whichever is empty.

    A ``None`` style prompt is treated as an empty string, so merging it only
    strips ``prompt``.
    """
    # List-divider styles are stored with None instead of prompt text.
    if style_prompt is None:
        style_prompt = ""

    if "{prompt}" in style_prompt:
        return style_prompt.replace("{prompt}", prompt)

    parts = filter(None, (prompt.strip(), style_prompt.strip()))
    return ", ".join(parts)
2022-09-10 04:16:02 +08:00
2022-09-14 22:56:21 +08:00
def apply_styles_to_prompt(prompt, styles):
    """Return ``prompt`` with every style text in ``styles`` merged in, first to last.

    Each step feeds the previous result back into ``merge_prompts``, so later
    styles wrap or extend the output of earlier ones.
    """
    merged = prompt
    for style_text in styles:
        merged = merge_prompts(style_text, merged)
    return merged
2022-09-10 04:16:02 +08:00
2023-12-30 21:51:02 +08:00
def extract_style_text_from_prompt(style_text, prompt):
    """Try to undo ``merge_prompts`` for a single style text.

    Returns ``(True, remaining_prompt)`` when ``prompt`` looks like the result of
    applying ``style_text``; otherwise ``(False, prompt)`` with ``prompt`` unchanged.

    * If the style text contains ``{prompt}``, the stripped prompt must start with
      the text before the first placeholder and end with the text after it, and
      the two parts must not overlap. The text in between is returned.
    * Otherwise the style text must either be the whole prompt or follow a comma
      at its end (``merge_prompts`` joins with ", "). The text before that comma
      is returned, stripped.
    * An empty style text always matches and returns the stripped prompt.

    Examples:
        extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") -> (True, "1girl, art by greg")
        extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") -> (True, "1girl, art by greg")
        extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") -> (False, "exquisite, 1girl, art by greg")
        extract_style_text_from_prompt("piece", "masterpiece") -> (False, "masterpiece")
    """
    stripped_prompt = prompt.strip()
    stripped_style_text = style_text.strip()

    if "{prompt}" in stripped_style_text:
        # maxsplit=1: a style with several placeholders must not fail to unpack;
        # any extra placeholder just stays part of `right`.
        left, right = stripped_style_text.split("{prompt}", 1)
        # Prefix and suffix must fit side by side, e.g. "ab{prompt}ba" cannot
        # have produced "aba".
        fits = len(stripped_prompt) >= len(left) + len(right)
        if fits and stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
            return True, stripped_prompt[len(left):len(stripped_prompt) - len(right)]
        return False, prompt

    if not stripped_style_text:
        return True, stripped_prompt

    if not stripped_prompt.endswith(stripped_style_text):
        return False, prompt

    head = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)].rstrip()
    if not head:
        # The style text was the entire prompt.
        return True, ""
    if not head.endswith(","):
        # A mere suffix (style "piece" vs "masterpiece") is not an appended entry.
        return False, prompt
    return True, head[:-1].rstrip()
2023-11-27 19:39:50 +08:00
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
    """
    Takes a style and compares it to the prompt and negative prompt. If the style
    matches, returns True plus the prompt and negative prompt with the style text
    removed. Otherwise, returns False with the original prompt and negative prompt.

    A style with neither positive nor negative text never matches. A None prompt
    or negative prompt on the style is treated as empty text.
    """
    # Divider entries, and rows read by csv.DictReader with missing trailing
    # columns, can carry None instead of text; .strip() would fail on those.
    style_prompt = style.prompt or ""
    style_negative_prompt = style.negative_prompt or ""

    if not style_prompt and not style_negative_prompt:
        return False, prompt, negative_prompt

    match_positive, extracted_positive = extract_style_text_from_prompt(style_prompt, prompt)
    if not match_positive:
        return False, prompt, negative_prompt

    match_negative, extracted_negative = extract_style_text_from_prompt(style_negative_prompt, negative_prompt)
    if not match_negative:
        return False, prompt, negative_prompt

    return True, extracted_positive, extracted_negative
2022-09-14 22:56:21 +08:00
class StyleDatabase:
    """In-memory collection of prompt styles backed by one or more CSV files.

    ``path`` is either a single CSV file or a pattern whose file name contains
    ``*`` (e.g. ``styles*.csv``). With a pattern, every matching ``.csv`` file in
    that folder is loaded, each preceded by a divider entry.
    """

    def __init__(self, path: str):
        # Placeholder returned for unknown style names.
        self.no_style = PromptStyle("None", "", "", None)
        # Style name -> PromptStyle, in load order.
        self.styles = {}
        self.path = path

        # The default file is the path with the "*" removed,
        # e.g. "styles*.csv" -> "styles.csv".
        folder, file = os.path.split(self.path)
        filename, _, ext = file.partition('*')
        self.default_path = os.path.join(folder, filename + ext)

        # CSV columns written by save_styles; "path" is bookkeeping only.
        self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

        self.reload()

    def reload(self):
        """
        Clears the style database and reloads the styles from the CSV file(s)
        matching the path used to initialize the database.
        """
        self.styles.clear()

        path, filename = os.path.split(self.path)

        if "*" in filename:
            fileglob = filename.split("*")[0] + "*.csv"
            # A bare pattern such as "styles*.csv" splits into an empty folder,
            # and os.listdir("") raises; use the current directory instead.
            folder = path or "."
            # Sorted so that files (and their dividers) load in a stable order.
            candidates = sorted(os.listdir(folder)) if os.path.isdir(folder) else []
            filelist = []
            for file in candidates:
                if not fnmatch.fnmatch(file, fileglob):
                    continue
                filelist.append(file)
                # Add a visible divider to the style list
                half_len = round(len(file) / 2)
                divider = f"{'-' * (20 - half_len)} {file.upper()}"
                divider = f"{divider} {'-' * (40 - len(divider))}"
                self.styles[divider] = PromptStyle(
                    f"{divider}", None, None, "do_not_save"
                )
                # Add styles from this CSV file
                self.load_from_csv(os.path.join(path, file))
            if not filelist:
                print(f"No styles found in {path} matching {fileglob}")
                return
        elif not os.path.exists(self.path):
            print(f"Style database not found: {self.path}")
            return
        else:
            self.load_from_csv(self.path)

    def load_from_csv(self, path: str):
        """Add (or overwrite) styles from the CSV file at ``path``.

        The header must contain ``name`` and either ``prompt`` (with optional
        ``negative_prompt``) or the legacy ``text`` column. Rows whose name
        starts with ``#`` are skipped as comments.
        """
        with open(path, "r", encoding="utf-8-sig", newline="") as file:
            reader = csv.DictReader(file, skipinitialspace=True)
            for row in reader:
                # Ignore empty rows or rows starting with a comment
                if not row or row["name"].startswith("#"):
                    continue
                # Support loading old CSV format with "name, text"-columns
                prompt = row["prompt"] if "prompt" in row else row["text"]
                negative_prompt = row.get("negative_prompt", "")
                # DictReader fills columns missing from a short row with None;
                # store "" so the prompt helpers always receive strings.
                self.styles[row["name"]] = PromptStyle(
                    row["name"], prompt or "", negative_prompt or "", path
                )

    def get_style_paths(self) -> set:
        """Returns a set of all distinct paths of files that styles are loaded from.

        As a side effect, styles without a path are assigned ``default_path``.
        The default path is always included; the divider marker is excluded.
        """
        # Update any styles without a path to the default path
        for style in list(self.styles.values()):
            if not style.path:
                self.styles[style.name] = style._replace(path=self.default_path)

        # Collect all distinct paths, including the default path
        style_paths = {self.default_path}
        style_paths.update(style.path for style in self.styles.values() if style.path)

        # Remove any paths for styles that are just list dividers
        style_paths.discard("do_not_save")

        return style_paths

    def get_style_prompts(self, styles):
        """Positive prompt text for each name in ``styles`` ("" for unknown names)."""
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        """Negative prompt text for each name in ``styles`` ("" for unknown names)."""
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        """Merge the positive text of the named styles into ``prompt``, in order."""
        # `or ""`: divider entries store None as their prompt text.
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).prompt or "" for x in styles]
        )

    def apply_negative_styles_to_prompt(self, prompt, styles):
        """Merge the negative text of the named styles into ``prompt``, in order."""
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).negative_prompt or "" for x in styles]
        )

    def save_styles(self, path: str = None) -> None:
        """Write every style back to the CSV file it belongs to.

        An existing target file is first copied to ``<file>.bak``.
        """
        # The path argument is deprecated, but kept for backwards compatibility
        _ = path

        style_paths = self.get_style_paths()
        csv_names = [os.path.split(style_path)[1].lower() for style_path in style_paths]

        for style_path in style_paths:
            # Always keep a backup file around
            if os.path.exists(style_path):
                shutil.copy(style_path, f"{style_path}.bak")

            # Write the styles to the CSV file
            with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
                writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
                writer.writeheader()
                for style in (s for s in self.styles.values() if s.path == style_path):
                    # Skip style list dividers, e.g. "STYLES.CSV"
                    if style.name.lower().strip("#") in csv_names:
                        continue
                    # Write style fields, ignoring the path field
                    writer.writerow(
                        {k: v for k, v in style._asdict().items() if k != "path"}
                    )

    def extract_styles_from_prompt(self, prompt, negative_prompt):
        """Repeatedly strip matching styles off the prompts.

        Each pass removes the first style (in database order) that matches.
        That style is then excluded from later passes. Returns
        ``(style_names, prompt, negative_prompt)``, with ``style_names`` in the
        reverse of the order in which they were removed.
        """
        extracted = []
        applicable_styles = list(self.styles.values())

        while True:
            found_style = None

            for style in applicable_styles:
                is_match, new_prompt, new_neg_prompt = extract_original_prompts(
                    style, prompt, negative_prompt
                )
                if is_match:
                    found_style = style
                    prompt = new_prompt
                    negative_prompt = new_neg_prompt
                    break

            if not found_style:
                break

            applicable_styles.remove(found_style)
            extracted.append(found_style.name)

        return list(reversed(extracted)), prompt, negative_prompt