Commit
·
9bf0efa
1
Parent(s):
31521f5
log something while infer
Browse files- app.py +5 -9
- infer_concat.py +2 -1
app.py
CHANGED
|
@@ -7,8 +7,6 @@ st.set_page_config(layout="wide")
|
|
| 7 |
st.title("Tóm tắt Đa văn bản Tiếng Việt")
|
| 8 |
|
| 9 |
col1, col2 = st.columns([1, 1])
|
| 10 |
-
col2_title, = col2.columns(1)
|
| 11 |
-
col2_chdg, col2_vit5 = col2.columns(2)
|
| 12 |
|
| 13 |
# Initialize session state
|
| 14 |
if 'num_docs' not in st.session_state:
|
|
@@ -38,14 +36,12 @@ category = col1.selectbox("Chọn chủ để của văn bản: ", ['Giáo dục
|
|
| 38 |
def summarize():
|
| 39 |
summ, _ = infer(st.session_state.docs, category)
|
| 40 |
with col2.container():
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
col2_chdg.write("CHDG:")
|
| 46 |
-
col2_chdg.write(summ)
|
| 47 |
summ_vit5 = vit5_infer(st.session_state.docs)
|
| 48 |
-
|
| 49 |
|
| 50 |
if col1.button("Tóm tắt"):
|
| 51 |
summarize()
|
|
|
|
| 7 |
st.title("Tóm tắt Đa văn bản Tiếng Việt")
|
| 8 |
|
| 9 |
col1, col2 = st.columns([1, 1])
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Initialize session state
|
| 12 |
if 'num_docs' not in st.session_state:
|
|
|
|
| 36 |
def summarize():
|
| 37 |
summ, _ = infer(st.session_state.docs, category)
|
| 38 |
with col2.container():
|
| 39 |
+
col2.subheader("Kết quả: ")
|
| 40 |
+
col2.write("\n")
|
| 41 |
+
col2.write("Sử dụng CHDG:")
|
| 42 |
+
col2.write(summ)
|
|
|
|
|
|
|
| 43 |
summ_vit5 = vit5_infer(st.session_state.docs)
|
| 44 |
+
col2.write(summ_vit5)
|
| 45 |
|
| 46 |
if col1.button("Tóm tắt"):
|
| 47 |
summarize()
|
infer_concat.py
CHANGED
|
@@ -82,7 +82,8 @@ def infer_2_hier(model, data_loader, device, tokenizer):
|
|
| 82 |
att_mask = iter['list_att_mask']
|
| 83 |
|
| 84 |
for i in range(len(inputs)):
|
| 85 |
-
|
|
|
|
| 86 |
if torch.all(inputs[i] == 0):
|
| 87 |
# If the input is all zeros, skip this iteration
|
| 88 |
continue
|
|
|
|
| 82 |
att_mask = iter['list_att_mask']
|
| 83 |
|
| 84 |
for i in range(len(inputs)):
|
| 85 |
+
print(f"input {i}")
|
| 86 |
+
# Check if the input tensor is all zeros
|
| 87 |
if torch.all(inputs[i] == 0):
|
| 88 |
# If the input is all zeros, skip this iteration
|
| 89 |
continue
|