6 changes: 3 additions & 3 deletions README.md
@@ -44,9 +44,9 @@ const tokenizerConfig = await fetch(`https://huggingface.co/${modelId}/resolve/m
const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);

// Tokenize text
-const tokens = tokenizer.tokenize('Hello World'); // ['Hello', 'ĠWorld']
-const encoded = tokenizer.encode('Hello World'); // [9906, 4435]
-const decoded = tokenizer.decode(encoded); // 'Hello World'
+const tokens = tokenizer.tokenize('Hello World'); // ['Hello', 'ĠWorld']
+const encoded = tokenizer.encode('Hello World'); // { ids: [9906, 4435], tokens: ['Hello', 'ĠWorld'], attention_mask: [1, 1] }
+const decoded = tokenizer.decode(encoded.ids); // 'Hello World'
```

## Requirements
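For reference, a minimal sketch of consuming the updated return shape; this assumes the `tokenizer` instance constructed as in the README snippet above, and the ids shown are the ones from that example:

```ts
// Destructure the richer result returned by encode().
const { ids, tokens, attention_mask } = tokenizer.encode('Hello World');

console.log(ids);            // [9906, 4435]
console.log(tokens);         // ['Hello', 'ĠWorld']
console.log(attention_mask); // [1, 1]

// decode() now takes the id array rather than the whole result object.
const roundTripped = tokenizer.decode(ids); // 'Hello World'
```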
77 changes: 36 additions & 41 deletions src/core/Tokenizer.ts
@@ -128,19 +128,48 @@ class Tokenizer {
this.config.do_lowercase_and_remove_accent ?? false;
}

+  /**
+   * Encodes a single text or a pair of texts using the model's tokenizer.
+   *
+   * @param text The text to encode.
+   * @param options An optional object containing the following properties:
+   * @returns An object containing the encoded text.
+   */
+
+  // Overload: when return_token_type_ids is explicitly true
+  public encode(
+    text: string,
+    options: EncodeOptions & { return_token_type_ids: true },
+  ): EncodingSingle & { token_type_ids: number[] };
+
+  // Overload: when return_token_type_ids is false/null or not provided
+  public encode(text: string, options?: EncodeOptions): EncodingSingle;
+
+  // Implementation
  public encode(
    text: string,
    {
-      text_pair,
-      add_special_tokens,
-      return_token_type_ids,
+      text_pair = null,
+      add_special_tokens = true,
+      return_token_type_ids = null,
    }: EncodeOptions = {},
-  ): Array<number> {
-    return this.encode_plus(text, {
+  ): EncodingSingle {
+    const { tokens, token_type_ids } = this.tokenize_helper(text, {
      text_pair,
      add_special_tokens,
      return_token_type_ids,
-    }).input_ids;
+    });
+
+    const input_ids = this.model.convert_tokens_to_ids(tokens);
+    const result: EncodingSingle = {
+      ids: input_ids,
+      tokens,
+      attention_mask: new Array(input_ids.length).fill(1),
+    };
+
+    if (return_token_type_ids && token_type_ids) {
+      result.token_type_ids = token_type_ids;
+    }
+    return result;
  }

public decode(
@@ -198,40 +227,6 @@ class Tokenizer {
return this.tokenize_helper(text, { text_pair, add_special_tokens }).tokens;
}

-  /**
-   * Encodes a single text or a pair of texts using the model's tokenizer.
-   *
-   * @param text The text to encode.
-   * @param options An optional object containing the following properties:
-   * @returns An object containing the encoded text.
-   * @private
-   */
-
-  private encode_plus(
-    text: string,
-    {
-      text_pair = null,
-      add_special_tokens = true,
-      return_token_type_ids = null,
-    }: EncodeOptions,
-  ): EncodingSingle {
-    const { tokens, token_type_ids } = this.tokenize_helper(text, {
-      text_pair,
-      add_special_tokens,
-    });
-
-    const input_ids = this.model.convert_tokens_to_ids(tokens);
-    const result: EncodingSingle = {
-      input_ids,
-      attention_mask: new Array(input_ids.length).fill(1),
-    };
-
-    if (return_token_type_ids && token_type_ids) {
-      result.token_type_ids = token_type_ids;
-    }
-    return result;
-  }

private encode_text(text: string | null): string[] | null {
if (text === null) {
return null;
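The two `encode` overloads above let the compiler narrow the result: a call site that passes the literal `return_token_type_ids: true` sees `token_type_ids` as a required `number[]`, while every other call gets the plain `EncodingSingle` with `token_type_ids` optional. A rough sketch of what that looks like for a caller, assuming a constructed `tokenizer` instance whose model actually produces token type ids (e.g. a BERT-style tokenizer):

```ts
// First overload: the literal `true` makes token_type_ids a required field.
const withTypeIds = tokenizer.encode("Hello World", {
  return_token_type_ids: true,
});
const typeIds: number[] = withTypeIds.token_type_ids; // no undefined check needed

// Second overload: without the option, token_type_ids stays optional.
const plain = tokenizer.encode("Hello World");
console.log(plain.ids, plain.tokens, plain.attention_mask);
console.log(plain.token_type_ids); // number[] | undefined
```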
3 changes: 2 additions & 1 deletion src/static/types.ts
@@ -33,7 +33,8 @@ export type DataType =
| "int4";

export interface EncodingSingle {
-  input_ids: number[];
+  ids: number[];
+  tokens: string[];
attention_mask: number[];
token_type_ids?: number[];
}
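As a small illustration of the updated `EncodingSingle` shape (the type import path below is assumed; adjust it to wherever the package publishes its types), a helper that pairs tokens with ids and counts attended positions might look like this:

```ts
import type { EncodingSingle } from "./src/static/types"; // path assumed

// Summarize an encoding: count attended positions and pair each token with its id.
function describeEncoding(enc: EncodingSingle): string {
  const attended = enc.attention_mask.reduce((sum, m) => sum + m, 0);
  const pairs = enc.tokens.map((tok, i) => `${tok}=${enc.ids[i]}`).join(", ");
  return `${attended} attended token(s): ${pairs}`;
}

// e.g. describeEncoding(tokenizer.encode('Hello World'))
//   -> "2 attended token(s): Hello=9906, ĠWorld=4435"
```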
11 changes: 9 additions & 2 deletions tests/bundle.test.ts
@@ -14,14 +14,21 @@ const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
// Tokenize text
const tokens = tokenizer.tokenize('Hello World');
const encoded = tokenizer.encode('Hello World');
-const decoded = tokenizer.decode(encoded);
+const decoded = tokenizer.decode(encoded.ids);

console.log(tokens);
console.log(encoded);
console.log(decoded);
`;

-const TARGET_OUTPUT = "[ '▁Hello', '▁World' ]\n[ 1, 15043, 2787 ]\n<s> Hello World\n";
+const TARGET_OUTPUT = `[ '▁Hello', '▁World' ]
+{
+  ids: [ 1, 15043, 2787 ],
+  tokens: [ '<s>', '▁Hello', '▁World' ],
+  attention_mask: [ 1, 1, 1 ]
+}
+<s> Hello World
+`;

const wrap_async_iife = (code: string) => `(async function() { ${code} })();`;

16 changes: 8 additions & 8 deletions tests/edgeCases.test.ts
@@ -8,21 +8,21 @@ describe("Edge cases", () => {
const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);

let text = String.prototype.repeat.call("a", 50000);
-    let token_ids = tokenizer.encode(text);
-    expect(token_ids).toEqual([101, 100, 102]);
+    let { ids } = tokenizer.encode(text);
+    expect(ids).toEqual([101, 100, 102]);
}, 5000); // NOTE: 5 seconds

it("Special/added tokens with earlier partial matches", async () => {
const modelId = "Xenova/gemini-nano";
const { tokenizerJson, tokenizerConfig } = await fetchConfigById(modelId);
const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
{
-      let token_ids = tokenizer.encode("\n", { add_special_tokens: false });
-      expect(token_ids).toEqual([108]);
+      let { ids } = tokenizer.encode("\n", { add_special_tokens: false });
+      expect(ids).toEqual([108]);
}
{
-      let token_ids = tokenizer.encode("\n\n", { add_special_tokens: false });
-      expect(token_ids).toEqual([109]); // Should not be [108, 108]
+      let { ids } = tokenizer.encode("\n\n", { add_special_tokens: false });
+      expect(ids).toEqual([109]); // Should not be [108, 108]
}
}, 60_000);

@@ -32,7 +32,7 @@ describe("Edge cases", () => {
const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);

let text = "hello world!";
-    let token_ids = tokenizer.encode(text);
-    expect(token_ids).toEqual([128000, 15339, 1917, 0]);
+    let { ids } = tokenizer.encode(text);
+    expect(ids).toEqual([128000, 15339, 1917, 0]);
}, 5000); // NOTE: 5 seconds
});
6 changes: 3 additions & 3 deletions tests/models/llama/llama.test.ts
@@ -49,14 +49,14 @@ describe("hard-coded", () => {
const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);

for (const [text, expected] of Object.entries(data)) {
-      const token_ids = tokenizer.encode(text, {
+      const encoded = tokenizer.encode(text, {
add_special_tokens: false,
});
-      expect(token_ids).toEqual(expected);
+      expect(encoded.ids).toEqual(expected);

// If reversible, test that decoding produces the original text
if (reversible) {
-        const decoded = tokenizer.decode(token_ids);
+        const decoded = tokenizer.decode(encoded.ids);
expect(decoded).toEqual(text);
}
}
6 changes: 3 additions & 3 deletions tests/models/t5/t5.test.ts
@@ -38,14 +38,14 @@ describe("hard-coded", () => {
const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);

for (const [text, expected] of Object.entries(data)) {
-      const token_ids = tokenizer.encode(text, {
+      const encoded = tokenizer.encode(text, {
add_special_tokens: false,
});
-      expect(token_ids).toEqual(expected);
+      expect(encoded.ids).toEqual(expected);

// If reversible, test that decoding produces the original text
if (reversible) {
-        const decoded = tokenizer.decode(token_ids);
+        const decoded = tokenizer.decode(encoded.ids);
expect(decoded).toEqual(text);
}
}
4 changes: 2 additions & 2 deletions tests/tokenizers.test.ts
@@ -20,10 +20,10 @@ describe("Tokenizers (model-specific)", () => {
for (const [testName, testCase] of Object.entries(config.default[modelId])) {
test(testName, () => {
if (testCase.ids) {
-        const ids = tokenizer.encode(testCase.text, {
+        const encoded = tokenizer.encode(testCase.text, {
text_pair: testCase.text_pair,
});
-        expect(ids).toEqual(testCase.ids);
+        expect(encoded.ids).toEqual(testCase.ids);

if (testCase.decoded) {
const decoded = tokenizer.decode(testCase.ids);