Deep NLP Classifiers — CNN vs. RNN

In this blog, we will train multiple deep learning NLP classifiers to predict which kind of forum a post most likely comes from — dieting, eating disorders, general health, or irrelevant forums.

These models are baselines awaiting better-labeled training data. Please click here to help: https://www.zooniverse.org/projects/joyqiu/edetectives.

Transfer GloVe embeddings

GloVe is a pre-trained vector representation for words produced by an unsupervised learning algorithm. Training is performed on aggregated global word–word co-occurrence statistics from a corpus.

Similar to transfer learning in neural network architectures, we initialize the embedding layer with GloVe embedding values and set this layer trainable, so that those 'weights' get updated throughout the training process. Compared to random initialization, this carries over the linear substructures of the word-vector space that reflect human language, while still keeping the network context-sensitive for this particular classification task.
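The initialization step above can be sketched as follows: parse GloVe-format text into a lookup table, then build a matrix whose row i holds the vector for the word with index i. The vocabulary and the 3-dimensional vectors here are toy examples, not the actual data from this project.

```python
import numpy as np

def build_embedding_matrix(glove_lines, word_index, dim):
    """Row i of the returned matrix is the GloVe vector for the word
    with index i; words missing from GloVe stay at zero."""
    vectors = {}
    for line in glove_lines:
        parts = line.rstrip().split(" ")
        vectors[parts[0]] = np.asarray(parts[1:], dtype="float32")
    matrix = np.zeros((len(word_index) + 1, dim))  # index 0 reserved for padding
    for word, i in word_index.items():
        if word in vectors:
            matrix[i] = vectors[word]
    return matrix

# Tiny GloVe-format sample (hypothetical 3-d vectors)
glove_lines = ["diet 0.1 0.2 0.3", "health 0.4 0.5 0.6"]
word_index = {"diet": 1, "health": 2, "forum": 3}
emb = build_embedding_matrix(glove_lines, word_index, dim=3)
```

In Keras-style frameworks this matrix would be passed as the initial weights of the embedding layer, with the layer left trainable so the vectors can adapt to the task.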

CNN Classifier

Word embeddings are numerical representations of words, so we can treat each post as a sequence of vectors and use a convolutional neural network architecture to train the classifier. We trained two types of CNN: one with regularization and one without.

Figure 1

Figure 1 shows the training results of the CNN without regularization. The loss plot suggests overfitting from the very beginning, so a regularized CNN might be a better option. With an early stopping strategy, we save the 'best' weights of the model, which give us a loss of 2.2329 and an accuracy of 0.7821.
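The early stopping logic can be sketched in a few lines: track the best validation loss seen so far, and stop once it has failed to improve for a set number of epochs, keeping the weights from the best epoch. The loss curve below is hypothetical, chosen only to illustrate the mechanism.

```python
def early_stop_best(val_losses, patience=2):
    """Return the index of the 'best' epoch: training stops once the
    validation loss has failed to improve for `patience` epochs."""
    best_i, best = 0, float("inf")
    for i, loss in enumerate(val_losses):
        if loss < best:
            best_i, best = i, loss
        elif i - best_i >= patience:
            break
    return best_i

# Hypothetical validation-loss curve: improves, then overfits
losses = [1.9, 1.4, 1.1, 1.3, 1.6, 1.8]
best_epoch = early_stop_best(losses)  # epoch 2, where loss bottomed out at 1.1
```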

The confusion matrix indicates that this network does a good job of distinguishing eating disorder and irrelevant posts from the other categories, but performs less well at classifying general health and dieting posts.

Figure 2

Figure 2 shows the training results after adding penalizing (regularization) layers to the same CNN architecture above, in order to prevent overfitting. We can see that the validation loss is more stable and drops after a few epochs. Again, using early stopping, we keep the 'best' weights.
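One common penalizing layer for text CNNs is dropout. A minimal sketch of inverted dropout, assuming nothing about the actual layers used here: during training a random fraction of activations is zeroed and the survivors are rescaled, so the expected activation is unchanged and no rescaling is needed at test time.

```python
import numpy as np

def dropout(x, rate, rng, training=True):
    """Inverted dropout: zero a fraction `rate` of activations during
    training and rescale the rest; identity at test time."""
    if not training or rate == 0.0:
        return x
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)

rng = np.random.default_rng(42)
acts = np.ones((4, 5))
dropped = dropout(acts, rate=0.5, rng=rng)  # entries are either 0.0 or 2.0
```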

It gives a lower test loss of 1.96 and improves performance in identifying irrelevant, eating disorder, and dieting posts, but more posts from general health forums are misclassified as eating disorder or dieting posts.

RNN Classifier

Recurrent Neural Networks treat the input word embeddings as sequences that contain temporal information. We built an RNN architecture with one bidirectional LSTM layer, two LSTM layers, and several dense layers. This architecture is not as deep as the CNN, but it has more parameters owing to the nature of LSTM cells (more 'gate' nodes).
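The "more parameters" point follows directly from the standard LSTM parameter count: each of the four gates (input, forget, cell, output) has its own kernel, recurrent kernel, and bias. The 100-d/128-unit sizes below are illustrative, not the actual configuration of this model.

```python
def lstm_params(input_dim, units):
    """Parameter count of one LSTM layer: four gates, each with a
    kernel (input_dim x units), a recurrent kernel (units x units),
    and a bias (units)."""
    return 4 * (units * (input_dim + units) + units)

# e.g. 100-d embeddings feeding a 128-unit LSTM
n = lstm_params(100, 128)          # 117,248 parameters in one layer
n_bidirectional = 2 * n            # a bidirectional wrapper doubles this
```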

Figure 3

The model's performance improves in two ways: the test loss drops to 1.2254 and the test accuracy goes up; and the model trains longer before the validation loss begins to increase — in other words, it learns more textual information from the dataset before overfitting.

From the confusion matrix, this RNN architecture misclassifies more dieting posts as belonging to eating disorder and general health forums. However, for posts that come from eating disorders forums, it classifies over 91 percent of them accurately.
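The per-class figure quoted above is the class recall, read straight off the confusion matrix: the diagonal entry divided by its row sum. The matrix below is hypothetical (not the actual results), chosen only to show the computation.

```python
import numpy as np

def per_class_recall(cm):
    """Rows = true classes, columns = predicted classes; recall for
    class i is cm[i, i] divided by the sum of row i."""
    return np.diag(cm) / cm.sum(axis=1)

# Hypothetical 4-class confusion matrix
# (order: dieting, eating disorders, general health, irrelevant)
cm = np.array([[60, 25, 10,  5],
               [ 4, 92,  3,  1],
               [15, 20, 60,  5],
               [ 2,  3,  5, 90]])
recall = per_class_recall(cm)  # eating disorders row: 92 / 100 = 0.92
```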

Comments

In general, the RNN does perform better than the CNN on this NLP task. Temporal information seems to be important in identifying eating disorder posts.

Meanwhile, the RNN predicts more dieting-related posts to be from eating disorders forums or general health forums. This might suggest that some posts carry a higher risk of eating disorders, while others express more neutral attitudes in terms of public health.
