# app.py — Hugging Face Space by frankai98 (commit 27983d9).
# Turns an uploaded image into a short children's story with audio narration.
import nest_asyncio
nest_asyncio.apply()
import streamlit as st
from transformers import pipeline
import torch
from gtts import gTTS
import io
import time
from streamlit.components.v1 import html
import asyncio
# Ensure a usable asyncio event loop exists for this script thread.
# On newer Python versions asyncio.get_event_loop() raises RuntimeError
# (previously DeprecationWarning) when no loop is set, so the probe itself
# must be guarded before we can check is_running().
try:
    _loop = asyncio.get_event_loop()
except RuntimeError:
    _loop = None
if _loop is None or not _loop.is_running():
    asyncio.set_event_loop(asyncio.new_event_loop())
# Seed session state with empty result slots for each pipeline stage, so
# reruns of the script can find (or repopulate) previous outputs.
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = {
        key: None for key in ('scenario', 'story', 'audio')
    }
# Modified JavaScript timer component
def timer():
    """Return a <script> snippet that renders a live elapsed-time counter.

    The script runs inside a Streamlit component iframe but writes into the
    *parent* document: it removes any previous '#parent-timer' element and
    its interval, appends a fresh styled div to the parent body, and updates
    it every second with an MM:SS elapsed display. The interval id is stored
    on ``window.parent.timerInterval`` so later snippets can clear it.

    Returns:
        str: HTML/JS markup for use with streamlit.components.v1.html().
    """
    return """
<script>
// Remove existing timer if present
var existingTimer = window.parent.document.getElementById('parent-timer');
if (existingTimer) {
existingTimer.remove();
clearInterval(window.parent.timerInterval);
}
// Create new timer element in parent document
var timerDiv = window.parent.document.createElement('div');
timerDiv.id = 'parent-timer';
timerDiv.style.fontSize = '16px';
timerDiv.style.color = '#666';
timerDiv.style.marginBottom = '10px';
timerDiv.innerHTML = '⏱️ Elapsed: 00:00';
window.parent.document.body.appendChild(timerDiv);
var start = Date.now();
var timerInterval = setInterval(function() {
var elapsed = Date.now() - start;
var minutes = Math.floor(elapsed / 60000);
var seconds = Math.floor((elapsed % 60000) / 1000);
timerDiv.innerHTML = '⏱️ Elapsed: ' +
(minutes < 10 ? '0' : '') + minutes + ':' +
(seconds < 10 ? '0' : '') + seconds;
}, 1000);
// Store the interval ID in the parent window
window.parent.timerInterval = timerInterval;
</script>
"""
# Page setup — set_page_config must be the first Streamlit command executed.
st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
st.header("Turn Your Image to a Short Audio Story for Children")
# Model loading
@st.cache_resource
def load_models():
    """Build and cache the two inference pipelines used by this app.

    Decorated with st.cache_resource so the models are constructed once per
    server process instead of on every Streamlit script rerun.

    Returns:
        dict: {"img_model": image-captioning pipeline,
               "story_model": chat text-generation pipeline}.
    """
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}

models = load_models()
# Processing functions
def img2text(url):
    """Caption the image at *url* with the cached captioning pipeline.

    Args:
        url: Path (or URL) of the image file to describe.

    Returns:
        str: The generated caption text of the top prediction.
    """
    predictions = models["img_model"](url)
    return predictions[0]["generated_text"]
def text2story(text):
    """Generate a short children's story from an image caption.

    Args:
        text: Scene description (image caption) to base the story on.

    Returns:
        str: The generated story (the assistant's reply content).
    """
    prompt = f"Generate a brief 100-word story about: {text}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    # For chat input, the pipeline's "generated_text" is the full conversation
    # (system, user, plus the newly generated assistant turn appended last).
    response = models["story_model"](
        messages,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7
    )[0]["generated_text"]
    # Fix: take the *last* message rather than hard-coding index 2, so this
    # keeps working if the prompt template gains or loses a message.
    return response[-1]["content"]
def text2audio(story_text):
    """Synthesize *story_text* into spoken English audio via gTTS.

    Args:
        story_text: The story to narrate.

    Returns:
        dict: {'audio': in-memory MP3 buffer (io.BytesIO, rewound to 0),
               'sampling_rate': 16000}  # kept for interface compatibility
    """
    speech = gTTS(text=story_text, lang='en', slow=False)
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)
    return {'audio': buffer, 'sampling_rate': 16000}
# UI components
uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")
if uploaded_file is not None:
    # Initialize progress containers
    status_text = st.empty()
    progress_bar = st.progress(0)
    # Start JavaScript timer
    html(timer(), height=50)
    try:
        # Save uploaded file
        # NOTE(review): writes to the working directory under the client-supplied
        # file name — confirm name collisions / unsafe names are acceptable here.
        bytes_data = uploaded_file.getvalue()
        with open(uploaded_file.name, "wb") as file:
            file.write(bytes_data)
        # Track the most recently uploaded file name in session state.
        # NOTE(review): this guard only records the name; the stages below still
        # run on every rerun — confirm whether it was meant to skip reprocessing
        # of an unchanged file.
        if st.session_state.get('current_file') != uploaded_file.name:
            st.session_state.current_file = uploaded_file.name
        # Display image
        st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
        # Stage 1: Image to Text — caption the saved file.
        status_text.markdown("**🖼️ Analyzing image...**")
        progress_bar.progress(0)
        st.session_state.processed_data['scenario'] = img2text(uploaded_file.name)
        progress_bar.progress(33)
        # Stage 2: Text to Story — expand the caption into a short story.
        status_text.markdown("**📖 Generating story...**")
        progress_bar.progress(33)
        st.session_state.processed_data['story'] = text2story(
            st.session_state.processed_data['scenario']
        )
        progress_bar.progress(66)
        # Stage 3: Story to Audio — synthesize narration into memory.
        status_text.markdown("**🔊 Synthesizing audio...**")
        progress_bar.progress(66)
        st.session_state.processed_data['audio'] = text2audio(
            st.session_state.processed_data['story']
        )
        progress_bar.progress(100)
        # Final status: stop the parent-page timer and turn it green.
        status_text.success("**✅ Generation complete!**")
        html("""
<script>
clearInterval(window.parent.timerInterval);
var timerDiv = window.parent.document.getElementById('parent-timer');
if (timerDiv) {
timerDiv.style.color = '#00cc00';
}
</script>
""")
        # Show results
        st.subheader("Results")
        st.write("**Caption:**", st.session_state.processed_data['scenario'])
        st.write("**Story:**", st.session_state.processed_data['story'])
    except Exception as e:
        # On any failure: remove the timer entirely, surface the error in the
        # status slot, clear the progress bar, then re-raise so Streamlit
        # renders the full traceback for debugging.
        html("""
<script>
clearInterval(window.parent.timerInterval);
var timerDiv = window.parent.document.getElementById('parent-timer');
if (timerDiv) {
timerDiv.remove();
}
</script>
""")
        status_text.error(f"**❌ Error:** {str(e)}")
        progress_bar.empty()
        raise e
# Audio playback — replay the last synthesized narration, if any.
if st.button("Play Audio of the Story Generated"):
    audio_entry = st.session_state.processed_data.get('audio')
    if not audio_entry:
        # Nothing has been generated in this session yet.
        st.warning("Please generate a story first!")
    else:
        st.audio(audio_entry['audio'].getvalue(), format="audio/mp3")