lyimo commited on
Commit
ebd2318
·
verified ·
1 Parent(s): ef17f59

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -7
  2. main.py +463 -0
app.py DELETED
@@ -1,7 +0,0 @@
1
- # app.py
2
- from main import app
3
-
4
- # For Hugging Face Spaces
5
- if __name__ == "__main__":
6
- import uvicorn
7
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fasthtml.common import *
2
+ from fastai.vision.all import *
3
+ import os
4
+ import time
5
+ from pathlib import Path
6
+ import urllib.request
7
+ from io import BytesIO
8
+
9
+ # Create necessary directories
10
+ os.makedirs('uploads', exist_ok=True)
11
+
12
+ # Function to load model - with fallback for testing
13
+ def load_model():
14
+ try:
15
+ model_path = 'levit.pkl'
16
+ # Check if model exists, if not try to download a sample model (for demo purposes)
17
+ if not os.path.exists(model_path):
18
+ print("Model not found. This is just for testing purposes.")
19
+ # In a real deployment, you'd want to handle this more gracefully
20
+ return None, ['class1', 'class2', 'class3']
21
+
22
+ learn = load_learner(model_path)
23
+ labels = learn.dls.vocab
24
+ print(f"Model loaded successfully with labels: {labels}")
25
+ return learn, labels
26
+ except Exception as e:
27
+ print(f"Error loading model: {e}")
28
+ # Fallback for testing
29
+ return None, ['class1', 'class2', 'class3']
30
+
31
+ # Load the model at startup
32
+ learn, labels = load_model()
33
+
34
+ # Create a FastHTML app
35
+ app, rt = fast_app()
36
+
37
+ # Define the prediction function
38
+ def predict(img_bytes):
39
+ try:
40
+ # If no model is loaded, return mock predictions for testing
41
+ if learn is None:
42
+ import random
43
+ mock_results = {label: random.random() for label in labels}
44
+ # Sort by values and normalize to ensure they sum to 1
45
+ total = sum(mock_results.values())
46
+ return {k: v/total for k, v in sorted(mock_results.items(), key=lambda x: x[1], reverse=True)}
47
+
48
+ # Real prediction with the model
49
+ img = PILImage.create(BytesIO(img_bytes))
50
+ img = img.resize((512, 512))
51
+ pred, pred_idx, probs = learn.predict(img)
52
+ return {labels[i]: float(probs[i]) for i in range(len(labels))}
53
+ except Exception as e:
54
+ print(f"Prediction error: {e}")
55
+ return {"Error": 1.0}
56
+
57
+ # Main page route
58
+ @rt("/")
59
+ def get():
60
+ # Create a form for image upload
61
+ upload_form = Form(
62
+ Div(
63
+ H1("FastAI Image Classifier"),
64
+ P("Upload an image to classify it using a pre-trained model."),
65
+ cls="instructions"
66
+ ),
67
+ Div(
68
+ Input(type="file", name="image", accept="image/*", required=True,
69
+ hx_indicator="#loading"),
70
+ Button("Classify", type="submit"),
71
+ cls="upload-controls"
72
+ ),
73
+ hx_post="/predict",
74
+ hx_target="#result",
75
+ hx_swap="innerHTML",
76
+ hx_encoding="multipart/form-data",
77
+ id="upload-form"
78
+ )
79
+
80
+ # Add loading indicator
81
+ loading = Div(
82
+ P("Processing your image..."),
83
+ id="loading",
84
+ cls="htmx-indicator"
85
+ )
86
+
87
+ # Container for results
88
+ result_container = Div(id="result", cls="result-container")
89
+
90
+ # Example section
91
+ examples = Div(
92
+ H2("Or try an example:"),
93
+ A("Example Image", href="#",
94
+ hx_get="/predict_example",
95
+ hx_target="#result",
96
+ hx_indicator="#loading"),
97
+ cls="examples-section"
98
+ )
99
+
100
+ # CSS styles
101
+ css = """
102
+ :root {
103
+ --primary-color: #3498db;
104
+ --secondary-color: #2c3e50;
105
+ --background-color: #f9f9f9;
106
+ --error-color: #e74c3c;
107
+ --shadow-color: rgba(0, 0, 0, 0.1);
108
+ --border-color: #ddd;
109
+ }
110
+
111
+ body {
112
+ font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
113
+ line-height: 1.6;
114
+ color: #333;
115
+ max-width: 800px;
116
+ margin: 0 auto;
117
+ padding: 20px;
118
+ background-color: #fff;
119
+ }
120
+
121
+ h1 {
122
+ color: var(--secondary-color);
123
+ margin-bottom: 1rem;
124
+ font-weight: 600;
125
+ }
126
+
127
+ h2 {
128
+ color: var(--primary-color);
129
+ margin-top: 1.5rem;
130
+ font-weight: 500;
131
+ }
132
+
133
+ .instructions {
134
+ margin-bottom: 20px;
135
+ }
136
+
137
+ .upload-controls {
138
+ display: flex;
139
+ gap: 10px;
140
+ margin-bottom: 30px;
141
+ align-items: center;
142
+ flex-wrap: wrap;
143
+ }
144
+
145
+ button {
146
+ background-color: var(--primary-color);
147
+ color: white;
148
+ border: none;
149
+ padding: 10px 15px;
150
+ border-radius: 4px;
151
+ cursor: pointer;
152
+ transition: background-color 0.3s;
153
+ font-weight: 500;
154
+ }
155
+
156
+ button:hover {
157
+ background-color: #2980b9;
158
+ }
159
+
160
+ input[type="file"] {
161
+ padding: 10px;
162
+ border: 1px solid var(--border-color);
163
+ border-radius: 4px;
164
+ flex-grow: 1;
165
+ }
166
+
167
+ #upload-form {
168
+ margin-bottom: 40px;
169
+ padding: 20px;
170
+ border-radius: 8px;
171
+ background-color: var(--background-color);
172
+ box-shadow: 0 2px 10px var(--shadow-color);
173
+ }
174
+
175
+ .result-container {
176
+ margin-top: 20px;
177
+ }
178
+
179
+ .prediction-results {
180
+ margin-top: 20px;
181
+ padding: 20px;
182
+ border: 1px solid var(--border-color);
183
+ border-radius: 8px;
184
+ background-color: var(--background-color);
185
+ box-shadow: 0 2px 8px var(--shadow-color);
186
+ }
187
+
188
+ .result-image {
189
+ max-width: 100%;
190
+ height: auto;
191
+ border-radius: 8px;
192
+ box-shadow: 0 2px 5px var(--shadow-color);
193
+ margin-bottom: 20px;
194
+ display: block;
195
+ }
196
+
197
+ .prediction-list {
198
+ margin-top: 15px;
199
+ }
200
+
201
+ .prediction-item {
202
+ padding: 12px 15px;
203
+ margin-bottom: 10px;
204
+ background-color: white;
205
+ border-radius: 6px;
206
+ box-shadow: 0 1px 3px var(--shadow-color);
207
+ }
208
+
209
+ .label-text {
210
+ margin-bottom: 8px;
211
+ font-weight: 500;
212
+ display: flex;
213
+ justify-content: space-between;
214
+ }
215
+
216
+ .examples-section {
217
+ margin-top: 30px;
218
+ padding-top: 20px;
219
+ border-top: 1px solid var(--border-color);
220
+ }
221
+
222
+ .htmx-indicator {
223
+ display: none;
224
+ padding: 15px;
225
+ background-color: #e8f4fc;
226
+ border-radius: 6px;
227
+ text-align: center;
228
+ margin: 15px 0;
229
+ box-shadow: 0 1px 3px var(--shadow-color);
230
+ }
231
+
232
+ .htmx-request .htmx-indicator {
233
+ display: block;
234
+ }
235
+
236
+ .progress-bar {
237
+ height: 10px;
238
+ background-color: #f0f0f0;
239
+ border-radius: 5px;
240
+ margin: 5px 0;
241
+ overflow: hidden;
242
+ }
243
+
244
+ .progress-fill {
245
+ height: 100%;
246
+ background-color: var(--primary-color);
247
+ width: 0;
248
+ transition: width 0.5s ease;
249
+ }
250
+
251
+ .error-message {
252
+ color: var(--error-color);
253
+ padding: 15px;
254
+ border: 1px solid var(--error-color);
255
+ border-radius: 5px;
256
+ background-color: #fde9e7;
257
+ }
258
+
259
+ a {
260
+ color: var(--primary-color);
261
+ text-decoration: none;
262
+ font-weight: 500;
263
+ }
264
+
265
+ a:hover {
266
+ text-decoration: underline;
267
+ }
268
+
269
+ /* Responsive styling */
270
+ @media (max-width: 600px) {
271
+ .upload-controls {
272
+ flex-direction: column;
273
+ align-items: stretch;
274
+ }
275
+
276
+ button {
277
+ width: 100%;
278
+ }
279
+ }
280
+
281
+ .model-info {
282
+ font-size: 0.9rem;
283
+ color: #666;
284
+ margin-top: 40px;
285
+ padding-top: 20px;
286
+ border-top: 1px solid var(--border-color);
287
+ }
288
+ """
289
+
290
+ # Model information
291
+ model_info = Div(
292
+ P(f"Model: {'Model loaded successfully' if learn is not None else 'Demo mode - no model loaded'}"),
293
+ P(f"Classes: {', '.join(labels)}"),
294
+ cls="model-info"
295
+ )
296
+
297
+ return Titled("FastAI Image Classifier",
298
+ upload_form,
299
+ loading,
300
+ result_container,
301
+ examples,
302
+ model_info,
303
+ Style(css))
304
+
305
+ # Prediction route for uploaded images
306
+ @rt("/predict")
307
+ async def post(image: UploadFile):
308
+ try:
309
+ # Read the uploaded image
310
+ image_bytes = await image.read()
311
+
312
+ # Generate a unique filename to avoid conflicts
313
+ from datetime import datetime
314
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
315
+ safe_filename = f"{timestamp}_{image.filename.replace(' ', '_')}"
316
+
317
+ # Save the image temporarily
318
+ img_path = f"uploads/{safe_filename}"
319
+ with open(img_path, "wb") as f:
320
+ f.write(image_bytes)
321
+
322
+ # Add a small delay to make the loading indicator visible
323
+ time.sleep(0.5)
324
+
325
+ # Make a prediction
326
+ results = predict(image_bytes)
327
+
328
+ # Sort results by probability
329
+ sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
330
+ top_results = dict(list(sorted_results.items())[:3])
331
+
332
+ # Create prediction items with progress bars
333
+ prediction_items = []
334
+ for label, prob in top_results.items():
335
+ percentage = int(prob * 100)
336
+ prediction_items.append(
337
+ Div(
338
+ Div(
339
+ Span(f"{label}"),
340
+ Span(f"{percentage}%"),
341
+ cls="label-text"
342
+ ),
343
+ Div(
344
+ Div(cls="progress-fill", style=f"width: {percentage}%;"),
345
+ cls="progress-bar"
346
+ ),
347
+ cls="prediction-item"
348
+ )
349
+ )
350
+
351
+ # Create result HTML
352
+ result_html = Div(
353
+ H2("Prediction Results:"),
354
+ Img(src=f"/image/{safe_filename}", cls="result-image", alt="Uploaded image"),
355
+ Div(*prediction_items, cls="prediction-list"),
356
+ cls="prediction-results"
357
+ )
358
+
359
+ return result_html
360
+
361
+ except Exception as e:
362
+ return Div(
363
+ H2("Error"),
364
+ P(f"An error occurred during prediction: {str(e)}"),
365
+ cls="error-message"
366
+ )
367
+
368
+ # Route to serve saved images
369
+ @rt("/image/{filename}")
370
+ def get(filename: str):
371
+ file_path = f"uploads/{filename}"
372
+ if os.path.exists(file_path):
373
+ return FileResponse(file_path)
374
+ else:
375
+ return Div(
376
+ H2("Error"),
377
+ P("Image not found."),
378
+ cls="error-message"
379
+ )
380
+
381
+ # Route for example image
382
+ @rt("/predict_example")
383
+ def get():
384
+ try:
385
+ # Path to example image
386
+ example_path = "image.jpg"
387
+
388
+ # Check if example image exists
389
+ if os.path.exists(example_path):
390
+ with open(example_path, "rb") as f:
391
+ image_bytes = f.read()
392
+
393
+ # Save the example image to uploads
394
+ example_name = "example.jpg"
395
+ with open(f"uploads/{example_name}", "wb") as f:
396
+ f.write(image_bytes)
397
+
398
+ # Add a small delay to make the loading indicator visible
399
+ time.sleep(0.5)
400
+
401
+ # Make a prediction
402
+ results = predict(image_bytes)
403
+
404
+ # Sort results by probability
405
+ sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
406
+ top_results = dict(list(sorted_results.items())[:3])
407
+
408
+ # Create prediction items with progress bars
409
+ prediction_items = []
410
+ for label, prob in top_results.items():
411
+ percentage = int(prob * 100)
412
+ prediction_items.append(
413
+ Div(
414
+ Div(
415
+ Span(f"{label}"),
416
+ Span(f"{percentage}%"),
417
+ cls="label-text"
418
+ ),
419
+ Div(
420
+ Div(cls="progress-fill", style=f"width: {percentage}%;"),
421
+ cls="progress-bar"
422
+ ),
423
+ cls="prediction-item"
424
+ )
425
+ )
426
+
427
+ # Create result HTML
428
+ result_html = Div(
429
+ H2("Prediction Results:"),
430
+ Img(src=f"/image/{example_name}", cls="result-image", alt="Example image"),
431
+ Div(*prediction_items, cls="prediction-list"),
432
+ P("This is a demonstration using the provided example image.", style="font-style: italic; color: #666;"),
433
+ cls="prediction-results"
434
+ )
435
+
436
+ return result_html
437
+ else:
438
+ return Div(
439
+ H2("Example Not Found"),
440
+ P("The example image 'image.jpg' was not found. Please try uploading your own image."),
441
+ cls="error-message"
442
+ )
443
+
444
+ except Exception as e:
445
+ return Div(
446
+ H2("Error"),
447
+ P(f"An error occurred with the example: {str(e)}"),
448
+ cls="error-message"
449
+ )
450
+
451
+ # Health check endpoint (useful for Docker/Kubernetes)
452
+ @rt("/health")
453
+ def get():
454
+ return {"status": "ok", "model_loaded": learn is not None}
455
+
456
+ # Run the app
457
+ if __name__ == "__main__":
458
+ # Use environment variables if available (common in Docker)
459
+ host = os.environ.get("HOST", "0.0.0.0")
460
+ port = int(os.environ.get("PORT", 8000))
461
+
462
+ print(f"Starting FastHTML server on {host}:{port}")
463
+ serve(app=app, host=host, port=port)