frankai98's picture
Update app.py
cb3fae5 verified
raw
history blame
5.79 kB
import asyncio
import hashlib
import io
import os
import tempfile

import nest_asyncio
nest_asyncio.apply()

import streamlit as st
from gtts import gTTS
from streamlit.components.v1 import html
from transformers import pipeline
# Make sure this thread has a usable asyncio event loop, replacing it only
# when it is missing or closed. Unconditionally swapping in a fresh loop
# whenever the current one was not running (the previous behaviour) left the
# old loop unclosed on every script run and discarded the loop that the
# `nest_asyncio.apply()` call above had patched.
try:
    _loop = asyncio.get_event_loop()
except RuntimeError:
    # Raised when this thread has no current loop (e.g. a non-main thread
    # without a loop set).
    _loop = None
if _loop is None or _loop.is_closed():
    asyncio.set_event_loop(asyncio.new_event_loop())
del _loop
# Initialize session state. Each key is set only when it is not already
# present, so values stored by earlier runs of this script are left untouched.
# 'processed_data' holds the caption/story/audio results; the two timer flags
# record whether the elapsed-time widget has been shown and frozen.
for _state_key, _default_factory in (
    ('processed_data', lambda: {'scenario': None, 'story': None, 'audio': None}),
    ('timer_started', lambda: False),
    ('timer_frozen', lambda: False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_factory()
del _state_key, _default_factory
# Timer component that uses localStorage to check for a freeze flag
def timer():
    """Return the HTML/JavaScript snippet for the elapsed-time widget.

    The snippet renders a ``#timer`` div and a script. When the script runs, it
    first removes any stale ``freezeTimer`` entry from ``localStorage``. It then
    updates the display once per second as minutes:seconds, each zero-padded to
    two digits. As soon as ``localStorage.freezeTimer`` equals the string
    ``"true"``, the interval is cleared and the text turns green, so the last
    shown value stays on screen.
    """
    snippet_lines = (
        '',
        '<div id="timer" style="font-size:16px;color:#666;margin-bottom:10px;">⏱️ Elapsed: 00:00</div>',
        '<script>',
        '(function() {',
        'var start = Date.now();',
        "var timerElement = document.getElementById('timer');",
        '// Clear any previous freeze flag',
        'localStorage.removeItem("freezeTimer");',
        'var interval = setInterval(function() {',
        '// Check if freeze flag is set in localStorage.',
        'if(localStorage.getItem("freezeTimer") === "true"){',
        'clearInterval(interval);',
        "timerElement.style.color = '#00cc00'; // Change color to indicate frozen.",
        'return;',
        '}',
        'var elapsed = Date.now() - start;',
        'var minutes = Math.floor(elapsed / 60000);',
        'var seconds = Math.floor((elapsed % 60000) / 1000);',
        "timerElement.innerHTML = '⏱️ Elapsed: ' +",
        "(minutes < 10 ? '0' : '') + minutes + ':' +",
        "(seconds < 10 ? '0' : '') + seconds;",
        '}, 1000);',
        '})();',
        '</script>',
        '',
    )
    return '\n'.join(snippet_lines)
# Page setup: browser-tab title and icon, followed by the page heading.
st.set_page_config(
    page_icon="🦜",
    page_title="Your Image to Audio Story",
)
st.header(body="Turn Your Image to a Short Audio Story for Children")
# Model loading
@st.cache_resource
def load_models():
    """Build the two Hugging Face pipelines used by this app.

    ``st.cache_resource`` caches the returned dict, so later script runs reuse
    the same pipelines instead of reloading the models.

    Returns:
        ``{"img_model": <image-to-text pipeline>,
        "story_model": <text-generation pipeline>}``.
    """
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}


models = load_models()
# Processing functions
def img2text(url):
    """Caption an image with the cached image-to-text pipeline.

    ``url`` is passed to the pipeline unchanged; in this script it is a local
    file path. Returns the ``generated_text`` of the first prediction.
    """
    predictions = models["img_model"](url)
    top_prediction = predictions[0]
    return top_prediction["generated_text"]
def text2story(text, *, max_new_tokens=200, temperature=0.7):
    """Generate a short (about 100-word) story about ``text``.

    Args:
        text: Caption or short description the story should be about.
        max_new_tokens: Upper bound on generated tokens. The prompt asks for
            about 100 words, which typically needs more than 100 tokens, so
            the former hard-coded cap of 100 cut stories off mid-sentence.
            This is only a ceiling; generation can end earlier.
        temperature: Sampling temperature passed to the generator.

    Returns:
        The content of the model's (assistant) reply.
    """
    prompt = f"Generate a 100-word story about: {text}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    conversation = models["story_model"](
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )[0]["generated_text"]
    # For chat-style input the pipeline returns the conversation with the new
    # assistant message appended. Take the last entry instead of a fixed index
    # so the result does not depend on how many messages were sent.
    return conversation[-1]["content"]
def text2audio(story_text, *, lang='en', slow=False):
    """Synthesize ``story_text`` to speech with gTTS, kept in memory.

    Args:
        story_text: Text to speak.
        lang: gTTS language code (English by default, as before).
        slow: Whether gTTS should use its slower speaking rate.

    Returns:
        A dict with two entries:
        ``'audio'``: an ``io.BytesIO`` holding the audio that gTTS wrote,
        rewound to position 0.
        ``'sampling_rate'``: the constant 16000.

    Raises:
        ValueError: If ``story_text`` is empty, ``None`` or only whitespace.
    """
    # gTTS rejects empty input only through ``assert`` statements, which are
    # stripped under ``python -O``; fail up front with a real exception.
    if not story_text or not story_text.strip():
        raise ValueError("story_text must contain non-whitespace text to synthesize")
    audio_io = io.BytesIO()
    tts = gTTS(text=story_text, lang=lang, slow=slow)
    tts.write_to_fp(audio_io)
    # Rewind so readers of the buffer start at the beginning of the audio.
    audio_io.seek(0)
    # NOTE(review): 'sampling_rate' is a fixed constant, not read from the
    # audio gTTS produced, and nothing in this file reads it; confirm before
    # relying on it.
    return {'audio': audio_io, 'sampling_rate': 16000}
# UI components
# Script that tells the timer widget to stop (it polls this localStorage flag).
_FREEZE_TIMER_HTML = "<script>localStorage.setItem('freezeTimer', 'true');</script>"


def _run_pipeline(image_bytes, file_name, status_text, progress_bar):
    """Caption the image, write a story, synthesize audio, and report progress.

    Results are stored in ``st.session_state.processed_data`` under
    ``'scenario'``, ``'story'`` and ``'audio'``. The image bytes are written
    into a private temporary directory under a fixed base name; only the
    extension of ``file_name`` is kept. The client-supplied name therefore
    cannot choose where the file lands, and the file is deleted as soon as
    captioning finishes.
    """
    results = st.session_state.processed_data

    # Stage 1: Image to Text
    status_text.markdown("**🖼️ Generating caption...**")
    progress_bar.progress(0)
    suffix = os.path.splitext(file_name)[1]
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, "upload" + suffix)
        with open(image_path, "wb") as image_file:
            image_file.write(image_bytes)
        results['scenario'] = img2text(image_path)
    progress_bar.progress(33)

    # Stage 2: Text to Story
    status_text.markdown("**📖 Generating story...**")
    results['story'] = text2story(results['scenario'])
    progress_bar.progress(66)

    # Stage 3: Story to Audio
    status_text.markdown("**🔊 Synthesizing audio...**")
    results['audio'] = text2audio(results['story'])
    progress_bar.progress(100)
    status_text.success("**✅ Generation complete!**")


uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")
if uploaded_file is not None:
    bytes_data = uploaded_file.getvalue()
    # Identify the upload by name *and* content. Keying on the name alone
    # skipped processing for a different image that reused a file name.
    upload_key = f"{uploaded_file.name}:{hashlib.sha256(bytes_data).hexdigest()}"
    is_new_upload = st.session_state.get('current_file') != upload_key

    if is_new_upload:
        # Start a timer for every image about to be processed, not just the
        # first upload of the session. The timer script itself clears any
        # freeze flag left over from an earlier run.
        st.session_state.timer_started = True
        st.session_state.timer_frozen = False
        html(timer(), height=50)

    # Initialize progress containers
    status_text = st.empty()
    progress_bar = st.progress(0)
    # Display the image on every run, not only on the run that processes it.
    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    if is_new_upload:
        try:
            _run_pipeline(bytes_data, uploaded_file.name, status_text, progress_bar)
        except Exception as e:
            # components.html renders each snippet in its own iframe, so a new
            # snippet cannot find the timer element by id. Stop the timer via
            # the same localStorage flag the success path uses.
            html(_FREEZE_TIMER_HTML, height=0)
            st.session_state.timer_frozen = True
            status_text.error(f"**❌ Error:** {str(e)}")
            progress_bar.empty()
            raise
        # Remember the upload only after it was fully processed, so a failed
        # run is retried on the next script run instead of being skipped.
        st.session_state.current_file = upload_key
        # Immediately freeze the timer by setting the freeze flag.
        html(_FREEZE_TIMER_HTML, height=0)
        st.session_state.timer_frozen = True

    st.write("**Caption:**", st.session_state.processed_data['scenario'])
    st.write("**Story:**", st.session_state.processed_data['story'])
    # Optionally, you can still have the Play Audio button.
    if st.button("Play Audio of the Story Generated"):
        if st.session_state.processed_data.get('audio'):
            # The timer has already been frozen; just play the audio.
            audio_data = st.session_state.processed_data['audio']
            st.audio(audio_data['audio'].getvalue(), format="audio/mp3")
        else:
            st.warning("Please generate a story first!")