From 4ded88c46eeadc74b8de5d1942081c1550f355be Mon Sep 17 00:00:00 2001 From: ChristopheZhao <398453241@qq.com> Date: Sat, 15 Mar 2025 08:08:07 +0000 Subject: [PATCH 1/2] feat: Add multilingual prompt optimizer with LangGraph support - Add txt2img_prompt_optimizer.py script for automatic prompt translation and optimization - Support non-English prompts with automatic translation to English - Implement prompt optimization using LangGraph workflow - Add python-dotenv and langgraph dependencies - Update requirements.txt and requirements_versions.txt with new dependencies --- .gitignore | 1 + requirements.txt | 5 +- requirements_versions.txt | 3 + scripts/txt2img_prompt_optimizer.py | 560 ++++++++++++++++++++++++++++ 4 files changed, 568 insertions(+), 1 deletion(-) create mode 100644 scripts/txt2img_prompt_optimizer.py diff --git a/.gitignore b/.gitignore index e81ad31f5..31683f118 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ notification.mp3 /cache trace.json /sysinfo-????-??-??-??-??.json +.env diff --git a/requirements.txt b/requirements.txt index 0d6bac600..37545db21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,4 +31,7 @@ torch torchdiffeq torchsde transformers==4.30.2 -pillow-avif-plugin==1.4.3 \ No newline at end of file +pillow-avif-plugin==1.4.3 + +python-dotenv +langgraph \ No newline at end of file diff --git a/requirements_versions.txt b/requirements_versions.txt index 0306ce94f..9fae9d81e 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -33,3 +33,6 @@ torchsde==0.2.6 transformers==4.30.2 httpx==0.24.1 pillow-avif-plugin==1.4.3 + +python-dotenv==1.0.1 +langgraph==0.2.32 diff --git a/scripts/txt2img_prompt_optimizer.py b/scripts/txt2img_prompt_optimizer.py new file mode 100644 index 000000000..775f52f14 --- /dev/null +++ b/scripts/txt2img_prompt_optimizer.py @@ -0,0 +1,560 @@ +import modules.scripts as scripts +from modules.processing import StableDiffusionProcessingTxt2Img +import os +import re +import json +from dotenv import load_dotenv +import requests +from typing import Dict, List, Literal, TypedDict, Union, Optional, Any, Callable, Annotated +import functools + +# Try to import LangGraph related libraries +try: + from langgraph.graph import StateGraph, END, START + from langgraph.checkpoint.memory import MemorySaver + LANGGRAPH_AVAILABLE = True +except ImportError: + LANGGRAPH_AVAILABLE = False + print("Warning: LangGraph library not installed, using simplified implementation") + print("Can be installed via 'pip install langgraph'") + +# Try to import Pydantic +try: + from pydantic import BaseModel, Field + PYDANTIC_AVAILABLE = True +except ImportError: + PYDANTIC_AVAILABLE = False + print("Warning: Pydantic library not installed, using simplified implementation") + print("Can be installed via 'pip install pydantic'") + +# Load environment variables +load_dotenv() + +# Get DeepSeek API key +DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") + +# Define state type +class PromptState(TypedDict): + original_prompt: str + language: str # Changed: from is_chinese to language, can be "english", "chinese", "other", etc. + translated_prompt: Optional[str] + optimized_prompt: Optional[str] + error: Optional[str] + + +class PromptTemplate(BaseModel): + """Prompt template Pydantic model""" + name: str = Field(..., description="Template name") + content: str = Field(..., description="Template content") + + def __str__(self) -> str: + return self.content.strip() + +class PromptTemplates(BaseModel): + """Collection of prompt templates""" + txt2img_optimizer: PromptTemplate = Field( + default=PromptTemplate( + name="Stable Diffusion Prompt Optimizer", + content=""" + You are an expert prompt engineer for Stable Diffusion image generation with deep knowledge of how SD models interpret text. + + Your task is to transform standard prompts into highly optimized versions that produce exceptional quality images. Follow these guidelines: + + 1. Maintain the original subject and core concept + 2. Enhance with precise descriptive adjectives and specific details + 3. Add appropriate artistic style references (artists, movements, platforms) + 4. Incorporate quality-boosting terms (masterpiece, best quality, highly detailed) + 5. Apply technical enhancements through brackets for emphasis: + - Use (term) for 1.1x emphasis + - Use ((term)) for 1.2x emphasis + - Use [term] for 0.9x emphasis + - Use [[term]] for 0.8x emphasis + - Use :1.x for specific weighting + + 6. Structure prompts effectively: + - Main subject first with strongest emphasis + - Scene details and environment + - Style, quality, and technical terms last + + Return ONLY the optimized prompt without explanations or commentary. Preserve all special formatting like (), [], {}, :1.2, etc. from the original prompt. + """ + ), + description="Stable Diffusion prompt optimization template" + ) + + language_detector: PromptTemplate = Field( + default=PromptTemplate( + name="Language Detector", + content=""" + You are a language detection expert. Your task is to identify if the given text is in English or not. + + Analyze the provided text and determine if it's in English. Return ONLY 'yes' if the text is primarily in English, or 'no' if it's primarily in another language. + + If the text is primarily in English or contains mostly English words with a few non-English terms, return 'yes'. + If the text is primarily in another language, return 'no'. + + Return ONLY 'yes' or 'no' without any explanations or additional text. + """ + ), + description="Language detection template" + ) + + universal_translator: PromptTemplate = Field( + default=PromptTemplate( + name="Universal Translator", + content=""" + You are a professional translator specializing in translating text to English for image generation. + + Your task is to accurately translate prompts from any language to English while preserving the original meaning and intent. Follow these guidelines: + + 1. Maintain the core subject and concept of the original prompt + 2. Preserve any special formatting like (), [], {}, :1.2, etc. + 3. Translate cultural-specific terms appropriately for an international audience + 4. Keep artistic style references intact + 5. Ensure the translation is natural and fluent in English + + Return ONLY the translated English prompt without explanations or commentary. + """ + ), + description="Universal translation template" + ) + + def get(self, template_name: str) -> PromptTemplate: + """Get template by name""" + if hasattr(self, template_name): + return getattr(self, template_name) + raise ValueError(f"Template not found: {template_name}") + +# Create template instance +TEMPLATES = PromptTemplates() + + +# Agent functions +def router_agent(state: PromptState) -> Dict[str, Any]: + """Determine the language of the prompt""" + prompt = state["original_prompt"] + + if not prompt: + return {"language": "unknown"} + + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {DEEPSEEK_API_KEY}" + } + + # Use predefined language detection template + detector_template = TEMPLATES.get("language_detector") + + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "system", "content": detector_template.content}, + {"role": "user", "content": f"Is this text in English? {prompt}"} + ], + "temperature": 0.1, + "max_tokens": 10 + } + + response = requests.post( + "https://api.deepseek.com/v1/chat/completions", + headers=headers, + json=payload + ) + + if response.status_code == 200: + result = response.json() + is_english = result["choices"][0]["message"]["content"].strip().lower() == "yes" + language = "english" if is_english else "other" + print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}'") + return {"language": language} + else: + print(f"RouterAgent: Language detection failed - {response.status_code} - {response.text}") + # Fallback to simple detection + non_ascii_chars = 0 + for char in prompt: + if ord(char) > 127: + non_ascii_chars += 1 + + language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other" + print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)") + return {"language": language} + except Exception as e: + print(f"RouterAgent: Language detection failed - {str(e)}") + # Fallback to simple detection + non_ascii_chars = 0 + for char in prompt: + if ord(char) > 127: + non_ascii_chars += 1 + + language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other" + print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)") + return {"language": language} + + +def translator_agent(state: PromptState) -> Dict[str, Any]: + """Translate non-English prompts to English""" + prompt = state["original_prompt"] + language = state["language"] + + if language == "english": + print(f"TranslatorAgent: Prompt is already in English, no translation needed") + return {"translated_prompt": prompt} + + if not DEEPSEEK_API_KEY: + print(f"TranslatorAgent: Warning - DEEPSEEK_API_KEY not set, using simplified translation") + return {"error": "DEEPSEEK_API_KEY not set", "translated_prompt": prompt} + + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {DEEPSEEK_API_KEY}" + } + + # Use predefined universal translation template + translator_template = TEMPLATES.get("universal_translator") + + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "system", "content": translator_template.content}, + {"role": "user", "content": f"Translate this prompt from {language} to English: {prompt}"} + ], + "temperature": 0.1, + "max_tokens": 1000 + } + + response = requests.post( + "https://api.deepseek.com/v1/chat/completions", + headers=headers, + json=payload + ) + + if response.status_code == 200: + result = response.json() + translated_text = result["choices"][0]["message"]["content"].strip() + print(f"TranslatorAgent: Translation result - '{translated_text}'") + return {"translated_prompt": translated_text} + else: + print(f"TranslatorAgent: Translation failed - {response.status_code} - {response.text}") + return {"error": f"Translation API error: {response.status_code}", "translated_prompt": prompt} + except Exception as e: + print(f"TranslatorAgent: Translation failed - {str(e)}") + return {"error": f"Translation error: {str(e)}", "translated_prompt": prompt} + +def optimizer_agent(state: PromptState) -> Dict[str, Any]: + """Optimize English prompts""" + # Determine the prompt to optimize + prompt_to_optimize = state.get("translated_prompt") or state["original_prompt"] + + if not DEEPSEEK_API_KEY: + print("OptimizerAgent: Warning - DEEPSEEK_API_KEY not set, using local optimization") + optimized = local_optimize(prompt_to_optimize) + return {"optimized_prompt": optimized} + + try: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {DEEPSEEK_API_KEY}" + } + + # Use predefined optimization template + optimizer_template = TEMPLATES.get("txt2img_optimizer") + + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "system", "content": optimizer_template.content}, + {"role": "user", "content": f"Optimize this prompt: {prompt_to_optimize}"} + ], + "temperature": 0.3, + "max_tokens": 1000 + } + + response = requests.post( + "https://api.deepseek.com/v1/chat/completions", + headers=headers, + json=payload + ) + + if response.status_code == 200: + result = response.json() + enhanced_text = result["choices"][0]["message"]["content"].strip() + print(f"OptimizerAgent: Optimization result - '{enhanced_text}'") + return {"optimized_prompt": enhanced_text} + else: + print(f"OptimizerAgent: Optimization failed - {response.status_code} - {response.text}") + optimized = local_optimize(prompt_to_optimize) + return {"error": f"Optimization API error: {response.status_code}", "optimized_prompt": optimized} + except Exception as e: + print(f"OptimizerAgent: Optimization failed - {str(e)}") + optimized = local_optimize(prompt_to_optimize) + return {"error": f"Optimization error: {str(e)}", "optimized_prompt": optimized} + +def local_optimize(prompt: str) -> str: + """Local prompt optimization method (used when API is unavailable)""" + # Example optimization: add quality-boosting keywords + quality_terms = ["high quality", "detailed", "sharp focus"] + style_terms = ["masterpiece", "best quality"] + + # Check if prompt already contains these terms + optimized = prompt + + # Add quality terms + for term in quality_terms: + if term.lower() not in optimized.lower(): + if optimized.strip().endswith((',', '。', ',', '.')): + optimized = f"{optimized} {term}" + else: + optimized = f"{optimized}, {term}" + + # Add style terms (at the beginning) + for term in reversed(style_terms): + if term.lower() not in optimized.lower(): + optimized = f"{term}, {optimized}" + + print(f"OptimizerAgent: Local optimization result - '{optimized}'") + return optimized + +# Define routing logic +def should_translate(state: PromptState) -> Literal["translator", "optimizer"]: + """Determine if translation is needed""" + if state.get("language", "") != "english": + return "translator" + else: + return "optimizer" + +# Create LangGraph workflow +def create_prompt_optimization_graph(): + """Create prompt optimization workflow graph""" + # If LangGraph is not available, return None + if not LANGGRAPH_AVAILABLE: + return None + + # Create state graph + workflow = StateGraph(PromptState) + + # Add nodes + workflow.add_node("router", router_agent) + workflow.add_node("translator", translator_agent) + workflow.add_node("optimizer", optimizer_agent) + + # Add edges + # From start to router + workflow.set_entry_point("router") + + # From router to translator or optimizer (based on language) + workflow.add_conditional_edges( + "router", + should_translate, + { + "translator": "translator", + "optimizer": "optimizer" + } + ) + + # From translator to optimizer + workflow.add_edge("translator", "optimizer") + + # From optimizer to end + workflow.add_edge("optimizer", END) + + # Compile workflow + return workflow.compile() + +# Simplified workflow (used when LangGraph is not available) +def simple_prompt_optimization_workflow(prompt: str) -> str: + """Simplified prompt optimization workflow""" + print(f"\n--- Simplified workflow started ---") + print(f"Original prompt: '{prompt}'") + + # Initialize state + state = PromptState( + original_prompt=prompt, + language="unknown", + translated_prompt=None, + optimized_prompt=None, + error=None + ) + + # Step 1: Router - determine language + router_result = router_agent(state) + state["language"] = router_result["language"] + + # Step 2: Translator - translate if not English + if state["language"] != "english": + translator_result = translator_agent(state) + state["translated_prompt"] = translator_result.get("translated_prompt") + if "error" in translator_result: + state["error"] = translator_result["error"] + + # Step 3: Optimizer - optimize prompt + optimizer_result = optimizer_agent(state) + state["optimized_prompt"] = optimizer_result.get("optimized_prompt") + if "error" in optimizer_result and not state["error"]: + state["error"] = optimizer_result["error"] + + print(f"Final optimized prompt: '{state['optimized_prompt']}'") + print(f"--- Simplified workflow finished ---\n") + + return state["optimized_prompt"] or prompt + +class PromptOptimizer(scripts.Script): + # Class-level flag to track if initialization message has been shown + _init_message_shown = False + + def __init__(self): + super().__init__() + # Show initialization message only once + if not PromptOptimizer._init_message_shown: + print("\n\n=== Txt2Img Prompt Optimizer (Multilingual) script loaded ===\n\n") + PromptOptimizer._init_message_shown = True + + # Try to create LangGraph workflow + self.graph = create_prompt_optimization_graph() + + # If LangGraph is not available, use simplified workflow + if self.graph is None and not PromptOptimizer._init_message_shown: + print("Using simplified prompt optimization workflow") + + # Track processed prompts to avoid duplicates + self.processed_prompts = set() + + def title(self): + return "Txt2Img Prompt Optimizer (Multilingual)" + + # Return AlwaysVisible to show script in UI + def show(self, is_img2img): + return scripts.AlwaysVisible + + # No UI elements needed + def ui(self, is_img2img): + return [] + + # Optimize prompt before processing + def process(self, p): + # Only optimize Txt2Img processing objects + if not isinstance(p, StableDiffusionProcessingTxt2Img): + return p + + # Record original prompt + original_prompt = p.prompt + print(f"\n=== Original prompt ===\n{original_prompt}\n") + + # Optimize main prompt (if not already processed) + if p.prompt not in self.processed_prompts: + optimized_prompt = self.optimize_prompt(p.prompt) + p.prompt = optimized_prompt + # Ensure all_prompts also uses optimized prompt + if hasattr(p, 'all_prompts') and p.all_prompts: + p.all_prompts = [optimized_prompt] * len(p.all_prompts) + # Ensure main_prompt also uses optimized prompt + if hasattr(p, 'main_prompt'): + p.main_prompt = optimized_prompt + self.processed_prompts.add(optimized_prompt) + + # Record optimization information (optional, for verification) + if not hasattr(p, 'extra_generation_params'): + p.extra_generation_params = {} + p.extra_generation_params['Prompt optimized'] = True + + # Record final prompt sent to model + print(f"\n=== Final prompt sent to model ===\n{p.prompt}\n") + + # Add post-processing hook to ensure prompt remains optimized + original_setup_prompts = p.setup_prompts + + def patched_setup_prompts(): + # Call original method + original_setup_prompts() + # Ensure prompt remains optimized + if p.prompt in self.processed_prompts: + p.all_prompts = [p.prompt] * len(p.all_prompts) + p.main_prompt = p.prompt + + # Replace method + p.setup_prompts = patched_setup_prompts + + return p + + def postprocess(self, p, processed): + """Post-process after image generation""" + # Nothing to do here + return processed + + def optimize_prompt(self, prompt): + """Optimize prompt to improve generation quality""" + if not prompt: + return prompt + + # Use LangGraph workflow or simplified workflow + if self.graph is not None: + # Use LangGraph workflow + try: + print(f"\n--- LangGraph started ---") + print(f"Original prompt: '{prompt}'") + + # Create initial state + initial_state = PromptState( + original_prompt=prompt, + language="unknown", + translated_prompt=None, + optimized_prompt=None, + error=None + ) + + # Execute workflow + final_state = self.graph.invoke(initial_state) + + optimized = final_state.get("optimized_prompt") or prompt + print(f"Final optimized prompt: '{optimized}'") + print(f"--- LangGraph finished ---\n") + return optimized + except Exception as e: + print(f"LangGraph failed: {str(e)}") + # Fallback to simplified workflow + return simple_prompt_optimization_workflow(prompt) + else: + # Use simplified workflow + return simple_prompt_optimization_workflow(prompt) + + +# Check if we're in a Stable Diffusion Webui environment +try: + from modules.processing import StableDiffusionProcessingTxt2Img +except ImportError: + # Create a mock class for testing outside of webui + class StableDiffusionProcessingTxt2Img: + """Mock class for testing outside of webui""" + def __init__(self): + self.prompt = "" + self.all_prompts = [] + self.main_prompt = "" + self.extra_generation_params = {} + + def setup_prompts(self): + """Mock setup_prompts method""" + pass + + +# For standalone testing +if __name__ == "__main__": + # Test the prompt optimization workflow + test_prompts = [ + "a cat", + "beautiful landscape", + "portrait of a woman", + "科幻城市", # Chinese: "sci-fi city" + "美丽的山水画", # Chinese: "beautiful landscape painting" + ] + + print("Testing prompt optimization workflow...") + + # Initialize optimizer + optimizer = PromptOptimizer() + + # Test each prompt + for prompt in test_prompts: + print(f"\nTesting prompt: '{prompt}'") + optimized = optimizer.optimize_prompt(prompt) + print(f"Optimized: '{optimized}'") \ No newline at end of file From 654ee7ec00484f498b56fefe7ecdcbced8430c72 Mon Sep 17 00:00:00 2001 From: ChristopheZhao <398453241@qq.com> Date: Sat, 15 Mar 2025 10:20:05 +0000 Subject: [PATCH 2/2] Translate and optimize code to comply with coding styles --- scripts/txt2img_prompt_optimizer.py | 274 ++++++++++++++-------------- 1 file changed, 133 insertions(+), 141 deletions(-) diff --git a/scripts/txt2img_prompt_optimizer.py b/scripts/txt2img_prompt_optimizer.py index 775f52f14..562ebf135 100644 --- a/scripts/txt2img_prompt_optimizer.py +++ b/scripts/txt2img_prompt_optimizer.py @@ -1,17 +1,25 @@ -import modules.scripts as scripts +""" +Txt2Img Prompt Optimizer (Multilingual) + +This script optimizes text prompts for Stable Diffusion image generation. +It can detect non-English prompts, translate them to English, and then optimize them +for better image generation results. + +The script uses a LangGraph workflow to manage the optimization process, with nodes for +language detection, translation, and optimization. If LangGraph is not available, +it falls back to a simplified workflow. +""" + +from modules import scripts from modules.processing import StableDiffusionProcessingTxt2Img import os -import re -import json from dotenv import load_dotenv import requests -from typing import Dict, List, Literal, TypedDict, Union, Optional, Any, Callable, Annotated -import functools +from typing import Dict, Literal, TypedDict, Optional, Any # Try to import LangGraph related libraries try: - from langgraph.graph import StateGraph, END, START - from langgraph.checkpoint.memory import MemorySaver + from langgraph.graph import StateGraph, END LANGGRAPH_AVAILABLE = True except ImportError: LANGGRAPH_AVAILABLE = False @@ -36,33 +44,32 @@ DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY") # Define state type class PromptState(TypedDict): original_prompt: str - language: str # Changed: from is_chinese to language, can be "english", "chinese", "other", etc. + language: str translated_prompt: Optional[str] optimized_prompt: Optional[str] error: Optional[str] - class PromptTemplate(BaseModel): - """Prompt template Pydantic model""" + """Prompt template for specific tasks""" name: str = Field(..., description="Template name") content: str = Field(..., description="Template content") - + def __str__(self) -> str: return self.content.strip() - + class PromptTemplates(BaseModel): """Collection of prompt templates""" txt2img_optimizer: PromptTemplate = Field( default=PromptTemplate( name="Stable Diffusion Prompt Optimizer", - content=""" - You are an expert prompt engineer for Stable Diffusion image generation with deep knowledge of how SD models interpret text. + content="""\ + You are an expert prompt engineer for Stable Diffusion image generation with deep knowledge of how SD models interpret text. Your task is to transform standard prompts into highly optimized versions that produce exceptional quality images. Follow these guidelines: 1. Maintain the original subject and core concept 2. Enhance with precise descriptive adjectives and specific details - 3. Add appropriate artistic style references (artists, movements, platforms) + 3. Add appropriate artistic style references (artists, movements, platforms) 4. Incorporate quality-boosting terms (masterpiece, best quality, highly detailed) 5. Apply technical enhancements through brackets for emphasis: - Use (term) for 1.1x emphasis @@ -81,11 +88,11 @@ class PromptTemplates(BaseModel): ), description="Stable Diffusion prompt optimization template" ) - + language_detector: PromptTemplate = Field( default=PromptTemplate( name="Language Detector", - content=""" + content="""\ You are a language detection expert. Your task is to identify if the given text is in English or not. Analyze the provided text and determine if it's in English. Return ONLY 'yes' if the text is primarily in English, or 'no' if it's primarily in another language. @@ -95,15 +102,15 @@ class PromptTemplates(BaseModel): Return ONLY 'yes' or 'no' without any explanations or additional text. """ - ), + ), description="Language detection template" ) - + universal_translator: PromptTemplate = Field( default=PromptTemplate( name="Universal Translator", - content=""" - You are a professional translator specializing in translating text to English for image generation. + content="""\ + You are a professional translator specializing in translating text to English for image generation. Your task is to accurately translate prompts from any language to English while preserving the original meaning and intent. Follow these guidelines: @@ -118,7 +125,7 @@ class PromptTemplates(BaseModel): ), description="Universal translation template" ) - + def get(self, template_name: str) -> PromptTemplate: """Get template by name""" if hasattr(self, template_name): @@ -129,23 +136,38 @@ class PromptTemplates(BaseModel): TEMPLATES = PromptTemplates() +# Helper function for simple language detection +def simple_language_detection(prompt: str) -> str: + """Simple language detection based on ASCII character ratio""" + if not prompt: + return "unknown" + + non_ascii_chars = 0 + for char in prompt: + if ord(char) > 127: + non_ascii_chars += 1 + + language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other" + print(f"Simple language detection: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}'") + return language + # Agent functions def router_agent(state: PromptState) -> Dict[str, Any]: """Determine the language of the prompt""" prompt = state["original_prompt"] - + if not prompt: return {"language": "unknown"} - + try: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {DEEPSEEK_API_KEY}" } - + # Use predefined language detection template detector_template = TEMPLATES.get("language_detector") - + payload = { "model": "deepseek-chat", "messages": [ @@ -155,13 +177,13 @@ def router_agent(state: PromptState) -> Dict[str, Any]: "temperature": 0.1, "max_tokens": 10 } - + response = requests.post( "https://api.deepseek.com/v1/chat/completions", headers=headers, json=payload ) - + if response.status_code == 200: result = response.json() is_english = result["choices"][0]["message"]["content"].strip().lower() == "yes" @@ -171,49 +193,36 @@ def router_agent(state: PromptState) -> Dict[str, Any]: else: print(f"RouterAgent: Language detection failed - {response.status_code} - {response.text}") # Fallback to simple detection - non_ascii_chars = 0 - for char in prompt: - if ord(char) > 127: - non_ascii_chars += 1 - - language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other" - print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)") + language = simple_language_detection(prompt) return {"language": language} except Exception as e: print(f"RouterAgent: Language detection failed - {str(e)}") # Fallback to simple detection - non_ascii_chars = 0 - for char in prompt: - if ord(char) > 127: - non_ascii_chars += 1 - - language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other" - print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)") + language = simple_language_detection(prompt) return {"language": language} - def translator_agent(state: PromptState) -> Dict[str, Any]: """Translate non-English prompts to English""" prompt = state["original_prompt"] language = state["language"] - + if language == "english": - print(f"TranslatorAgent: Prompt is already in English, no translation needed") + print("TranslatorAgent: Prompt is already in English, no translation needed") return {"translated_prompt": prompt} - + if not DEEPSEEK_API_KEY: - print(f"TranslatorAgent: Warning - DEEPSEEK_API_KEY not set, using simplified translation") + print("TranslatorAgent: Warning - DEEPSEEK_API_KEY not set, using simplified translation") return {"error": "DEEPSEEK_API_KEY not set", "translated_prompt": prompt} - + try: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {DEEPSEEK_API_KEY}" } - + # Use predefined universal translation template translator_template = TEMPLATES.get("universal_translator") - + payload = { "model": "deepseek-chat", "messages": [ @@ -223,13 +232,13 @@ def translator_agent(state: PromptState) -> Dict[str, Any]: "temperature": 0.1, "max_tokens": 1000 } - + response = requests.post( "https://api.deepseek.com/v1/chat/completions", headers=headers, json=payload ) - + if response.status_code == 200: result = response.json() translated_text = result["choices"][0]["message"]["content"].strip() @@ -246,21 +255,21 @@ def optimizer_agent(state: PromptState) -> Dict[str, Any]: """Optimize English prompts""" # Determine the prompt to optimize prompt_to_optimize = state.get("translated_prompt") or state["original_prompt"] - + if not DEEPSEEK_API_KEY: print("OptimizerAgent: Warning - DEEPSEEK_API_KEY not set, using local optimization") optimized = local_optimize(prompt_to_optimize) return {"optimized_prompt": optimized} - + try: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {DEEPSEEK_API_KEY}" } - + # Use predefined optimization template optimizer_template = TEMPLATES.get("txt2img_optimizer") - + payload = { "model": "deepseek-chat", "messages": [ @@ -270,13 +279,13 @@ def optimizer_agent(state: PromptState) -> Dict[str, Any]: "temperature": 0.3, "max_tokens": 1000 } - + response = requests.post( "https://api.deepseek.com/v1/chat/completions", headers=headers, json=payload ) - + if response.status_code == 200: result = response.json() enhanced_text = result["choices"][0]["message"]["content"].strip() @@ -296,10 +305,10 @@ def local_optimize(prompt: str) -> str: # Example optimization: add quality-boosting keywords quality_terms = ["high quality", "detailed", "sharp focus"] style_terms = ["masterpiece", "best quality"] - + # Check if prompt already contains these terms optimized = prompt - + # Add quality terms for term in quality_terms: if term.lower() not in optimized.lower(): @@ -307,12 +316,12 @@ def local_optimize(prompt: str) -> str: optimized = f"{optimized} {term}" else: optimized = f"{optimized}, {term}" - + # Add style terms (at the beginning) for term in reversed(style_terms): if term.lower() not in optimized.lower(): optimized = f"{term}, {optimized}" - + print(f"OptimizerAgent: Local optimization result - '{optimized}'") return optimized @@ -330,21 +339,21 @@ def create_prompt_optimization_graph(): # If LangGraph is not available, return None if not LANGGRAPH_AVAILABLE: return None - + # Create state graph - workflow = StateGraph(PromptState) - + graph = StateGraph(PromptState) + # Add nodes - workflow.add_node("router", router_agent) - workflow.add_node("translator", translator_agent) - workflow.add_node("optimizer", optimizer_agent) - + graph.add_node("router", router_agent) + graph.add_node("translator", translator_agent) + graph.add_node("optimizer", optimizer_agent) + # Add edges # From start to router - workflow.set_entry_point("router") - + graph.set_entry_point("router") + # From router to translator or optimizer (based on language) - workflow.add_conditional_edges( + graph.add_conditional_edges( "router", should_translate, { @@ -352,22 +361,22 @@ def create_prompt_optimization_graph(): "optimizer": "optimizer" } ) - + # From translator to optimizer - workflow.add_edge("translator", "optimizer") - + graph.add_edge("translator", "optimizer") + # From optimizer to end - workflow.add_edge("optimizer", END) - + graph.add_edge("optimizer", END) + # Compile workflow - return workflow.compile() + return graph.compile() # Simplified workflow (used when LangGraph is not available) def simple_prompt_optimization_workflow(prompt: str) -> str: """Simplified prompt optimization workflow""" - print(f"\n--- Simplified workflow started ---") + print("\n--- Simplified workflow started ---") print(f"Original prompt: '{prompt}'") - + # Initialize state state = PromptState( original_prompt=prompt, @@ -376,71 +385,71 @@ def simple_prompt_optimization_workflow(prompt: str) -> str: optimized_prompt=None, error=None ) - + # Step 1: Router - determine language router_result = router_agent(state) state["language"] = router_result["language"] - + # Step 2: Translator - translate if not English if state["language"] != "english": translator_result = translator_agent(state) state["translated_prompt"] = translator_result.get("translated_prompt") if "error" in translator_result: state["error"] = translator_result["error"] - + # Step 3: Optimizer - optimize prompt optimizer_result = optimizer_agent(state) state["optimized_prompt"] = optimizer_result.get("optimized_prompt") if "error" in optimizer_result and not state["error"]: state["error"] = optimizer_result["error"] - + print(f"Final optimized prompt: '{state['optimized_prompt']}'") - print(f"--- Simplified workflow finished ---\n") - + print("--- Simplified workflow finished ---\n") + return state["optimized_prompt"] or prompt class PromptOptimizer(scripts.Script): # Class-level flag to track if initialization message has been shown _init_message_shown = False - + def __init__(self): super().__init__() # Show initialization message only once if not PromptOptimizer._init_message_shown: print("\n\n=== Txt2Img Prompt Optimizer (Multilingual) script loaded ===\n\n") PromptOptimizer._init_message_shown = True - + # Try to create LangGraph workflow self.graph = create_prompt_optimization_graph() - + # If LangGraph is not available, use simplified workflow if self.graph is None and not PromptOptimizer._init_message_shown: print("Using simplified prompt optimization workflow") - + # Track processed prompts to avoid duplicates self.processed_prompts = set() - + def title(self): return "Txt2Img Prompt Optimizer (Multilingual)" - + # Return AlwaysVisible to show script in UI def show(self, is_img2img): return scripts.AlwaysVisible - + # No UI elements needed def ui(self, is_img2img): return [] - + # Optimize prompt before processing def process(self, p): # Only optimize Txt2Img processing objects if not isinstance(p, StableDiffusionProcessingTxt2Img): return p - + # Record original prompt original_prompt = p.prompt print(f"\n=== Original prompt ===\n{original_prompt}\n") - + # Optimize main prompt (if not already processed) if p.prompt not in self.processed_prompts: optimized_prompt = self.optimize_prompt(p.prompt) @@ -452,18 +461,18 @@ class PromptOptimizer(scripts.Script): if hasattr(p, 'main_prompt'): p.main_prompt = optimized_prompt self.processed_prompts.add(optimized_prompt) - + # Record optimization information (optional, for verification) if not hasattr(p, 'extra_generation_params'): p.extra_generation_params = {} p.extra_generation_params['Prompt optimized'] = True - + # Record final prompt sent to model print(f"\n=== Final prompt sent to model ===\n{p.prompt}\n") - + # Add post-processing hook to ensure prompt remains optimized original_setup_prompts = p.setup_prompts - + def patched_setup_prompts(): # Call original method original_setup_prompts() @@ -471,29 +480,34 @@ class PromptOptimizer(scripts.Script): if p.prompt in self.processed_prompts: p.all_prompts = [p.prompt] * len(p.all_prompts) p.main_prompt = p.prompt - + # Replace method p.setup_prompts = patched_setup_prompts - + return p - + def postprocess(self, p, processed): """Post-process after image generation""" + # Add original prompt to extra generation params + if hasattr(self, 'extra_generation_params') and hasattr(self, 'main_prompt'): + processed.infotexts[0] = processed.infotexts[0].replace( + "Prompt: ", f"Prompt: {self.extra_generation_params.get('Original prompt', '')}\nOptimized: " + ) # Nothing to do here return processed - - def optimize_prompt(self, prompt): - """Optimize prompt to improve generation quality""" + + def optimize_prompt(self, prompt: str) -> str: + """Optimize a prompt using the workflow""" if not prompt: return prompt - + # Use LangGraph workflow or simplified workflow if self.graph is not None: # Use LangGraph workflow try: - print(f"\n--- LangGraph started ---") + print("\n--- LangGraph started ---") print(f"Original prompt: '{prompt}'") - + # Create initial state initial_state = PromptState( original_prompt=prompt, @@ -502,59 +516,37 @@ class PromptOptimizer(scripts.Script): optimized_prompt=None, error=None ) - + # Execute workflow final_state = self.graph.invoke(initial_state) - + optimized = final_state.get("optimized_prompt") or prompt print(f"Final optimized prompt: '{optimized}'") - print(f"--- LangGraph finished ---\n") + print("--- LangGraph finished ---\n") return optimized except Exception as e: - print(f"LangGraph failed: {str(e)}") - # Fallback to simplified workflow + print(f"LangGraph workflow error: {str(e)}") + print("Falling back to simplified workflow") return simple_prompt_optimization_workflow(prompt) else: # Use simplified workflow return simple_prompt_optimization_workflow(prompt) - -# Check if we're in a Stable Diffusion Webui environment -try: - from modules.processing import StableDiffusionProcessingTxt2Img -except ImportError: - # Create a mock class for testing outside of webui - class StableDiffusionProcessingTxt2Img: - """Mock class for testing outside of webui""" - def __init__(self): - self.prompt = "" - self.all_prompts = [] - self.main_prompt = "" - self.extra_generation_params = {} - - def setup_prompts(self): - """Mock setup_prompts method""" - pass - - # For standalone testing if __name__ == "__main__": # Test the prompt optimization workflow test_prompts = [ - "a cat", - "beautiful landscape", - "portrait of a woman", - "科幻城市", # Chinese: "sci-fi city" + "a beautiful landscape with mountains", # English "美丽的山水画", # Chinese: "beautiful landscape painting" ] - + print("Testing prompt optimization workflow...") - + # Initialize optimizer optimizer = PromptOptimizer() - + # Test each prompt for prompt in test_prompts: print(f"\nTesting prompt: '{prompt}'") optimized = optimizer.optimize_prompt(prompt) - print(f"Optimized: '{optimized}'") \ No newline at end of file + print(f"Optimized: '{optimized}'")