This is the second part of a two-part blog series, where we explore how to develop the machine learning model that powers our solution.
In the first part we presented an end-to-end, AI-powered solution architecture to automate support ticket classification and discussed key details highlighting the usage of serverless and PaaS services in Microsoft Azure. This approach allows a rapid implementation, so we can focus our effort on solving the business problem.
The AI piece in our solution is a machine learning model that automatically categorizes text extracted from support tickets, so that each category can be matched with the correct recipient. The same kind of AI piece appears in many industries and business scenarios, and we can frame it as a text classification problem.
In this second part, we dive deep into the details of developing that AI piece as a machine learning model on Azure ML using advanced deep learning techniques for NLP (Natural Language Processing). We explore the usage of pre-trained state-of-the-art deep learning models and how we can leverage such models to solve our specific NLP task. For our implementation, we choose the BERT model, due to its popularity, performance and availability of open-source implementations.
What is BERT
BERT (Bidirectional Encoder Representations from Transformers) is a language representation model developed by Google Research. Alongside other models such as ELMo and OpenAI GPT, BERT is a successful example from the most recent generation of deep learning-based models for NLP, which are pre-trained in an unsupervised way using a very large text corpus. The learned language representation is powerful enough that it can be used in several different downstream tasks with minimal architecture modifications.
How to use BERT for text classification
We can use a pre-trained BERT model and then leverage transfer learning as a technique to solve specific NLP tasks in specific domains, such as text classification of support tickets in a specific business domain.
Transfer learning is key here because training BERT from scratch is computationally very expensive. The original BERT model was pre-trained with a combined text corpus containing about 3.3 billion words. The pre-training takes about 4 days to complete on 16 TPU chips, whereas most fine-tuning procedures starting from a pre-trained model take from one to a few hours on a single GPU.
This process can be implemented with the following tasks:
- Choose a pre-trained BERT model according to the language needs for our task.
- Modify the pre-trained model architecture to fit our specific task.
- Prepare the training data according to our specific task.
- Fine-tune the modified pre-trained model by further training it using our own dataset.
Fig. 1: high-level overview of BERT pre-training and fine-tuning tasks
Now let's get into more details about each of the four tasks listed above.
Choose a pre-trained BERT model according to the language needs for our task
BERT is a very popular model and the original implementation was open sourced by Google. Therefore, there are several pre-trained models and extension packages readily available. Here we use the popular transformers package from Hugging Face, which provides pre-trained BERT models of various sizes and for several languages.
For our task we choose distilbert-base-uncased, which is pre-trained on the same data used to pre-train BERT (the concatenation of the Toronto Book Corpus and the full English Wikipedia) using a technique known as knowledge distillation, under the supervision of the bert-base-uncased version of BERT. The model has 6 layers, a hidden dimension of 768, and 12 attention heads, totaling 66M parameters. According to its authors, it runs 60% faster than the original uncased base BERT, which has 12 layers and approximately 110M parameters, while preserving 97% of the model performance.
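As a sanity check on the 66M figure, the parameter count can be reproduced with simple arithmetic. This is a back-of-the-envelope sketch, assuming the published configuration of distilbert-base-uncased (a 30,522-entry vocabulary, 512 learned positions, and a 3,072-wide feed-forward layer, none of which are stated above) and ignoring any task-specific head.

```python
# Assumed configuration values for distilbert-base-uncased.
VOCAB_SIZE = 30522     # WordPiece vocabulary entries
MAX_POSITIONS = 512    # learned position embeddings
HIDDEN = 768           # hidden dimension
FFN = 3072             # feed-forward inner dimension
LAYERS = 6             # transformer layers


def linear(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def layer_norm(dim):
    """Scale and shift vectors of a layer normalization."""
    return 2 * dim


# Token embeddings + position embeddings + embedding layer norm.
embeddings = VOCAB_SIZE * HIDDEN + MAX_POSITIONS * HIDDEN + layer_norm(HIDDEN)

# Each layer: query, key, value and output projections, then a 2-layer MLP,
# with one layer norm after the attention block and one after the MLP.
attention = 4 * linear(HIDDEN, HIDDEN)
feed_forward = linear(HIDDEN, FFN) + linear(FFN, HIDDEN)
per_layer = attention + layer_norm(HIDDEN) + feed_forward + layer_norm(HIDDEN)

total = embeddings + LAYERS * per_layer
print(f"{total / 1e6:.1f}M parameters")
```

Note that the number of attention heads does not change the count, since the heads only partition the same 768-dimensional projections.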
Modify the pre-trained model architecture to fit our specific task
BERT was designed to be pre-trained in an unsupervised way to perform two tasks: masked language modeling and next sentence prediction. In the masked language modeling, some percentage of the input tokens are masked at random and the model is trained to predict those masked tokens at the output. For the next sentence prediction task, the model is trained for a binary classification task by choosing pairs of sentences A and B for each pretraining example, so that 50% of the time B is the actual next sentence that follows A (labeled as IsNext), and 50% of the time it is a random sentence from the corpus (labeled as NotNext).
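To make the two pre-training tasks more concrete, here is a simplified sketch of how training examples for each could be built. The helper names `mask_tokens` and `make_nsp_pairs` are hypothetical, and the 15% masking rate is the value reported in the original BERT paper. The sketch leaves out details of the real procedure, such as occasionally replacing a selected token with a random token or leaving it unchanged instead of masking it.

```python
import random

MASK = "[MASK]"


def mask_tokens(tokens, mask_prob, rng):
    """Hide a random subset of tokens; return the masked input and the targets."""
    n_to_mask = max(1, round(len(tokens) * mask_prob))
    positions = set(rng.sample(range(len(tokens)), n_to_mask))
    masked = [MASK if i in positions else tok for i, tok in enumerate(tokens)]
    # Targets are only defined at masked positions (None elsewhere).
    targets = [tok if i in positions else None for i, tok in enumerate(tokens)]
    return masked, targets


def make_nsp_pairs(sentences, rng):
    """Build (A, B, label) examples: roughly half IsNext, half NotNext."""
    pairs = []
    for i in range(len(sentences) - 1):
        if rng.random() < 0.5:
            pairs.append((sentences[i], sentences[i + 1], "IsNext"))
        else:
            # Draw a random sentence that is neither A nor the true next one.
            candidates = [s for j, s in enumerate(sentences) if j not in (i, i + 1)]
            pairs.append((sentences[i], rng.choice(candidates), "NotNext"))
    return pairs


rng = random.Random(0)
tokens = "my card was charged twice for the same purchase last week".split()
masked, targets = mask_tokens(tokens, mask_prob=0.15, rng=rng)

sentences = [
    "I was charged twice.",
    "The bank refused a refund.",
    "I filed a dispute.",
    "Nobody called me back.",
]
pairs = make_nsp_pairs(sentences, rng)
```

In both cases the labels come from the text itself, which is why no manual annotation is needed for pre-training.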
Because a single architecture accommodates both pre-training tasks described above, BERT can then be fine-tuned for a variety of downstream NLP tasks involving single sentences or pairs of sentences, such as text classification, NER (Named Entity Recognition), question answering, and others.
In our specific task, we need to modify the base BERT model to perform text classification. This can be done by feeding the first output token of the last transformer layer into a classifier of our choice. That first token at the output layer is an aggregate sequence representation of an entire sequence that is fed as input to the model.
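The idea can be illustrated with a minimal classifier on top of that first output vector: a single linear layer followed by a softmax, trained with cross-entropy loss. This is only a toy sketch in plain Python with made-up weights (hidden size 4 instead of 768, 3 classes instead of our 6); the `classify` and `cross_entropy` helpers are hypothetical and are not the implementation used by the package.

```python
import math


def classify(first_token_vector, weights, biases):
    """Linear layer followed by a softmax over the class scores."""
    logits = [
        sum(w * x for w, x in zip(row, first_token_vector)) + b
        for row, b in zip(weights, biases)
    ]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probabilities, true_class):
    """Negative log-probability assigned to the correct class."""
    return -math.log(probabilities[true_class])


# Toy values, invented for illustration only.
hidden = [0.5, -1.0, 0.25, 2.0]   # output vector at the first position
weights = [                        # one row of weights per class
    [0.1, 0.0, -0.2, 0.3],
    [-0.4, 0.2, 0.1, 0.0],
    [0.0, -0.1, 0.3, -0.2],
]
biases = [0.0, 0.1, -0.1]

probs = classify(hidden, weights, biases)
loss = cross_entropy(probs, true_class=0)
```

During fine-tuning, the gradient of this loss flows back through the classifier and, unless layers are frozen, through the transformer layers as well.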
The package we use in our implementation already has several modified BERT models to perform different tasks, including one for text classification, so we don't need to plug in a custom classifier.
Fig. 2: high-level overview of the modified BERT model to perform text classification
Prepare the training data according to our specific task
In order to reduce the input data dimensionality and work with a smaller vocabulary, we can pre-process the input text according to common NLP practices such as removing punctuation, removing stop words, and normalizing all text to lower case.
To work with BERT, we also need to prepare our data according to what the model architecture expects. For the text classification task, the input text needs to be prepared as follows:
- Tokenize text sequences according to the WordPiece specification. In this specification, tokens can represent words, sub-words, or even single characters. For example, the word 'requisitions' is tokenized as ['re', '##qui', '##sit', '##ions']. Here, the two hash signs preceding some sub-words denote that a sub-word is part of a larger word and is preceded by another sub-word.
- Truncate and pad your sequences to the maximum sequence length suitable for your task, respecting the hard limit of 512 tokens per sequence according to the BERT specification.
- Annotate your tokenized sequences with the special tokens '[CLS]' and '[SEP]' to mark the beginning and end of each sequence, respectively.
- Convert your tokenized sequences into sequences of indices that are specific to the BERT vocabulary.
- Create a sequence mask to indicate which elements in a sequence are tokens and which are padding.
- Create the numeric sequential array to be used for the positional embeddings, which is required by the transformer.
Fig. 3: overview of the expected input data for the BERT model
Notice that for the text classification task we don't need the segment embeddings.
Luckily, the package that we use in our scenario has a class that implements all we need to generate the tokenized, indexed input data in the expected format. It also deals internally with the requirements for the embedding layer to take the positional information into account.
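To see what such a class does on our behalf, here is a hand-written sketch of the preparation steps listed above. The toy vocabulary and the `wordpiece` and `encode` helpers are hypothetical and only meant for illustration; the real BERT vocabulary has tens of thousands of entries. The sketch assumes greedy longest-match-first splitting, and it truncates before adding the special tokens so that '[CLS]' and '[SEP]' always fit within the maximum length.

```python
PAD, UNK, CLS, SEP = "[PAD]", "[UNK]", "[CLS]", "[SEP]"

# Toy vocabulary, invented for this example.
VOCAB = [PAD, UNK, CLS, SEP, "re", "##qui", "##sit", "##ions", "new", "are", "late"]
TOKEN_TO_ID = {token: index for index, token in enumerate(VOCAB)}


def wordpiece(word):
    """Greedy longest-match-first split of one word into sub-word pieces."""
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation of a larger word
            if candidate in TOKEN_TO_ID:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [UNK]  # no piece matches: the whole word is unknown
        pieces.append(piece)
        start = end
    return pieces


def encode(text, max_len):
    """Tokenize, truncate, add special tokens, pad, and build model inputs."""
    tokens = [piece for word in text.lower().split() for piece in wordpiece(word)]
    tokens = [CLS] + tokens[: max_len - 2] + [SEP]
    n_real = len(tokens)
    tokens += [PAD] * (max_len - n_real)
    return {
        "tokens": tokens,
        "input_ids": [TOKEN_TO_ID[token] for token in tokens],
        "attention_mask": [1] * n_real + [0] * (max_len - n_real),
        "position_ids": list(range(max_len)),
    }


encoded = encode("New requisitions are late", max_len=10)
```

Here `encoded["attention_mask"]` contains nine ones followed by a single zero, marking the one padding position at the end of the sequence.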
Fine-tune the modified pre-trained model by further training it using our own dataset
After choosing and instantiating a pre-trained BERT model and preparing our data for model training and validation, we can finally perform the model fine-tuning. This is very similar to training a model from scratch, except that for fine-tuning we usually have far less training data and fewer hyperparameters to tune, and we can get good results by training for only a couple of epochs.
There are several strategies to perform fine-tuning. As a rule of thumb, if our dataset represents a specific language domain (as in our case, for support tickets in a specific business domain), we usually perform fine-tuning for all model parameters. On the other hand, if we don't have a specific language domain (e.g. classifying sentiment from tweets), we could freeze all but the classification layer and train only that last layer. This would be equivalent to using the pre-trained BERT model as a feature extractor. Notice that in this case it would be preferable to use all token outputs from the model as features for the classifier, instead of using only the [CLS] token output.
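One simple way to combine all token outputs into a single feature vector is to average them while ignoring padding positions. The sketch below assumes this mean-pooling choice, which is only one of several possible options; the `mean_pool` helper and the tiny 2-dimensional vectors are made up for illustration.

```python
def mean_pool(token_vectors, attention_mask):
    """Average the output vectors of real tokens, ignoring padding."""
    dim = len(token_vectors[0])
    n_real = sum(attention_mask)
    pooled = [0.0] * dim
    for vector, keep in zip(token_vectors, attention_mask):
        if keep:
            for d in range(dim):
                pooled[d] += vector[d] / n_real
    return pooled


# Two real tokens followed by one padding position (toy values).
token_vectors = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
attention_mask = [1, 1, 0]

features = mean_pool(token_vectors, attention_mask)  # [2.0, 3.0]
```

The resulting vector has the same dimension as a single token output, so it can be fed to the same kind of classifier that would otherwise receive the [CLS] output.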
The original BERT paper gives some guidance regarding batch size, learning rate, number of epochs and dropout rate as a starting point for fine-tuning.
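As a starting point, the ranges suggested in the BERT paper can be enumerated as a small grid of candidate configurations. The values below are the ones recommended in that paper for fine-tuning (dropout is kept fixed at 0.1); the dictionary layout and variable names are just an illustrative choice, not the format expected by any particular tuning tool.

```python
from itertools import product

# Fine-tuning ranges suggested in the original BERT paper.
search_space = {
    "batch_size": [16, 32],
    "learning_rate": [5e-5, 3e-5, 2e-5],
    "epochs": [2, 3, 4],
}
DROPOUT = 0.1  # kept constant in the paper

names = list(search_space)
grid = [
    dict(zip(names, values), dropout=DROPOUT)
    for values in product(*search_space.values())
]

print(len(grid), "candidate configurations")
```

With only 18 combinations, an exhaustive search is feasible, although a random or early-terminating search can reduce the cost further.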
Here we use the Azure ML platform and associated SDK to run the code for fine-tuning according to the steps described above. We take advantage of some Azure ML artifacts, such as the PyTorch Estimator, which facilitates distributed training of PyTorch models on multiple GPUs running on Azure, and Hyperdrive, which performs model hyperparameter search. In this way, we don't need to worry about creating and configuring clusters, writing CUDA and MPI code, or writing code to perform multiple runs and find the best hyperparameters. All of this is taken care of by the Azure ML platform.
The code used in this implementation and corresponding instructions for creating the Azure ML environment are available in this GitHub repository. For this implementation, in order to make the code publicly accessible, we used data from the Consumer Complaint Database, provided by the U.S. Government. This data provides a collection of complaints about consumer financial products and services, where we have the textual information describing a complaint and the corresponding financial product or service category. Therefore, it is very similar to what we would find in a call center of a financial institution, for example.
Data preparation, model training, and results in the example implementation
The original consumer complaints dataset has approximately 1.5 million rows. From all available columns, we used only the Consumer Complaint Narrative as feature, which is the textual information describing a complaint. Our label is the Product column, which indicates the most relevant financial products or services associated with a complaint and originally has 18 distinct values.
To make training faster, we sampled about 10% of the original data, which corresponds to approximately 47 thousand rows. We then split this sample into 80% for training, 10% for validation, and 10% for testing.
We also aggregated the product column into 6 distinct categories: Credit Reporting, Debt, Mortgage, Card Services, Loans, and Banking Services, because there are several overlapping categories and some of them are very underrepresented in the original data.
We then fine-tuned the entire distilled BERT-based model (all model parameters in all layers, plus a classification layer) for only 4 epochs, and obtained an accuracy equal to or greater than 80% for each of the 6 distinct categories.
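An 80/10/10 split like this one can be sketched in a few lines. The `split_indices` helper and the fixed seed are hypothetical, and the sketch uses a plain random shuffle; a stratified split that preserves the label distribution in each subset would be a reasonable alternative.

```python
import random


def split_indices(n_rows, train_pct=80, validation_pct=10, seed=42):
    """Shuffle row indices and cut them into train/validation/test parts."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    n_train = n_rows * train_pct // 100
    n_validation = n_rows * validation_pct // 100
    return (
        indices[:n_train],
        indices[n_train:n_train + n_validation],
        indices[n_train + n_validation:],  # the remainder is the test set
    )


train_idx, validation_idx, test_idx = split_indices(47_000)
print(len(train_idx), len(validation_idx), len(test_idx))
```

Fixing the seed keeps the three subsets identical across runs, which matters when comparing hyperparameter configurations.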
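Such an aggregation boils down to a lookup table from raw Product values to coarse categories. The mapping below is purely illustrative: the raw product names are examples of values found in the Consumer Complaint Database, but which raw value goes into which of the 6 categories is our guess for this sketch and not necessarily the grouping used in the implementation.

```python
# Illustrative mapping only; the actual grouping may differ.
PRODUCT_TO_CATEGORY = {
    "Credit reporting": "Credit Reporting",
    "Debt collection": "Debt",
    "Mortgage": "Mortgage",
    "Credit card": "Card Services",
    "Prepaid card": "Card Services",
    "Student loan": "Loans",
    "Payday loan": "Loans",
    "Consumer Loan": "Loans",
    "Checking or savings account": "Banking Services",
    "Bank account or service": "Banking Services",
}


def aggregate(product):
    """Map a raw Product value to one of the 6 coarse categories."""
    return PRODUCT_TO_CATEGORY[product]


raw_labels = ["Mortgage", "Student loan", "Debt collection", "Payday loan"]
labels = [aggregate(product) for product in raw_labels]
```

Raising a `KeyError` on unmapped values, as the plain dictionary lookup does here, makes it obvious when a raw category has been forgotten.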
Fig. 4: losses (cross-entropy loss) on the training data (left) and on the validation data (right)
Fig. 5: classification report (left) and confusion matrix (right) on the test data
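For readers less familiar with these two views, the per-category figures in a classification report can be derived directly from the confusion matrix. The sketch below uses a small 3-class matrix with invented numbers (not the results shown in the figure), where rows are true categories and columns are predicted categories; the helper names are hypothetical.

```python
def per_class_report(matrix):
    """Precision and recall per class from a square confusion matrix."""
    n = len(matrix)
    report = []
    for c in range(n):
        true_positives = matrix[c][c]
        actual = sum(matrix[c])                          # row total
        predicted = sum(matrix[r][c] for r in range(n))  # column total
        report.append({
            "precision": true_positives / predicted,
            "recall": true_positives / actual,
        })
    return report


def overall_accuracy(matrix):
    """Share of all examples that lie on the diagonal."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / sum(sum(row) for row in matrix)


# Invented toy matrix: rows are true classes, columns are predictions.
toy_matrix = [
    [8, 1, 1],
    [2, 7, 1],
    [0, 2, 8],
]

report = per_class_report(toy_matrix)
accuracy = overall_accuracy(toy_matrix)
```

Recall answers "how many complaints of this category were found", while precision answers "how many complaints assigned to this category really belong to it".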
In the first part of this blog series, we explored the architecture of an end-to-end AI-powered solution to automate support ticket classification on Azure. The AI piece in that and similar solutions appears in many industries and business scenarios, and we can frame it as a text classification problem.
In this post we explored the development of that AI piece. We implemented it as a machine learning model for text classification, using state-of-the-art deep learning techniques that we exploited by leveraging transfer learning, through the fine-tuning of a distilled BERT-based model.
We also presented a high-level overview of BERT and how we used its power to create the AI piece in our solution. The solution was developed using the Azure Machine Learning Platform, where we started with a pre-trained BERT model which was modified for text classification, then performed the fine-tuning and automatic model hyperparameter search in a distributed manner, on a remote GPU cluster managed by Azure ML.
The code used in this implementation and corresponding instructions for creating the Azure ML environment are available in this GitHub repository.