library(tidyverse)
library(forcats)
library(ggrepel)
library(smodels)
library(cleanNLP)
library(glmnet)
library(Matrix)
library(xgboost)
library(stringi)
library(magrittr)

theme_set(theme_minimal())
options(dplyr.summarise.inform = FALSE)
options(width = 77L)
options(sparse.colnames = TRUE)

Amazon Authorship Data: Music

As in the previous notes, I will use the extra project dataset today:

amazon <- read_csv(file.path("data", sprintf("%s.csv.gz", "cds")))
token <- read_csv(file.path("data", sprintf("%s_token.csv.gz", "cds")))

We will use the same design matrix from the previous notes as well:

X <- token %>%
  cnlp_utils_tf(doc_set = amazon$doc_id,
                min_df = 0.005,
                max_df = 1,
                max_features = 200,
                token_var = "lemma")

It will be useful for reference later to have a sense of how well a penalized regression model performs on this data:

model <- cv.glmnet(
  X[amazon$train_id == "train", ],
  amazon$user_id[amazon$train_id == "train"],
  alpha = 0.9,
  family = "multinomial",
  nfolds = 3,
  trace.it = FALSE,
  relax = FALSE,
  lambda.min.ratio = 0.01,
  nlambda = 100
)

amazon %>%
  mutate(pred = as.vector(predict(model, newx = X, type = "class"))) %>%
  group_by(train_id) %>%
  summarize(class_rate = mean(user_id == pred))
## # A tibble: 2 x 2
##   train_id class_rate
##   <chr>         <dbl>
## 1 train         0.959
## 2 valid         0.869

Today we will continue our study of local models by looking at a very powerful model based on an adaptive form of KNN.

Decision Trees

Before we get to the main model today, we need to understand an intermediate model called a decision tree. We will not fit a standalone decision tree in R here; instead, we will just discuss how the model works. I will explain it for the case where we are predicting a continuous response variable y.

Consider a single feature (column of X) in our dataset. For some cutoff value v, we can split the entire dataset into two parts: observations whose value of the feature is less than v and those whose value is greater than or equal to v. Within each of these halves, we give every observation the same prediction: the average value of the response in that half. Measuring how much this split reduces the RMSE tells us how good a particular split is.
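To make the split criterion concrete, here is a minimal base R sketch that scores every candidate cutoff for a single numeric feature. It is an illustration under assumptions of my own: the function name best_split and the toy data are hypothetical, and candidate cutoffs are taken as midpoints between consecutive distinct values of the feature.

```r
# Score every candidate cutoff v for one numeric feature x.
# Observations with x < v form one half and the rest form the other;
# each half is predicted by its own mean response.
best_split <- function(x, y) {
  ux <- sort(unique(x))
  if (length(ux) < 2) return(NULL)                  # nothing to split on
  cutoffs <- (head(ux, -1) + tail(ux, -1)) / 2      # midpoints between values
  rmse <- vapply(cutoffs, function(v) {
    left <- x < v
    pred <- ifelse(left, mean(y[left]), mean(y[!left]))
    sqrt(mean((y - pred)^2))
  }, numeric(1))
  list(cutoff = cutoffs[which.min(rmse)],
       rmse = min(rmse),
       rmse_no_split = sqrt(mean((y - mean(y))^2)))  # baseline for comparison
}

# Hypothetical toy data: the response jumps when x crosses 5
x <- c(1, 2, 3, 4, 6, 7, 8, 9)
y <- c(1.0, 1.2, 0.9, 1.1, 5.0, 5.2, 4.8, 5.1)
best_split(x, y)
```

Comparing rmse with rmse_no_split shows how much the best split improves on simply predicting the overall mean.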

A decision tree starts by considering all possible split values for all of the features. It then picks the best split and divides the data into two buckets based on it. Next, it recursively splits each of these subgroups further, again choosing the best split from all possible options within each subgroup. This continues until some stopping criterion is reached (minimum improvement, maximum depth, maximum number of splits, etc.). The resulting model splits the data into N buckets (the leaves of the tree), and each bucket predicts the average response of the training data that fall into it.
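The recursive procedure can be sketched in base R as follows. This is a teaching sketch, not the algorithm used by any particular package: fit_tree and predict_tree are hypothetical names, the stopping rules (max_depth, min_size) are assumed defaults, and splits are scored by the sum of squared errors, which ranks the candidate splits of a node exactly as the RMSE does because the node size is fixed.

```r
# A minimal recursive regression tree (teaching sketch only).
# X: numeric matrix of features; y: numeric response.
# max_depth and min_size are hypothetical stopping rules.
fit_tree <- function(X, y, depth = 0, max_depth = 2, min_size = 2) {
  leaf <- list(leaf = TRUE, value = mean(y))   # bucket prediction = mean response
  if (depth >= max_depth || length(y) < 2 * min_size) return(leaf)

  # search all features and all midpoint cutoffs for the lowest SSE
  best <- list(sse = sum((y - mean(y))^2), var = NA, cutoff = NA)
  for (j in seq_len(ncol(X))) {
    ux <- sort(unique(X[, j]))
    if (length(ux) < 2) next
    for (v in (head(ux, -1) + tail(ux, -1)) / 2) {
      left <- X[, j] < v
      if (sum(left) < min_size || sum(!left) < min_size) next
      sse <- sum((y[left] - mean(y[left]))^2) +
             sum((y[!left] - mean(y[!left]))^2)
      if (sse < best$sse) best <- list(sse = sse, var = j, cutoff = v)
    }
  }
  if (is.na(best$var)) return(leaf)            # no split improves the fit

  # recurse into the two halves
  left <- X[, best$var] < best$cutoff
  list(leaf = FALSE, var = best$var, cutoff = best$cutoff,
       left  = fit_tree(X[left, , drop = FALSE], y[left],
                        depth + 1, max_depth, min_size),
       right = fit_tree(X[!left, , drop = FALSE], y[!left],
                        depth + 1, max_depth, min_size))
}

# Predict one observation by walking down to its bucket (leaf).
predict_tree <- function(tree, x) {
  while (!tree$leaf) {
    tree <- if (x[tree$var] < tree$cutoff) tree$left else tree$right
  }
  tree$value
}

# Hypothetical toy data: y depends only on feature "a"
X <- cbind(a = c(1, 2, 3, 4, 5, 6, 7, 8),
           b = c(5, 3, 8, 1, 7, 2, 6, 4))
y <- c(1, 1, 1, 1, 10, 10, 10, 10)
tree <- fit_tree(X, y, max_depth = 1)
c(var = tree$var, cutoff = tree$cutoff)   # splits on column 1 (a) at 4.5
predict_tree(tree, c(a = 2, b = 0))       # 1
predict_tree(tree, c(a = 7, b = 0))       # 10
```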

The final form of a decision tree looks something like this:

[Figure: example of a fitted decision tree]

A decision tree is therefore an adaptive form of KNN. We use the response vector y to determine how to group the training data into observations that are similar to one another, based on the best variables and the best cutoff values, rather than using Euclidean distance directly.

If you want to see a longer visual explanation of decision trees, I suggest checking out A visual introduction to machine learning. It has the benefit of giving great visual intuition for the model as well as reviewing some key machine learning concepts.

Gradient Boosted Trees

An individual decision tree is often not a particularly powerful model for complex prediction tasks. A clever way to increase predictive power is to build a large collection of trees and combine their predictions. In a random forest, for example, prediction is done by predicting on each individual tree and averaging (or taking the majority class) across the whole set. The model that we will look at here is a gradient boosted tree, or gradient boosted machine (GBM), in which each new tree is fit to correct the errors of the trees before it. For a continuous response, the algorithm works like this:

  • select a random subset of the training data
  • build a decision tree with the selected training data to predict the response variable
  • take the predictions from this first tree, multiply by a fixed parameter called the learning rate (say, 0.01), and compute the residuals for the entire training set
  • take another random subset of the training data
  • build a decision tree with the selected training data to predict these residuals
  • repeat this process many times, scaling each new tree's predictions by the learning rate and recomputing the residuals before fitting the next tree
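The steps listed above can be sketched in base R for a continuous response. To keep it short, each tree is a depth-one "stump" (a single split), and the subsample fraction, learning rate, and function names (fit_stump, predict_stump, boost, predict_boost) are hypothetical choices for illustration; xgboost's implementation differs in many details.

```r
# Depth-one tree ("stump"): the single best split over all columns of X,
# fit to the current residuals r.
fit_stump <- function(X, r) {
  best <- list(sse = Inf, var = 1, cutoff = -Inf, lval = mean(r), rval = mean(r))
  for (j in seq_len(ncol(X))) {
    ux <- sort(unique(X[, j]))
    if (length(ux) < 2) next
    for (v in (head(ux, -1) + tail(ux, -1)) / 2) {
      left <- X[, j] < v
      sse <- sum((r[left] - mean(r[left]))^2) +
             sum((r[!left] - mean(r[!left]))^2)
      if (sse < best$sse) {
        best <- list(sse = sse, var = j, cutoff = v,
                     lval = mean(r[left]), rval = mean(r[!left]))
      }
    }
  }
  best  # if no split was possible, this predicts mean(r) everywhere
}

predict_stump <- function(s, X) ifelse(X[, s$var] < s$cutoff, s$lval, s$rval)

boost <- function(X, y, n_trees = 200, eta = 0.1, frac = 0.5) {
  n <- length(y)
  fitted <- rep(0, n)    # eta times the sum of all previous trees' predictions
  trees <- vector("list", n_trees)
  for (k in seq_len(n_trees)) {
    z <- y - fitted                                   # residuals on the full training set
    idx <- sample(n, size = max(2, floor(frac * n)))  # random subset of rows
    trees[[k]] <- fit_stump(X[idx, , drop = FALSE], z[idx])
    fitted <- fitted + eta * predict_stump(trees[[k]], X)
  }
  list(trees = trees, eta = eta)
}

predict_boost <- function(model, X) {
  pred <- rep(0, nrow(X))
  for (s in model$trees) pred <- pred + model$eta * predict_stump(s, X)
  pred
}

# Hypothetical toy data: a step function of one feature
set.seed(1)
X <- matrix(1:20, ncol = 1)
y <- ifelse(X[, 1] > 10, 5, 1)
model <- boost(X, y, n_trees = 200, eta = 0.1, frac = 0.5)
mean((predict_boost(model, X) - y)^2)   # training MSE
```

With frac = 1 every tree sees all of the data; values below one give the random subsets described in the list.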

For a mathematical description, let the fitted values from the t-th tree be denoted by:

\[ \widehat{Y_i^t} \]

We then train the k-th tree on the values Z_i given by:

\[ Z_i = Y_i - \eta \cdot \sum_{t = 1}^{k - 1} \widehat{Y_i^t} \]

The parameter η (eta) is the learning rate. If it is set to one, each new tree is fit exactly to the residuals of the prior trees. Setting it to less than one keeps the model from relying too heavily on the first few trees, which helps to prevent overfitting.
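A small worked example may help; the numbers here are hypothetical. Take η = 0.1 and an observation with Y_i = 10, and suppose the first tree predicts 9.5 for it and the second tree, trained on the resulting Z values, predicts 9.0. The targets for the second and third trees are then:

\[ k = 2: \quad Z_i = 10 - 0.1 \cdot 9.5 = 9.05 \]

\[ k = 3: \quad Z_i = 10 - 0.1 \cdot (9.5 + 9.0) = 8.15 \]

With η = 1, the second tree would instead be trained on the plain residual Z_i = 10 - 9.5 = 0.5.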

The details for classification tasks are a bit more complex, but the general ideas are the same. To run gradient boosted trees, we will use the xgboost package, which has a very fast implementation of the algorithm. It requires us to convert our categorical user_id variable into integer codes starting at 0:

author_set <- unique(amazon$user_id)
y <- (match(amazon$user_id, author_set) - 1L)
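As a quick check of what this encoding does, here is the same match-based idiom applied to a hypothetical toy vector of labels, together with the reverse mapping from 0-based codes back to labels (the same +1 shift used later when decoding predictions).

```r
# Hypothetical toy labels standing in for amazon$user_id
user_id <- c("ann", "bob", "ann", "cat", "bob")

author_set <- unique(user_id)           # "ann" "bob" "cat", in order of first appearance
y <- match(user_id, author_set) - 1L    # 0-based integer codes, as xgboost expects
y                                        # 0 1 0 2 1

# Decoding: shift the 0-based codes back by one and index into author_set
author_set[y + 1L]                       # "ann" "bob" "ann" "cat" "bob"
```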

Then, we create training and validation sets, which are converted into an efficient data structure by the function xgb.DMatrix.

y_train <- y[amazon$train_id == "train"]
y_valid <- y[amazon$train_id == "valid"]
X_train <- X[amazon$train_id == "train",]
X_valid <- X[amazon$train_id == "valid",]

data_train <- xgb.DMatrix(data = X_train, label = y_train)
data_valid <- xgb.DMatrix(data = X_valid, label = y_valid)

watchlist <- list(train=data_train, valid=data_valid)

Then, we train the actual model using the xgb.train function. We set the maximum depth of each decision tree (here, 3), the learning rate (here, 0.05), and the number of boosting rounds (here, 10); with a multiclass objective, xgboost builds one tree per class in each round. The number of threads is just a computational detail about how many cores to use on your machine. We also have to indicate the number of classes (here, 25, computed as length(author_set)) and tell xgboost that we are running a multiclass prediction.

model <- xgb.train(data = data_train,
                   max_depth = 3,
                   eta = 0.05,
                   nthread = 2,
                   nrounds = 10,
                   objective = "multi:softmax",
                   eval_metric = "mlogloss",
                   watchlist = watchlist,
                   verbose = TRUE,
                   num_class = length(author_set))
## [1]  train-mlogloss:2.994378 valid-mlogloss:3.019480 
## [2]  train-mlogloss:2.819161 valid-mlogloss:2.861098 
## [3]  train-mlogloss:2.677038 valid-mlogloss:2.734014 
## [4]  train-mlogloss:2.552745 valid-mlogloss:2.622512 
## [5]  train-mlogloss:2.445944 valid-mlogloss:2.526886 
## [6]  train-mlogloss:2.349151 valid-mlogloss:2.440435 
## [7]  train-mlogloss:2.262303 valid-mlogloss:2.363707 
## [8]  train-mlogloss:2.183780 valid-mlogloss:2.293034 
## [9]  train-mlogloss:2.110232 valid-mlogloss:2.228540 
## [10] train-mlogloss:2.042316 valid-mlogloss:2.168054

You can see that the function prints out the training and validation multiclass log loss (mlogloss, where lower is better) after each round. We can do much better by slightly decreasing the learning rate and greatly increasing the number of rounds. Because there are so many rounds, I use print_every_n to print only every 25th one.
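Since the printed metric is multiclass log loss rather than an error rate, it may help to see how it is computed. The following base R sketch is my own illustration, not xgboost's source: mlogloss is a hypothetical helper, and the clipping constant eps is an assumed safeguard against taking log(0).

```r
# Multiclass log loss: the average negative log of the probability a model
# assigns to each observation's true class (lower is better).
# probs: n x K matrix of predicted class probabilities (rows sum to 1)
# y:     0-based integer class labels, as passed to xgboost
mlogloss <- function(probs, y, eps = 1e-15) {
  p_true <- probs[cbind(seq_along(y), y + 1L)]   # probability of the true class
  p_true <- pmin(pmax(p_true, eps), 1 - eps)     # clip to avoid log(0)
  -mean(log(p_true))
}

# A model that spreads probability evenly over 25 classes
uniform <- matrix(1 / 25, nrow = 4, ncol = 25)
mlogloss(uniform, c(0L, 5L, 24L, 3L))            # log(25), about 3.22

# A model that is confident and correct scores close to zero
mlogloss(diag(3), 0:2)
```

A model that spreads probability evenly over K classes scores log(K); with 25 authors that is about 3.22, a useful reference point for the first-round values in the training logs.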

model <- xgb.train(data = data_train,
                   max_depth = 3,
                   eta = 0.04,
                   nthread = 2,
                   nrounds = 1500,
                   objective = "multi:softmax",
                   eval_metric = "mlogloss",
                   watchlist = watchlist,
                   verbose = TRUE,
                   print_every_n = 25,
                   num_class = length(author_set))
## [1]  train-mlogloss:3.038783 valid-mlogloss:3.058880 
## [26] train-mlogloss:1.525352 valid-mlogloss:1.719471 
## [51] train-mlogloss:1.004015 valid-mlogloss:1.271631 
## [76] train-mlogloss:0.716291 valid-mlogloss:1.021495 
## [101]    train-mlogloss:0.537913 valid-mlogloss:0.865591 
## [126]    train-mlogloss:0.417408 valid-mlogloss:0.761944 
## [151]    train-mlogloss:0.330678 valid-mlogloss:0.688407 
## [176]    train-mlogloss:0.267053 valid-mlogloss:0.633421 
## [201]    train-mlogloss:0.218353 valid-mlogloss:0.589987 
## [226]    train-mlogloss:0.180151 valid-mlogloss:0.555693 
## [251]    train-mlogloss:0.149932 valid-mlogloss:0.528557 
## [276]    train-mlogloss:0.125521 valid-mlogloss:0.506466 
## [301]    train-mlogloss:0.105331 valid-mlogloss:0.487855 
## [326]    train-mlogloss:0.089070 valid-mlogloss:0.472458 
## [351]    train-mlogloss:0.075822 valid-mlogloss:0.459866 
## [376]    train-mlogloss:0.065094 valid-mlogloss:0.449602 
## [401]    train-mlogloss:0.056115 valid-mlogloss:0.440133 
## [426]    train-mlogloss:0.048670 valid-mlogloss:0.432330 
## [451]    train-mlogloss:0.042498 valid-mlogloss:0.425696 
## [476]    train-mlogloss:0.037360 valid-mlogloss:0.419929 
## [501]    train-mlogloss:0.033155 valid-mlogloss:0.415227 
## [526]    train-mlogloss:0.029655 valid-mlogloss:0.411227 
## [551]    train-mlogloss:0.026683 valid-mlogloss:0.407738 
## [576]    train-mlogloss:0.024153 valid-mlogloss:0.404778 
## [601]    train-mlogloss:0.022059 valid-mlogloss:0.401762 
## [626]    train-mlogloss:0.020240 valid-mlogloss:0.399363 
## [651]    train-mlogloss:0.018671 valid-mlogloss:0.397032 
## [676]    train-mlogloss:0.017324 valid-mlogloss:0.394898 
## [701]    train-mlogloss:0.016177 valid-mlogloss:0.393018 
## [726]    train-mlogloss:0.015147 valid-mlogloss:0.391456 
## [751]    train-mlogloss:0.014251 valid-mlogloss:0.389942 
## [776]    train-mlogloss:0.013448 valid-mlogloss:0.388700 
## [801]    train-mlogloss:0.012743 valid-mlogloss:0.387385 
## [826]    train-mlogloss:0.012115 valid-mlogloss:0.386056 
## [851]    train-mlogloss:0.011547 valid-mlogloss:0.385188 
## [876]    train-mlogloss:0.011042 valid-mlogloss:0.384252 
## [901]    train-mlogloss:0.010593 valid-mlogloss:0.383384 
## [926]    train-mlogloss:0.010185 valid-mlogloss:0.382547 
## [951]    train-mlogloss:0.009809 valid-mlogloss:0.381767 
## [976]    train-mlogloss:0.009465 valid-mlogloss:0.381178 
## [1001]   train-mlogloss:0.009154 valid-mlogloss:0.380621 
## [1026]   train-mlogloss:0.008853 valid-mlogloss:0.380230 
## [1051]   train-mlogloss:0.008588 valid-mlogloss:0.379909 
## [1076]   train-mlogloss:0.008347 valid-mlogloss:0.379612 
## [1101]   train-mlogloss:0.008125 valid-mlogloss:0.379468 
## [1126]   train-mlogloss:0.007924 valid-mlogloss:0.379201 
## [1151]   train-mlogloss:0.007742 valid-mlogloss:0.378864 
## [1176]   train-mlogloss:0.007573 valid-mlogloss:0.378591 
## [1201]   train-mlogloss:0.007416 valid-mlogloss:0.378302 
## [1226]   train-mlogloss:0.007266 valid-mlogloss:0.378164 
## [1251]   train-mlogloss:0.007129 valid-mlogloss:0.377894 
## [1276]   train-mlogloss:0.007003 valid-mlogloss:0.377840 
## [1301]   train-mlogloss:0.006882 valid-mlogloss:0.377730 
## [1326]   train-mlogloss:0.006768 valid-mlogloss:0.377577 
## [1351]   train-mlogloss:0.006661 valid-mlogloss:0.377562 
## [1376]   train-mlogloss:0.006559 valid-mlogloss:0.377572 
## [1401]   train-mlogloss:0.006463 valid-mlogloss:0.377573 
## [1426]   train-mlogloss:0.006377 valid-mlogloss:0.377521 
## [1451]   train-mlogloss:0.006295 valid-mlogloss:0.377530 
## [1476]   train-mlogloss:0.006219 valid-mlogloss:0.377596 
## [1500]   train-mlogloss:0.006148 valid-mlogloss:0.377568

Let’s see how well this model predicts the classes in our data:

y_hat <- author_set[predict(model, newdata = X) + 1]

amazon %>%
  mutate(author_pred = y_hat) %>%
  group_by(train_id) %>%
  summarize(class_rate = mean(user_id == author_pred))
## # A tibble: 2 x 2
##   train_id class_rate
##   <chr>         <dbl>
## 1 train         1    
## 2 valid         0.890

It manages about 89% on the validation set, noticeably better than the KNN model we built in the previous section and slightly better even than the penalized regression model (about 87%). Another useful benefit of gradient boosted trees over KNN is that they also provide variable importance scores:

importance_matrix <- xgb.importance(model = model)
importance_matrix
##      Feature         Gain        Cover    Frequency
##   1:       " 8.460528e-02 5.610164e-02 3.850989e-02
##   2:    \n\n 7.452512e-02 5.066716e-02 3.074273e-02
##   3:       , 5.700294e-02 4.258762e-02 4.168348e-02
##   4:       . 4.842310e-02 3.103626e-02 2.805798e-02
##   5:       ; 4.820510e-02 1.886420e-02 1.431609e-02
##  ---                                               
## 196:   where 4.286685e-05 1.488001e-04 2.483007e-04
## 197:     own 3.950088e-05 1.419483e-04 1.862255e-04
## 198:   right 1.432234e-05 1.680970e-05 5.431578e-05
## 199:    long 1.415058e-05 3.087550e-05 1.241503e-04
## 200:     off 8.793385e-06 2.562117e-05 1.008722e-04

Finally, let’s create a confusion matrix for our model:

y_hat_id <- (predict(model, newdata = X) + 1)

amazon %>%
  mutate(user_num = match(user_id, author_set)) %>%
  mutate(user_num_pred = y_hat_id) %>%
  filter(train_id == "valid") %>%
  select(y = user_num, yhat = user_num_pred) %>%
  table()
##     yhat
## y      1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
##   1  119   0   0   0   0   0   0   0   0   0   0   1   0   0   0   2   1   0
##   2    0 106   0   2   0   3   0   1   0   2   0   0   3   0   0   0   1   1
##   3    0   1 115   0   0   4   0   0   0   1   0   0   0   3   0   0   0   0
##   4    0   0   0 102   2   0   4   0   0   1   0   0   0   0   0   0   0   1
##   5    0   2   0   6 102   0   3   0   0   1   0   0   0   0   0   0   0   3
##   6    0   1   2   1   0 107   0   2   0   0   0   0   6   0   0   0   0   0
##   7    1   0   0   5   1   1 103   0   0   1   2   1   0   0   0   4   1   1
##   8    0   1   2   0   0   5   0 105   0   1   1   0   5   1   0   0   0   0
##   9    0   0   1   0   0   0   0   0 119   0   0   0   0   0   0   3   0   2
##   10   0   0   0   1   2   0   3   1   0 114   0   1   2   0   0   0   0   0
##   11   0   1   0   0   0   0   1   0   0   0 120   0   0   0   0   0   0   1
##   12   0   0   0   0   0   0   2   0   0   0   0 118   0   0   0   1   2   1
##   13   0   5   0   0   0   4   0   5   0   0   0   0 105   1   0   1   0   0
##   14   0   0   2   0   0   6   0   0   0   0   0   0   1 116   0   0   0   0
##   15   0   0   0   0   0   0   0   0   0   1   0   1   1   0 118   1   1   1
##   16   0   0   0   0   1   0   3   0   0   1   0   0   0   0   0 117   0   1
##   17   1   0   0   0   0   1   0   0   0   0   0   1   0   0   0   0 119   0
##   18   2   1   1   3   2   1  10   0   0   1   2   0   0   0   0   3   0  98
##   19   0   3   0   0   1   1   1   0   0   0   0   0   0   0   0   1   2   2
##   20   1   0   0   0   0   0   0   0   0   1   2   2   1   0   0   0   5   1
##   21   0   2   0   5   0   0   0   0   0   0   0   0   0   0   0   0   1   2
##   22   0   2   0   6   2   0   1   0   0   6   2   0   0   0   0   0   2   0
##   23   0   1   0   0   0   3   0   6   0   0   0   0   2   1   0   0   0   0
##   24   1   0   0   0   0   0   0   2   0   0   2   0   1   0   0   0   0   0
##   25   0   0   0   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0
##     yhat
## y     19  20  21  22  23  24  25
##   1    0   1   0   0   1   0   0
##   2    4   0   0   1   1   0   0
##   3    0   0   0   0   1   0   0
##   4    1   0   8   5   0   0   1
##   5    6   0   1   1   0   0   0
##   6    0   0   2   0   3   1   0
##   7    0   1   1   2   0   0   0
##   8    0   0   1   0   3   0   0
##   9    0   0   0   0   0   0   0
##   10   0   0   1   0   0   0   0
##   11   0   1   0   0   0   1   0
##   12   0   0   0   1   0   0   0
##   13   0   2   0   0   2   0   0
##   14   0   0   0   0   0   0   0
##   15   0   1   0   0   0   0   0
##   16   0   2   0   0   0   0   0
##   17   0   1   2   0   0   0   0
##   18   0   1   0   0   0   0   0
##   19 109   1   2   1   0   1   0
##   20   0 112   0   0   0   0   0
##   21   0   0 113   2   0   0   0
##   22   0   0   4 100   0   0   0
##   23   0   1   0   0 110   0   1
##   24   1   0   1   1   1 115   0
##   25   3   0   0   3   0   0 118

Here we see that, on the validation set, no pair of reviewers is confused very often (the largest off-diagonal count is 10, for reviewer 18 predicted as reviewer 7) and that the model does fairly well across all authors.
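To back up a statement like this with numbers rather than by scanning the table, one option is to compute the recall for each class and locate the largest off-diagonal cell. The helper below is a hypothetical base R sketch; it assumes rows are true classes and columns are predicted classes in the same order, as in the table above.

```r
# Summarize a confusion table whose rows are true classes and whose columns
# are predicted classes, listed in the same order.
confusion_summary <- function(tab) {
  tab <- unclass(tab)                      # plain integer matrix with dimnames
  recall <- diag(tab) / rowSums(tab)       # share of each class predicted correctly
  off <- tab
  diag(off) <- 0                           # keep only the mistakes
  worst <- which(off == max(off), arr.ind = TRUE)[1, ]
  list(recall = recall,
       worst_true = rownames(tab)[worst[1]],
       worst_pred = colnames(tab)[worst[2]],
       worst_count = max(off))
}

# Hypothetical toy predictions for three classes
tab <- table(y    = c(1, 1, 1, 2, 2, 2, 3, 3, 3),
             yhat = c(1, 1, 2, 2, 2, 2, 3, 1, 1))
confusion_summary(tab)
```

Applied to the validation table above, the largest off-diagonal cell is reviewer 18 predicted as reviewer 7 (10 reviews).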

Thoughts on local models

We will continue to make the heaviest use of regression-based models, but the two local models we have seen (particularly gradient boosted trees) will be useful to augment them, especially for features where interaction terms matter, such as POS N-grams. In machine learning competitions, particularly those on non-image and non-textual data, gradient boosted trees are very often the winning model. They can, however, take a bit of tuning to get right. Usually this consists of slowly lowering the learning rate and increasing the number of trees until the model saturates (no longer improves).
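One simple way to decide when the model has saturated is to look at the validation loss recorded during training. The sketch below uses the validation mlogloss values printed above for every 25th round from 1001 to 1500; saturation_round and its tolerance tol are hypothetical choices, not an xgboost feature. (xgb.train also accepts an early_stopping_rounds argument that automates a similar idea.)

```r
# Validation mlogloss copied from the training log printed above
# (every 25th round from 1001 to 1476, plus the final round 1500)
rounds <- c(seq(1001, 1476, by = 25), 1500)
valid <- c(0.380621, 0.380230, 0.379909, 0.379612, 0.379468, 0.379201,
           0.378864, 0.378591, 0.378302, 0.378164, 0.377894, 0.377840,
           0.377730, 0.377577, 0.377562, 0.377572, 0.377573, 0.377521,
           0.377530, 0.377596, 0.377568)

# Hypothetical saturation rule: the first checkpoint that no later
# checkpoint improves on by more than tol
saturation_round <- function(rounds, loss, tol = 1e-4) {
  n <- length(loss)
  for (i in seq_len(n)) {
    if (loss[i] - min(loss[i:n]) < tol) return(rounds[i])
  }
}

rounds[which.min(valid)]          # 1426: lowest validation loss
saturation_round(rounds, valid)   # 1326: later rounds gain less than 1e-4
```

Here the lowest validation loss occurs at round 1426, but by round 1326 no later checkpoint improves on it by more than 0.0001, so the remaining rounds buy essentially nothing.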