diff --git a/Cargo.lock b/Cargo.lock
index 6083e83..d82f790 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1288,7 +1288,6 @@ dependencies = [
  "esaxx-rs",
  "fancy-regex",
  "getrandom 0.3.3",
- "hf-hub",
  "indicatif",
  "itertools",
  "log",
diff --git a/Cargo.toml b/Cargo.toml
index 4f1b27e..c9f7f0e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,22 +14,21 @@ categories = ["science", "text-processing"]
 exclude = ["tests/*"]
 
 [features]
-default = ["onig"]
+default = ["onig", "hf-hub"]
+hf-hub = ["dep:hf-hub"]
 onig = ["tokenizers/onig",
         "tokenizers/progressbar",
-        "tokenizers/esaxx_fast",
-        "tokenizers/http"]
+        "tokenizers/esaxx_fast"]
 
 fancy-regex = ["tokenizers/fancy-regex",
                "tokenizers/progressbar",
-               "tokenizers/esaxx_fast",
-               "tokenizers/http"]
+               "tokenizers/esaxx_fast"]
 
 [dependencies]
 tokenizers = { version = "0.21", default-features = false }
 safetensors = "0.5"
 ndarray = "0.15"
-hf-hub = { version = "0.4", default-features = false, features = ["ureq"] }
+hf-hub = { version = "0.4", default-features = false, features = ["ureq"], optional = true }
 clap = { version = "4.0", features = ["derive"] }
 anyhow = "1.0"
 serde_json = "1.0"
diff --git a/src/model.rs b/src/model.rs
index d798f20..40d1d44 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,11 +1,12 @@
 use anyhow::{anyhow, Context, Result};
 use half::f16;
-use hf_hub::api::sync::Api;
+#[cfg(feature = "hf-hub")]
+use hf_hub::api::sync::ApiBuilder;
 use ndarray::{Array2, ArrayView2, CowArray, Ix2};
 use safetensors::{tensor::Dtype, SafeTensors};
 use serde_json::Value;
 use std::borrow::Cow;
-use std::{env, fs, path::Path};
+use std::{fs, path::Path};
 use tokenizers::Tokenizer;
 
 /// Static embedding model for Model2Vec
@@ -379,9 +380,8 @@ fn resolve_model_files<P: AsRef<Path>>(
     token: Option<&str>,
     subfolder: Option<&str>,
 ) -> Result<ModelFiles> {
-    if let Some(tok) = token {
-        env::set_var("HF_HUB_TOKEN", tok);
-    }
+    #[cfg(not(feature = "hf-hub"))]
+    let _ = token;
 
     let (tokenizer, model, config) = {
         let base = repo_or_path.as_ref();
@@ -395,14 +395,17 @@ fn resolve_model_files<P: AsRef<Path>>(
             }
             (tokenizer, model, config)
         } else {
-            let api = Api::new().context("hf-hub API init failed")?;
-            let repo = api.model(repo_or_path.as_ref().to_string_lossy().into_owned());
-            let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
-            (
-                repo.get(&format!("{prefix}tokenizer.json"))?,
-                repo.get(&format!("{prefix}model.safetensors"))?,
-                repo.get(&format!("{prefix}config.json"))?,
-            )
+            #[cfg(feature = "hf-hub")]
+            {
+                let files = download_model_files(repo_or_path.as_ref().to_string_lossy().as_ref(), token, subfolder)?;
+                (files.tokenizer, files.model, files.config)
+            }
+            #[cfg(not(feature = "hf-hub"))]
+            {
+                return Err(anyhow!(
+                    "remote model downloads require the `hf-hub` feature; pass a local model directory instead"
+                ));
+            }
         }
     };
 
@@ -412,3 +415,36 @@ fn resolve_model_files<P: AsRef<Path>>(
         config,
     })
 }
+
+/// Fetches the tokenizer, weights and config of `repo_id` through hf-hub.
+///
+/// `token`, when present, is handed to the client with `ApiBuilder::with_token`;
+/// otherwise the builder is left at its defaults. `subfolder` names an optional
+/// directory inside the repository that is prefixed to every file name.
+///
+/// # Errors
+///
+/// Fails if the client cannot be built or any of the three files cannot be fetched.
+#[cfg(feature = "hf-hub")]
+fn download_model_files(
+    repo_id: &str,
+    token: Option<&str>,
+    subfolder: Option<&str>,
+) -> Result<ModelFiles> {
+    // Pass the token to the client explicitly instead of writing it to the
+    // process environment: `env::set_var` is process-global and racy when other
+    // threads read or write the environment concurrently.
+    let mut builder = ApiBuilder::new();
+    if let Some(tok) = token {
+        builder = builder.with_token(Some(tok.to_owned()));
+    }
+    let api = builder.build().context("hf-hub API init failed")?;
+
+    let repo = api.model(repo_id.to_owned());
+    let prefix = subfolder.map(|s| format!("{s}/")).unwrap_or_default();
+    Ok(ModelFiles {
+        tokenizer: repo.get(&format!("{prefix}tokenizer.json"))?,
+        model: repo.get(&format!("{prefix}model.safetensors"))?,
+        config: repo.get(&format!("{prefix}config.json"))?,
+    })
+}
diff --git a/tests/test_model.rs b/tests/test_model.rs
index 33f8b8f..6004a27 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -122,3 +122,13 @@ fn test_from_bytes_matches_from_pretrained_for_local_model() {
         );
     }
 }
+
+#[cfg(not(feature = "hf-hub"))]
+#[test]
+fn test_from_pretrained_remote_requires_hf_hub_feature() {
+    let err = StaticModel::from_pretrained("minishlab/potion-base-2M", None, None, None).unwrap_err();
+    assert!(
+        err.to_string().contains("hf-hub"),
+        "expected remote loading without hf-hub to mention the missing feature"
+    );
+}