Juna190825 committed on
Commit 013ddd6 · verified · 1 Parent(s): b0e1169

Update app.py

Files changed (1)
  1. app.py +237 -194
app.py CHANGED
@@ -1,209 +1,252 @@
  import os
- import torch
- import time
- import warnings
- from fastapi import FastAPI, Request
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
- # Suppress specific warnings
- warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.hub")
-
- # Configure environment variables for cache
- os.environ["HF_HOME"] = os.getenv("HF_HOME", "/app/cache/huggingface")
- os.environ["MPLCONFIGDIR"] = os.getenv("MPLCONFIGDIR", "/app/cache/matplotlib")
-
- # Ensure cache directories exist
- os.makedirs(os.environ["HF_HOME"], exist_ok=True)
- os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
-
- # Initialize FastAPI app
- app = FastAPI()
-
- def log_message(message: str):
-     """Helper function for logging"""
-     print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")
-
- def load_model():
-     """Load the model with CPU optimization"""
-     model_name = "trillionlabs/Trillion-7B-preview-AWQ"
-
-     log_message("Loading tokenizer...")
-     try:
-         tokenizer = AutoTokenizer.from_pretrained(
-             model_name,
-             trust_remote_code=True
          )
-     except Exception as e:
-         log_message(f"Tokenizer loading failed: {e}")
-         # Fallback to LlamaTokenizer if available
-         from transformers import LlamaTokenizer
-         tokenizer = LlamaTokenizer.from_pretrained(model_name)

-     log_message("Loading model...")
-     try:
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.float32,
-             trust_remote_code=True
-         )

-         # Explicitly move to CPU
-         model = model.to("cpu")
-         model.eval()

-     except Exception as e:
-         log_message(f"Model loading failed: {e}")
-         raise

-     log_message("Creating pipeline...")
-     text_generator = pipeline(
-         "text-generation",
-         model=model,
-         tokenizer=tokenizer,
-         device="cpu"
-     )

-     return text_generator, tokenizer
-
- # Load model
- try:
-     log_message("Starting model loading process...")
-     text_generator, tokenizer = load_model()
-     log_message("Model loaded successfully")
- except Exception as e:
-     log_message(f"Critical error loading model: {e}")
-     raise
-
- # API endpoints
- @app.post("/api/generate")
- async def api_generate(request: Request):
-     """API endpoint for text generation"""
-     try:
-         data = await request.json()
-         prompt = data.get("prompt", "").strip()
-         if not prompt:
-             return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)

-         max_length = min(int(data.get("max_length", 100)), 300) # Conservative limit

-         start_time = time.time()
-         outputs = text_generator(
-             prompt,
-             max_length=max_length,
-             do_sample=True,
-             temperature=0.7,
-             top_k=50,
-             top_p=0.95,
-             pad_token_id=tokenizer.eos_token_id
-         )
-         generation_time = time.time() - start_time

-         response_data = {
-             "generated_text": outputs[0]["generated_text"],
-             "time_seconds": round(generation_time, 2),
-             "tokens_generated": len(tokenizer.tokenize(outputs[0]["generated_text"])),
-             "model": "Trillion-7B-preview-AWQ",
-             "device": "cpu"
-         }
-         return JSONResponse(response_data)
-     except Exception as e:
-         log_message(f"API Error: {e}")
-         return JSONResponse({"error": str(e)}, status_code=500)
-
- @app.get("/health")
- async def health_check():
-     """Health check endpoint"""
-     return {
-         "status": "healthy",
-         "model_loaded": text_generator is not None,
-         "device": "cpu",
-         "cache_path": os.environ["HF_HOME"]
-     }
-
- # Gradio Interface
- def gradio_generate(prompt, max_length=100):
-     """Function for Gradio interface generation"""
-     try:
-         max_length = min(int(max_length), 300) # Same conservative limit as API
-         if not prompt.strip():
-             return "Please enter a prompt"

-         outputs = text_generator(
-             prompt,
-             max_length=max_length,
-             do_sample=True,
-             temperature=0.7,
-             top_k=50,
-             top_p=0.95,
-             pad_token_id=tokenizer.eos_token_id
-         )
-         return outputs[0]["generated_text"]
-     except Exception as e:
-         log_message(f"Gradio Error: {e}")
-         return f"Error generating text: {str(e)}"
-
- with gr.Blocks(title="Trillion-7B CPU Demo", theme=gr.themes.Default()) as gradio_app:
-     gr.Markdown("""
-     # 🚀 Trillion-7B-preview-AWQ (CPU Version)
-     *Running on CPU with optimized settings - responses may be slower than GPU versions*
-     """)

-     with gr.Row():
-         with gr.Column():
-             input_prompt = gr.Textbox(
-                 label="Your Prompt",
-                 placeholder="Enter text here...",
-                 lines=5,
-                 max_lines=10
-             )
-             with gr.Row():
-                 max_length = gr.Slider(
-                     label="Max Length",
-                     minimum=20,
-                     maximum=300,
-                     value=100,
-                     step=10
-                 )
-                 generate_btn = gr.Button("Generate", variant="primary")
-         with gr.Column():
-             output_text = gr.Textbox(
-                 label="Generated Text",
-                 lines=10,
-                 interactive=False
-             )

-     # Examples
-     gr.Examples(
-         examples=[
-             ["Explain quantum computing in simple terms"],
-             ["Write a haiku about artificial intelligence"],
-             ["What are the main benefits of renewable energy?"],
-             ["Suggest three ideas for a science fiction story"]
-         ],
-         inputs=input_prompt,
-         label="Example Prompts"
-     )

-     generate_btn.click(
-         fn=gradio_generate,
-         inputs=[input_prompt, max_length],
-         outputs=output_text
-     )
-
- # Mount Gradio app
- app = gr.mount_gradio_app(app, gradio_app, path="/")
-
- # CORS configuration
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_methods=["*"],
-     allow_headers=["*"],
- )

- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)
  import os
+
+ from openai import OpenAI
+ import streamlit as st
+
+ st.title("Trillion-7B-Preview")
+
+ client = OpenAI(
+     api_key=os.getenv("API_KEY"),
+     base_url=os.getenv("BASE_URL"),
+ )
+
+ if "openai_model" not in st.session_state:
+     st.session_state["openai_model"] = "trillionlabs/Trillion-7B-preview"
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ if prompt := st.chat_input("Message"):
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     with st.chat_message("assistant"):
+         stream = client.chat.completions.create(
+             model=st.session_state["openai_model"],
+             messages=[
+                 {"role": m["role"], "content": m["content"]}
+                 for m in st.session_state.messages
+             ],
+             stream=True,
+             extra_body={
+                 "topP": 0.95,
+                 "maxTokens": 3072,
+                 "temperature": 0.6,
+             },
          )
+         response = st.write_stream(stream)
+         st.session_state.messages.append({"role": "assistant", "content": response})
+ # import os
+ # import torch
+ # import time
+ # import warnings
+ # from fastapi import FastAPI, Request
+ # from fastapi.responses import JSONResponse
+ # from fastapi.middleware.cors import CORSMiddleware
+ # import gradio as gr
+ # from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ # # Suppress specific warnings
+ # warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.hub")
+
+ # # Configure environment variables for cache
+ # os.environ["HF_HOME"] = os.getenv("HF_HOME", "/app/cache/huggingface")
+ # os.environ["MPLCONFIGDIR"] = os.getenv("MPLCONFIGDIR", "/app/cache/matplotlib")
+
+ # # Ensure cache directories exist
+ # os.makedirs(os.environ["HF_HOME"], exist_ok=True)
+ # os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
+
+ # # Initialize FastAPI app
+ # app = FastAPI()
+
+ # def log_message(message: str):
+ #     """Helper function for logging"""
+ #     print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")
+
+ # def load_model():
+ #     """Load the model with CPU optimization"""
+ #     model_name = "trillionlabs/Trillion-7B-preview-AWQ"

+ #     log_message("Loading tokenizer...")
+ #     try:
+ #         tokenizer = AutoTokenizer.from_pretrained(
+ #             model_name,
+ #             trust_remote_code=True
+ #         )
+ #     except Exception as e:
+ #         log_message(f"Tokenizer loading failed: {e}")
+ #         # Fallback to LlamaTokenizer if available
+ #         from transformers import LlamaTokenizer
+ #         tokenizer = LlamaTokenizer.from_pretrained(model_name)
+
+ #     log_message("Loading model...")
+ #     try:
+ #         model = AutoModelForCausalLM.from_pretrained(
+ #             model_name,
+ #             torch_dtype=torch.float32,
+ #             trust_remote_code=True
+ #         )

+ #         # Explicitly move to CPU
+ #         model = model.to("cpu")
+ #         model.eval()

+ #     except Exception as e:
+ #         log_message(f"Model loading failed: {e}")
+ #         raise

+ #     log_message("Creating pipeline...")
+ #     text_generator = pipeline(
+ #         "text-generation",
+ #         model=model,
+ #         tokenizer=tokenizer,
+ #         device="cpu"
+ #     )

+ #     return text_generator, tokenizer
+
+ # # Load model
+ # try:
+ #     log_message("Starting model loading process...")
+ #     text_generator, tokenizer = load_model()
+ #     log_message("Model loaded successfully")
+ # except Exception as e:
+ #     log_message(f"Critical error loading model: {e}")
+ #     raise
+
+ # # API endpoints
+ # @app.post("/api/generate")
+ # async def api_generate(request: Request):
+ #     """API endpoint for text generation"""
+ #     try:
+ #         data = await request.json()
+ #         prompt = data.get("prompt", "").strip()
+ #         if not prompt:
+ #             return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)

+ #         max_length = min(int(data.get("max_length", 100)), 300) # Conservative limit

+ #         start_time = time.time()
+ #         outputs = text_generator(
+ #             prompt,
+ #             max_length=max_length,
+ #             do_sample=True,
+ #             temperature=0.7,
+ #             top_k=50,
+ #             top_p=0.95,
+ #             pad_token_id=tokenizer.eos_token_id
+ #         )
+ #         generation_time = time.time() - start_time

+ #         response_data = {
+ #             "generated_text": outputs[0]["generated_text"],
+ #             "time_seconds": round(generation_time, 2),
+ #             "tokens_generated": len(tokenizer.tokenize(outputs[0]["generated_text"])),
+ #             "model": "Trillion-7B-preview-AWQ",
+ #             "device": "cpu"
+ #         }
+ #         return JSONResponse(response_data)
+ #     except Exception as e:
+ #         log_message(f"API Error: {e}")
+ #         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # @app.get("/health")
+ # async def health_check():
+ #     """Health check endpoint"""
+ #     return {
+ #         "status": "healthy",
+ #         "model_loaded": text_generator is not None,
+ #         "device": "cpu",
+ #         "cache_path": os.environ["HF_HOME"]
+ #     }
+
+ # # Gradio Interface
+ # def gradio_generate(prompt, max_length=100):
+ #     """Function for Gradio interface generation"""
+ #     try:
+ #         max_length = min(int(max_length), 300) # Same conservative limit as API
+ #         if not prompt.strip():
+ #             return "Please enter a prompt"

+ #         outputs = text_generator(
+ #             prompt,
+ #             max_length=max_length,
+ #             do_sample=True,
+ #             temperature=0.7,
+ #             top_k=50,
+ #             top_p=0.95,
+ #             pad_token_id=tokenizer.eos_token_id
+ #         )
+ #         return outputs[0]["generated_text"]
+ #     except Exception as e:
+ #         log_message(f"Gradio Error: {e}")
+ #         return f"Error generating text: {str(e)}"
+
+ # with gr.Blocks(title="Trillion-7B CPU Demo", theme=gr.themes.Default()) as gradio_app:
+ #     gr.Markdown("""
+ #     # 🚀 Trillion-7B-preview-AWQ (CPU Version)
+ #     *Running on CPU with optimized settings - responses may be slower than GPU versions*
+ #     """)

+ #     with gr.Row():
+ #         with gr.Column():
+ #             input_prompt = gr.Textbox(
+ #                 label="Your Prompt",
+ #                 placeholder="Enter text here...",
+ #                 lines=5,
+ #                 max_lines=10
+ #             )
+ #             with gr.Row():
+ #                 max_length = gr.Slider(
+ #                     label="Max Length",
+ #                     minimum=20,
+ #                     maximum=300,
+ #                     value=100,
+ #                     step=10
+ #                 )
+ #                 generate_btn = gr.Button("Generate", variant="primary")
+ #         with gr.Column():
+ #             output_text = gr.Textbox(
+ #                 label="Generated Text",
+ #                 lines=10,
+ #                 interactive=False
+ #             )

+ #     # Examples
+ #     gr.Examples(
+ #         examples=[
+ #             ["Explain quantum computing in simple terms"],
+ #             ["Write a haiku about artificial intelligence"],
+ #             ["What are the main benefits of renewable energy?"],
+ #             ["Suggest three ideas for a science fiction story"]
+ #         ],
+ #         inputs=input_prompt,
+ #         label="Example Prompts"
+ #     )

+ #     generate_btn.click(
+ #         fn=gradio_generate,
+ #         inputs=[input_prompt, max_length],
+ #         outputs=output_text
+ #     )
+
+ # # Mount Gradio app
+ # app = gr.mount_gradio_app(app, gradio_app, path="/")
+
+ # # CORS configuration
+ # app.add_middleware(
+ #     CORSMiddleware,
+ #     allow_origins=["*"],
+ #     allow_methods=["*"],
+ #     allow_headers=["*"],
+ # )

+ # if __name__ == "__main__":
+ #     import uvicorn
+ #     uvicorn.run(app, host="0.0.0.0", port=7860)