Translate and optimize code to comply with coding styles

This commit is contained in:
ChristopheZhao 2025-03-15 10:20:05 +00:00
parent 4ded88c46e
commit 654ee7ec00

View File

@ -1,17 +1,25 @@
import modules.scripts as scripts """
Txt2Img Prompt Optimizer (Multilingual)
This script optimizes text prompts for Stable Diffusion image generation.
It can detect non-English prompts, translate them to English, and then optimize them
for better image generation results.
The script uses a LangGraph workflow to manage the optimization process, with nodes for
language detection, translation, and optimization. If LangGraph is not available,
it falls back to a simplified workflow.
"""
from modules import scripts
from modules.processing import StableDiffusionProcessingTxt2Img from modules.processing import StableDiffusionProcessingTxt2Img
import os import os
import re
import json
from dotenv import load_dotenv from dotenv import load_dotenv
import requests import requests
from typing import Dict, List, Literal, TypedDict, Union, Optional, Any, Callable, Annotated from typing import Dict, Literal, TypedDict, Optional, Any
import functools
# Try to import LangGraph related libraries # Try to import LangGraph related libraries
try: try:
from langgraph.graph import StateGraph, END, START from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
LANGGRAPH_AVAILABLE = True LANGGRAPH_AVAILABLE = True
except ImportError: except ImportError:
LANGGRAPH_AVAILABLE = False LANGGRAPH_AVAILABLE = False
@ -36,14 +44,13 @@ DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
# Define state type # Define state type
class PromptState(TypedDict): class PromptState(TypedDict):
original_prompt: str original_prompt: str
language: str # Changed: from is_chinese to language, can be "english", "chinese", "other", etc. language: str
translated_prompt: Optional[str] translated_prompt: Optional[str]
optimized_prompt: Optional[str] optimized_prompt: Optional[str]
error: Optional[str] error: Optional[str]
class PromptTemplate(BaseModel): class PromptTemplate(BaseModel):
"""Prompt template Pydantic model""" """Prompt template for specific tasks"""
name: str = Field(..., description="Template name") name: str = Field(..., description="Template name")
content: str = Field(..., description="Template content") content: str = Field(..., description="Template content")
@ -55,7 +62,7 @@ class PromptTemplates(BaseModel):
txt2img_optimizer: PromptTemplate = Field( txt2img_optimizer: PromptTemplate = Field(
default=PromptTemplate( default=PromptTemplate(
name="Stable Diffusion Prompt Optimizer", name="Stable Diffusion Prompt Optimizer",
content=""" content="""\
You are an expert prompt engineer for Stable Diffusion image generation with deep knowledge of how SD models interpret text. You are an expert prompt engineer for Stable Diffusion image generation with deep knowledge of how SD models interpret text.
Your task is to transform standard prompts into highly optimized versions that produce exceptional quality images. Follow these guidelines: Your task is to transform standard prompts into highly optimized versions that produce exceptional quality images. Follow these guidelines:
@ -85,7 +92,7 @@ class PromptTemplates(BaseModel):
language_detector: PromptTemplate = Field( language_detector: PromptTemplate = Field(
default=PromptTemplate( default=PromptTemplate(
name="Language Detector", name="Language Detector",
content=""" content="""\
You are a language detection expert. Your task is to identify if the given text is in English or not. You are a language detection expert. Your task is to identify if the given text is in English or not.
Analyze the provided text and determine if it's in English. Return ONLY 'yes' if the text is primarily in English, or 'no' if it's primarily in another language. Analyze the provided text and determine if it's in English. Return ONLY 'yes' if the text is primarily in English, or 'no' if it's primarily in another language.
@ -95,14 +102,14 @@ class PromptTemplates(BaseModel):
Return ONLY 'yes' or 'no' without any explanations or additional text. Return ONLY 'yes' or 'no' without any explanations or additional text.
""" """
), ),
description="Language detection template" description="Language detection template"
) )
universal_translator: PromptTemplate = Field( universal_translator: PromptTemplate = Field(
default=PromptTemplate( default=PromptTemplate(
name="Universal Translator", name="Universal Translator",
content=""" content="""\
You are a professional translator specializing in translating text to English for image generation. You are a professional translator specializing in translating text to English for image generation.
Your task is to accurately translate prompts from any language to English while preserving the original meaning and intent. Follow these guidelines: Your task is to accurately translate prompts from any language to English while preserving the original meaning and intent. Follow these guidelines:
@ -129,6 +136,21 @@ class PromptTemplates(BaseModel):
TEMPLATES = PromptTemplates() TEMPLATES = PromptTemplates()
# Helper function for simple language detection
def simple_language_detection(prompt: str) -> str:
"""Simple language detection based on ASCII character ratio"""
if not prompt:
return "unknown"
non_ascii_chars = 0
for char in prompt:
if ord(char) > 127:
non_ascii_chars += 1
language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other"
print(f"Simple language detection: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}'")
return language
# Agent functions # Agent functions
def router_agent(state: PromptState) -> Dict[str, Any]: def router_agent(state: PromptState) -> Dict[str, Any]:
"""Determine the language of the prompt""" """Determine the language of the prompt"""
@ -171,38 +193,25 @@ def router_agent(state: PromptState) -> Dict[str, Any]:
else: else:
print(f"RouterAgent: Language detection failed - {response.status_code} - {response.text}") print(f"RouterAgent: Language detection failed - {response.status_code} - {response.text}")
# Fallback to simple detection # Fallback to simple detection
non_ascii_chars = 0 language = simple_language_detection(prompt)
for char in prompt:
if ord(char) > 127:
non_ascii_chars += 1
language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other"
print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)")
return {"language": language} return {"language": language}
except Exception as e: except Exception as e:
print(f"RouterAgent: Language detection failed - {str(e)}") print(f"RouterAgent: Language detection failed - {str(e)}")
# Fallback to simple detection # Fallback to simple detection
non_ascii_chars = 0 language = simple_language_detection(prompt)
for char in prompt:
if ord(char) > 127:
non_ascii_chars += 1
language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other"
print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)")
return {"language": language} return {"language": language}
def translator_agent(state: PromptState) -> Dict[str, Any]: def translator_agent(state: PromptState) -> Dict[str, Any]:
"""Translate non-English prompts to English""" """Translate non-English prompts to English"""
prompt = state["original_prompt"] prompt = state["original_prompt"]
language = state["language"] language = state["language"]
if language == "english": if language == "english":
print(f"TranslatorAgent: Prompt is already in English, no translation needed") print("TranslatorAgent: Prompt is already in English, no translation needed")
return {"translated_prompt": prompt} return {"translated_prompt": prompt}
if not DEEPSEEK_API_KEY: if not DEEPSEEK_API_KEY:
print(f"TranslatorAgent: Warning - DEEPSEEK_API_KEY not set, using simplified translation") print("TranslatorAgent: Warning - DEEPSEEK_API_KEY not set, using simplified translation")
return {"error": "DEEPSEEK_API_KEY not set", "translated_prompt": prompt} return {"error": "DEEPSEEK_API_KEY not set", "translated_prompt": prompt}
try: try:
@ -332,19 +341,19 @@ def create_prompt_optimization_graph():
return None return None
# Create state graph # Create state graph
workflow = StateGraph(PromptState) graph = StateGraph(PromptState)
# Add nodes # Add nodes
workflow.add_node("router", router_agent) graph.add_node("router", router_agent)
workflow.add_node("translator", translator_agent) graph.add_node("translator", translator_agent)
workflow.add_node("optimizer", optimizer_agent) graph.add_node("optimizer", optimizer_agent)
# Add edges # Add edges
# From start to router # From start to router
workflow.set_entry_point("router") graph.set_entry_point("router")
# From router to translator or optimizer (based on language) # From router to translator or optimizer (based on language)
workflow.add_conditional_edges( graph.add_conditional_edges(
"router", "router",
should_translate, should_translate,
{ {
@ -354,18 +363,18 @@ def create_prompt_optimization_graph():
) )
# From translator to optimizer # From translator to optimizer
workflow.add_edge("translator", "optimizer") graph.add_edge("translator", "optimizer")
# From optimizer to end # From optimizer to end
workflow.add_edge("optimizer", END) graph.add_edge("optimizer", END)
# Compile workflow # Compile workflow
return workflow.compile() return graph.compile()
# Simplified workflow (used when LangGraph is not available) # Simplified workflow (used when LangGraph is not available)
def simple_prompt_optimization_workflow(prompt: str) -> str: def simple_prompt_optimization_workflow(prompt: str) -> str:
"""Simplified prompt optimization workflow""" """Simplified prompt optimization workflow"""
print(f"\n--- Simplified workflow started ---") print("\n--- Simplified workflow started ---")
print(f"Original prompt: '{prompt}'") print(f"Original prompt: '{prompt}'")
# Initialize state # Initialize state
@ -395,7 +404,7 @@ def simple_prompt_optimization_workflow(prompt: str) -> str:
state["error"] = optimizer_result["error"] state["error"] = optimizer_result["error"]
print(f"Final optimized prompt: '{state['optimized_prompt']}'") print(f"Final optimized prompt: '{state['optimized_prompt']}'")
print(f"--- Simplified workflow finished ---\n") print("--- Simplified workflow finished ---\n")
return state["optimized_prompt"] or prompt return state["optimized_prompt"] or prompt
@ -479,11 +488,16 @@ class PromptOptimizer(scripts.Script):
def postprocess(self, p, processed): def postprocess(self, p, processed):
"""Post-process after image generation""" """Post-process after image generation"""
# Add original prompt to extra generation params
if hasattr(self, 'extra_generation_params') and hasattr(self, 'main_prompt'):
processed.infotexts[0] = processed.infotexts[0].replace(
"Prompt: ", f"Prompt: {self.extra_generation_params.get('Original prompt', '')}\nOptimized: "
)
# Nothing to do here # Nothing to do here
return processed return processed
def optimize_prompt(self, prompt): def optimize_prompt(self, prompt: str) -> str:
"""Optimize prompt to improve generation quality""" """Optimize a prompt using the workflow"""
if not prompt: if not prompt:
return prompt return prompt
@ -491,7 +505,7 @@ class PromptOptimizer(scripts.Script):
if self.graph is not None: if self.graph is not None:
# Use LangGraph workflow # Use LangGraph workflow
try: try:
print(f"\n--- LangGraph started ---") print("\n--- LangGraph started ---")
print(f"Original prompt: '{prompt}'") print(f"Original prompt: '{prompt}'")
# Create initial state # Create initial state
@ -508,43 +522,21 @@ class PromptOptimizer(scripts.Script):
optimized = final_state.get("optimized_prompt") or prompt optimized = final_state.get("optimized_prompt") or prompt
print(f"Final optimized prompt: '{optimized}'") print(f"Final optimized prompt: '{optimized}'")
print(f"--- LangGraph finished ---\n") print("--- LangGraph finished ---\n")
return optimized return optimized
except Exception as e: except Exception as e:
print(f"LangGraph failed: {str(e)}") print(f"LangGraph workflow error: {str(e)}")
# Fallback to simplified workflow print("Falling back to simplified workflow")
return simple_prompt_optimization_workflow(prompt) return simple_prompt_optimization_workflow(prompt)
else: else:
# Use simplified workflow # Use simplified workflow
return simple_prompt_optimization_workflow(prompt) return simple_prompt_optimization_workflow(prompt)
# Check if we're in a Stable Diffusion Webui environment
try:
from modules.processing import StableDiffusionProcessingTxt2Img
except ImportError:
# Create a mock class for testing outside of webui
class StableDiffusionProcessingTxt2Img:
"""Mock class for testing outside of webui"""
def __init__(self):
self.prompt = ""
self.all_prompts = []
self.main_prompt = ""
self.extra_generation_params = {}
def setup_prompts(self):
"""Mock setup_prompts method"""
pass
# For standalone testing # For standalone testing
if __name__ == "__main__": if __name__ == "__main__":
# Test the prompt optimization workflow # Test the prompt optimization workflow
test_prompts = [ test_prompts = [
"a cat", "a beautiful landscape with mountains", # English
"beautiful landscape",
"portrait of a woman",
"科幻城市", # Chinese: "sci-fi city"
"美丽的山水画", # Chinese: "beautiful landscape painting" "美丽的山水画", # Chinese: "beautiful landscape painting"
] ]