tools/run/run.cpp (30 changes: 15 additions & 15 deletions)
@@ -788,7 +788,7 @@ class LlamaData {
     }
 
   private:
-    int download(const std::string & url, const std::string & output_file, const bool progress,
+    static int download(const std::string & url, const std::string & output_file, const bool progress,
                  const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
         if (http.init(url, headers, output_file, progress, response_str)) {
@@ -799,7 +799,7 @@
     }
 
     // Helper function to handle model tag extraction and URL construction
-    std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
+    static std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
         std::string model_tag = "latest";
         const size_t colon_pos = model.find(':');
         if (colon_pos != std::string::npos) {
@@ -813,7 +813,7 @@
     }
 
     // Helper function to download and parse the manifest
-    int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
+    static int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
                                     nlohmann::json & manifest) {
         std::string manifest_str;
         int ret = download(url, "", false, headers, &manifest_str);
@@ -826,7 +826,7 @@
         return 0;
     }
 
-    int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) {
+    static int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) {
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
@@ -855,17 +855,17 @@
         return download(url, bn, true, headers);
     }
 
-    int modelscope_dl(std::string & model, const std::string & bn) {
+    static int modelscope_dl(std::string & model, const std::string & bn) {
         std::string model_endpoint = "https://modelscope.cn/models/";
         return dl_from_endpoint(model_endpoint, model, bn);
     }
 
-    int huggingface_dl(std::string & model, const std::string & bn) {
+    static int huggingface_dl(std::string & model, const std::string & bn) {
         std::string model_endpoint = get_model_endpoint();
         return dl_from_endpoint(model_endpoint, model, bn);
     }
 
-    int ollama_dl(std::string & model, const std::string & bn) {
+    static int ollama_dl(std::string & model, const std::string & bn) {
         const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
@@ -891,7 +891,7 @@
         return download(blob_url, bn, true, headers);
     }
 
-    int github_dl(const std::string & model, const std::string & bn) {
+    static int github_dl(const std::string & model, const std::string & bn) {
         std::string repository = model;
         std::string branch = "main";
         const size_t at_pos = model.find('@');
@@ -916,7 +916,7 @@
         return download(url, bn, true);
     }
 
-    int s3_dl(const std::string & model, const std::string & bn) {
+    static int s3_dl(const std::string & model, const std::string & bn) {
         const size_t slash_pos = model.find('/');
         if (slash_pos == std::string::npos) {
             return 1;
@@ -949,7 +949,7 @@
         return download(url, bn, true, headers);
     }
 
-    std::string basename(const std::string & path) {
+    static std::string basename(const std::string & path) {
         const size_t pos = path.find_last_of("/\\");
         if (pos == std::string::npos) {
             return path;
@@ -958,7 +958,7 @@
         return path.substr(pos + 1);
     }
 
-    int rm_until_substring(std::string & model_, const std::string & substring) {
+    static int rm_until_substring(std::string & model_, const std::string & substring) {
         const std::string::size_type pos = model_.find(substring);
         if (pos == std::string::npos) {
             return 1;
@@ -968,7 +968,7 @@
         return 0;
     }
 
-    int resolve_model(std::string & model_) {
+    static int resolve_model(std::string & model_) {
         int ret = 0;
         if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
             rm_until_substring(model_, "://");
@@ -1007,7 +1007,7 @@
     }
 
     // Initializes the model and returns a unique pointer to it
-    llama_model_ptr initialize_model(Opt & opt) {
+    static llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
         resolve_model(opt.model_);
         printe("\r" LOG_CLR_TO_EOL "Loading model");
@@ -1021,7 +1021,7 @@
     }
 
     // Initializes the context with the specified parameters
-    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
+    static llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
         llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
@@ -1031,7 +1031,7 @@
     }
 
     // Initializes and configures the sampler
-    llama_sampler_ptr initialize_sampler(const Opt & opt) {
+    static llama_sampler_ptr initialize_sampler(const Opt & opt) {
         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
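The change is mechanical: every private helper in LlamaData that depends only on its arguments gains a `static` qualifier. A minimal sketch of the idea behind this, using a hypothetical `Downloader` class rather than the real `LlamaData` (all names and members below are illustrative, not from run.cpp):

#include <cstddef>
#include <string>
#include <utility>

// Illustrative stand-in for LlamaData; not the real class.
class Downloader {
  public:
    int fetch(const std::string & model) {
        // A static helper is called like any other member, but the
        // compiler guarantees it cannot read or write instance state.
        const std::pair<std::string, std::string> parts = split_tag(model);
        return parts.first.empty() ? 1 : 0;
    }

  private:
    // Fine as static: the result depends only on the argument.
    static std::pair<std::string, std::string> split_tag(const std::string & model) {
        const size_t colon_pos = model.find(':');
        if (colon_pos == std::string::npos) {
            return { model, "latest" };
        }

        return { model.substr(0, colon_pos), model.substr(colon_pos + 1) };
    }

    // Could NOT be made static: it reads a data member through `this`.
    // int cached() const { return n_cached; }
    // int n_cached = 0;
};

int main() {
    Downloader d;
    return d.fetch("library/llama3:latest");  // returns 0 on success
}

Declaring the helpers static also documents intent at the call site: none of them depend on a LlamaData instance, so they behave like free functions scoped inside the class, with no `this` pointer involved.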