Spaces:
Runtime error
Runtime error
File size: 5,618 Bytes
c6eb9ce e7247e4 fb2f0a7 c6eb9ce fb2f0a7 c6eb9ce e7247e4 c6eb9ce 4584c11 c6eb9ce 5fb9af6 c6eb9ce 5fb9af6 c6eb9ce 5fb9af6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import streamlit as st
from PIL import Image
from dotenv import load_dotenv
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator
# Pull variables from a local .env file into os.environ before any
# component below is constructed.
load_dotenv()

# App-wide page settings: browser-tab title and icon, full-width layout.
st.set_page_config(layout="wide", page_icon="🎨", page_title="WeavePrompt")
def _init_session_state():
    """Create the state that must survive Streamlit reruns, once per session.

    Streamlit re-executes the whole script on every interaction, so the
    optimizer (with its history), the latest result tuple, and the per-step
    "analysis expanded" flags all live in ``st.session_state``.
    """
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = PromptOptimizer(
            image_generator=FalImageGenerator(),
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=10,
            similarity_threshold=0.95,
        )
    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False
    if 'current_results' not in st.session_state:
        st.session_state.current_results = None
    if 'analysis_expanded' not in st.session_state:
        # Maps "expand_analysis_<idx>" -> bool for each history entry.
        st.session_state['analysis_expanded'] = {}


def _start_optimization(target_image):
    """Run the optimizer's initial step on *target_image* and show the result.

    Raises whatever ``optimizer.initialize`` raises; in that case the app
    stays in the "not started" state.
    """
    is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
    st.session_state.current_results = (is_completed, prompt, generated_image)
    # Flip the flag only after initialize() succeeded: setting it first would
    # leave current_results as None on failure, and the next rerun would crash
    # when unpacking it.
    st.session_state.optimization_started = True
    # Rerun so the Start button disappears and the progress view renders.
    st.rerun()


def _reset_optimization():
    """Return the UI to its pre-start state so another image can be tried."""
    st.session_state.optimization_started = False
    st.session_state.current_results = None
    # Expand flags are keyed by step index, so flags left over from the
    # previous run would otherwise auto-expand the same steps of the next run.
    st.session_state['analysis_expanded'] = {}
    # NOTE(review): the optimizer object (and its history) is kept as-is;
    # presumably optimizer.initialize() resets history for a new run — confirm.


def _render_progress(result_column):
    """Show the latest generated image, prompt, metrics and step/reset controls.

    Args:
        result_column: the Streamlit column beside the target image, used for
            the generated image.
    """
    optimizer = st.session_state.optimizer
    is_completed, prompt, generated_image = st.session_state.current_results

    with result_column:
        st.subheader("Generated Image")
        st.image(generated_image, width='stretch')

    st.text_area("Current Prompt", prompt, height=100)

    iteration_col, similarity_col, status_col = st.columns(3)
    with iteration_col:
        st.metric("Iteration", len(optimizer.history))
    with similarity_col:
        if optimizer.history:
            similarity = optimizer.history[-1]['similarity']
            st.metric("Similarity", f"{similarity:.2%}")
    with status_col:
        st.metric("Status", "Completed" if is_completed else "In Progress")

    if is_completed:
        st.success("Optimization completed! Click 'Reset' to try another image.")
    elif st.button("Next Step"):
        is_completed, prompt, generated_image = optimizer.step()
        st.session_state.current_results = (is_completed, prompt, generated_image)
        st.rerun()

    if st.button("Reset"):
        _reset_optimization()
        st.rerun()


def _render_analysis_toggle(idx, entry, expanded):
    """Render the Expand/Hide Analysis button for one history entry.

    Args:
        idx: zero-based position of *entry* in the optimizer history.
        entry: the history record; its 'analysis' value is read via .items().
        expanded: the session-level dict of per-entry expand flags.
    """
    expand_key = f"expand_analysis_{idx}"
    if expanded.get(expand_key, False):
        if st.button("Hide Analysis", key=f"hide_{expand_key}"):
            expanded[expand_key] = False
            st.rerun()
        st.text("Analysis:")
        for name, value in entry['analysis'].items():
            st.text(f"{name}: {value}")
    elif st.button("Expand Analysis", key=expand_key):
        expanded[expand_key] = True
        st.rerun()


def _render_history():
    """List every optimizer step: image, similarity, prompt, analysis toggle."""
    history = st.session_state.optimizer.history
    if not history:
        return

    st.subheader("Optimization History")
    expanded = st.session_state['analysis_expanded']
    for idx, entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        image_col, details_col = st.columns([2, 3])
        with image_col:
            st.image(entry['image'], width='stretch')
        with details_col:
            st.text(f"Similarity: {entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(entry['prompt'])
            _render_analysis_toggle(idx, entry, expanded)


def main():
    """Render the WeavePrompt app.

    The user uploads a target image, starts the optimizer, and then advances
    it one refinement step per "Next Step" click until it reports completion.
    """
    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
    st.markdown(
        "\nUpload a target image and watch as WeavePrompt iteratively "
        "optimizes a text prompt to recreate it.\n"
    )

    _init_session_state()

    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    target_image = Image.open(uploaded_file)
    target_col, result_col = st.columns(2)
    with target_col:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    if not st.session_state.optimization_started:
        if st.button("Start Optimization"):
            _start_optimization(target_image)
        return

    _render_progress(result_col)
    _render_history()
if __name__ == "__main__":
    # Script entry point: build and render the Streamlit UI.
    main()
|