## -----------------------------------------------------------------------------
# remotes::install_github('dhicks/tmfast.realbooks')

## -----------------------------------------------------------------------------
# install.packages('tmfast.realbooks', repos = 'https://dhicks.github.io/drat/')

## -----------------------------------------------------------------------------
# Evaluate the remaining chunks only when the data package is installed.
# The check runs once and is reused, so the chunk option and the warning
# below cannot drift apart.
has_realbooks = requireNamespace('tmfast.realbooks', quietly = TRUE)
knitr::opts_chunk$set(eval = has_realbooks)
if (!has_realbooks) {
      # call. = FALSE keeps the evaluating call out of the rendered message
      warning(
            'Data package not available; skipping execution',
            call. = FALSE
      )
}

## -----------------------------------------------------------------------------
# Data manipulation helpers
library(dplyr)
library(tidyr)
library(tibble)
# Plotting; theme_minimal() becomes the default theme for every figure below
library(ggplot2)
theme_set(theme_minimal())
library(ggbeeswarm)
# tic()/toc() timing around the slower steps
library(tictoc)

# Tokenization helpers and the topic-model fitting package
library(tidytext)
library(tmfast)

# Data package checked for in the setup chunk
library(tmfast.realbooks)

## -----------------------------------------------------------------------------
# Load the `corpus_raw` dataset (presumably from tmfast.realbooks)
data(corpus_raw)

## ~17 sec
tic()
# Split each text into word tokens, then tally every term within each book
dataf = count(
      unnest_tokens(corpus_raw, term, text, token = 'words'),
      gutenberg_id, author, title, term
)
toc()

# One row per (author, title) pair, kept as document-level metadata
meta_df = dataf |>
      distinct(author, title)
dataf

## -----------------------------------------------------------------------------
# Number of distinct titles per author
dataf |>
      distinct(author, title) |>
      count(author)

# Total number of distinct (author, title) combinations
n_distinct(dataf$author, dataf$title)

## -----------------------------------------------------------------------------
# Total token count per book, then summarized across each author's books.
# `.groups = 'drop_last'` states explicitly that the intermediate result
# stays grouped by author for the second summarize(), rather than relying
# on the implicit default (which also emits an informational message).
dataf |>
      group_by(author, title) |>
      summarize(n = sum(n), .groups = 'drop_last') |>
      summarize(
            min = min(n),
            median = median(n),
            max = max(n),
            total = sum(n)
      ) |>
      arrange(desc(total))

# Per-book token counts, one box (plus individual points) per author.
# The plot does not use dplyr grouping, so the groups are dropped outright.
dataf |>
      group_by(author, title) |>
      summarize(n = sum(n), .groups = 'drop') |>
      ggplot(aes(author, n, color = author)) +
      geom_boxplot() +
      geom_beeswarm() +
      scale_color_discrete(guide = 'none') +
      coord_flip()

## -----------------------------------------------------------------------------
# Vocabulary size: ten terms per distinct (author, title) pair
vocab_size = 10 * with(dataf, n_distinct(author, title))
vocab_size

## -----------------------------------------------------------------------------
tic()
# Score terms with ndH() and ndR(), passing the title, term and n columns
H_df = dataf |>
      ndH(title, term, n)
# Flag the terms whose descending ndR rank falls within vocab_size
R_df = dataf |>
      ndR(title, term, n) |>
      mutate(in_vocab = rank(desc(ndR)) <= vocab_size)
toc()
H_df
R_df

## -----------------------------------------------------------------------------
# ndH against ndR for each term: color marks membership in the ndR-based
# vocabulary, alpha marks whether the term would also make an ndH-based cut
H_df |>
      inner_join(R_df, by = 'term') |>
      ggplot(mapping = aes(x = ndH, y = ndR, color = in_vocab)) +
      geom_point(mapping = aes(alpha = rank(desc(ndH)) <= vocab_size))

# The same comparison on descending ranks, with log-scaled axes
H_df |>
      inner_join(R_df, by = 'term') |>
      mutate(
            ndH_rank = rank(desc(ndH)),
            ndR_rank = rank(desc(ndR))
      ) |>
      ggplot(mapping = aes(x = ndH_rank, y = ndR_rank, color = in_vocab)) +
      geom_point(mapping = aes(alpha = ndH_rank <= vocab_size)) +
      scale_x_log10() +
      scale_y_log10()

## -----------------------------------------------------------------------------
# Terms flagged as in-vocabulary, as a plain vector
vocab = pull(filter(R_df, in_vocab), term)
head(vocab, n = 50)

## -----------------------------------------------------------------------------
# Per-book token counts restricted to vocabulary terms, by author.
# `.groups = 'drop'` makes the post-summarize grouping explicit and avoids
# the informational message from the implicit default; the plot itself does
# not depend on dplyr grouping.
dataf |>
      filter(term %in% vocab) |>
      group_by(author, title) |>
      summarize(n = sum(n), .groups = 'drop') |>
      ggplot(aes(author, n, color = author)) +
      geom_boxplot() +
      geom_beeswarm() +
      scale_color_discrete(guide = 'none') +
      coord_flip()

## -----------------------------------------------------------------------------
# Number of distinct authors; the middle of the three requested model sizes
n_authors = with(dataf, n_distinct(author))

# Keep only vocabulary terms and replace raw counts with log1p(count)
dtm = mutate(
      filter(dataf, term %in% vocab),
      n = log1p(n)
)

tic()
# Fit with 5, n_authors and n_authors + 5 as the `n` values
fitted_tmf = dtm |>
      tmfast(
            n = c(5, n_authors, n_authors + 5),
            row = title,
            column = term,
            value = n
      )
toc()

screeplot(fitted_tmf, npcs = n_authors + 5)

## -----------------------------------------------------------------------------
# Stacked gamma values per document for the n_authors model, faceted by
# the author attached through the title join
fitted_tmf |>
      tidy(n_authors, 'gamma') |>
      left_join(meta_df, by = c(document = 'title')) |>
      ggplot(mapping = aes(x = document, y = gamma, fill = topic)) +
      geom_col() +
      facet_wrap(vars(author), scales = 'free_x') +
      scale_x_discrete(guide = 'none') +
      scale_fill_viridis_d()

## -----------------------------------------------------------------------------
# Derive a target entropy from the alpha returned by peak_alpha()
alpha = peak_alpha(n_authors, 1, peak = .8, scale = 10)
target_entropy = expected_entropy(alpha)
target_entropy

# Exponent from target_power() for the gamma values and the target entropy
exponent = target_power(
      tidy(fitted_tmf, n_authors, 'gamma'),
      document, gamma, target_entropy
)
exponent

# Stacked gamma values again, now tidied with the exponent applied
fitted_tmf |>
      tidy(n_authors, 'gamma', exponent = exponent) |>
      left_join(meta_df, by = c(document = 'title')) |>
      ggplot(mapping = aes(x = document, y = gamma, fill = topic)) +
      geom_col() +
      facet_wrap(vars(author), scales = 'free_x') +
      scale_x_discrete(guide = 'none') +
      scale_fill_viridis_d()

# The same values as a topic-by-document raster, one facet row per author
fitted_tmf |>
      tidy(n_authors, 'gamma', exponent = exponent) |>
      left_join(meta_df, by = c(document = 'title')) |>
      ggplot(mapping = aes(x = topic, y = document, fill = gamma)) +
      geom_raster() +
      facet_grid(rows = vars(author), scales = 'free_y', switch = 'y') +
      scale_y_discrete(guide = 'none') +
      theme(strip.text.y.left = element_text(angle = 0))

## -----------------------------------------------------------------------------
# Hand-written lookup from topic id to the author label(s) shown with it
topic_author = tibble(
      topic = c(
            'V01', 'V02', 'V03', 'V04', 'V05',
            'V06', 'V07', 'V08', 'V09', 'V10'
      ),
      authors = c(
            'C. Brontë, Shelley',
            'Dickens and Lovecraft',
            'Dickens',
            'Austen and Shelley',
            'Alcott',
            'Wells',
            'Dickens, Eliot, Lovecraft',
            'Dickens',
            'Brontës, Eliot, Shelley',
            'Dickens, Lovecraft, Wells'
      )
)

## -----------------------------------------------------------------------------
# Target entropy for the term values, computed with k = vocab_size
target_entropy_term = expected_entropy(5, k = vocab_size)
target_entropy_term

# Exponent from target_power() for the beta values and that target entropy
exponent_term = target_power(
      tidy(fitted_tmf, n_authors, 'beta'),
      topic, beta, target_entropy_term
)
exponent_term

# Beta values tidied with the exponent applied
beta_df = fitted_tmf |>
      tidy(n_authors, 'beta', exponent = exponent_term)

## -----------------------------------------------------------------------------
# Top 15 terms per topic by beta (ties kept), labeled with the author
# associations from topic_author. slice_max() replaces the superseded
# top_n(); the result is ungrouped so the join and later mutate() do not
# silently run per topic.
top_terms = beta_df |>
      group_by(topic) |>
      slice_max(beta, n = 15) |>
      ungroup() |>
      arrange(topic, desc(beta)) |>
      left_join(topic_author, by = 'topic')
top_terms

# Lollipop chart of the top terms, one facet per topic. reorder_within() and
# scale_x_reordered() order the tokens by beta separately inside each facet.
top_terms |>
      mutate(token = reorder_within(token, by = beta, within = topic)) |>
      ggplot(aes(token, beta)) +
      geom_point() +
      geom_segment(aes(xend = token), yend = 0) +
      facet_wrap(vars(topic, authors), scales = 'free_y') +
      coord_flip() +
      scale_x_reordered()

