Update content with the text model from Thomas's repository: https://huggingface.co/spaces/tombou/frugal-ai-challenge
42b7ac6
| import random | |
| import numpy as np | |
| import pytest | |
| from main import load_config | |
| from tasks.data.data_loaders import TextDataLoader | |
| from tasks.models.text_classifiers import DistilBERTModel, ModelFactory, TextEmbedder, MLModel, EmbeddingMLModel, \ | |
| TfIdfEmbedder | |
| from tasks.utils.evaluation import TextEvaluationRequest | |
@pytest.fixture
def data_loader():
    """Provide a ``TextDataLoader`` built from a default ``TextEvaluationRequest``.

    Registered as a pytest fixture so that tests and other fixtures can
    request it simply by naming ``data_loader`` as a parameter.
    """
    # define text request
    text_request = TextEvaluationRequest()
    # NOTE(review): light=True presumably selects a reduced dataset so the
    # tests stay fast -- confirm against TextDataLoader.
    return TextDataLoader(text_request, light=True)
@pytest.fixture
def train_dataset(data_loader):
    """Return ``data_loader.get_train_dataset()`` as a pytest fixture.

    ``data_loader`` is itself resolved by pytest from the fixture of the
    same name.
    """
    return data_loader.get_train_dataset()
@pytest.fixture
def test_dataset(data_loader):
    """Return ``data_loader.get_test_dataset()`` as a pytest fixture.

    ``data_loader`` is itself resolved by pytest from the fixture of the
    same name.
    """
    return data_loader.get_test_dataset()
class TestDistilBERTModel:
    """Tests for the model created from ``config_training_test.json``."""

    @pytest.fixture
    def distilBERT_model(self):
        """Build the model under test through ``ModelFactory``.

        Must be a fixture: both tests below request ``distilBERT_model``
        by parameter name.
        """
        config = load_config("config_training_test.json")
        return ModelFactory.create_model(config)

    def test_trained_distilBERT(self, train_dataset, distilBERT_model, test_dataset):
        """Train the model, then check every prediction is a label in 0..7."""
        assert "DistilBERT" in distilBERT_model.description

        # train model
        distilBERT_model.train(train_dataset)

        # inference: one prediction per quote of the test dataset
        predictions = [distilBERT_model.predict(quote) for quote in test_dataset["quote"]]
        for prediction in predictions:
            assert prediction in range(8)

    def test_data_preprocessing(self, train_dataset, distilBERT_model):
        """Check the split sizes and feature names produced by pre-processing."""
        pre_processed_data = distilBERT_model.pre_process_data(train_dataset)
        assert pre_processed_data is not None

        # NOTE(review): the 8 / 2 row counts presumably follow from the size
        # of the light training dataset and the model's split ratio -- confirm.
        assert pre_processed_data["train"].num_rows == 8
        assert pre_processed_data["test"].num_rows == 2

        # Both subsets must expose the raw columns and the tokenizer outputs.
        for subset in ["train", "test"]:
            for feature_name in ['quote', 'label', 'input_ids', 'attention_mask']:
                assert feature_name in pre_processed_data[subset].features.keys()
class DummyEmbedder(TextEmbedder):
    """Test double for ``TextEmbedder`` that ignores its input text."""

    def encode(self, text: str) -> np.ndarray:
        """Return a fresh random vector of length 42, whatever ``text`` is."""
        embedding_size = 42
        random_embedding = np.random.rand(embedding_size)
        return random_embedding
class DummyMLModel(MLModel):
    """Test double for ``MLModel`` that learns nothing and predicts at random."""

    def fit(self, X, y):
        """Accept the training data and do nothing with it."""
        return None

    def predict(self, X):
        """Return a random integer label in 0..7, ignoring ``X``."""
        label_space = range(8)
        return random.choice(label_space)
class TestEmbeddingMLModel:
    """Tests for ``EmbeddingMLModel``, built by the factory or from dummies."""

    @pytest.fixture
    def embeddingML(self):
        """Build an ``EmbeddingMLModel`` through ``ModelFactory``.

        Must be a fixture: ``test_EmbeddingML`` requests ``embeddingML`` by
        parameter name.
        """
        config = load_config("config_training_embedding_test.json")
        # Force the model type regardless of what the config file declares.
        config["model"] = "EmbeddingMLModel"
        return ModelFactory.create_model(config)

    def test_EmbeddingML(self, train_dataset, embeddingML):
        """Train the factory-built model and check one prediction is in 0..7."""
        assert "EmbeddingMLModel" in embeddingML.description

        # train model
        embeddingML.train(train_dataset)

        # inference
        assert embeddingML.predict("a quote") in range(8)

    def test_dummy_train_EmbeddingML(self, train_dataset):
        """Run train/predict with dummy components standing in for the real ones."""
        dummy_model = EmbeddingMLModel(embedder=DummyEmbedder(),
                                       ml_model=DummyMLModel())
        dummy_model.train(train_dataset)
        assert dummy_model.predict("dummy") in range(8)
class TestEmbedders:
    """Tests for the standalone text embedders."""

    def test_tf_idf(self):
        """Encode five short texts and check the shape of the result."""
        corpus = [
            "hello world",
            "world hello",
            "yet another text",
            "this is a test",
            "this one as well"
        ]
        tf_idf = TfIdfEmbedder()
        matrix = tf_idf.encode(corpus)

        # One row per input text.
        # NOTE(review): 11 columns presumably is the vocabulary size -- the
        # 12 distinct words minus the one-letter "a", if the embedder uses
        # scikit-learn's default token pattern -- confirm.
        expected_shape = (5, 11)
        assert matrix.shape == expected_shape