mohsinmubaraksk commited on
Commit
304986b
Β·
verified Β·
1 Parent(s): 1dd600d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -0
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import joblib
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
+
8
+ # Set the title of the app
9
+ st.title("πŸ€– Robotic Grasp Robustness Predictor")
10
+ st.markdown("""
11
+ Welcome! This application predicts the **stability/robustness** of a robotic grasp based on sensor data.
12
+ You can either upload a CSV file with sensor features or input values manually. πŸš€
13
+ """)
14
+
15
+ # Display robotics images at the top (ensure the URLs are accessible)
16
+ st.image("https://upload.wikimedia.org/wikipedia/commons/thumb/3/34/Robotic_arm_by_Robust_automation_in_Australia.jpg/640px-Robotic_arm_by_Robust_automation_in_Australia.jpg",
17
+ caption="Robotic Arm 🦾", use_container_width=True)
18
+ st.image("https://upload.wikimedia.org/wikipedia/commons/9/95/Robotic_arm.jpg",
19
+ caption="Advanced Robotics πŸ€–", use_container_width=True)
20
+
21
+ # Load the saved model (make sure the model file 'model.pkl' is in the same directory)
22
+ @st.cache_resource
23
+ def load_model():
24
+ model = joblib.load("model.pkl")
25
+ return model
26
+
27
+ model = load_model()
28
+
29
+ # Sidebar for input selection
30
+ st.sidebar.header("Input Options")
31
+ input_method = st.sidebar.radio("Choose input method:", ("CSV Upload", "Manual Input"))
32
+
33
+ # Define the list of features expected by the model.
34
+ FEATURES = [
35
+ "H1_F1J2_pos", "H1_F1J2_vel", "H1_F1J2_eff",
36
+ "H1_F1J3_pos", "H1_F1J3_vel", "H1_F1J3_eff",
37
+ "H1_F1J1_pos", "H1_F1J1_vel", "H1_F1J1_eff",
38
+ "H1_F3J1_pos", "H1_F3J1_vel", "H1_F3J1_eff",
39
+ "H1_F3J2_pos", "H1_F3J2_vel", "H1_F3J2_eff",
40
+ "H1_F3J3_pos", "H1_F3J3_vel", "H1_F3J3_eff",
41
+ "H1_F2J1_pos", "H1_F2J1_vel", "H1_F2J1_eff",
42
+ "H1_F2J3_pos", "H1_F2J3_vel", "H1_F2J3_eff",
43
+ "H1_F2J2_pos", "H1_F2J2_vel", "H1_F2J2_eff"
44
+ ]
45
+
46
+ # CSV Upload Option
47
+ if input_method == "CSV Upload":
48
+ st.header("Upload CSV File πŸ“")
49
+ uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
50
+ if uploaded_file is not None:
51
+ try:
52
+ input_df = pd.read_csv(uploaded_file)
53
+ st.success("CSV file successfully loaded! βœ…")
54
+ # Clean column names
55
+ input_df.columns = input_df.columns.str.strip()
56
+ missing_features = [col for col in FEATURES if col not in input_df.columns]
57
+ if missing_features:
58
+ st.error(f"Missing required feature(s): {', '.join(missing_features)} 😟")
59
+ else:
60
+ # Make predictions
61
+ predictions = model.predict(input_df[FEATURES])
62
+ input_df["Predicted Robustness"] = predictions
63
+
64
+ # Interactive histogram of predictions
65
+ fig_hist = px.histogram(input_df, x="Predicted Robustness", nbins=30,
66
+ title="Distribution of Predicted Robustness")
67
+ st.plotly_chart(fig_hist, use_container_width=True)
68
+
69
+ st.subheader("Predictions")
70
+ st.dataframe(input_df)
71
+ except Exception as e:
72
+ st.error(f"Error processing file: {e}")
73
+
74
+ # Manual Input Option
75
+ else:
76
+ st.header("Manual Input ✍️")
77
+ st.markdown("Enter the sensor values for prediction below:")
78
+
79
+ # Add a radio button for selecting the input mode within manual input
80
+ input_mode = st.radio("Select input mode:", ("Custom Input", "API Default Values"))
81
+
82
+ # Default API values (you can adjust these as appropriate)
83
+ default_values = {
84
+ "H1_F1J2_pos": 0.5, "H1_F1J2_vel": 0.0, "H1_F1J2_eff": 0.1,
85
+ "H1_F1J3_pos": 0.5, "H1_F1J3_vel": 0.0, "H1_F1J3_eff": 0.1,
86
+ "H1_F1J1_pos": 0.5, "H1_F1J1_vel": 0.0, "H1_F1J1_eff": 0.1,
87
+ "H1_F3J1_pos": 1.0, "H1_F3J1_vel": 0.0, "H1_F3J1_eff": 0.2,
88
+ "H1_F3J2_pos": 1.0, "H1_F3J2_vel": 0.0, "H1_F3J2_eff": 0.2,
89
+ "H1_F3J3_pos": 1.0, "H1_F3J3_vel": 0.0, "H1_F3J3_eff": 0.2,
90
+ "H1_F2J1_pos": 0.7, "H1_F2J1_vel": 0.0, "H1_F2J1_eff": 0.15,
91
+ "H1_F2J3_pos": 0.7, "H1_F2J3_vel": 0.0, "H1_F2J3_eff": 0.15,
92
+ "H1_F2J2_pos": 0.7, "H1_F2J2_vel": 0.0, "H1_F2J2_eff": 0.15,
93
+ }
94
+
95
+ # Container dictionary for all input values.
96
+ input_data = {}
97
+
98
+ # Create expandable sections for sensor groups.
99
+ with st.expander("F1 Joint Sensors"):
100
+ st.write("Sensors for Finger 1:")
101
+ input_data["H1_F1J2_pos"] = st.number_input("H1_F1J2_pos",
102
+ value=default_values["H1_F1J2_pos"] if input_mode=="API Default Values" else 0.0,
103
+ format="%.5f")
104
+ input_data["H1_F1J2_vel"] = st.number_input("H1_F1J2_vel",
105
+ value=default_values["H1_F1J2_vel"] if input_mode=="API Default Values" else 0.0,
106
+ format="%.5f")
107
+ input_data["H1_F1J2_eff"] = st.number_input("H1_F1J2_eff",
108
+ value=default_values["H1_F1J2_eff"] if input_mode=="API Default Values" else 0.0,
109
+ format="%.5f")
110
+
111
+ st.write("Sensors for Finger 3:")
112
+ input_data["H1_F1J3_pos"] = st.number_input("H1_F1J3_pos",
113
+ value=default_values["H1_F1J3_pos"] if input_mode=="API Default Values" else 0.0,
114
+ format="%.5f")
115
+ input_data["H1_F1J3_vel"] = st.number_input("H1_F1J3_vel",
116
+ value=default_values["H1_F1J3_vel"] if input_mode=="API Default Values" else 0.0,
117
+ format="%.5f")
118
+ input_data["H1_F1J3_eff"] = st.number_input("H1_F1J3_eff",
119
+ value=default_values["H1_F1J3_eff"] if input_mode=="API Default Values" else 0.0,
120
+ format="%.5f")
121
+
122
+ st.write("Sensors for Finger 1 (alternate):")
123
+ input_data["H1_F1J1_pos"] = st.number_input("H1_F1J1_pos",
124
+ value=default_values["H1_F1J1_pos"] if input_mode=="API Default Values" else 0.0,
125
+ format="%.5f")
126
+ input_data["H1_F1J1_vel"] = st.number_input("H1_F1J1_vel",
127
+ value=default_values["H1_F1J1_vel"] if input_mode=="API Default Values" else 0.0,
128
+ format="%.5f")
129
+ input_data["H1_F1J1_eff"] = st.number_input("H1_F1J1_eff",
130
+ value=default_values["H1_F1J1_eff"] if input_mode=="API Default Values" else 0.0,
131
+ format="%.5f")
132
+
133
+ with st.expander("F3 Joint Sensors"):
134
+ input_data["H1_F3J1_pos"] = st.number_input("H1_F3J1_pos",
135
+ value=default_values["H1_F3J1_pos"] if input_mode=="API Default Values" else 0.0,
136
+ format="%.5f")
137
+ input_data["H1_F3J1_vel"] = st.number_input("H1_F3J1_vel",
138
+ value=default_values["H1_F3J1_vel"] if input_mode=="API Default Values" else 0.0,
139
+ format="%.5f")
140
+ input_data["H1_F3J1_eff"] = st.number_input("H1_F3J1_eff",
141
+ value=default_values["H1_F3J1_eff"] if input_mode=="API Default Values" else 0.0,
142
+ format="%.5f")
143
+
144
+ input_data["H1_F3J2_pos"] = st.number_input("H1_F3J2_pos",
145
+ value=default_values["H1_F3J2_pos"] if input_mode=="API Default Values" else 0.0,
146
+ format="%.5f")
147
+ input_data["H1_F3J2_vel"] = st.number_input("H1_F3J2_vel",
148
+ value=default_values["H1_F3J2_vel"] if input_mode=="API Default Values" else 0.0,
149
+ format="%.5f")
150
+ input_data["H1_F3J2_eff"] = st.number_input("H1_F3J2_eff",
151
+ value=default_values["H1_F3J2_eff"] if input_mode=="API Default Values" else 0.0,
152
+ format="%.5f")
153
+
154
+ input_data["H1_F3J3_pos"] = st.number_input("H1_F3J3_pos",
155
+ value=default_values["H1_F3J3_pos"] if input_mode=="API Default Values" else 0.0,
156
+ format="%.5f")
157
+ input_data["H1_F3J3_vel"] = st.number_input("H1_F3J3_vel",
158
+ value=default_values["H1_F3J3_vel"] if input_mode=="API Default Values" else 0.0,
159
+ format="%.5f")
160
+ input_data["H1_F3J3_eff"] = st.number_input("H1_F3J3_eff",
161
+ value=default_values["H1_F3J3_eff"] if input_mode=="API Default Values" else 0.0,
162
+ format="%.5f")
163
+
164
+ with st.expander("F2 Joint Sensors"):
165
+ input_data["H1_F2J1_pos"] = st.number_input("H1_F2J1_pos",
166
+ value=default_values["H1_F2J1_pos"] if input_mode=="API Default Values" else 0.0,
167
+ format="%.5f")
168
+ input_data["H1_F2J1_vel"] = st.number_input("H1_F2J1_vel",
169
+ value=default_values["H1_F2J1_vel"] if input_mode=="API Default Values" else 0.0,
170
+ format="%.5f")
171
+ input_data["H1_F2J1_eff"] = st.number_input("H1_F2J1_eff",
172
+ value=default_values["H1_F2J1_eff"] if input_mode=="API Default Values" else 0.0,
173
+ format="%.5f")
174
+
175
+ input_data["H1_F2J3_pos"] = st.number_input("H1_F2J3_pos",
176
+ value=default_values["H1_F2J3_pos"] if input_mode=="API Default Values" else 0.0,
177
+ format="%.5f")
178
+ input_data["H1_F2J3_vel"] = st.number_input("H1_F2J3_vel",
179
+ value=default_values["H1_F2J3_vel"] if input_mode=="API Default Values" else 0.0,
180
+ format="%.5f")
181
+ input_data["H1_F2J3_eff"] = st.number_input("H1_F2J3_eff",
182
+ value=default_values["H1_F2J3_eff"] if input_mode=="API Default Values" else 0.0,
183
+ format="%.5f")
184
+
185
+ input_data["H1_F2J2_pos"] = st.number_input("H1_F2J2_pos",
186
+ value=default_values["H1_F2J2_pos"] if input_mode=="API Default Values" else 0.0,
187
+ format="%.5f")
188
+ input_data["H1_F2J2_vel"] = st.number_input("H1_F2J2_vel",
189
+ value=default_values["H1_F2J2_vel"] if input_mode=="API Default Values" else 0.0,
190
+ format="%.5f")
191
+ input_data["H1_F2J2_eff"] = st.number_input("H1_F2J2_eff",
192
+ value=default_values["H1_F2J2_eff"] if input_mode=="API Default Values" else 0.0,
193
+ format="%.5f")
194
+
195
+ # Predict button for manual input
196
+ if st.button("Predict 🏁"):
197
+ input_df = pd.DataFrame([input_data])
198
+ prediction = model.predict(input_df)
199
+ st.success(f"The predicted grasp robustness is: {prediction[0]:.3f}")
200
+
201
+ # Create a gauge chart using Plotly for visualization.
202
+ fig_gauge = go.Figure(go.Indicator(
203
+ mode="gauge+number",
204
+ value=prediction[0],
205
+ title={'text': "Grasp Robustness"},
206
+ gauge={
207
+ 'axis': {'range': [0, 100]},
208
+ 'bar': {'color': "darkblue"},
209
+ 'steps': [
210
+ {'range': [0, 30], 'color': "red"},
211
+ {'range': [30, 70], 'color': "yellow"},
212
+ {'range': [70, 100], 'color': "green"}
213
+ ],
214
+ }
215
+ ))
216
+ st.plotly_chart(fig_gauge, use_container_width=True)
217
+
218
+ # Display feature importances if available
219
+ try:
220
+ feature_importances = model.named_steps['model'].feature_importances_
221
+ imp_df = pd.DataFrame({
222
+ 'Feature': FEATURES,
223
+ 'Importance': feature_importances
224
+ }).sort_values(by='Importance', ascending=False)
225
+ st.subheader("Feature Importances πŸ“Š")
226
+ fig_imp = px.bar(imp_df, x='Feature', y='Importance', title="Feature Importances")
227
+ st.plotly_chart(fig_imp, use_container_width=True)
228
+ except Exception as ex:
229
+ st.info("Feature importance is not available for this model.")