# A multi-task learning model with two prediction heads
* One prediction head classifies keyword queries vs. statements/questions
* The other prediction head classifies statements vs. questions (an illustrative sketch of the two-head setup follows below)
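
The snippets later in this card import a custom `BertForSequenceClassification` from `multitask_model.py` in this repo, which routes each batch to a task-specific head via `task_name`. The sketch below only illustrates that general pattern (a shared BERT encoder with one classification head per task); it is not the repo's actual implementation:

```python
import torch.nn as nn
from transformers import BertModel


class TwoHeadBert(nn.Module):
    """Illustrative multi-task model: one shared BERT encoder,
    one linear classification head per task."""

    def __init__(self, encoder_name, task_labels_map):
        super().__init__()
        self.bert = BertModel.from_pretrained(encoder_name)  # shared encoder
        hidden_size = self.bert.config.hidden_size
        # e.g. task_labels_map = {"quora_keyword_pairs": 2, "spaadia_squad_pairs": 2}
        self.heads = nn.ModuleDict(
            {task: nn.Linear(hidden_size, num_labels)
             for task, num_labels in task_labels_map.items()}
        )

    def forward(self, input_ids, task_name, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output          # [CLS] pooled representation
        return self.heads[task_name](pooled)    # logits from the selected task head
```

In a setup like this, both tasks update the shared encoder during training, while each head only sees batches from its own task.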
## Scores
##### Spaadia SQuaD Test acc: **0.9891**
##### Quora Keyword Pairs Test acc: **0.98048**
## Datasets
* Quora Keyword Pairs: https://www.kaggle.com/stefanondisponibile/quora-question-keyword-pairs
* Spaadia SQuaD pairs: https://www.kaggle.com/shahrukhkhan/questions-vs-statementsclassificationdataset
## Article
[Medium article](https://medium.com/@shahrukhx01/multi-task-learning-with-transformers-part-1-multi-prediction-heads-b7001cf014bf)
## Demo Notebook
[Colab Notebook: Multi-task Query Classifiers](https://colab.research.google.com/drive/1R7WcLHxDsVvZXPhr5HBgIWa3BlSZKY6p?usp=sharing)
## Clone the model repo
```bash
git clone https://huggingface.co/shahrukhx01/bert-multitask-query-classifiers
```
Then change into the cloned directory (here with the notebook `%cd` magic) so that the custom `multitask_model` module can be imported:
```python
%cd bert-multitask-query-classifiers/
```
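
Outside a notebook, the `%cd` magic is not available; an equivalent plain-Python alternative (directory name taken from the clone command above) is to put the cloned repo on the import path:

```python
import sys

# make the cloned repo (which contains multitask_model.py) importable
sys.path.insert(0, "bert-multitask-query-classifiers")
```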
## Load model
```python
from multitask_model import BertForSequenceClassification
from transformers import AutoTokenizer
import torch

# Load the multi-task model, specifying the number of labels for each task head
model = BertForSequenceClassification.from_pretrained(
    "shahrukhx01/bert-multitask-query-classifiers",
    task_labels_map={"quora_keyword_pairs": 2, "spaadia_squad_pairs": 2},
)
tokenizer = AutoTokenizer.from_pretrained("shahrukhx01/bert-multitask-query-classifiers")
```
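
As a quick sanity check (the example query is just illustrative): passing a single sentence through the `quora_keyword_pairs` head should yield a logits tensor of shape `(1, 2)`, since that task has two labels:

```python
import torch

ids = tokenizer(["population of berlin"], return_tensors="pt")["input_ids"]
with torch.no_grad():
    logits = model(ids, task_name="quora_keyword_pairs")[0]
print(logits.shape)  # expected: torch.Size([1, 2])
```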
## Run inference on both tasks
```python
from multitask_model import BertForSequenceClassification
from transformers import AutoTokenizer
import torch

model = BertForSequenceClassification.from_pretrained(
    "shahrukhx01/bert-multitask-query-classifiers",
    task_labels_map={"quora_keyword_pairs": 2, "spaadia_squad_pairs": 2},
)
tokenizer = AutoTokenizer.from_pretrained("shahrukhx01/bert-multitask-query-classifiers")

## Keyword vs Statement/Question Classifier
texts = ["keyword query", "is this a keyword query?"]
task_name = "quora_keyword_pairs"
# Tokenize the batch and run it through the head for this task
input_ids = tokenizer(texts, padding=True, return_tensors="pt")["input_ids"]
logits = model(input_ids, task_name=task_name)[0]
predictions = torch.argmax(torch.softmax(logits, dim=1).detach().cpu(), dim=1)
for text, prediction in zip(texts, predictions):
    print(f"task: {task_name}, input: {text} \n prediction=> {prediction}")
    print()

## Statement vs Question Classifier
texts = ["where is berlin?", "is this a keyword query?", "Berlin is in Germany."]
task_name = "spaadia_squad_pairs"
input_ids = tokenizer(texts, padding=True, return_tensors="pt")["input_ids"]
logits = model(input_ids, task_name=task_name)[0]
predictions = torch.argmax(torch.softmax(logits, dim=1).detach().cpu(), dim=1)
for text, prediction in zip(texts, predictions):
    print(f"task: {task_name}, input: {text} \n prediction=> {prediction}")
    print()
```
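
The two snippets above can also be wrapped in a small convenience function; `classify` below is purely illustrative (it is not part of the repo) and returns the predicted class index for each input:

```python
def classify(texts, task_name):
    """Run one task head over a list of strings and return class indices."""
    input_ids = tokenizer(texts, padding=True, return_tensors="pt")["input_ids"]
    with torch.no_grad():  # no gradients needed for inference
        logits = model(input_ids, task_name=task_name)[0]
    return torch.argmax(logits, dim=1).tolist()


print(classify(["keyword query"], task_name="quora_keyword_pairs"))
print(classify(["where is berlin?"], task_name="spaadia_squad_pairs"))
```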