diff --git a/Cargo.lock b/Cargo.lock index d82f790..170e389 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -473,9 +473,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c9f7f0e..b00cd75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ exclude = ["tests/*"] [features] default = ["onig", "hf-hub"] hf-hub = ["dep:hf-hub"] +local-only = [] onig = ["tokenizers/onig", "tokenizers/progressbar", "tokenizers/esaxx_fast"] @@ -23,6 +24,8 @@ onig = ["tokenizers/onig", fancy-regex = ["tokenizers/fancy-regex", "tokenizers/progressbar", "tokenizers/esaxx_fast"] +wasm = ["local-only", + "tokenizers/unstable_wasm"] [dependencies] tokenizers = { version = "0.21", default-features = false } diff --git a/README.md b/README.md index f2f267e..4d91cf5 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,36 @@ cargo build --release * **Batch Processing:** Encodes multiple sentences in batches. * **Configurable Encoding:** Allows customization of maximum sequence length and batch size during encoding. +### Feature flags + +The crate exposes a few feature combinations for different runtimes: + +* `default`: native build with `onig` tokenization and optional Hugging Face Hub downloads +* `fancy-regex`: alternative tokenizer backend for native builds +* `local-only`: disable remote model downloads and restrict loading to local paths or `from_bytes(...)` +* `wasm`: minimal WebAssembly-oriented feature set for in-memory loading via `from_bytes(...)` + +Typical invocations are: + +* native local-only build: + `cargo build --no-default-features --features onig,local-only` +* wasm check: + `RUSTFLAGS='--cfg getrandom_backend="wasm_js"' cargo check --no-default-features --features wasm --target wasm32-unknown-unknown` + +The `wasm` feature is intended for `wasm32-unknown-unknown` builds that load models +from in-memory bytes, for example after fetching assets over HTTP or embedding them +into the binary. Direct filesystem access is usually not available in browser-style +WebAssembly environments, so callers should pass file contents through `from_bytes(...)`. +Remote Hugging Face downloads are not available in this mode. + +For `wasm32-unknown-unknown`, `getrandom` also requires a target-specific backend +configuration. The minimal check command is: + +```bash +RUSTFLAGS='--cfg getrandom_backend="wasm_js"' \ +cargo check --no-default-features --features wasm --target wasm32-unknown-unknown +``` + ## What is Model2Vec? Model2Vec is a technique to distill large sentence transformer models into highly efficient static embedding models. This process significantly reduces model size and computational requirements for inference. For a detailed understanding of how Model2Vec works, including the distillation process and model training, please refer to the [main Model2Vec Python repository](https://github.com/MinishLab/model2vec) and its [documentation](https://github.com/MinishLab/model2vec/blob/main/docs/what_is_model2vec.md). diff --git a/src/model.rs b/src/model.rs index 40d1d44..fac2768 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,12 +1,12 @@ use anyhow::{anyhow, Context, Result}; use half::f16; -#[cfg(feature = "hf-hub")] +#[cfg(all(feature = "hf-hub", not(feature = "local-only")))] use hf_hub::api::sync::Api; use ndarray::{Array2, ArrayView2, CowArray, Ix2}; use safetensors::{tensor::Dtype, SafeTensors}; use serde_json::Value; use std::borrow::Cow; -#[cfg(feature = "hf-hub")] +#[cfg(all(feature = "hf-hub", not(feature = "local-only")))] use std::env; use std::{fs, path::Path}; use tokenizers::Tokenizer; @@ -384,6 +384,8 @@ fn resolve_model_files>( ) -> Result { #[cfg(not(feature = "hf-hub"))] let _ = token; + #[cfg(feature = "local-only")] + let _ = token; let (tokenizer, model, config) = { let base = repo_or_path.as_ref(); @@ -397,12 +399,18 @@ fn resolve_model_files>( } (tokenizer, model, config) } else { - #[cfg(feature = "hf-hub")] + #[cfg(all(feature = "hf-hub", not(feature = "local-only")))] { let files = download_model_files(repo_or_path.as_ref().to_string_lossy().as_ref(), token, subfolder)?; (files.tokenizer, files.model, files.config) } - #[cfg(not(feature = "hf-hub"))] + #[cfg(feature = "local-only")] + { + return Err(anyhow!( + "remote model downloads are disabled by the `local-only` feature; pass a local model directory instead" + )); + } + #[cfg(all(not(feature = "hf-hub"), not(feature = "local-only")))] { return Err(anyhow!( "remote model downloads require the `hf-hub` feature; pass a local model directory instead" @@ -418,7 +426,7 @@ fn resolve_model_files>( }) } -#[cfg(feature = "hf-hub")] +#[cfg(all(feature = "hf-hub", not(feature = "local-only")))] fn download_model_files(repo_id: &str, token: Option<&str>, subfolder: Option<&str>) -> Result { let previous = token.and_then(|_| env::var_os("HF_HUB_TOKEN")); if let Some(tok) = token { diff --git a/tests/test_model.rs b/tests/test_model.rs index 6004a27..9f59931 100644 --- a/tests/test_model.rs +++ b/tests/test_model.rs @@ -132,3 +132,13 @@ fn test_from_pretrained_remote_requires_hf_hub_feature() { "expected remote loading without hf-hub to mention the missing feature" ); } + +#[cfg(all(feature = "hf-hub", feature = "local-only"))] +#[test] +fn test_from_pretrained_remote_disallowed_by_local_only_feature() { + let err = StaticModel::from_pretrained("minishlab/potion-base-2M", None, None, None).unwrap_err(); + assert!( + err.to_string().contains("local-only"), + "expected remote loading with local-only to mention the local-only restriction" + ); +}