diff --git a/README.md b/README.md
index 33c14d9..7d2432c 100644
--- a/README.md
+++ b/README.md
@@ -44,9 +44,9 @@ const tokenizerConfig = await fetch(`https://huggingface.co/${modelId}/resolve/m
 const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
 // Tokenize text
-const tokens = tokenizer.tokenize('Hello World'); // ['Hello', 'ĠWorld']
-const encoded = tokenizer.encode('Hello World'); // [9906, 4435]
-const decoded = tokenizer.decode(encoded); // 'Hello World'
+const tokens = tokenizer.tokenize('Hello World'); // ['Hello', 'ĠWorld']
+const encoded = tokenizer.encode('Hello World'); // { ids: [9906, 4435], tokens: ['Hello', 'ĠWorld'], attention_mask: [1, 1] }
+const decoded = tokenizer.decode(encoded.ids); // 'Hello World'
 ```
 
 ## Requirements
diff --git a/src/core/Tokenizer.ts b/src/core/Tokenizer.ts
index 844086b..95946b0 100644
--- a/src/core/Tokenizer.ts
+++ b/src/core/Tokenizer.ts
@@ -128,19 +128,48 @@ class Tokenizer {
       this.config.do_lowercase_and_remove_accent ?? false;
   }
 
+  /**
+   * Encodes a single text or a pair of texts using the model's tokenizer.
+   *
+   * @param text The text to encode.
+   * @param options Optional settings: `text_pair`, `add_special_tokens`, and `return_token_type_ids`.
+   * @returns The encoding: token `ids`, the corresponding `tokens`, an `attention_mask`, and `token_type_ids` when requested.
+   */
+
+  // Overload: when return_token_type_ids is explicitly true
+  public encode(
+    text: string,
+    options: EncodeOptions & { return_token_type_ids: true },
+  ): EncodingSingle & { token_type_ids: number[] };
+
+  // Overload: when return_token_type_ids is false/null or not provided
+  public encode(text: string, options?: EncodeOptions): EncodingSingle;
+
+  // Implementation
   public encode(
     text: string,
     {
-      text_pair,
-      add_special_tokens,
-      return_token_type_ids,
+      text_pair = null,
+      add_special_tokens = true,
+      return_token_type_ids = null,
     }: EncodeOptions = {},
-  ): Array<number> {
-    return this.encode_plus(text, {
+  ): EncodingSingle {
+    const { tokens, token_type_ids } = this.tokenize_helper(text, {
       text_pair,
       add_special_tokens,
-      return_token_type_ids,
-    }).input_ids;
+    });
+
+    const input_ids = this.model.convert_tokens_to_ids(tokens);
+    const result: EncodingSingle = {
+      ids: input_ids,
+      tokens,
+      attention_mask: new Array(input_ids.length).fill(1),
+    };
+
+    if (return_token_type_ids && token_type_ids) {
+      result.token_type_ids = token_type_ids;
+    }
+    return result;
   }
 
   public decode(
@@ -198,40 +227,6 @@ class Tokenizer {
     return this.tokenize_helper(text, { text_pair, add_special_tokens }).tokens;
   }
 
-  /**
-   * Encodes a single text or a pair of texts using the model's tokenizer.
-   *
-   * @param text The text to encode.
-   * @param options An optional object containing the following properties:
-   * @returns An object containing the encoded text.
-   * @private
-   */
-
-  private encode_plus(
-    text: string,
-    {
-      text_pair = null,
-      add_special_tokens = true,
-      return_token_type_ids = null,
-    }: EncodeOptions,
-  ): EncodingSingle {
-    const { tokens, token_type_ids } = this.tokenize_helper(text, {
-      text_pair,
-      add_special_tokens,
-    });
-
-    const input_ids = this.model.convert_tokens_to_ids(tokens);
-    const result: EncodingSingle = {
-      input_ids,
-      attention_mask: new Array(input_ids.length).fill(1),
-    };
-
-    if (return_token_type_ids && token_type_ids) {
-      result.token_type_ids = token_type_ids;
-    }
-    return result;
-  }
-
   private encode_text(text: string | null): string[] | null {
     if (text === null) {
       return null;
diff --git a/src/static/types.ts b/src/static/types.ts
index 78e77fb..f86e7ee 100644
--- a/src/static/types.ts
+++ b/src/static/types.ts
@@ -33,7 +33,8 @@ export type DataType =
   | "int4";
 
 export interface EncodingSingle {
-  input_ids: number[];
+  ids: number[];
+  tokens: string[];
   attention_mask: number[];
   token_type_ids?: number[];
 }
diff --git a/tests/bundle.test.ts b/tests/bundle.test.ts
index dcee92b..e39b62b 100644
--- a/tests/bundle.test.ts
+++ b/tests/bundle.test.ts
@@ -14,14 +14,21 @@ const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
 // Tokenize text
 const tokens = tokenizer.tokenize('Hello World');
 const encoded = tokenizer.encode('Hello World');
-const decoded = tokenizer.decode(encoded);
+const decoded = tokenizer.decode(encoded.ids);
 
 console.log(tokens);
 console.log(encoded);
 console.log(decoded);
 `;
 
-const TARGET_OUTPUT = "[ '▁Hello', '▁World' ]\n[ 1, 15043, 2787 ]\n Hello World\n";
+const TARGET_OUTPUT = `[ '▁Hello', '▁World' ]
+{
+  ids: [ 1, 15043, 2787 ],
+  tokens: [ '', '▁Hello', '▁World' ],
+  attention_mask: [ 1, 1, 1 ]
+}
+ Hello World
+`;
 
 const wrap_async_iife = (code: string) => `(async function() { ${code} })();`;
diff --git a/tests/edgeCases.test.ts b/tests/edgeCases.test.ts
index e05273f..e5280ef 100644
--- a/tests/edgeCases.test.ts
+++ b/tests/edgeCases.test.ts
@@ -8,8 +8,8 @@ describe("Edge cases", () => {
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     let text = String.prototype.repeat.call("a", 50000);
-    let token_ids = tokenizer.encode(text);
-    expect(token_ids).toEqual([101, 100, 102]);
+    let { ids } = tokenizer.encode(text);
+    expect(ids).toEqual([101, 100, 102]);
   }, 5000); // NOTE: 5 seconds
 
   it("Special/added tokens with earlier partial matches", async () => {
@@ -17,12 +17,12 @@
     const { tokenizerJson, tokenizerConfig } = await fetchConfigById(modelId);
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     {
-      let token_ids = tokenizer.encode("\n", { add_special_tokens: false });
-      expect(token_ids).toEqual([108]);
+      let { ids } = tokenizer.encode("\n", { add_special_tokens: false });
+      expect(ids).toEqual([108]);
     }
     {
-      let token_ids = tokenizer.encode("\n\n", { add_special_tokens: false });
-      expect(token_ids).toEqual([109]); // Should not be [108, 108]
+      let { ids } = tokenizer.encode("\n\n", { add_special_tokens: false });
+      expect(ids).toEqual([109]); // Should not be [108, 108]
     }
   }, 60_000);
 
@@ -32,7 +32,7 @@
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     let text = "hello world!";
-    let token_ids = tokenizer.encode(text);
-    expect(token_ids).toEqual([128000, 15339, 1917, 0]);
+    let { ids } = tokenizer.encode(text);
+    expect(ids).toEqual([128000, 15339, 1917, 0]);
   }, 5000); // NOTE: 5 seconds
 });
diff --git a/tests/models/llama/llama.test.ts b/tests/models/llama/llama.test.ts
index 307628c..8186914 100644
--- a/tests/models/llama/llama.test.ts
+++ b/tests/models/llama/llama.test.ts
@@ -49,14 +49,14 @@ describe("hard-coded", () => {
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     for (const [text, expected] of Object.entries(data)) {
-      const token_ids = tokenizer.encode(text, {
+      const encoded = tokenizer.encode(text, {
        add_special_tokens: false,
       });
-      expect(token_ids).toEqual(expected);
+      expect(encoded.ids).toEqual(expected);
 
       // If reversible, test that decoding produces the original text
       if (reversible) {
-        const decoded = tokenizer.decode(token_ids);
+        const decoded = tokenizer.decode(encoded.ids);
         expect(decoded).toEqual(text);
       }
     }
diff --git a/tests/models/t5/t5.test.ts b/tests/models/t5/t5.test.ts
index 2743df4..30c4d37 100644
--- a/tests/models/t5/t5.test.ts
+++ b/tests/models/t5/t5.test.ts
@@ -38,14 +38,14 @@ describe("hard-coded", () => {
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     for (const [text, expected] of Object.entries(data)) {
-      const token_ids = tokenizer.encode(text, {
+      const encoded = tokenizer.encode(text, {
        add_special_tokens: false,
       });
-      expect(token_ids).toEqual(expected);
+      expect(encoded.ids).toEqual(expected);
 
       // If reversible, test that decoding produces the original text
       if (reversible) {
-        const decoded = tokenizer.decode(token_ids);
+        const decoded = tokenizer.decode(encoded.ids);
         expect(decoded).toEqual(text);
       }
     }
diff --git a/tests/tokenizers.test.ts b/tests/tokenizers.test.ts
index 0e4ab24..387aacc 100644
--- a/tests/tokenizers.test.ts
+++ b/tests/tokenizers.test.ts
@@ -20,10 +20,10 @@ describe("Tokenizers (model-specific)", () => {
   for (const [testName, testCase] of Object.entries(config.default[modelId])) {
     test(testName, () => {
       if (testCase.ids) {
-        const ids = tokenizer.encode(testCase.text, {
+        const encoded = tokenizer.encode(testCase.text, {
           text_pair: testCase.text_pair,
         });
-        expect(ids).toEqual(testCase.ids);
+        expect(encoded.ids).toEqual(testCase.ids);
 
         if (testCase.decoded) {
           const decoded = tokenizer.decode(testCase.ids);