fixed viz function parameter error
Browse files- app.py +1 -1
- attention_viz.py +1 -1
app.py
CHANGED
|
@@ -40,7 +40,7 @@ def infer_bart(context, task_type, decoding_type_str):
|
|
| 40 |
if decoding_type_str =='default':
|
| 41 |
response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
|
| 42 |
else:
|
| 43 |
-
response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True)
|
| 44 |
elif Data_Type(task_type) == Data_Type.ELI5:
|
| 45 |
response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
|
| 46 |
else:
|
|
|
|
| 40 |
if decoding_type_str =='default':
|
| 41 |
response, _, _ = commongen_bart.generate_based_on_context(context, use_kg=False)
|
| 42 |
else:
|
| 43 |
+
response, _, _ = commongen_bart.generate_contrained_based_on_context([context], use_kg=True, max_concepts=2)
|
| 44 |
elif Data_Type(task_type) == Data_Type.ELI5:
|
| 45 |
response, _, _ = qa_bart.generate_based_on_context(context, use_kg=False)
|
| 46 |
else:
|
attention_viz.py
CHANGED
|
@@ -170,7 +170,7 @@ class AttentionVisualizer:
|
|
| 170 |
plt.title(title)
|
| 171 |
plt.show()
|
| 172 |
|
| 173 |
-
def plot_attn_lines_concepts_ids(title, examples, layer, head,
|
| 174 |
relations_total, width=3, example_sep=3,
|
| 175 |
word_height=1, pad=0.1, hide_sep=False):
|
| 176 |
# examples -> {'words': tokens, 'attentions': [layer][head]}
|
|
|
|
| 170 |
plt.title(title)
|
| 171 |
plt.show()
|
| 172 |
|
| 173 |
+
def plot_attn_lines_concepts_ids(self, title, examples, layer, head,
|
| 174 |
relations_total, width=3, example_sep=3,
|
| 175 |
word_height=1, pad=0.1, hide_sep=False):
|
| 176 |
# examples -> {'words': tokens, 'attentions': [layer][head]}
|