mohsinmubaraksk commited on
Commit
fa9714e
·
verified ·
1 Parent(s): 1bc659d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import joblib
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ import plotly.express as px
7
+
8
+ # Set the title of the app
9
+ st.title("Robotic Grasp Robustness Predictor")
10
+ st.markdown("""
11
+ This application predicts the stability/robustness of a robotic grasp based on sensor data.
12
+ You can either upload a CSV file with sensor features or input values manually.
13
+ """)
14
+
15
+ # Load the saved model (make sure the model file 'model.pkl' is in the same directory)
16
+ @st.cache_resource
17
+ def load_model():
18
+ model = joblib.load("model.pkl")
19
+ return model
20
+
21
+ model = load_model()
22
+
23
+ # Sidebar for input selection
24
+ st.sidebar.header("Input Options")
25
+ input_method = st.sidebar.radio("Choose input method:", ("CSV Upload", "Manual Input"))
26
+
27
+ # Define the list of features expected by the model.
28
+ # Replace these with your actual feature names after cleaning.
29
+ FEATURES = [
30
+ "H1_F1J2_pos", "H1_F1J2_vel", "H1_F1J2_eff",
31
+ "H1_F1J3_pos", "H1_F1J3_vel", "H1_F1J3_eff",
32
+ "H1_F1J1_pos", "H1_F1J1_vel", "H1_F1J1_eff",
33
+ "H1_F3J1_pos", "H1_F3J1_vel", "H1_F3J1_eff",
34
+ "H1_F3J2_pos", "H1_F3J2_vel", "H1_F3J2_eff",
35
+ "H1_F3J3_pos", "H1_F3J3_vel", "H1_F3J3_eff",
36
+ "H1_F2J1_pos", "H1_F2J1_vel", "H1_F2J1_eff",
37
+ "H1_F2J3_pos", "H1_F2J3_vel", "H1_F2J3_eff",
38
+ "H1_F2J2_pos", "H1_F2J2_vel", "H1_F2J2_eff"
39
+ ]
40
+
41
+ # Option 1: CSV Upload
42
+ if input_method == "CSV Upload":
43
+ st.header("Upload CSV File")
44
+ uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
45
+ if uploaded_file is not None:
46
+ try:
47
+ input_df = pd.read_csv(uploaded_file)
48
+ st.success("CSV file successfully loaded!")
49
+ # Clean column names
50
+ input_df.columns = input_df.columns.str.strip()
51
+ missing_features = [col for col in FEATURES if col not in input_df.columns]
52
+ if missing_features:
53
+ st.error(f"The following required feature(s) are missing from the file: {', '.join(missing_features)}")
54
+ else:
55
+ # Make predictions
56
+ predictions = model.predict(input_df[FEATURES])
57
+ input_df["Predicted Robustness"] = predictions
58
+
59
+ # Interactive histogram of predictions
60
+ fig_hist = px.histogram(input_df, x="Predicted Robustness", nbins=30,
61
+ title="Distribution of Predicted Robustness")
62
+ st.plotly_chart(fig_hist, use_container_width=True)
63
+
64
+ st.subheader("Predictions")
65
+ st.dataframe(input_df)
66
+ except Exception as e:
67
+ st.error(f"Error processing file: {e}")
68
+ # Option 2: Manual Input
69
+ else:
70
+ st.header("Manual Input")
71
+ st.markdown("Enter the sensor values for prediction:")
72
+
73
+ # Group sensor inputs into sections to improve readability
74
+ input_data = {}
75
+
76
+ with st.expander("F1 Joint Sensors"):
77
+ st.write("Sensors for Finger 1:")
78
+ input_data["H1_F1J2_pos"] = st.number_input("H1_F1J2_pos", value=0.0, format="%.5f")
79
+ input_data["H1_F1J2_vel"] = st.number_input("H1_F1J2_vel", value=0.0, format="%.5f")
80
+ input_data["H1_F1J2_eff"] = st.number_input("H1_F1J2_eff", value=0.0, format="%.5f")
81
+
82
+ st.write("Sensors for Finger 3:")
83
+ input_data["H1_F1J3_pos"] = st.number_input("H1_F1J3_pos", value=0.0, format="%.5f")
84
+ input_data["H1_F1J3_vel"] = st.number_input("H1_F1J3_vel", value=0.0, format="%.5f")
85
+ input_data["H1_F1J3_eff"] = st.number_input("H1_F1J3_eff", value=0.0, format="%.5f")
86
+
87
+ st.write("Sensors for Finger 1 (alternate):")
88
+ input_data["H1_F1J1_pos"] = st.number_input("H1_F1J1_pos", value=0.0, format="%.5f")
89
+ input_data["H1_F1J1_vel"] = st.number_input("H1_F1J1_vel", value=0.0, format="%.5f")
90
+ input_data["H1_F1J1_eff"] = st.number_input("H1_F1J1_eff", value=0.0, format="%.5f")
91
+
92
+ with st.expander("F3 Joint Sensors"):
93
+ input_data["H1_F3J1_pos"] = st.number_input("H1_F3J1_pos", value=0.0, format="%.5f")
94
+ input_data["H1_F3J1_vel"] = st.number_input("H1_F3J1_vel", value=0.0, format="%.5f")
95
+ input_data["H1_F3J1_eff"] = st.number_input("H1_F3J1_eff", value=0.0, format="%.5f")
96
+
97
+ input_data["H1_F3J2_pos"] = st.number_input("H1_F3J2_pos", value=0.0, format="%.5f")
98
+ input_data["H1_F3J2_vel"] = st.number_input("H1_F3J2_vel", value=0.0, format="%.5f")
99
+ input_data["H1_F3J2_eff"] = st.number_input("H1_F3J2_eff", value=0.0, format="%.5f")
100
+
101
+ input_data["H1_F3J3_pos"] = st.number_input("H1_F3J3_pos", value=0.0, format="%.5f")
102
+ input_data["H1_F3J3_vel"] = st.number_input("H1_F3J3_vel", value=0.0, format="%.5f")
103
+ input_data["H1_F3J3_eff"] = st.number_input("H1_F3J3_eff", value=0.0, format="%.5f")
104
+
105
+ with st.expander("F2 Joint Sensors"):
106
+ input_data["H1_F2J1_pos"] = st.number_input("H1_F2J1_pos", value=0.0, format="%.5f")
107
+ input_data["H1_F2J1_vel"] = st.number_input("H1_F2J1_vel", value=0.0, format="%.5f")
108
+ input_data["H1_F2J1_eff"] = st.number_input("H1_F2J1_eff", value=0.0, format="%.5f")
109
+
110
+ input_data["H1_F2J3_pos"] = st.number_input("H1_F2J3_pos", value=0.0, format="%.5f")
111
+ input_data["H1_F2J3_vel"] = st.number_input("H1_F2J3_vel", value=0.0, format="%.5f")
112
+ input_data["H1_F2J3_eff"] = st.number_input("H1_F2J3_eff", value=0.0, format="%.5f")
113
+
114
+ input_data["H1_F2J2_pos"] = st.number_input("H1_F2J2_pos", value=0.0, format="%.5f")
115
+ input_data["H1_F2J2_vel"] = st.number_input("H1_F2J2_vel", value=0.0, format="%.5f")
116
+ input_data["H1_F2J2_eff"] = st.number_input("H1_F2J2_eff", value=0.0, format="%.5f")
117
+
118
+ # Real-time prediction as values change.
119
+ if st.button("Predict"):
120
+ input_df = pd.DataFrame([input_data])
121
+ prediction = model.predict(input_df)
122
+ st.success(f"The predicted grasp robustness is: {prediction[0]:.3f}")
123
+
124
+ # Create a gauge chart (using Plotly) to visualize the prediction
125
+ # (Customize the min/max values based on your domain; below, 0-100 is used as an example.)
126
+ fig_gauge = go.Figure(go.Indicator(
127
+ mode = "gauge+number",
128
+ value = prediction[0],
129
+ title = {'text': "Grasp Robustness"},
130
+ gauge = {'axis': {'range': [None, 100]},
131
+ 'bar': {'color': "darkblue"},
132
+ 'steps' : [
133
+ {'range': [0, 30], 'color': "red"},
134
+ {'range': [30, 70], 'color': "yellow"},
135
+ {'range': [70, 100], 'color': "green"}],
136
+ }
137
+ ))
138
+ st.plotly_chart(fig_gauge, use_container_width=True)
139
+
140
+ # If the model supports feature importances (e.g. RandomForest), display them
141
+ try:
142
+ feature_importances = model.named_steps['model'].feature_importances_
143
+ imp_df = pd.DataFrame({
144
+ 'Feature': FEATURES,
145
+ 'Importance': feature_importances
146
+ }).sort_values(by='Importance', ascending=False)
147
+ st.subheader("Feature Importances")
148
+ fig_imp = px.bar(imp_df, x='Feature', y='Importance', title="Feature Importances")
149
+ st.plotly_chart(fig_imp, use_container_width=True)
150
+ except Exception as ex:
151
+ st.info("Feature importance is not available for this model.")