Spaces:
Paused
Paused
| import os | |
| import shutil | |
| import subprocess | |
| import streamlit as st | |
| # βββ 1. Mode detection & data directory βββββββββββββββββββββββββββββββββββββββ | |
| # LOCAL_TRAIN=1 β use "./data"; otherwise Spaces uses "/tmp/data" | |
| LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true") | |
| DATA_DIR = os.path.join(os.getcwd(), "data") if LOCAL else "/tmp/data" | |
| os.makedirs(DATA_DIR, exist_ok=True) | |
| # βββ 2. Page layout βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| st.set_page_config(page_title="HiDream LoRA Trainer", layout="wide") | |
| st.title("π¨ HiDream LoRA Trainer (Streamlit)") | |
| # Sidebar for configuration | |
| with st.sidebar: | |
| st.header("π Configuration") | |
| base_model = st.selectbox( | |
| "Base Model", | |
| ["HiDream-ai/HiDream-I1-Dev", | |
| "runwayml/stable-diffusion-v1-5", | |
| "stabilityai/stable-diffusion-2-1"] | |
| ) | |
| trigger_word = st.text_input("Trigger Word", value="default-style") | |
| num_steps = st.slider("Training Steps", min_value=10, max_value=500, value=100, step=10) | |
| lora_r = st.slider("LoRA Rank (r)", min_value=4, max_value=128, value=16, step=4) | |
| lora_alpha = st.slider("LoRA Alpha", min_value=4, max_value=128, value=16, step=4) | |
| st.markdown("---") | |
| st.header("π Upload Dataset") | |
| uploaded_files = st.file_uploader( | |
| "Select your images & text files", | |
| type=["jpg","jpeg","png","txt"], | |
| accept_multiple_files=True | |
| ) | |
| if st.button("Upload Dataset"): | |
| # Clear old files | |
| for f in os.listdir(DATA_DIR): | |
| os.remove(os.path.join(DATA_DIR, f)) | |
| # Write new files | |
| for up in uploaded_files: | |
| dest = os.path.join(DATA_DIR, up.name) | |
| with open(dest, "wb") as f: | |
| f.write(up.getbuffer()) | |
| st.success(f"β Uploaded {len(uploaded_files)} files to `{DATA_DIR}`") | |
| st.markdown("---") | |
| # Trigger training | |
| if st.button("π Start Training"): | |
| st.session_state.training = True | |
| # βββ 3. Training log area βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| log_area = st.empty() | |
| # βββ 4. Invoke training when triggered ββββββββββββββββββββββββββββββββββββββββ | |
| if st.session_state.get("training", False): | |
| st.info("Training started⦠Logs below:") | |
| log_lines = [] | |
| # Prepare environment for train.py | |
| env = os.environ.copy() | |
| env.update({ | |
| "BASE_MODEL": base_model, | |
| "TRIGGER_WORD": trigger_word, | |
| "NUM_STEPS": str(num_steps), | |
| "LORA_R": str(lora_r), | |
| "LORA_ALPHA": str(lora_alpha), | |
| "LOCAL_TRAIN": os.environ.get("LOCAL_TRAIN","") | |
| }) | |
| # Launch train.py as subprocess and stream logs | |
| proc = subprocess.Popen( | |
| ["python3", "train.py"], | |
| stdout=subprocess.PIPE, | |
| stderr=subprocess.STDOUT, | |
| text=True, | |
| env=env | |
| ) | |
| for line in proc.stdout: | |
| log_lines.append(line) | |
| # Update the text area with all lines so far | |
| log_area.text_area("Training Log", value="".join(log_lines), height=400) | |
| proc.wait() | |
| if proc.returncode == 0: | |
| st.success("β Training complete!") | |
| else: | |
| st.error(f"β Training failed (exit code {proc.returncode})") | |
| # Reset trigger | |
| st.session_state.training = False | |