File size: 5,618 Bytes
c6eb9ce
 
e7247e4
fb2f0a7
 
c6eb9ce
fb2f0a7
 
c6eb9ce
e7247e4
 
 
c6eb9ce
4584c11
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
5fb9af6
c6eb9ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fb9af6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6eb9ce
 
5fb9af6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
from PIL import Image
from dotenv import load_dotenv
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator

# Populate environment variables from a local .env file before main()
# constructs the optimizer and its generator/evaluator/refiner components.
load_dotenv()

# Browser-tab title and icon, plus a full-width layout for the two-column UI.
st.set_page_config(page_title="WeavePrompt", page_icon="🎨", layout="wide")

def _init_session_state():
    """Create the per-session objects main() relies on, if not already present.

    Each key is only created when missing from ``st.session_state``, so the
    optimizer and progress flags are built once and then reused on later runs.
    """
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = PromptOptimizer(
            image_generator=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95
        )

    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False

    if 'current_results' not in st.session_state:
        st.session_state.current_results = None

    # Per-history-entry "analysis visible" flags, keyed by "expand_analysis_<idx>".
    if 'analysis_expanded' not in st.session_state:
        st.session_state['analysis_expanded'] = {}


def _reset_optimization():
    """Return the page to the pre-start state and forget per-entry toggles."""
    st.session_state.optimization_started = False
    st.session_state.current_results = None
    # Toggle flags are keyed only by history index; without clearing them a
    # new run would inherit the expanded/collapsed state of the previous run.
    st.session_state['analysis_expanded'] = {}
    # NOTE(review): optimizer.history is not cleared here; this relies on
    # PromptOptimizer.initialize() resetting it for the next run — confirm.


def _render_history(history):
    """Render one row per optimizer history entry, each with an analysis toggle.

    Each entry is read as a mapping with 'image', 'similarity', 'prompt' and
    'analysis' keys; 'analysis' is iterated with ``.items()``.
    """
    st.subheader("Optimization History")
    expanded = st.session_state['analysis_expanded']

    for idx, hist_entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        image_col, text_col = st.columns([2, 3])
        with image_col:
            st.image(hist_entry['image'], width='stretch')
        with text_col:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])

            expand_key = f"expand_analysis_{idx}"
            if expanded.get(expand_key, False):
                if st.button("Hide Analysis", key=f"hide_{expand_key}"):
                    expanded[expand_key] = False
                    st.rerun()
                st.text("Analysis:")
                for key, value in hist_entry['analysis'].items():
                    st.text(f"{key}: {value}")
            elif st.button("Expand Analysis", key=expand_key):
                expanded[expand_key] = True
                st.rerun()


def main():
    """Render the WeavePrompt page: upload a target image, then step the optimizer.

    The optimizer, the latest ``(is_completed, prompt, generated_image)``
    result and UI flags are kept in ``st.session_state`` so they persist
    between script runs triggered by widget interactions.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
    """)

    _init_session_state()

    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    # NOTE(review): replacing the uploaded file mid-run keeps the previous
    # run's results on screen; consider resetting when the upload changes.
    target_image = Image.open(uploaded_file)

    target_col, generated_col = st.columns(2)
    with target_col:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    if not st.session_state.optimization_started:
        if st.button("Start Optimization"):
            # Call initialize() *before* flipping the flag. If it raises, the
            # session stays in the pre-start state instead of being stuck with
            # optimization_started=True and current_results=None, which made
            # every later run crash while unpacking the results (with the
            # Reset button unreachable).
            is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
            st.session_state.current_results = (is_completed, prompt, generated_image)
            st.session_state.optimization_started = True
            # Rerun so the page redraws without the Start button, matching
            # the Next Step / Reset handlers.
            st.rerun()
        return

    is_completed, prompt, generated_image = st.session_state.current_results
    history = st.session_state.optimizer.history

    with generated_col:
        st.subheader("Generated Image")
        st.image(generated_image, width='stretch')

    # Display prompt and controls
    st.text_area("Current Prompt", prompt, height=100)

    # Progress metrics
    iter_col, sim_col, status_col = st.columns(3)
    with iter_col:
        st.metric("Iteration", len(history))
    with sim_col:
        if history:
            st.metric("Similarity", f"{history[-1]['similarity']:.2%}")
    with status_col:
        st.metric("Status", "Completed" if is_completed else "In Progress")

    if is_completed:
        st.success("Optimization completed! Click 'Reset' to try another image.")
    elif st.button("Next Step"):
        is_completed, prompt, generated_image = st.session_state.optimizer.step()
        st.session_state.current_results = (is_completed, prompt, generated_image)
        st.rerun()

    if st.button("Reset"):
        _reset_optimization()
        st.rerun()

    if history:
        _render_history(history)


if __name__ == "__main__":
    main()