Skip to content

Commit

Permalink
Merge pull request #104 from koheiw/add-wordvector
Browse files Browse the repository at this point in the history
Add wordvector
  • Loading branch information
koheiw authored Dec 13, 2024
2 parents b792446 + e118750 commit bf5539b
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 15 deletions.
13 changes: 7 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,31 @@ License: GPL-3
LazyData: TRUE
Encoding: UTF-8
Depends:
methods,
R (>= 3.5.0)
Imports:
methods,
quanteda (>= 2.0),
quanteda.textstats,
stringi,
digest,
Matrix,
RSpectra,
irlba,
rsvd,
rsparse,
proxyC,
stats,
ggplot2,
ggrepel,
reshape2,
locfit
Suggests:
testthat,
spelling,
knitr,
rmarkdown,
testthat
RoxygenNote: 7.3.1
wordvector,
irlba,
rsvd,
rsparse
RoxygenNote: 7.3.2
Roxygen: list(markdown = TRUE)
BugReports: https://github.com/koheiw/LSX/issues
URL: https://koheiw.github.io/LSX/
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ S3method(as.statistics_textmodel,matrix)
S3method(as.textmodel_lss,matrix)
S3method(as.textmodel_lss,numeric)
S3method(as.textmodel_lss,textmodel_lss)
S3method(as.textmodel_lss,textmodel_wordvector)
S3method(coef,textmodel_lss)
S3method(diagnosys,character)
S3method(diagnosys,corpus)
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## Changes in v1.4.2
## Changes in v1.5.0

* Add `as.textmodel_lss()` for objects from the **wordvector** package.
* Reduce package dependencies by moving **rsparse**, **irlba** and **rsvd** to Suggests.
* Fix handling of phrasal patterns in `textplot_terms()`.
* Improve objects created by `as.textmodel_lss.textmodel_lss()`.

Expand Down
12 changes: 12 additions & 0 deletions R/as.textmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,15 @@ as.textmodel_lss.textmodel_lss <- function(x, ...) {
result$frequency <- x$frequency[names(result$beta)]
return(result)
}

#' @export
#' @method as.textmodel_lss textmodel_wordvector
as.textmodel_lss.textmodel_wordvector <- function(x, ...) {
  # Convert a textmodel_wordvector object into a textmodel_lss object.
  # `x` must carry a `vectors` element; `...` is forwarded unchanged to
  # as.textmodel_lss() (presumably the matrix method -- confirm).

  # wordvector is an optional dependency. quietly = TRUE keeps the
  # availability check from printing "Loading required namespace: ...".
  if (!requireNamespace("wordvector", quietly = TRUE))
    stop("wordvector package must be installed")
  if (is.null(x$vectors))
    stop("x must be a valid textmodel_wordvector object")

  # The word vectors are transposed before conversion.
  # NOTE(review): assumes x$vectors has one row per word -- confirm.
  result <- as.textmodel_lss(t(x$vectors), ...)

  # Align the stored frequencies with the words kept in the fitted model.
  result$frequency <- x$frequency[names(result$beta)]
  return(result)
}
14 changes: 10 additions & 4 deletions R/textmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,16 @@ cache_svd <- function(x, k, weight, engine, cache = TRUE, ...) {
message("Reading cache file: ", file_cache)
result <- readRDS(file_cache)
} else {
if (engine == "RSpectra") {
result <- RSpectra::svds(as(x, "dgCMatrix"), k = k, nu = 0, nv = k, ...)
} else if (engine == "rsvd") {
if (engine == "rsvd") {
# The rsvd engine needs the rsvd package; name it in the error so the
# user knows which package to install.
if (!requireNamespace("rsvd"))
    stop("rsvd package must be installed")
result <- rsvd::rsvd(as(x, "dgCMatrix"), k = k, nu = 0, nv = k, ...)
} else {
} else if (engine == "irlba") {
if (!requireNamespace("irlba"))
stop("irlba package must be installed")
result <- irlba::irlba(as(x, "dgCMatrix"), nv = k, right_only = TRUE, ...)
} else {
result <- RSpectra::svds(as(x, "dgCMatrix"), k = k, nu = 0, nv = k, ...)
}
if (cache) {
message("Writing cache file: ", file_cache)
Expand All @@ -337,6 +341,8 @@ cache_glove <- function(x, w, x_max = 10, n_iter = 10, cache = TRUE, ...) {
message("Reading cache file: ", file_cache)
result <- readRDS(file_cache)
} else {
# GloVe fitting needs the rsparse package; name it in the error so the
# user knows which package to install.
if (!requireNamespace("rsparse"))
    stop("rsparse package must be installed")
glove <- rsparse::GloVe$new(rank = w, x_max = x_max, ...)
temp <- glove$fit_transform(Matrix::drop0(x), n_iter = n_iter,
n_threads = getOption("quanteda_threads", 1L))
Expand Down
Binary file added tests/data/word2vec_test.RDS
Binary file not shown.
11 changes: 11 additions & 0 deletions tests/testthat/test-as.textmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ test_that("as.textmodel_lss works with textmodel_lss", {
)
})

test_that("as.textmodel_lss works with textmodel_wordvector", {

  # The method under test stops when wordvector is not installed, so skip
  # rather than fail on machines without it.
  skip_if_not_installed("wordvector")

  # Pre-trained word vector object stored as a test fixture.
  wdv <- readRDS("../data/word2vec_test.RDS")
  lss <- as.textmodel_lss(wdv, seed)

  # The embedding is the transposed word-vector matrix.
  expect_equal(lss$embedding, t(wdv$vectors))
  # Frequencies are carried over with the same values, names and order.
  expect_identical(lss$frequency, wdv$frequency)
  expect_identical(names(lss$frequency), names(wdv$frequency))

})

test_that("as.textmodel_lss works with vector", {
weight <- c("decision" = 0.1, "instance" = -0.1,
"foundations" = 0.3, "the" = 0)
Expand Down
23 changes: 19 additions & 4 deletions tests/testthat/test-textmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,25 @@ test_that("textmodel_lss works with non-existent seeds", {
"No seed word is found in the dfm")
})

test_that("RSpectra and irlba work", {

expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "RSpectra"))
expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "irlba"))
test_that("rsvd and irlba work", {

  # Each engine either runs silently when its package is installed, or
  # fails with a message naming the missing package.
  # quietly = TRUE keeps the availability probe itself from printing.

  if (requireNamespace("irlba", quietly = TRUE)) {
    expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "irlba"))
  } else {
    # No expect_silent() here: the call is expected to raise an error.
    expect_error(
      textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "irlba"),
      "irlba package must be installed"
    )
  }

  if (requireNamespace("rsvd", quietly = TRUE)) {
    expect_silent(textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "rsvd"))
  } else {
    expect_error(
      textmodel_lss(dfmt_test, seedwords("pos-neg"), k = 10, engine = "rsvd"),
      "rsvd package must be installed"
    )
  }

})

Expand Down

0 comments on commit bf5539b

Please sign in to comment.