diff --git a/psyche-book/src/development/models.md b/psyche-book/src/development/models.md index 08c033aa7..170b28601 100644 --- a/psyche-book/src/development/models.md +++ b/psyche-book/src/development/models.md @@ -9,15 +9,15 @@ The `train` example, documented below, is useful to test how your model trains u ## Running ```bash -cargo run --example train -- ---help +cargo run --example train -- --help ``` -You'll need a pre-tokenized dataset downloaded to your disk for training. - -> A PR is welcome to add an option to the trainer to use the HTTP data provider! You can refer to the http example in the data-provider crate for a sample implementation. +You'll need a pre-tokenized dataset for training. The `train` example supports multiple data sources: local files, HTTP URLs, GCP buckets, and weighted configurations. For a Llama 2 model, a pre-tokenized dataset to test with is available at [https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/](https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/tree/main). -Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can download just one for smaller tests. +Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can use just one for smaller tests. 
+ +### Local data If you've downloaded part or all of the above dataset into a folder `data/fineweb-10bt` inside the Psyche repo, you can start a simple training run on a 20m parameter Llama 2 model: @@ -29,6 +29,68 @@ cargo run --example train -- \ --micro-batch 1 ``` +#### Local preprocessed data + +For preprocessed data in parquet format (with `inputs` column), use `local-preprocessed`: + +```bash +cargo run --example train -- \ + --model emozilla/llama2-20m-init \ + --total-batch 2 \ + --micro-batch 1 \ + local-preprocessed --path ./data/parquet/ +``` + +### HTTP + +You can stream data directly from HTTP URLs without downloading the dataset first. There are several ways to specify HTTP data sources: + +#### URL template + +Use a template with `{}` placeholder that gets replaced with padded numbers: + +```bash +cargo run --example train -- \ + --model emozilla/llama2-20m-init \ + --total-batch 2 \ + --micro-batch 1 \ + http-template \ + --template "https://example.com/data/{}.ds" \ + --start 0 \ + --end 10 \ + --left-pad-zeros 5 +``` + +This would load files from `https://example.com/data/00000.ds` through `https://example.com/data/00009.ds`. + +#### Explicit URLs + +Provide a list of URLs directly: + +```bash +cargo run --example train -- \ + --model emozilla/llama2-20m-init \ + --total-batch 2 \ + --micro-batch 1 \ + urls \ + https://example.com/data/file1.ds \ + https://example.com/data/file2.ds +``` + +#### GCP bucket + +Load all `.ds` files from a Google Cloud Storage bucket: + +```bash +cargo run --example train -- \ + --model emozilla/llama2-20m-init \ + --total-batch 2 \ + --micro-batch 1 \ + gcp \ + --bucket-name my-bucket \ + --directory data/tokenized +``` + ## Adding a new model type The `train` example currently asssumes your model is a Llama or Deepseek v2/v3 model, and instantiates it via `(LlamaForCausalLM|DeepseekForCausalLM)::from_pretrained`. 
diff --git a/shared/modeling/examples/train.rs b/shared/modeling/examples/train.rs index b67a17f46..d776f3d45 100644 --- a/shared/modeling/examples/train.rs +++ b/shared/modeling/examples/train.rs @@ -1,11 +1,13 @@ use anyhow::{Context, Result}; -use clap::{Args, Parser, Subcommand, ValueEnum}; +use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; use psyche_core::{ Barrier, BatchId, CancellableBarrier, ClosedInterval, CosineLR, OptimizerDefinition, Shuffle, }; use psyche_data_provider::{ DataProvider, LengthKnownDataProvider, LocalDataProvider, PreprocessedDataProvider, Split, - TokenizedDataProvider, download_model_repo_sync, + TokenizedDataProvider, WeightedDataProvider, WeightedHttpProvidersConfig, + download_model_repo_sync, + http::{FileURLs, HttpDataProvider}, }; use psyche_modeling::{ AttentionImplementation, Batch, BatchData, BatchDataCPU, CausalLM, CommunicatorId, @@ -91,14 +93,68 @@ enum Commands { }, } +#[derive(Subcommand, Debug, Clone)] +enum DataSource { + /// Local directory with .ds files + Local { + #[arg(long, default_value = "data")] + path: String, + }, + /// Local directory with .parquet files + LocalPreprocessed { + #[arg(long, default_value = "data")] + path: String, + }, + /// HTTP with URL template + HttpTemplate { + /// URL template with {} placeholder (e.g., "http://example.com/{}.ds") + #[arg(long)] + template: String, + /// Start index + #[arg(long, default_value = "0")] + start: u32, + /// End index + #[arg(long)] + end: u32, + /// Number of zeros to left-pad to + #[arg(long, default_value = "0")] + left_pad_zeros: u8, + }, + /// HTTP with explicit URLs + Urls { + /// List of data URLs + urls: Vec<String>, + }, + /// HTTP from GCP bucket + Gcp { + /// The name of the GCP bucket + #[arg(long)] + bucket_name: String, + /// An optional directory to filter by + #[arg(long)] + directory: Option<String>, + }, + /// Weighted HTTP config (JSON file or URL) + WeightedConfig { + /// Path or URL to WeightedHttpProvidersConfig JSON file 
#[arg(long)] + config: String, + }, +} + #[derive(Args, Debug, Clone)] struct RunArgs { #[arg(long, default_value = "emozilla/llama2-215m-init")] model: String, + /// Path to local data directory (used when no data source subcommand is specified) #[arg(long, default_value = "data")] data_path: String, + /// Data source subcommand (defaults to local .ds files from --data-path) + #[command(subcommand)] + data_source: Option<DataSource>, + #[arg(long, default_value_t = 2048)] sequence_length: usize, @@ -233,27 +289,41 @@ async fn main() -> Result<()> { None => Shuffle::DontShuffle, }; - let mut dataset: DataProvider = match LocalDataProvider::new_from_directory( - &args.data_path, - args.token_size.try_into()?, - args.sequence_length, - shuffle, - ) - .with_context(|| "Failed to load data with local data provider.") - { - Ok(dataset) => { + let token_size = args.token_size.try_into()?; + + let mut dataset: DataProvider = match &args.data_source { + None | Some(DataSource::Local { .. }) => { + // Use data_path from Local subcommand or fallback to args.data_path + let data_path = match &args.data_source { + Some(DataSource::Local { path }) => path, + _ => &args.data_path, + }; + if !std::path::Path::new(data_path).exists() { + eprintln!("Error: Data directory '{}' does not exist.\n", data_path); + CliArgs::command().print_long_help()?; + std::process::exit(1); + } + let dataset = LocalDataProvider::new_from_directory( + data_path, + token_size, + args.sequence_length, + shuffle, + ) + .with_context(|| format!("Failed to load local .ds data from '{}'", data_path))?; info!( "Loaded local dataset with {} samples", dataset.num_sequences() ); DataProvider::Local(dataset) } - Err(err) => { - println!( - "Failed to load with local data provider. 
{err:?} Trying preprocessed data provider instead" - ); + Some(DataSource::LocalPreprocessed { path }) => { + if !std::path::Path::new(path).exists() { + eprintln!("Error: Data directory '{}' does not exist.\n", path); + CliArgs::command().print_long_help()?; + std::process::exit(1); + } let dataset = PreprocessedDataProvider::new_from_directory( - &args.data_path, + path, args.sequence_length, shuffle, Some(Split::Train), @@ -266,6 +336,69 @@ async fn main() -> Result<()> { ); DataProvider::Preprocessed(dataset) } + Some(DataSource::HttpTemplate { + template, + start, + end, + left_pad_zeros, + }) => { + if end <= start { + anyhow::bail!("end ({}) must be greater than start ({})", end, start); + } + let urls = FileURLs::from_template(template, *start, *left_pad_zeros, end - start) + .await + .with_context(|| format!("Failed to load URLs from template: {}", template))?; + let provider = + HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle) + .with_context(|| "Failed to create HTTP data provider from template")?; + info!("Loaded HTTP template dataset"); + DataProvider::Http(provider) + } + Some(DataSource::Urls { urls }) => { + if urls.is_empty() { + anyhow::bail!("At least one URL must be provided"); + } + let urls = FileURLs::from_list(urls) + .await + .with_context(|| "Failed to load URLs from list")?; + let provider = + HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle) + .with_context(|| "Failed to create HTTP data provider from URLs")?; + info!("Loaded HTTP URLs dataset"); + DataProvider::Http(provider) + } + Some(DataSource::Gcp { + bucket_name, + directory, + }) => { + let urls = FileURLs::from_gcp_bucket(bucket_name, directory.clone()) + .await + .with_context(|| format!("Failed to load URLs from GCP bucket: {}", bucket_name))?; + let provider = + HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle) + .with_context(|| "Failed to create HTTP data provider from GCP bucket")?; + 
info!("Loaded GCP bucket dataset"); + DataProvider::Http(provider) + } + Some(DataSource::WeightedConfig { config }) => { + let provider = if config.starts_with("http://") || config.starts_with("https://") { + WeightedDataProvider::from_config_url(config, args.sequence_length as u32) + .await + .with_context(|| { + format!("Failed to load weighted config from URL: {}", config) + })? + } else { + let content = std::fs::read_to_string(config) + .with_context(|| format!("Failed to read config file: {}", config))?; + let cfg: WeightedHttpProvidersConfig = serde_json::from_str(&content) + .with_context(|| format!("Failed to parse config JSON: {}", config))?; + WeightedDataProvider::from_config(cfg, args.sequence_length as u32) + .await + .with_context(|| "Failed to create weighted data provider")? + }; + info!("Loaded weighted HTTP dataset"); + DataProvider::WeightedHttp(provider) + } }; let schedule = CosineLR::new(