mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-05-06 20:09:06 +08:00
Translate and optimize code to comply with coding styles
This commit is contained in:
parent
4ded88c46e
commit
654ee7ec00
@ -1,17 +1,25 @@
|
|||||||
import modules.scripts as scripts
|
"""
|
||||||
|
Txt2Img Prompt Optimizer (Multilingual)
|
||||||
|
|
||||||
|
This script optimizes text prompts for Stable Diffusion image generation.
|
||||||
|
It can detect non-English prompts, translate them to English, and then optimize them
|
||||||
|
for better image generation results.
|
||||||
|
|
||||||
|
The script uses a LangGraph workflow to manage the optimization process, with nodes for
|
||||||
|
language detection, translation, and optimization. If LangGraph is not available,
|
||||||
|
it falls back to a simplified workflow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from modules import scripts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import json
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import requests
|
import requests
|
||||||
from typing import Dict, List, Literal, TypedDict, Union, Optional, Any, Callable, Annotated
|
from typing import Dict, Literal, TypedDict, Optional, Any
|
||||||
import functools
|
|
||||||
|
|
||||||
# Try to import LangGraph related libraries
|
# Try to import LangGraph related libraries
|
||||||
try:
|
try:
|
||||||
from langgraph.graph import StateGraph, END, START
|
from langgraph.graph import StateGraph, END
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
|
||||||
LANGGRAPH_AVAILABLE = True
|
LANGGRAPH_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LANGGRAPH_AVAILABLE = False
|
LANGGRAPH_AVAILABLE = False
|
||||||
@ -36,14 +44,13 @@ DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
|
|||||||
# Define state type
|
# Define state type
|
||||||
class PromptState(TypedDict):
    """State dictionary handed between the agent functions of the prompt workflow.

    Each agent reads the keys it needs and returns a partial dict with the
    keys it produced (for example ``{"language": ...}`` or
    ``{"translated_prompt": ...}``).
    """

    # Prompt text exactly as it was given to the workflow.
    original_prompt: str
    # Language label for the prompt. The translator only checks for
    # "english"; the simple detector also produces "other" and "unknown".
    language: str
    # English version of the prompt; equals the original prompt when no
    # translation was needed or possible.
    translated_prompt: Optional[str]
    # Optimized prompt; callers fall back to the original prompt when empty.
    optimized_prompt: Optional[str]
    # Human-readable error message recorded by an agent, if any.
    error: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplate(BaseModel):
    """Prompt template for specific tasks.

    A named block of instruction text. Both fields are required
    (``Field(...)`` declares no default).
    """

    # Short human-readable title, e.g. "Language Detector".
    name: str = Field(..., description="Template name")
    # Full instruction text of the template.
    content: str = Field(..., description="Template content")
|
||||||
|
|
||||||
@ -55,7 +62,7 @@ class PromptTemplates(BaseModel):
|
|||||||
txt2img_optimizer: PromptTemplate = Field(
|
txt2img_optimizer: PromptTemplate = Field(
|
||||||
default=PromptTemplate(
|
default=PromptTemplate(
|
||||||
name="Stable Diffusion Prompt Optimizer",
|
name="Stable Diffusion Prompt Optimizer",
|
||||||
content="""
|
content="""\
|
||||||
You are an expert prompt engineer for Stable Diffusion image generation with deep knowledge of how SD models interpret text.
|
You are an expert prompt engineer for Stable Diffusion image generation with deep knowledge of how SD models interpret text.
|
||||||
|
|
||||||
Your task is to transform standard prompts into highly optimized versions that produce exceptional quality images. Follow these guidelines:
|
Your task is to transform standard prompts into highly optimized versions that produce exceptional quality images. Follow these guidelines:
|
||||||
@ -85,7 +92,7 @@ class PromptTemplates(BaseModel):
|
|||||||
language_detector: PromptTemplate = Field(
|
language_detector: PromptTemplate = Field(
|
||||||
default=PromptTemplate(
|
default=PromptTemplate(
|
||||||
name="Language Detector",
|
name="Language Detector",
|
||||||
content="""
|
content="""\
|
||||||
You are a language detection expert. Your task is to identify if the given text is in English or not.
|
You are a language detection expert. Your task is to identify if the given text is in English or not.
|
||||||
|
|
||||||
Analyze the provided text and determine if it's in English. Return ONLY 'yes' if the text is primarily in English, or 'no' if it's primarily in another language.
|
Analyze the provided text and determine if it's in English. Return ONLY 'yes' if the text is primarily in English, or 'no' if it's primarily in another language.
|
||||||
@ -102,7 +109,7 @@ class PromptTemplates(BaseModel):
|
|||||||
universal_translator: PromptTemplate = Field(
|
universal_translator: PromptTemplate = Field(
|
||||||
default=PromptTemplate(
|
default=PromptTemplate(
|
||||||
name="Universal Translator",
|
name="Universal Translator",
|
||||||
content="""
|
content="""\
|
||||||
You are a professional translator specializing in translating text to English for image generation.
|
You are a professional translator specializing in translating text to English for image generation.
|
||||||
|
|
||||||
Your task is to accurately translate prompts from any language to English while preserving the original meaning and intent. Follow these guidelines:
|
Your task is to accurately translate prompts from any language to English while preserving the original meaning and intent. Follow these guidelines:
|
||||||
@ -129,6 +136,21 @@ class PromptTemplates(BaseModel):
|
|||||||
TEMPLATES = PromptTemplates()
|
TEMPLATES = PromptTemplates()
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function for simple language detection
|
||||||
|
def simple_language_detection(prompt: str) -> str:
    """Classify a prompt as English or not from its share of non-ASCII characters.

    This is a purely local heuristic: it makes no network calls and only
    looks at character code points.

    Args:
        prompt: Text to classify.

    Returns:
        ``"unknown"`` when ``prompt`` is empty, ``"english"`` when fewer than
        30% of its characters are non-ASCII, otherwise ``"other"``.
    """
    if not prompt:
        # Also avoids a division by zero below.
        return "unknown"

    # Code points above 127 lie outside the ASCII range.
    non_ascii_chars = sum(1 for char in prompt if ord(char) > 127)

    language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other"
    print(f"Simple language detection: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}'")
    return language
|
||||||
|
|
||||||
# Agent functions
|
# Agent functions
|
||||||
def router_agent(state: PromptState) -> Dict[str, Any]:
|
def router_agent(state: PromptState) -> Dict[str, Any]:
|
||||||
"""Determine the language of the prompt"""
|
"""Determine the language of the prompt"""
|
||||||
@ -171,38 +193,25 @@ def router_agent(state: PromptState) -> Dict[str, Any]:
|
|||||||
else:
|
else:
|
||||||
print(f"RouterAgent: Language detection failed - {response.status_code} - {response.text}")
|
print(f"RouterAgent: Language detection failed - {response.status_code} - {response.text}")
|
||||||
# Fallback to simple detection
|
# Fallback to simple detection
|
||||||
non_ascii_chars = 0
|
language = simple_language_detection(prompt)
|
||||||
for char in prompt:
|
|
||||||
if ord(char) > 127:
|
|
||||||
non_ascii_chars += 1
|
|
||||||
|
|
||||||
language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other"
|
|
||||||
print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)")
|
|
||||||
return {"language": language}
|
return {"language": language}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"RouterAgent: Language detection failed - {str(e)}")
|
print(f"RouterAgent: Language detection failed - {str(e)}")
|
||||||
# Fallback to simple detection
|
# Fallback to simple detection
|
||||||
non_ascii_chars = 0
|
language = simple_language_detection(prompt)
|
||||||
for char in prompt:
|
|
||||||
if ord(char) > 127:
|
|
||||||
non_ascii_chars += 1
|
|
||||||
|
|
||||||
language = "english" if (non_ascii_chars / len(prompt) < 0.3) else "other"
|
|
||||||
print(f"RouterAgent: Prompt '{prompt}' detected as '{'English' if language == 'english' else 'Non-English'}' (simple detection)")
|
|
||||||
return {"language": language}
|
return {"language": language}
|
||||||
|
|
||||||
|
|
||||||
def translator_agent(state: PromptState) -> Dict[str, Any]:
|
def translator_agent(state: PromptState) -> Dict[str, Any]:
|
||||||
"""Translate non-English prompts to English"""
|
"""Translate non-English prompts to English"""
|
||||||
prompt = state["original_prompt"]
|
prompt = state["original_prompt"]
|
||||||
language = state["language"]
|
language = state["language"]
|
||||||
|
|
||||||
if language == "english":
|
if language == "english":
|
||||||
print(f"TranslatorAgent: Prompt is already in English, no translation needed")
|
print("TranslatorAgent: Prompt is already in English, no translation needed")
|
||||||
return {"translated_prompt": prompt}
|
return {"translated_prompt": prompt}
|
||||||
|
|
||||||
if not DEEPSEEK_API_KEY:
|
if not DEEPSEEK_API_KEY:
|
||||||
print(f"TranslatorAgent: Warning - DEEPSEEK_API_KEY not set, using simplified translation")
|
print("TranslatorAgent: Warning - DEEPSEEK_API_KEY not set, using simplified translation")
|
||||||
return {"error": "DEEPSEEK_API_KEY not set", "translated_prompt": prompt}
|
return {"error": "DEEPSEEK_API_KEY not set", "translated_prompt": prompt}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -332,19 +341,19 @@ def create_prompt_optimization_graph():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Create state graph
|
# Create state graph
|
||||||
workflow = StateGraph(PromptState)
|
graph = StateGraph(PromptState)
|
||||||
|
|
||||||
# Add nodes
|
# Add nodes
|
||||||
workflow.add_node("router", router_agent)
|
graph.add_node("router", router_agent)
|
||||||
workflow.add_node("translator", translator_agent)
|
graph.add_node("translator", translator_agent)
|
||||||
workflow.add_node("optimizer", optimizer_agent)
|
graph.add_node("optimizer", optimizer_agent)
|
||||||
|
|
||||||
# Add edges
|
# Add edges
|
||||||
# From start to router
|
# From start to router
|
||||||
workflow.set_entry_point("router")
|
graph.set_entry_point("router")
|
||||||
|
|
||||||
# From router to translator or optimizer (based on language)
|
# From router to translator or optimizer (based on language)
|
||||||
workflow.add_conditional_edges(
|
graph.add_conditional_edges(
|
||||||
"router",
|
"router",
|
||||||
should_translate,
|
should_translate,
|
||||||
{
|
{
|
||||||
@ -354,18 +363,18 @@ def create_prompt_optimization_graph():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# From translator to optimizer
|
# From translator to optimizer
|
||||||
workflow.add_edge("translator", "optimizer")
|
graph.add_edge("translator", "optimizer")
|
||||||
|
|
||||||
# From optimizer to end
|
# From optimizer to end
|
||||||
workflow.add_edge("optimizer", END)
|
graph.add_edge("optimizer", END)
|
||||||
|
|
||||||
# Compile workflow
|
# Compile workflow
|
||||||
return workflow.compile()
|
return graph.compile()
|
||||||
|
|
||||||
# Simplified workflow (used when LangGraph is not available)
|
# Simplified workflow (used when LangGraph is not available)
|
||||||
def simple_prompt_optimization_workflow(prompt: str) -> str:
|
def simple_prompt_optimization_workflow(prompt: str) -> str:
|
||||||
"""Simplified prompt optimization workflow"""
|
"""Simplified prompt optimization workflow"""
|
||||||
print(f"\n--- Simplified workflow started ---")
|
print("\n--- Simplified workflow started ---")
|
||||||
print(f"Original prompt: '{prompt}'")
|
print(f"Original prompt: '{prompt}'")
|
||||||
|
|
||||||
# Initialize state
|
# Initialize state
|
||||||
@ -395,7 +404,7 @@ def simple_prompt_optimization_workflow(prompt: str) -> str:
|
|||||||
state["error"] = optimizer_result["error"]
|
state["error"] = optimizer_result["error"]
|
||||||
|
|
||||||
print(f"Final optimized prompt: '{state['optimized_prompt']}'")
|
print(f"Final optimized prompt: '{state['optimized_prompt']}'")
|
||||||
print(f"--- Simplified workflow finished ---\n")
|
print("--- Simplified workflow finished ---\n")
|
||||||
|
|
||||||
return state["optimized_prompt"] or prompt
|
return state["optimized_prompt"] or prompt
|
||||||
|
|
||||||
@ -479,11 +488,16 @@ class PromptOptimizer(scripts.Script):
|
|||||||
|
|
||||||
def postprocess(self, p, processed):
    """Post-process after image generation.

    Rewrites the first infotext so that it shows the original prompt
    followed by an ``Optimized:`` label in front of the text that used to
    follow ``"Prompt: "``.

    Args:
        p: Processing object supplied by the caller; not used here.
        processed: Result object; its ``infotexts`` list is updated in place
            when present and non-empty.

    Returns:
        The same ``processed`` object.
    """
    # Only annotate when this script instance carries the attributes that
    # hold the original prompt.
    if not (hasattr(self, 'extra_generation_params') and hasattr(self, 'main_prompt')):
        return processed

    # Guard against a missing or empty infotext list instead of raising
    # IndexError on infotexts[0].
    infotexts = getattr(processed, "infotexts", None)
    if not infotexts:
        return processed

    original_prompt = self.extra_generation_params.get('Original prompt', '')
    infotexts[0] = infotexts[0].replace(
        "Prompt: ", f"Prompt: {original_prompt}\nOptimized: "
    )
    return processed
|
||||||
|
|
||||||
def optimize_prompt(self, prompt):
|
def optimize_prompt(self, prompt: str) -> str:
|
||||||
"""Optimize prompt to improve generation quality"""
|
"""Optimize a prompt using the workflow"""
|
||||||
if not prompt:
|
if not prompt:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
@ -491,7 +505,7 @@ class PromptOptimizer(scripts.Script):
|
|||||||
if self.graph is not None:
|
if self.graph is not None:
|
||||||
# Use LangGraph workflow
|
# Use LangGraph workflow
|
||||||
try:
|
try:
|
||||||
print(f"\n--- LangGraph started ---")
|
print("\n--- LangGraph started ---")
|
||||||
print(f"Original prompt: '{prompt}'")
|
print(f"Original prompt: '{prompt}'")
|
||||||
|
|
||||||
# Create initial state
|
# Create initial state
|
||||||
@ -508,43 +522,21 @@ class PromptOptimizer(scripts.Script):
|
|||||||
|
|
||||||
optimized = final_state.get("optimized_prompt") or prompt
|
optimized = final_state.get("optimized_prompt") or prompt
|
||||||
print(f"Final optimized prompt: '{optimized}'")
|
print(f"Final optimized prompt: '{optimized}'")
|
||||||
print(f"--- LangGraph finished ---\n")
|
print("--- LangGraph finished ---\n")
|
||||||
return optimized
|
return optimized
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"LangGraph failed: {str(e)}")
|
print(f"LangGraph workflow error: {str(e)}")
|
||||||
# Fallback to simplified workflow
|
print("Falling back to simplified workflow")
|
||||||
return simple_prompt_optimization_workflow(prompt)
|
return simple_prompt_optimization_workflow(prompt)
|
||||||
else:
|
else:
|
||||||
# Use simplified workflow
|
# Use simplified workflow
|
||||||
return simple_prompt_optimization_workflow(prompt)
|
return simple_prompt_optimization_workflow(prompt)
|
||||||
|
|
||||||
|
|
||||||
# Check if we're in a Stable Diffusion Webui environment
|
|
||||||
try:
|
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
|
||||||
except ImportError:
|
|
||||||
# Create a mock class for testing outside of webui
|
|
||||||
class StableDiffusionProcessingTxt2Img:
|
|
||||||
"""Mock class for testing outside of webui"""
|
|
||||||
def __init__(self):
|
|
||||||
self.prompt = ""
|
|
||||||
self.all_prompts = []
|
|
||||||
self.main_prompt = ""
|
|
||||||
self.extra_generation_params = {}
|
|
||||||
|
|
||||||
def setup_prompts(self):
|
|
||||||
"""Mock setup_prompts method"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# For standalone testing
|
# For standalone testing
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Test the prompt optimization workflow
|
# Test the prompt optimization workflow
|
||||||
test_prompts = [
|
test_prompts = [
|
||||||
"a cat",
|
"a beautiful landscape with mountains", # English
|
||||||
"beautiful landscape",
|
|
||||||
"portrait of a woman",
|
|
||||||
"科幻城市", # Chinese: "sci-fi city"
|
|
||||||
"美丽的山水画", # Chinese: "beautiful landscape painting"
|
"美丽的山水画", # Chinese: "beautiful landscape painting"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user