Skip to content

Commit 605d8ee

Browse files
committed
Normalize input and output
1 parent b6287bc commit 605d8ee

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

examples/imageClassifier-transformer-single-image/sketch.js

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,15 @@ let img;
1515
// Variables for displaying the results on the canvas
1616
let label = "";
1717
let confidence = "";
18-
let url = "images/bird.jpg";
1918

2019
function preload() {
2120
classifier = ml5.imageClassifier("transformer");
22-
img = loadImage(url);
21+
img = loadImage("images/bird.jpg");
2322
}
2423

2524
function setup() {
2625
createCanvas(400, 400);
27-
classifier.classify(url, gotResult);
26+
classifier.classify(img, gotResult);
2827
image(img, 0, 0, width, height);
2928
}
3029

@@ -39,7 +38,7 @@ function gotResult(results) {
3938
stroke(0);
4039
textSize(18);
4140
label = "Label: " + results[0].label;
42-
confidence = "Confidence: " + nf(results[0].score, 0, 2);
41+
confidence = "Confidence: " + nf(results[0].confidence, 0, 2);
4342
text(label, 10, 360);
4443
text(confidence, 10, 380);
4544
}

src/ImageClassifier/transformer.js

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import { pipeline } from "@huggingface/transformers";
2+
import handleArguments from "../utils/handleArguments";
23

4+
/**
5+
* @reference https://huggingface.co/docs/transformers.js/en/api/pipelines#module_pipelines.ImageClassificationPipeline
6+
*/
37
export class ImageClassifierTransformer {
48
constructor(options, callback) {
59
this.classifier = null;
@@ -16,18 +20,30 @@ export class ImageClassifierTransformer {
1620
});
1721
}
1822

19-
async classify(input, callback) {
20-
if (this.isClassifying) return;
21-
if (!this.classifier) return;
23+
async classify(inputNumOrCallback, numOrCallback, cb) {
24+
if (this.isClassifying || !this.classifier) return;
2225
this.isClassifying = true;
23-
const results = await this.classifier(input);
24-
callback(results);
26+
const { image, number, callback } = handleArguments(
27+
inputNumOrCallback,
28+
numOrCallback,
29+
cb
30+
).require(
31+
"image",
32+
"No input image provided. If you want to classify a video, use classifyStart."
33+
);
34+
const options = number !== undefined ? { top_k: number } : {};
35+
const results = await this.classifier(image, options);
36+
const normalized = results.map((result) => ({
37+
label: result.label,
38+
confidence: result.score,
39+
}));
40+
callback(normalized);
2541
this.isClassifying = false;
42+
return normalized;
2643
}
2744

2845
async classifyStart(input, callback) {
29-
if (this.isClassifying) return;
30-
if (!this.classifier) return;
46+
if (this.isClassifying || !this.classifier) return;
3147
this.needToStop = false;
3248
const next = (...args) => {
3349
if (this.needToStop) return;

0 commit comments

Comments
 (0)