diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc index 07834ebbd..8588bdb6b 100644 --- a/flutter/cpp/datasets/ifeval.cc +++ b/flutter/cpp/datasets/ifeval.cc @@ -32,6 +32,13 @@ IFEval::IFEval(Backend* backend, const std::string& input_tfrecord, std::vector input_tokens; sp_processor->Encode(input_formatted.c_str(), &input_tokens).ok(); + // input token sanity check + if (input_tokens.size() > input_token_limit_) { + LOG(WARNING) << "Input token limit exceeded for entry " + << std::to_string(i) << ". Ignoring."; + continue; + } + auto sample = std::make_unique(); sample->key = key; sample->prompt = prompt; diff --git a/flutter/cpp/datasets/ifeval.h b/flutter/cpp/datasets/ifeval.h index 5ce4ea357..684bca3c0 100644 --- a/flutter/cpp/datasets/ifeval.h +++ b/flutter/cpp/datasets/ifeval.h @@ -81,6 +81,7 @@ class IFEval : public Dataset { std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; + static constexpr int input_token_limit_ = 1024; static constexpr int token_limit_ = 1024; }; diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 4b79efc44..15358af8b 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -45,6 +45,13 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, std::vector input_tokens; sp_processor->Encode(input.c_str(), &input_tokens).ok(); + // input token sanity check + if (input_tokens.size() > input_token_limit_) { + LOG(WARNING) << "Input token limit exceeded for entry " + << std::to_string(i) << ". Ignoring."; + continue; + } + auto sample = std::make_unique(); sample->input = input; sample->input_tokens = input_tokens; diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 18ecc4d0d..8bbb8f544 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -66,6 +66,7 @@ class MmluGen : public Dataset { std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; + static constexpr int input_token_limit_ = 1024; static constexpr int token_limit_ = 4; }; diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index f7552cd4d..31818ea75 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -151,8 +151,8 @@ struct LLMBackendData { kv_cache_t kv_cache; std::vector prompt_tokens; std::vector output_tokens; - uint8_t threads = 2; - int max_output_tokens = 1024; + uint8_t threads = 8; + int max_output_tokens = 128; std::unordered_set stop_token_ids{128001, 128008, 128009}; LLMBackendData() {}