Sam Dobson committed
Commit c514928 · 0 Parent(s)

First commit

Files changed (3)
  1. README.md +73 -0
  2. app.py +151 -0
  3. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,73 @@
+ ---
+ title: TinyStories Story Generator
+ emoji: 📚
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ tags:
+ - text-generation
+ - llama
+ - tinystories
+ - storytelling
+ ---
+
+ # TinyStories Story Generator
+
+ An interactive Gradio Space for generating simple children's stories using a small Llama-architecture model trained on the TinyStories dataset.
+
+ ## About
+
+ This Space provides a chat-style interface to a ~15M-parameter language model that generates simple, coherent children's stories. The model uses vocabulary and concepts that a typical 3-4 year old would understand.
+
+ ## Features
+
+ - Interactive story generation
+ - Adjustable generation parameters (temperature, top-k, top-p, max length)
+ - Example prompts to get started
+ - Real-time generation
+ - Clean, user-friendly interface
+
+ ## Model Details
+
+ - **Architecture:** Llama 2
+ - **Parameters:** ~15M
+ - **Layers:** 6
+ - **Attention Heads:** 6
+ - **Max Context Length:** 256 tokens
+ - **Training Dataset:** [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories)
+
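+ To poke at the checkpoint outside the Space, here is a minimal loading sketch with `transformers` (assuming the default repo `sdobson/tinystories-llama-15m` from `app.py`; swap in your own):
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Default checkpoint used by app.py; replace with your own repo if needed.
+ model = AutoModelForCausalLM.from_pretrained("sdobson/tinystories-llama-15m")
+ tokenizer = AutoTokenizer.from_pretrained("sdobson/tinystories-llama-15m")
+
+ inputs = tokenizer("Once upon a time, there was a", return_tensors="pt")
+ outputs = model.generate(**inputs, max_length=200, do_sample=True, top_p=0.9)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
+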
+ ## Usage
+
+ 1. Enter a story prompt (e.g., "Once upon a time, there was a...")
+ 2. Optionally adjust generation settings
+ 3. Click "Generate Story"
+ 4. Enjoy your AI-generated children's story!
+
+ ## Setup Instructions
+
+ To run this Space:
+
+ 1. Upload your trained model to the HuggingFace Hub
+ 2. Update the `MODEL_REPO` variable in `app.py` with your model repository (format: `username/model-name`), or set the `MODEL_REPO` environment variable in the Space settings (shown below)
+
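+ For example, the environment-variable route can be exercised like this (a sketch; `username/model-name` is a placeholder):
+
+ ```bash
+ # Point the app at a different checkpoint without editing app.py
+ export MODEL_REPO=username/model-name
+ python app.py
+ ```
+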
+ ## Local Development
+
+ ```bash
+ pip install -r requirements.txt
+ python app.py
+ ```
+
+ ## License
+
+ MIT License
+
+ ## Acknowledgments
+
+ - Model architecture and training code adapted from [llama2.c](https://github.com/karpathy/llama2.c) by Andrej Karpathy
+ - Trained on the [TinyStories dataset](https://huggingface.co/datasets/roneneldan/TinyStories) by Ronen Eldan and Yuanzhi Li
+ - Based on the Llama 2 architecture by Meta AI
app.py ADDED
@@ -0,0 +1,151 @@
+ """
+ Gradio interface for TinyStories Llama model chat.
+ """
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import warnings
+ import os
+
+ warnings.filterwarnings('ignore', category=UserWarning)
+
+ MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/tinystories-llama-15m")
+
+ print(f"Loading model and tokenizer from {MODEL_REPO}...")
+ model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = model.to(device)
+ model.eval()
+
+ print(f"Model loaded on {device}")
+ print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+
+ def generate_story(
+     prompt,
+     max_length=200,
+     temperature=0.8,
+     top_k=50,
+     top_p=0.9,
+     do_sample=True
+ ):
+     """Generate a story continuation from the prompt."""
+     if not prompt.strip():
+         return "Please provide a story prompt!"
+
+     # Tokenize input and move it to the model's device
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     # Generate (max_length counts prompt tokens plus new tokens)
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_length=max_length,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             do_sample=do_sample,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+     # Decode and return
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return generated_text
+
+
+ with gr.Blocks(title="TinyStories Story Generator") as demo:
+     gr.Markdown(
+         """
+         # TinyStories Llama Model Chat
+
+         This is a small Llama-architecture model trained on the TinyStories dataset.
+         It generates simple, coherent children's stories using vocabulary that a typical 3-4 year old would understand.
+
+         **Try starting your story with:**
+         - "Once upon a time, there was a..."
+         - "One day, a little boy named..."
+         - "In a small town, there lived a..."
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(
+                 label="Story Prompt",
+                 placeholder="Once upon a time, there was a",
+                 lines=3
+             )
+
+             with gr.Accordion("Generation Settings", open=False):
+                 max_length_slider = gr.Slider(
+                     minimum=50,
+                     maximum=256,
+                     value=200,
+                     step=10,
+                     label="Max Length (tokens)"
+                 )
+                 temperature_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=2.0,
+                     value=0.8,
+                     step=0.1,
+                     label="Temperature (higher = more creative)"
+                 )
+                 top_k_slider = gr.Slider(
+                     minimum=1,
+                     maximum=100,
+                     value=50,
+                     step=1,
+                     label="Top-k"
+                 )
+                 top_p_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.9,
+                     step=0.05,
+                     label="Top-p (nucleus sampling)"
+                 )
+                 do_sample_checkbox = gr.Checkbox(
+                     label="Use Sampling",
+                     value=True
+                 )
+
+             generate_btn = gr.Button("Generate Story", variant="primary")
+
+         with gr.Column():
+             output_text = gr.Textbox(
+                 label="Generated Story",
+                 lines=15,
+                 show_copy_button=True
+             )
+
+     gr.Examples(
+         examples=[
+             ["Once upon a time, there was a little girl named Lily."],
+             ["One day, a little boy found a magic"],
+             ["The little dog was very happy because"],
+             ["In a small garden, there lived a"],
+             ["Timmy wanted to play with his friend, but"],
+         ],
+         inputs=prompt_input,
+         label="Example Prompts"
+     )
+
+     generate_btn.click(
+         fn=generate_story,
+         inputs=[
+             prompt_input,
+             max_length_slider,
+             temperature_slider,
+             top_k_slider,
+             top_p_slider,
+             do_sample_checkbox
+         ],
+         outputs=output_text
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
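
As a sanity check, `generate_story` can be exercised without launching the UI — a sketch, assuming the model weights download successfully (output varies with sampling):

```python
# Importing app runs the module-level setup (model load, Blocks definition)
# but does not call demo.launch(), so this works as a quick smoke test.
from app import generate_story

print(generate_story("Once upon a time, there was a", max_length=100))
```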
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio>=4.0.0
+ transformers>=4.46.0
+ torch>=2.0.0
+ accelerate>=0.20.0
+ sentencepiece>=0.1.99