diff --git a/DESCRIPTION b/DESCRIPTION index 13115b0..2037321 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -10,18 +10,15 @@ License: GPL-3 LazyData: TRUE Encoding: UTF-8 Depends: - methods, R (>= 3.5.0) Imports: + methods, quanteda (>= 2.0), quanteda.textstats, stringi, digest, Matrix, RSpectra, - irlba, - rsvd, - rsparse, proxyC, stats, ggplot2, @@ -29,11 +26,15 @@ Imports: reshape2, locfit Suggests: + testthat, spelling, knitr, rmarkdown, - testthat -RoxygenNote: 7.3.1 + wordvector, + irlba, + rsvd, + rsparse +RoxygenNote: 7.3.2 Roxygen: list(markdown = TRUE) BugReports: https://github.com/koheiw/LSX/issues URL: https://koheiw.github.io/LSX/ diff --git a/NAMESPACE b/NAMESPACE index 50611a9..8d89225 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -8,6 +8,7 @@ S3method(as.statistics_textmodel,matrix) S3method(as.textmodel_lss,matrix) S3method(as.textmodel_lss,numeric) S3method(as.textmodel_lss,textmodel_lss) +S3method(as.textmodel_lss,textmodel_wordvector) S3method(coef,textmodel_lss) S3method(diagnosys,character) S3method(diagnosys,corpus) diff --git a/NEWS.md b/NEWS.md index b081ca7..92c58a3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ -## Changes in v1.4.2 +## Changes in v1.5.0 +* Add `as.textmodel_lss()` for objects from the **wordvector** package. +* Reduce dependent packages by moving **rsparse**, **irlba** and *rsvd* to Suggests. * Fix handling of phrasal patterns in `textplot_terms()`. * Improve objects created by `as.textmodel_lss.textmodel_lss()`. diff --git a/R/as.textmodel.R b/R/as.textmodel.R index f29ef6a..6929f72 100644 --- a/R/as.textmodel.R +++ b/R/as.textmodel.R @@ -99,3 +99,15 @@ as.textmodel_lss.textmodel_lss <- function(x, ...) { result$frequency <- x$frequency[names(result$beta)] return(result) } + +#' @export +#' @method as.textmodel_lss textmodel_wordvector +as.textmodel_lss.textmodel_wordvector <- function(x, ...) { + if (!requireNamespace("wordvector")) + stop("wordvector package must be installed") + if (is.null(x$vectors)) + stop("x must be a valid textmodel_wordvector object") + result <- as.textmodel_lss(t(x$vectors), ...) + result$frequency <- x$frequency[names(result$beta)] + return(result) +} diff --git a/R/textmodel.R b/R/textmodel.R index eacd032..2f77f61 100644 --- a/R/textmodel.R +++ b/R/textmodel.R @@ -307,12 +307,16 @@ cache_svd <- function(x, k, weight, engine, cache = TRUE, ...) { message("Reading cache file: ", file_cache) result <- readRDS(file_cache) } else { - if (engine == "RSpectra") { - result <- RSpectra::svds(as(x, "dgCMatrix"), k = k, nu = 0, nv = k, ...) - } else if (engine == "rsvd") { + if (engine == "rsvd") { + if (!requireNamespace("rsvd")) + stop("wordvector package must be installed") result <- rsvd::rsvd(as(x, "dgCMatrix"), k = k, nu = 0, nv = k, ...) - } else { + } else if (engine == "irlba") { + if (!requireNamespace("irlba")) + stop("irlba package must be installed") result <- irlba::irlba(as(x, "dgCMatrix"), nv = k, right_only = TRUE, ...) + } else { + result <- RSpectra::svds(as(x, "dgCMatrix"), k = k, nu = 0, nv = k, ...) } if (cache) { message("Writing cache file: ", file_cache) @@ -337,6 +341,8 @@ cache_glove <- function(x, w, x_max = 10, n_iter = 10, cache = TRUE, ...) { message("Reading cache file: ", file_cache) result <- readRDS(file_cache) } else { + if (!requireNamespace("rsparse")) + stop("wordvector package must be installed") glove <- rsparse::GloVe$new(rank = w, x_max = x_max, ...) temp <- glove$fit_transform(Matrix::drop0(x), n_iter = n_iter, n_threads = getOption("quanteda_threads", 1L)) diff --git a/tests/data/word2vec_test.RDS b/tests/data/word2vec_test.RDS new file mode 100644 index 0000000..cce628b Binary files /dev/null and b/tests/data/word2vec_test.RDS differ diff --git a/tests/testthat/test-as.textmodel.R b/tests/testthat/test-as.textmodel.R index 0f26ab2..4277302 100644 --- a/tests/testthat/test-as.textmodel.R +++ b/tests/testthat/test-as.textmodel.R @@ -96,6 +96,17 @@ test_that("as.textmodel_lss works with textmodel_lss", { ) }) +test_that("as.textmodel_lss works with textmodel_wordvector", { + + wdv <- readRDS("../data/word2vec_test.RDS") + lss <- as.textmodel_lss(wdv, seed) + + expect_equal(lss$embedding, t(wdv$vectors)) + expect_identical(lss$frequency, wdv$frequency) + expect_identical(names(lss$frequency), names(wdv$frequency)) + +}) + test_that("as.textmodel_lss works with vector", { weight <- c("decision" = 0.1, "instance" = -0.1, "foundations" = 0.3, "the" = 0) diff --git a/tests/testthat/test-textmodel.R b/tests/testthat/test-textmodel.R index e1ce151..a231d09 100644 --- a/tests/testthat/test-textmodel.R +++ b/tests/testthat/test-textmodel.R @@ -304,10 +304,25 @@ test_that("textmodel_lss works with non-existent seeds", { "No seed word is found in the dfm") }) -test_that("RSpectra and irlba work", { - - expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "RSpectra")) - expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "irlba")) +test_that("rsvd and irlba work", { + + if (requireNamespace("irlba")) { + expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "irlba")) + } else { + expect_error( + expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "irlba")), + "irlba package must be installed" + ) + } + + if (requireNamespace("rsvd")) { + expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "rsvd")) + } else { + expect_error( + expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "rsvd")), + "rsvd package must be installed" + ) + } })