Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions psyche-book/src/development/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ The `train` example, documented below, is useful to test how your model trains u
## Running

```bash
cargo run --example train -- ---help
cargo run --example train -- --help
```

You'll need a pre-tokenized dataset downloaded to your disk for training.

> A PR is welcome to add an option to the trainer to use the HTTP data provider! You can refer to the http example in the data-provider crate for a sample implementation.
You'll need a pre-tokenized dataset for training. The `train` example supports multiple data sources: local files, HTTP URLs, GCP buckets, and weighted configurations.

For a Llama 2 model, a pre-tokenized dataset to test with is available at [https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/](https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/tree/main).
Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can download just one for smaller tests.
Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can use just one for smaller tests.

### Local data

If you've downloaded part or all of the above dataset into a folder `data/fineweb-10bt` inside the Psyche repo, you can start a simple training run on a 20m parameter Llama 2 model:

Expand All @@ -29,6 +29,68 @@ cargo run --example train -- \
--micro-batch 1
```

#### Local preprocessed data

For preprocessed data in parquet format (with `inputs` column), use `local-preprocessed`:

```bash
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
local-preprocessed --path ./data/parquet/
```

### HTTP

You can stream data directly from HTTP URLs without downloading the dataset first. There are several ways to specify HTTP data sources:

#### URL template

Use a template with `{}` placeholder that gets replaced with padded numbers:

```bash
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
http-template \
--template "https://example.com/data/{}.ds" \
--start 0 \
--end 10 \
--left-pad-zeros 5
```

This would load files from `https://example.com/data/00000.ds` through `https://example.com/data/00009.ds`.

#### Explicit URLs

Provide a list of URLs directly:

```bash
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
urls \
https://example.com/data/file1.ds \
https://example.com/data/file2.ds
```

#### GCP bucket

Load all `.ds` files from a Google Cloud Storage bucket:

```bash
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
gcp \
--bucket-name my-bucket \
--directory data/tokenized
```

## Adding a new model type

The `train` example currently asssumes your model is a Llama or Deepseek v2/v3 model, and instantiates it via `(LlamaForCausalLM|DeepseekForCausalLM)::from_pretrained`.
Expand Down
165 changes: 149 additions & 16 deletions shared/modeling/examples/train.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use anyhow::{Context, Result};
use clap::{Args, Parser, Subcommand, ValueEnum};
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use psyche_core::{
Barrier, BatchId, CancellableBarrier, ClosedInterval, CosineLR, OptimizerDefinition, Shuffle,
};
use psyche_data_provider::{
DataProvider, LengthKnownDataProvider, LocalDataProvider, PreprocessedDataProvider, Split,
TokenizedDataProvider, download_model_repo_sync,
TokenizedDataProvider, WeightedDataProvider, WeightedHttpProvidersConfig,
download_model_repo_sync,
http::{FileURLs, HttpDataProvider},
};
use psyche_modeling::{
AttentionImplementation, Batch, BatchData, BatchDataCPU, CausalLM, CommunicatorId,
Expand Down Expand Up @@ -91,14 +93,68 @@ enum Commands {
},
}

#[derive(Subcommand, Debug, Clone)]
enum DataSource {
/// Local directory with .ds files
Local {
#[arg(long, default_value = "data")]
path: String,
},
/// Local directory with .parquet files
LocalPreprocessed {
#[arg(long, default_value = "data")]
path: String,
},
/// HTTP with URL template
HttpTemplate {
/// URL template with {} placeholder (e.g., "http://example.com/{}.ds")
#[arg(long)]
template: String,
/// Start index
#[arg(long, default_value = "0")]
start: u32,
/// End index
#[arg(long)]
end: u32,
/// Number of zeros to left-pad to
#[arg(long, default_value = "0")]
left_pad_zeros: u8,
},
/// HTTP with explicit URLs
Urls {
/// List of data URLs
urls: Vec<String>,
},
/// HTTP from GCP bucket
Gcp {
/// The name of the GCP bucket
#[arg(long)]
bucket_name: String,
/// An optional directory to filter by
#[arg(long)]
directory: Option<String>,
},
/// Weighted HTTP config (JSON file or URL)
WeightedConfig {
/// Path or URL to WeightedHttpProvidersConfig JSON file
#[arg(long)]
config: String,
},
}

#[derive(Args, Debug, Clone)]
struct RunArgs {
#[arg(long, default_value = "emozilla/llama2-215m-init")]
model: String,

/// Path to local data directory (used when no data source subcommand is specified)
#[arg(long, default_value = "data")]
data_path: String,

/// Data source subcommand (defaults to local .ds files from --data-path)
#[command(subcommand)]
data_source: Option<DataSource>,

#[arg(long, default_value_t = 2048)]
sequence_length: usize,

Expand Down Expand Up @@ -233,27 +289,41 @@ async fn main() -> Result<()> {
None => Shuffle::DontShuffle,
};

let mut dataset: DataProvider<DummyNodeIdentity> = match LocalDataProvider::new_from_directory(
&args.data_path,
args.token_size.try_into()?,
args.sequence_length,
shuffle,
)
.with_context(|| "Failed to load data with local data provider.")
{
Ok(dataset) => {
let token_size = args.token_size.try_into()?;

let mut dataset: DataProvider<DummyNodeIdentity> = match &args.data_source {
None | Some(DataSource::Local { .. }) => {
// Use data_path from Local subcommand or fallback to args.data_path
let data_path = match &args.data_source {
Some(DataSource::Local { path }) => path,
_ => &args.data_path,
};
if !std::path::Path::new(data_path).exists() {
eprintln!("Error: Data directory '{}' does not exist.\n", data_path);
CliArgs::command().print_long_help()?;
std::process::exit(1);
}
let dataset = LocalDataProvider::new_from_directory(
data_path,
token_size,
args.sequence_length,
shuffle,
)
.with_context(|| format!("Failed to load local .ds data from '{}'", data_path))?;
info!(
"Loaded local dataset with {} samples",
dataset.num_sequences()
);
DataProvider::Local(dataset)
}
Err(err) => {
println!(
"Failed to load with local data provider. {err:?} Trying preprocessed data provider instead"
);
Some(DataSource::LocalPreprocessed { path }) => {
if !std::path::Path::new(path).exists() {
eprintln!("Error: Data directory '{}' does not exist.\n", path);
CliArgs::command().print_long_help()?;
std::process::exit(1);
}
let dataset = PreprocessedDataProvider::new_from_directory(
&args.data_path,
path,
args.sequence_length,
shuffle,
Some(Split::Train),
Expand All @@ -266,6 +336,69 @@ async fn main() -> Result<()> {
);
DataProvider::Preprocessed(dataset)
}
Some(DataSource::HttpTemplate {
template,
start,
end,
left_pad_zeros,
}) => {
if end <= start {
anyhow::bail!("end ({}) must be greater than start ({})", end, start);
}
let urls = FileURLs::from_template(template, *start, *left_pad_zeros, end - start)
.await
.with_context(|| format!("Failed to load URLs from template: {}", template))?;
let provider =
HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)
.with_context(|| "Failed to create HTTP data provider from template")?;
info!("Loaded HTTP template dataset");
DataProvider::Http(provider)
}
Some(DataSource::Urls { urls }) => {
if urls.is_empty() {
anyhow::bail!("At least one URL must be provided");
}
let urls = FileURLs::from_list(urls)
.await
.with_context(|| "Failed to load URLs from list")?;
let provider =
HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)
.with_context(|| "Failed to create HTTP data provider from URLs")?;
info!("Loaded HTTP URLs dataset");
DataProvider::Http(provider)
}
Some(DataSource::Gcp {
bucket_name,
directory,
}) => {
let urls = FileURLs::from_gcp_bucket(bucket_name, directory.clone())
.await
.with_context(|| format!("Failed to load URLs from GCP bucket: {}", bucket_name))?;
let provider =
HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)
.with_context(|| "Failed to create HTTP data provider from GCP bucket")?;
info!("Loaded GCP bucket dataset");
DataProvider::Http(provider)
}
Some(DataSource::WeightedConfig { config }) => {
let provider = if config.starts_with("http://") || config.starts_with("https://") {
WeightedDataProvider::from_config_url(config, args.sequence_length as u32)
.await
.with_context(|| {
format!("Failed to load weighted config from URL: {}", config)
})?
} else {
let content = std::fs::read_to_string(config)
.with_context(|| format!("Failed to read config file: {}", config))?;
let cfg: WeightedHttpProvidersConfig = serde_json::from_str(&content)
.with_context(|| format!("Failed to parse config JSON: {}", config))?;
WeightedDataProvider::from_config(cfg, args.sequence_length as u32)
.await
.with_context(|| "Failed to create weighted data provider")?
};
info!("Loaded weighted HTTP dataset");
DataProvider::WeightedHttp(provider)
}
};

let schedule = CosineLR::new(
Expand Down