progs2002 committed on
Commit
a87d429
·
1 Parent(s): b1d7f25

basic streamlit interface

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. app.py +38 -0
  3. model.py +23 -0
  4. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/
2
+ __pycache__/
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ascii_art = r"""
3
+ _ _ _
4
+ ___| |_ __ _ _ __ | |_ _ __ ___| | __
5
+ / __| __/ _` | '__| | __| '__/ _ \ |/ /
6
+ \__ \ || (_| | | | |_| | | __/ <
7
+ |___/\__\__,_|_| \__|_| \___|_|\_\
8
+
9
+ """
10
+
11
+ from model import LLM
12
+ import streamlit as st
13
+ import streamlit_scrollable_textbox as stx
14
+
15
# Show the ASCII-art banner at the top of the page.
st.text(ascii_art)

# NOTE(review): Streamlit re-executes this script on widget interaction, so the
# model is presumably re-instantiated on every rerun; consider caching the LLM
# instance once the minimum supported Streamlit version is confirmed.
with st.spinner("Please wait... loading model"):
    llm = LLM()

# Default prompt offered to the user; it can be edited freely in the text area.
demo_text = """DATA: The ship has gone into warp, sir.
RIKER: Who gave the command?
DATA: Apparently no one. Helm and navigation controls are not functioning. Our speed is now warp seven-point-three and holding.
PICARD: Picard to Engineering. Mister La Forge, what's going on down there?
"""
text = st.text_area("First few lines of the script:", demo_text)

# Sampling controls are laid out in two side-by-side columns.
col1, col2 = st.columns(2)

with col1:
    # All slider arguments are floats: integer bounds (0, 1) would yield a
    # slider that can only select 0 or 1. The minimum is kept strictly positive
    # because a temperature of 0 is not a valid sampling temperature.
    temp = st.slider('Temperature', min_value=0.05, max_value=1.0, value=1.0, step=0.05)
    max_len = st.number_input('Max length', min_value=1, max_value=2048, value=512)

with col2:
    # Bounds and default must share one numeric type; mixing int bounds with a
    # float default (0, 1, 0.95) is rejected by st.slider.
    top_p = st.slider('p', min_value=0.0, max_value=1.0, value=0.95, step=0.01)
    top_k = st.slider('k', 1, 100, 50)

# Generate a continuation of the prompt and show it in a scrollable box.
with st.spinner("Generating text..."):
    stx.scrollableTextbox(llm.generate(text, max_len, temp, top_k, top_p), height=400)
model.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+
3
class LLM:
    """Text generator built on a pretrained causal language model.

    Loads a model and its tokenizer via the ``transformers`` auto classes and
    exposes a single ``generate`` method that continues a text prompt.
    """

    # Checkpoint used when the caller does not name one explicitly.
    DEFAULT_MODEL = 'progs2002/star-trek-tng-script-generator'

    def __init__(self, model_name=DEFAULT_MODEL):
        """Load the model and tokenizer identified by *model_name*.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer. Defaults to ``DEFAULT_MODEL``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def generate(self, text, max_len=512, temp=1, k=50, p=0.95):
        """Continue *text* and return the decoded result as a string.

        Args:
            text: Prompt to continue.
            max_len: Value forwarded as ``max_length`` to the model's
                ``generate`` call.
            temp: Sampling temperature. A value of ``None`` or ``<= 0``
                disables sampling and falls back to deterministic decoding
                instead of forwarding an invalid temperature.
            k: Forwarded as ``top_k`` when sampling.
            p: Forwarded as ``top_p`` when sampling.

        Returns:
            The first generated sequence, decoded with special tokens skipped.
        """
        encoded_prompt = self.tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")

        generation_kwargs = {
            "input_ids": encoded_prompt,
            "max_length": max_len,
            "num_return_sequences": 1,
            # The EOS id doubles as the padding id for the generate call.
            "pad_token_id": self.model.config.eos_token_id,
        }

        if temp is not None and temp > 0:
            generation_kwargs.update(do_sample=True, temperature=temp, top_k=k, top_p=p)
        else:
            # Sampling-only options are omitted so they are not passed
            # alongside do_sample=False.
            generation_kwargs["do_sample"] = False

        output_tokens = self.model.generate(**generation_kwargs)

        return self.tokenizer.decode(
            output_tokens[0],
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ streamlit-scrollable-textbox