128 changes: 97 additions & 31 deletions src/llm_chat.ts
@@ -53,6 +53,8 @@ export class LLMChatPipeline {
private fapplyPenalty: tvmjs.PackedFunc;
private fapplyLogitBias: tvmjs.PackedFunc;
private fsoftmaxWithTemperature: tvmjs.PackedFunc;
private fsampleWithTopP: tvmjs.PackedFunc;
private fargsortProbs: tvmjs.PackedFunc;

// Functions related to PagedKVCache
private fclearKVCaches: tvmjs.PackedFunc;
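// Note: the two new fields hold PackedFuncs fetched from the compiled
// relax VM: argsort_probs is expected to return the probabilities of a
// row sorted in descending order together with the token ids they came
// from, and sample_with_top_p draws from that sorted layout under a
// top-p cutoff. A CPU sketch of the assumed argsort_probs contract, for
// reference only (the function name and return shape are illustrative):
function argsortProbsReference(probs: Float32Array): {
  sortedProbs: Float32Array;
  sortedIndices: Int32Array;
} {
  // Sort token indices by descending probability, then gather.
  const order = Array.from(probs.keys()).sort((a, b) => probs[b] - probs[a]);
  return {
    sortedProbs: new Float32Array(order.map((i) => probs[i])),
    sortedIndices: new Int32Array(order),
  };
}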
@@ -142,6 +144,10 @@ export class LLMChatPipeline {
private curRoundGrammarInitTotalTime = 0;
// Total time of getting next bitmask and accepting token in seconds
private curRoundGrammarPerTokenTotalTime = 0;
// Instance variables for supporting sampling on WebGPU
private sampleIndices: Int32Array;
private sampleIndicesDevice: tvmjs.Tensor;
private topPDevice: tvmjs.Tensor;

constructor(
tvm: tvmjs.Instance,
@@ -213,6 +219,12 @@ export class LLMChatPipeline {
this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope(
this.vm.getFunction("softmax_with_temperature"),
);
this.fsampleWithTopP = this.tvm.detachFromCurrentScope(
this.vm.getFunction("sample_with_top_p"),
);
this.fargsortProbs = this.tvm.detachFromCurrentScope(
this.vm.getFunction("argsort_probs"),
);
try {
this.image_embed = this.tvm.detachFromCurrentScope(
this.vm.getFunction("image_embed"),
@@ -310,6 +322,25 @@ export class LLMChatPipeline {

this.filledKVCacheLength = 0;
this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence

// Initialize WebGPU sampling related device tensors
const numSamples = 1;
const numProbs = 1;

this.sampleIndices = new Int32Array(numSamples);
for (let i = 0; i < numSamples; i++) {
this.sampleIndices[i] = i;
}
this.sampleIndicesDevice = this.tvm.detachFromCurrentScope(
this.tvm
.empty([numSamples], "int32", this.device)
.copyFrom(this.sampleIndices),
);

this.topPDevice = this.tvm.detachFromCurrentScope(
this.tvm.empty([numProbs], "float32", this.device),
);

tvm.endScope();
}
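// Note: sampleIndicesDevice and topPDevice are allocated once in the
// constructor and detached from the TVM scope so they survive across
// decode steps; per-token sampling then only refreshes their contents
// instead of reallocating device memory. With numSamples = numProbs = 1,
// sampleIndices is simply [0], mapping the single sample to the single
// probability row. Illustrative-only sketch of the refresh pattern used
// later in this diff (the class and its shape are assumptions):
class TopPHostBuffer {
  private readonly host: Float32Array;
  constructor(numProbs: number) {
    this.host = new Float32Array(numProbs).fill(-1); // -1 marks unused rows
  }
  refresh(sampleIndices: Int32Array, topP: number): Float32Array {
    this.host.fill(-1);
    for (const row of sampleIndices) {
      this.host[row] = Math.max(topP, 1e-5); // clamp, matching the diff
    }
    return this.host; // would be copied into topPDevice via copyFrom()
  }
}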

@@ -942,7 +973,7 @@ export class LLMChatPipeline {
for (let i = 0; i < inputData.length; i++) {
const data = inputData[i];
if (Array.isArray(data)) {
embeddings.push(await this.getTokensEmbeddings(data));
embeddings.push(this.getTokensEmbeddings(data));
} else {
embeddings.push(await this.getImageEmbeddings(data));
}
@@ -1024,6 +1055,12 @@ export class LLMChatPipeline {
if (_hasValue(genConfig.top_p)) {
top_p = genConfig.top_p!;
}
// TODO: setting top_p to 1.0 by default might run into issues since
// top_p masking in relax uses < instead of <=
// Set default top_p to 1.0 if not set
if (!_hasValue(top_p)) {
top_p = 1.0;
}
if (_hasValue(genConfig.repetition_penalty)) {
repetition_penalty = genConfig.repetition_penalty!;
}
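// Note on the TODO above: one plausible reading is that the relax top-p
// mask keeps a token only while the running cumulative mass is strictly
// below top_p, so the token whose probability makes the sum reach
// exactly top_p can be masked out when top_p defaults to 1.0.
// Illustrative-only sketch, not the actual kernel:
function topPKeepFlags(sortedProbs: number[], topP: number): boolean[] {
  let cumulative = 0;
  return sortedProbs.map((p) => {
    cumulative += p;
    return cumulative < topP; // "<" drops the boundary token; "<=" keeps it
  });
}
// topPKeepFlags([0.5, 0.5], 1.0) yields [true, false]: the second token
// lands exactly on 1.0 and is dropped under "<", which is the hazard the
// TODO warns about.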
@@ -1033,13 +1070,19 @@ export class LLMChatPipeline {
if (_hasValue(genConfig.presence_penalty)) {
presence_penalty = genConfig.presence_penalty!;
}
// If only one of frequency or presence penatly is set, make the other one 0.0
// If only one of frequency or presence penalty is set, make the other one 0.0
if (_hasValue(frequency_penalty) && !_hasValue(presence_penalty)) {
presence_penalty = 0.0;
}
if (_hasValue(presence_penalty) && !_hasValue(frequency_penalty)) {
frequency_penalty = 0.0;
}
if (!_hasValue(frequency_penalty)) {
frequency_penalty = 0.0;
}
if (!_hasValue(presence_penalty)) {
presence_penalty = 0.0;
}
if (_hasValue(genConfig.logit_bias)) {
logit_bias = genConfig.logit_bias!;
}
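// Note: frequency_penalty and presence_penalty are normalized above so
// that both are always defined, with 0.0 acting as a no-op. A sketch of
// the usual OpenAI-style penalty rule the pair feeds into (illustrative;
// the real work happens in the compiled fapplyPenalty kernel):
function applyPenaltiesReference(
  logits: Float32Array,
  seenTokenCounts: Map<number, number>, // token id -> occurrence count
  frequencyPenalty: number,
  presencePenalty: number,
): void {
  for (const [token, count] of seenTokenCounts) {
    // Every entry has count >= 1, so presencePenalty applies once per
    // distinct seen token; with both penalties at 0.0 this is a no-op.
    logits[token] -= count * frequencyPenalty + presencePenalty;
  }
}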
@@ -1267,46 +1310,69 @@ export class LLMChatPipeline {
}
}

// TODO: Explore usage of multinomial sampling kernel (currently blocked due to usage
// of i8) for cases where top_p is not set
// 4. Sample token from logits
// If logprobs, need the actual distribution via softmax, otherwise directly sample from logits
const sampleBegin = performance.now();
let sampledToken: number;
if (logprobs) {
// Inplace transform logitsOnCPU to a distribution
temperature = Math.max(1e-6, temperature); // to prevent division by zero

const numSeqs = 1;
// Inplace transform logitsOnCPU to a distribution
temperature = Math.max(1e-6, temperature); // to prevent division by zero

const temperatures = new Float32Array([temperature]);
const numSeqs = 1;
const numProbs = 1;

this.tvm.beginScope();
const temperaturesDevice = this.tvm
.empty([numSeqs], "float32", this.device)
.copyFrom(temperatures);
const temperatures = new Float32Array([temperature]);

const probs = this.fsoftmaxWithTemperature(
logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
temperaturesDevice,
);
this.tvm.beginScope();
const temperaturesDevice = this.tvm
.empty([numSeqs], "float32", this.device)
.copyFrom(temperatures);

let probs = this.fsoftmaxWithTemperature(
logitsOnGPU.view([numSeqs, numProbs, this.fullVocabSize]),
temperaturesDevice,
);
probs = probs.view([numProbs, this.fullVocabSize]);

const argsortResults = this.fargsortProbs(probs);
const sortedProbsDevice = argsortResults.get(0);
const sortedIndicesDevice = argsortResults.get(1);

const uniformSamplesDevice = this.tvm.uniform([1], 0.0, 1.0, this.device);

const topPHost = new Float32Array(numProbs).fill(-1);
const topPValue = Math.max(top_p, 1e-5);
this.sampleIndices.forEach((row) => {
topPHost[row] = topPValue;
});
this.topPDevice.copyFrom(topPHost);

const sampledTokensDevice = this.tvm.detachFromCurrentScope(
this.fsampleWithTopP(
sortedProbsDevice,
sortedIndicesDevice,
uniformSamplesDevice,
this.sampleIndicesDevice,
this.topPDevice,
),
);
const sampledTokensHost = this.tvm.detachFromCurrentScope(
this.tvm
.empty([numSeqs], "int32", this.tvm.cpu())
.copyFrom(sampledTokensDevice),
);
if (logprobs && top_logprobs! > 0) {
this.updateLogitsOnCPU(probs);
this.tvm.endScope();
await this.device.sync();
}
this.tvm.endScope();
await this.device.sync();

const sampledToken = sampledTokensHost.toArray()[0];

sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU!, top_p);
if (logprobs && top_logprobs! > 0) {
this.tokenLogprobArray.push(
this.getTokenLogprob(sampledToken, top_logprobs!),
);
} else {
// temperature being 0 is allowed here, equivalent to argmax
this.tvm.beginScope();
this.updateLogitsOnCPU(logitsOnGPU);
this.tvm.endScope();
await this.device.sync();
sampledToken = this.tvm.sampleTopPFromLogits(
this.logitsOnCPU!,
temperature,
top_p,
);
}

if (genConfig?.enable_latency_breakdown) {
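// End-to-end CPU reference of the new device-side sampling path, for
// intuition only: the production code runs softmax_with_temperature,
// argsort_probs, and sample_with_top_p as compiled kernels on the WebGPU
// device and copies back a single int32 token id. The function name and
// the exact nucleus/renormalization details below are assumptions.
function sampleTopPReference(
  logits: Float32Array,
  temperature: number,
  topP: number,
  uniformSample: number, // in [0, 1), mirrors tvm.uniform([1], 0.0, 1.0)
): number {
  // 1. Softmax with temperature (clamped to avoid division by zero).
  const t = Math.max(1e-6, temperature);
  const scaled = Array.from(logits, (v) => v / t);
  const maxVal = scaled.reduce((a, b) => Math.max(a, b), -Infinity);
  const exps = scaled.map((v) => Math.exp(v - maxVal));
  const total = exps.reduce((a, b) => a + b, 0);
  const probs = exps.map((v) => v / total);

  // 2. Descending argsort, as argsort_probs is assumed to produce.
  const order = Array.from(probs.keys()).sort((a, b) => probs[b] - probs[a]);

  // 3. Truncate to the smallest prefix whose mass reaches top_p.
  const cutoff = Math.max(topP, 1e-5);
  const nucleus: number[] = [];
  let mass = 0;
  for (const i of order) {
    nucleus.push(i);
    mass += probs[i];
    if (mass >= cutoff) break;
  }

  // 4. Inverse-CDF draw within the renormalized nucleus.
  let acc = 0;
  for (const i of nucleus) {
    acc += probs[i] / mass;
    if (uniformSample < acc) return i;
  }
  return nucleus[nucleus.length - 1]; // guard against rounding drift
}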