Getting Started

Before running this notebook, select “Session > Restart R and Clear Output” in the menu above to start a new R session. You may also need to click the broom icon in the Environment pane (upper right-hand corner of the window) to clear the workspace. This removes any old data sets and gives us a blank slate to start with.

After starting a new session, run the following code chunk to load the libraries and data that we will be working with today.

I have set the chunk options message=FALSE and echo=FALSE so that neither this code nor its messages clutter your solutions.

Reading the Data

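Today we are going to look at a dataset of short texts written by five British authors. The docs table holds the document-level metadata (including the author label and the train/validation split), and the anno table holds the token-level annotations: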

docs <- read_csv("../data/stylo_uk.csv")
anno <- read_csv("../data/stylo_uk_token.csv.gz")

The prediction task is to determine the identity of the author based on the text.


Baseline Elastic Net

Start by fitting an elastic net model with all of the default parameters. We will need this several different times, so save the model with a unique name such as model_enet:

# Question 01
model_enet <- dsst_enet_build(anno, docs)
## as(<dgCMatrix>, "dgTMatrix") is deprecated since Matrix 1.5-0; do as(., "TsparseMatrix") instead

Compute the error rate of the elastic net model on the training and validation sets.

# Question 02
model_enet$docs %>%
  group_by(train_id) %>%
  summarize(erate = mean(label != pred_label))
## # A tibble: 2 × 2
##   train_id erate
##   <chr>    <dbl>
## 1 train    0.129
## 2 valid    0.157
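The summarize step relies on the fact that, in R, mean() of a logical vector is the proportion of TRUE values. A minimal base-R sketch on a small hypothetical table (the labels and predictions are made up for illustration) reproduces the same per-split error rate with tapply():

```r
# Hypothetical toy predictions; not taken from the notebook's data
toy <- data.frame(
  train_id   = c("train", "train", "train", "train", "valid", "valid"),
  label      = c("Austen", "Doyle", "Wells", "Austen", "Doyle", "Wells"),
  pred_label = c("Austen", "Doyle", "Doyle", "Austen", "Wells", "Wells")
)

# label != pred_label is TRUE for each misclassified row;
# its mean within each split is the error rate for that split
erate <- tapply(toy$label != toy$pred_label, toy$train_id, mean)
print(erate)
```

Here one of the four training rows and one of the two validation rows are wrong, giving error rates of 0.25 and 0.5.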

Baseline Gradient Boosted Trees

Now, fit a gradient boosted tree model using the default parameters:

# Question 03
model_gbm <- dsst_gbm_build(anno, docs)

Compute the error rate of this model on the training and validation sets:

# Question 04
model_gbm$docs %>%
  group_by(train_id) %>%
  summarize(erate = mean(label != pred_label))
## # A tibble: 2 × 2
##   train_id erate
##   <chr>    <dbl>
## 1 train    0.383
## 2 valid    0.393

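Note that this model is not particularly good: the training and validation error rates are both close to 40% and nearly equal, which points to underfitting rather than overfitting. The problem is that we need a significantly larger set of trees. Create a new gradient boosted tree model using 1000 trees: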

# Question 05
model_gbm <- dsst_gbm_build(anno, docs, nrounds = 1000)
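Why does raising nrounds help so much? Assuming dsst_gbm_build wraps a boosting library such as xgboost (where nrounds is the number of boosting rounds, i.e. trees), each round adds one small tree fitted to the current residuals and scaled by a learning rate. The base-R sketch below is a toy version of that idea (squared-error boosting with one-split "stumps" on simulated 1-D data), not the library's implementation; it shows the training error falling as rounds are added.

```r
set.seed(1)
# Simulated 1-D regression data (hypothetical, for illustration only)
x <- sort(runif(200, 0, 6))
y <- sin(x) + rnorm(200, sd = 0.2)

# Fit a one-split regression stump to residuals r by least squares
fit_stump <- function(x, r) {
  cuts <- unique(x)[-1]            # candidate split points (both sides non-empty)
  best <- list(sse = Inf)
  for (cut in cuts) {
    left <- x < cut
    ml <- mean(r[left]); mr <- mean(r[!left])
    sse <- sum((r[left] - ml)^2) + sum((r[!left] - mr)^2)
    if (sse < best$sse) best <- list(sse = sse, cut = cut, ml = ml, mr = mr)
  }
  best
}

predict_stump <- function(s, x) ifelse(x < s$cut, s$ml, s$mr)

# Squared-error boosting: each round fits a stump to the residuals,
# then adds a shrunken (eta-scaled) copy of it to the prediction.
# Returns the training mean squared error after nrounds rounds.
boost_mse <- function(x, y, nrounds, eta = 0.1) {
  pred <- rep(mean(y), length(y))
  for (i in seq_len(nrounds)) {
    s <- fit_stump(x, y - pred)
    pred <- pred + eta * predict_stump(s, x)
  }
  mean((y - pred)^2)
}

mse <- sapply(c(5, 50, 200), function(n) boost_mse(x, y, n))
print(mse)
```

With a small learning rate each tree moves the fit only a little, so a handful of rounds underfits, much like the default model above. This toy tracks training error only; with real data, too many rounds can overfit, which is why the validation error still matters.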

Compute the training and validation error rates below in the usual way; note that they should also appear in the output printed above.

# Question 06
model_gbm$docs %>%
  group_by(train_id) %>%
  summarize(erate = mean(label != pred_label))
## # A tibble: 2 × 2
##   train_id  erate
##   <chr>     <dbl>
## 1 train    0.0641
## 2 valid    0.134

How does the error rate compare to the elastic net model? Make sure to look at both the training and validation results.
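One way to structure the comparison is to put the printed error rates side by side and compute, for each model, the gap between validation and training error. The numbers below are copied from the outputs above (the boosted model is the 1000-tree version).

```r
# Error rates copied from the outputs printed above
erates <- data.frame(
  model = c("enet", "gbm_1000"),
  train = c(0.129, 0.0641),
  valid = c(0.157, 0.134)
)
# Gap: how much worse each model does on the validation set than on training data
erates$gap <- erates$valid - erates$train
print(erates)
```

The boosted model has the lower validation error, but its gap (about 0.070) is roughly two and a half times that of the elastic net (0.028), so it fits the training data noticeably more closely than it generalizes.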

Variable Importance

Now, look at the coefficients for the elastic net model. There are a lot, so you may want to limit the lambda number to something around 30.

# Question 07
dsst_coef(model_enet$model, lambda_num = 30)
## 17 x 6 sparse Matrix of class "dgCMatrix"
##                  Austen    Dickens       Doyle    Stevenson        Wells
## (Intercept) -0.12749647 -0.1597772  0.15570271 -0.004719147  0.136290059
## '            .           0.3009869  .           .            .          
## ;            .           .         -0.06233148  0.140627172 -0.001383177
## Graham       .           .          .           .            0.559097593
## Catherine    0.41900412  .          .           .            .          
## not          0.08767208  .          .           .            .          
## Elinor       0.36386735  .          .           .            .          
## be           0.02473169  .          .           .            .          
## sister       0.24203049  .          .           .            .          
## Holmes       .           .          0.46787891  .            .          
## Miss         0.13067056  .          .           .            .          
## Anne         0.11738379  .          .           .            .          
## very         0.03034543  .          .           .            .          
## Emma         0.11215630  .          .           .            .          
## Mrs.         0.04832228  .          .           .            .          
## Elizabeth    0.05997300  .          .           .            .          
## -PRON-       .           .          .           .           -0.000688534
##             MLN
## (Intercept)   .
## '             2
## ;            19
## Graham       21
## Catherine    22
## not          22
## Elinor       23
## be           23
## sister       24
## Holmes       25
## Miss         25
## Anne         28
## very         28
## Emma         29
## Mrs.         29
## Elizabeth    30
## -PRON-       30
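The MLN column appears to give, for each feature, the lambda number at which its coefficient first becomes nonzero along the regularization path (this reading is an assumption based on the output, not on dsst_coef's documentation). The base-R sketch below computes that quantity from a small hypothetical coefficient path, with features as rows and successive lambda numbers as columns:

```r
# Hypothetical coefficient path: rows = features, columns = lambda numbers 1..5
path <- rbind(
  ";"      = c(0, 0, 0.05, 0.10, 0.14),
  "Holmes" = c(0, 0, 0, 0, 0.20),
  "'"      = c(0, 0.08, 0.15, 0.22, 0.30),
  "the"    = c(0, 0, 0, 0, 0)
)
# First lambda number at which each feature is nonzero; NA if it never enters
first_nonzero <- apply(path, 1, function(b) {
  idx <- which(b != 0)
  if (length(idx) == 0) NA_integer_ else min(idx)
})
print(sort(first_nonzero, na.last = TRUE))
```

Under this reading, features that enter at a small lambda number survive the strongest penalty, which is why the apostrophe (MLN 2) heads the table above.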

And, for comparison, look at the importance scores for the gradient boosted trees:

# Question 08
## # A tibble: 946 × 4
##    feature       gain   cover frequency
##    <chr>        <dbl>   <dbl>     <dbl>
##  1 ";"         0.0864 0.0263    0.0200 
##  2 "'"         0.0841 0.0194    0.0195 
##  3 "\""        0.0402 0.0146    0.0204 
##  4 "Graham"    0.0202 0.00670   0.00406
##  5 "Catherine" 0.0200 0.00708   0.00642
##  6 "."         0.0174 0.00855   0.00829
##  7 "the"       0.0172 0.00937   0.0214 
##  8 ","         0.0156 0.0193    0.0235 
##  9 "Holmes"    0.0152 0.00864   0.00618
## 10 "-PRON-"    0.0144 0.00800   0.0140 
## # … with 936 more rows

How do the two lists compare to one another? Is there a lot of overlap? Is there any particular pattern to the differences?
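A quick way to quantify the overlap is with set operations on the two feature lists. The vectors below are copied from the printed outputs (the 16 non-intercept elastic net features at lambda number 30 and the top 10 boosted-tree features by gain), so this compares only what is shown, not the full 946-row importance table.

```r
# Features copied from the two outputs above (intercept excluded)
enet_feats <- c("'", ";", "Graham", "Catherine", "not", "Elinor", "be",
                "sister", "Holmes", "Miss", "Anne", "very", "Emma",
                "Mrs.", "Elizabeth", "-PRON-")
gbm_top10  <- c(";", "'", "\"", "Graham", "Catherine", ".", "the", ",",
                "Holmes", "-PRON-")

shared   <- intersect(gbm_top10, enet_feats)  # in both lists
gbm_only <- setdiff(gbm_top10, enet_feats)    # top boosted features missing from the enet list
print(shared)
print(gbm_only)
```

Six of the ten top boosted features also appear in the elastic net list; the four that do not are the double quote, the period, the comma, and the word "the".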

Data Visualization

For the first visualization in today’s notebook, take the elastic net model you built above (model_enet) and select the attached docs table. Filter to include only the validation set and sort the rows in decreasing order of pred_value (the predicted probability of the label the model chose, i.e. how confident the model is in its prediction). Then slice to keep just the first 1000 rows and compute the error rate. This should be much lower than the rate you computed above. Play around with the number of rows included in the summary and try to understand the pattern.

# Question 09
model_enet$docs %>%
  filter(train_id == "valid") %>%
  arrange(desc(pred_value)) %>%
  slice(1:1000) %>%
  summarize(erate = mean(label != pred_label))
## # A tibble: 1 × 1
##   erate
##   <dbl>
## 1 0.006
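The logic of this pipeline (sort by confidence, keep the top k rows, compute the error rate) can be sketched as a small base-R function. The data frame below is hypothetical and only mimics the label, pred_label, and pred_value columns.

```r
# Hypothetical validation predictions with a confidence score per row
toy <- data.frame(
  label      = c("A", "B", "A", "B", "A", "B"),
  pred_label = c("A", "B", "B", "B", "A", "A"),
  pred_value = c(0.95, 0.90, 0.40, 0.85, 0.70, 0.35)
)

# Error rate among the k most confident predictions
top_k_error <- function(d, k) {
  d <- d[order(d$pred_value, decreasing = TRUE), ]
  d <- head(d, k)
  mean(d$label != d$pred_label)
}

print(top_k_error(toy, 3))  # the 3 most confident rows are all correct
print(top_k_error(toy, 6))  # all rows: 2 of 6 are wrong
```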

Now, we want to create a plot showing how the error rate changes as more of the validation set is included in the summary. To do this, drop the slice and replace the summarize command with a mutate verb, using the function cummean in place of mean. Also, create a variable with the command prop = row_number() / n(). Then plot the proportion of the validation set predicted on the x-axis and the error rate on the y-axis using a line geometry.
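The function cummean() (from dplyr) is the running mean, so its i-th value equals sum(err[1:i]) / i. The base-R sketch below, on a hypothetical vector of per-row errors already sorted from most to least confident, builds the same erate and prop columns that the plot uses:

```r
# Hypothetical per-row errors (TRUE = misclassified), sorted by decreasing confidence
err <- c(FALSE, FALSE, FALSE, TRUE, FALSE, TRUE, TRUE, FALSE)

erate <- cumsum(err) / seq_along(err)   # running mean, like dplyr::cummean(err)
prop  <- seq_along(err) / length(err)   # like row_number() / n()

print(data.frame(prop, erate))
```

The right end of the curve (prop = 1) is therefore the overall validation error rate, while the left end reflects only the most confident predictions.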

# Question 10
model_enet$docs %>%
  filter(train_id == "valid") %>%
  arrange(desc(pred_value)) %>%
  mutate(erate = cummean(label != pred_label), prop = row_number() / n()) %>%
  ggplot(aes(prop, erate)) +
    geom_line()

Now, take the 4500 most confident predictions in the validation data (about 75% of the values) and build a confusion matrix. Take a moment to understand what it tells us, keeping in mind that the data are balanced between the labels.

# Question 11
model_enet$docs %>%
  filter(train_id == "valid") %>%
  arrange(desc(pred_value)) %>%
  slice(1:4500) %>%
  select(label, pred_label) %>%
  table()
##            pred_label
## label       Austen Dickens Doyle Stevenson Wells
##   Austen       977      10     3        18     0
##   Dickens        2     938     5        42     1
##   Doyle          3      42   786        11    20
##   Stevenson      6      10    11       746    17
##   Wells          1      11     7        18   815
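To read the matrix numerically, the counts can be entered into base R and summarized. The values below are copied from the printed output; the code computes the overall error rate among these 4500 predictions and the per-author recall (each diagonal count divided by its row total).

```r
# Counts copied from the confusion matrix above (rows = true label)
authors <- c("Austen", "Dickens", "Doyle", "Stevenson", "Wells")
cm <- matrix(c(977,  10,   3,  18,   0,
                 2, 938,   5,  42,   1,
                 3,  42, 786,  11,  20,
                 6,  10,  11, 746,  17,
                 1,  11,   7,  18, 815),
             nrow = 5, byrow = TRUE,
             dimnames = list(label = authors, pred_label = authors))

erate  <- 1 - sum(diag(cm)) / sum(cm)  # error rate among the 4500 rows
recall <- diag(cm) / rowSums(cm)       # share of each author's texts labeled correctly
print(rowSums(cm))
print(round(erate, 4))
print(round(recall, 3))
```

The row totals range from 1008 (Austen) down to 790 (Stevenson), so although the full data are balanced, the model's most confident predictions are not spread evenly across the authors.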

Finally, repeat the last question from Notebook 5 using only the validation data from this dataset. Order the authors from the highest average number of verbs per sentence to the lowest.

# Question 12
anno %>%
  group_by(doc_id, sid) %>%
  summarize(n_verb = sum(upos == "VERB")) %>%
  left_join(docs, by = "doc_id") %>%
  filter(train_id == "valid") %>%
  group_by(label) %>%
  summarize(avg = mean(n_verb), s = sd(n_verb) / sqrt(n())) %>%
  arrange(desc(avg)) %>%
  mutate(label = fct_inorder(label)) %>%
  ggplot(aes(label, avg)) +
    geom_pointrange(aes(ymin = avg - 2 * s, ymax = avg + 2 * s))

Using the overlap of the confidence intervals (a rough, but reasonably accurate, indicator of statistical significance), you should see that the authors fall into three rough buckets.
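The overlap rule can be sketched in base R as follows. The per-sentence verb counts are hypothetical (three made-up groups); the code computes each group's mean plus or minus two standard errors, as in the plot, and checks which pairs of intervals overlap. Non-overlapping intervals are only a rough signal of a significant difference, not a formal test.

```r
# Hypothetical verb counts per sentence for three groups
counts <- list(
  A = c(3, 4, 5, 4, 3, 5, 4, 4),
  B = c(2, 3, 2, 3, 2, 3, 2, 3),
  C = c(3, 4, 4, 3, 5, 4, 3, 4)
)

# Mean +/- 2 standard errors, matching avg - 2 * s and avg + 2 * s above
ci <- t(sapply(counts, function(v) {
  s <- sd(v) / sqrt(length(v))
  c(lower = mean(v) - 2 * s, upper = mean(v) + 2 * s)
}))

# Two intervals overlap if each one starts before the other ends
overlaps <- function(i, j) {
  ci[i, "lower"] <= ci[j, "upper"] && ci[j, "lower"] <= ci[i, "upper"]
}

print(ci)
print(c(AB = overlaps("A", "B"), AC = overlaps("A", "C"), BC = overlaps("B", "C")))
```

Here A and C overlap while B is separated from both, which would put these toy groups into two buckets.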