Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,30 @@ rust_tokenizers = "8.1.1"
tch = { version = "0.17.0", features = ["download-libtorch"] }
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "4.2.0"
ordered-float = "5"
uuid = { version = "1", features = ["v4"] }
thiserror = "1"
thiserror = "2"
half = "2"
regex = "1.10"
regex = "1.11"

cached-path = { version = "0.6", default-features = false, optional = true }
dirs = { version = "5", optional = true }
cached-path = { version = "0.8", default-features = false, optional = true }
dirs = { version = "6", optional = true }
lazy_static = { version = "1", optional = true }
ort = { version = "1.16.3", optional = true, default-features = false, features = [
"half",
] }
ndarray = { version = "0.15", optional = true }
tokenizers = { version = "0.20", optional = true, default-features = false, features = [
tokenizers = { version = "0.21", optional = true, default-features = false, features = [
"onig",
] }

[dev-dependencies]
anyhow = "1"
csv = "1"
criterion = "0.5"
tokio = { version = "1.35", features = ["sync", "rt-multi-thread", "macros"] }
criterion = "0.6"
tokio = { version = "1.45", features = ["sync", "rt-multi-thread", "macros"] }
tempfile = "3"
itertools = "0.13.0"
itertools = "0.14.0"
tracing-subscriber = { version = "0.3", default-features = false, features = [
"env-filter",
"fmt",
Expand Down
3 changes: 2 additions & 1 deletion benches/generation_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#[macro_use]
extern crate criterion;

use criterion::{black_box, Criterion};
use criterion::Criterion;
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use std::hint::black_box;
use std::time::{Duration, Instant};
use tch::Device;

Expand Down
3 changes: 2 additions & 1 deletion benches/squad_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#[macro_use]
extern crate criterion;

use criterion::{black_box, Criterion};
use criterion::Criterion;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::RemoteResource;
use std::env;
use std::hint::black_box;
use std::path::PathBuf;
use std::time::{Duration, Instant};

Expand Down
3 changes: 2 additions & 1 deletion benches/summarization_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#[macro_use]
extern crate criterion;

use criterion::{black_box, Criterion};
use criterion::Criterion;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use std::hint::black_box;
use std::time::{Duration, Instant};
use tch::Device;

Expand Down
3 changes: 2 additions & 1 deletion benches/tensor_operations_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#[macro_use]
extern crate criterion;

use criterion::{black_box, Criterion};
use criterion::Criterion;
use std::hint::black_box;
use std::time::{Duration, Instant};
use tch::kind::Kind;
use tch::{Device, Tensor};
Expand Down
3 changes: 2 additions & 1 deletion benches/token_classification_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use criterion::{criterion_group, criterion_main, Criterion};
use rust_bert::pipelines::token_classification::{
TokenClassificationConfig, TokenClassificationModel,
};
use std::hint::black_box;
use tch::Device;

fn create_model() -> TokenClassificationModel {
Expand Down
3 changes: 2 additions & 1 deletion benches/translation_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#[macro_use]
extern crate criterion;

use criterion::{black_box, Criterion};
use criterion::Criterion;
use std::hint::black_box;
// use rust_bert::pipelines::common::ModelType;
// use rust_bert::pipelines::translation::TranslationOption::{Marian, T5};
use rust_bert::pipelines::common::ModelType;
Expand Down
18 changes: 9 additions & 9 deletions examples/generation_gptj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ use tch::Device;
/// inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
///
/// with torch.no_grad():
/// gen_tokens = model.generate(
/// **inputs,
/// min_length=0,
/// max_length=32,
/// do_sample=False,
/// early_stopping=True,
/// num_beams=1,
/// num_return_sequences=1
/// )
/// gen_tokens = model.generate(
/// **inputs,
/// min_length=0,
/// max_length=32,
/// do_sample=False,
/// early_stopping=True,
/// num_beams=1,
/// num_return_sequences=1
/// )
///
/// gen_texts = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
/// ````
Expand Down
4 changes: 2 additions & 2 deletions src/common/resources/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl ResourceProvider for LocalResource {
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_local_path();
/// ```
Expand All @@ -42,7 +42,7 @@ impl ResourceProvider for LocalResource {
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_resource();
/// ```
Expand Down
4 changes: 2 additions & 2 deletions src/common/resources/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL
//! - BufferResource: refers to a buffer that contains file contents for a resource (currently only
//! usable for weights)
//! usable for weights)
//!
//! For `LocalResource` and `RemoteResource`, the local location of the file can be retrieved using
//! `get_local_path`, which allows the resource file location to be referenced regardless of whether it is a remote
Expand Down Expand Up @@ -52,7 +52,7 @@ pub trait ResourceProvider: Debug + Send + Sync {
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_local_path();
/// ```
Expand Down
6 changes: 3 additions & 3 deletions src/common/resources/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ impl RemoteResource {
/// ```no_run
/// use rust_bert::resources::RemoteResource;
/// let model_resource = RemoteResource::from_pretrained((
/// "distilbert-sst2",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// "distilbert-sst2",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// ));
/// ```
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
Expand All @@ -84,7 +84,7 @@ impl ResourceProvider for RemoteResource {
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_local_path();
/// ```
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.4`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip` for a Linux version with CUDA12.
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip` for a Linux version with CUDA12.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:
Expand Down
92 changes: 46 additions & 46 deletions src/models/albert/albert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,19 @@ impl AlbertModel {
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
/// .unwrap()
/// albert_model
/// .forward_t(
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
pub fn forward_t(
Expand Down Expand Up @@ -430,17 +430,17 @@ impl AlbertForMaskedLM {
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
/// .expand(&[batch_size, sequence_length], true);
///
/// let masked_lm_output = no_grad(|| {
/// albert_model.forward_t(
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
/// albert_model.forward_t(
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
/// });
/// ```
pub fn forward_t(
Expand Down Expand Up @@ -505,7 +505,7 @@ impl AlbertForSequenceClassification {
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForSequenceClassification =
/// AlbertForSequenceClassification::new(&p.root(), &config).unwrap();
/// AlbertForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(
p: P,
Expand Down Expand Up @@ -581,12 +581,12 @@ impl AlbertForSequenceClassification {
///
/// let classification_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
Expand Down Expand Up @@ -655,7 +655,7 @@ impl AlbertForTokenClassification {
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForTokenClassification =
/// AlbertForTokenClassification::new(&p.root(), &config).unwrap();
/// AlbertForTokenClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(
p: P,
Expand Down Expand Up @@ -730,12 +730,12 @@ impl AlbertForTokenClassification {
///
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
Expand Down Expand Up @@ -862,12 +862,12 @@ impl AlbertForQuestionAnswering {
///
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
Expand Down Expand Up @@ -1005,12 +1005,12 @@ impl AlbertForMultipleChoice {
///
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false).unwrap()
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false).unwrap()
/// });
/// ```
pub fn forward_t(
Expand Down
Loading
Loading