1- import 'dart:io' ;
21import 'dart:math' ;
32
3+ import 'package:image/image.dart' ;
44import 'package:collection/collection.dart' ;
55import 'package:logger/logger.dart' ;
66import 'package:tflite_flutter/tflite_flutter.dart' ;
77import 'package:tflite_flutter_helper/tflite_flutter_helper.dart' ;
88
99abstract class Classifier {
10- Interpreter _interpreter;
10+ Interpreter interpreter;
11+ InterpreterOptions _interpreterOptions;
1112
1213 var logger = Logger ();
1314
@@ -19,38 +20,51 @@ abstract class Classifier {
1920
2021 TfLiteType _outputType = TfLiteType .uint8;
2122
22- final String _labelsFileName = 'assets/labels_mobilenet_quant_v1_224 .txt' ;
23+ final String _labelsFileName = 'assets/labels .txt' ;
2324
2425 final int _labelsLength = 1001 ;
2526
26- List <String > _labels;
27+ var _probabilityProcessor;
28+
29+ List <String > labels;
2730
2831 String get modelName;
2932
3033 NormalizeOp get preProcessNormalizeOp;
3134 NormalizeOp get postProcessNormalizeOp;
3235
33- Classifier () {
34- _loadModel ();
35- _loadLabels ();
36+ Classifier ({int numThreads}) {
37+ _interpreterOptions = InterpreterOptions ();
38+
39+ if (numThreads != null ) {
40+ _interpreterOptions.threads = numThreads;
41+ }
42+
43+ loadModel ();
44+ loadLabels ();
3645 }
3746
38- Future <void > _loadModel () async {
39- _interpreter = await Interpreter .fromAsset (modelName);
40- if (_interpreter != null ) {
47+ Future <void > loadModel () async {
48+ try {
49+ interpreter =
50+ await Interpreter .fromAsset (modelName, options: _interpreterOptions);
4151 print ('Interpreter Created Successfully' );
42- _inputShape = _interpreter.getInputTensor (0 ).shape;
43- _outputShape = _interpreter.getOutputTensor (0 ).shape;
44- _outputType = _interpreter.getOutputTensor (0 ).type;
52+
53+ _inputShape = interpreter.getInputTensor (0 ).shape;
54+ _outputShape = interpreter.getOutputTensor (0 ).shape;
55+ _outputType = interpreter.getOutputTensor (0 ).type;
56+
4557 _outputBuffer = TensorBuffer .createFixedSize (_outputShape, _outputType);
46- } else {
47- print ('Unable to create interpreter' );
58+ _probabilityProcessor =
59+ TensorProcessorBuilder ().add (postProcessNormalizeOp).build ();
60+ } catch (e) {
61+ print ('Unable to create interpreter, Caught Exception: ${e .toString ()}' );
4862 }
4963 }
5064
51- Future <void > _loadLabels () async {
52- _labels = await FileUtil .loadLabels (_labelsFileName);
53- if (_labels .length == _labelsLength) {
65+ Future <void > loadLabels () async {
66+ labels = await FileUtil .loadLabels (_labelsFileName);
67+ if (labels .length == _labelsLength) {
5468 print ('Labels loaded successfully' );
5569 } else {
5670 print ('Unable to load labels' );
@@ -68,30 +82,39 @@ abstract class Classifier {
6882 .process (_inputImage);
6983 }
7084
71- Future <Category > predict (File imageFile) async {
72- _inputImage = TensorImage .fromFile (imageFile);
85+ Category predict (Image image) {
86+ if (interpreter == null ) {
87+ throw StateError ('Cannot run inference, Intrepreter is null' );
88+ }
89+ final pres = DateTime .now ().millisecondsSinceEpoch;
90+ _inputImage = TensorImage .fromImage (image);
7391 _inputImage = _preProcess ();
92+ final pre = DateTime .now ().millisecondsSinceEpoch - pres;
7493
75- logger. d (_inputImage.tensorBuffer. getDoubleList (). sublist ( 0 , 30 ) );
94+ print ( 'Time to load image: $ pre ms' );
7695
77- final st = DateTime .now ().millisecondsSinceEpoch;
78- _interpreter.run (_inputImage.buffer, _outputBuffer.getBuffer ());
79- logger.d (
80- 'Run Time :' + (DateTime .now ().millisecondsSinceEpoch - st).toString ());
96+ final runs = DateTime .now ().millisecondsSinceEpoch;
97+ interpreter.run (_inputImage.buffer, _outputBuffer.getBuffer ());
98+ final run = DateTime .now ().millisecondsSinceEpoch - runs;
8199
82- final _probabilityProcessor =
83- TensorProcessorBuilder ().add (postProcessNormalizeOp).build ();
100+ print ('Time to run inference: $run ms' );
84101
85102 Map <String , double > labeledProb = TensorLabel .fromList (
86- _labels , _probabilityProcessor.process (_outputBuffer))
103+ labels , _probabilityProcessor.process (_outputBuffer))
87104 .getMapWithFloatValue ();
88-
89105 final pred = getTopProbability (labeledProb);
106+
90107 return Category (pred.key, pred.value);
91108 }
109+
110+ void close () {
111+ if (interpreter != null ) {
112+ interpreter.close ();
113+ }
114+ }
92115}
93116
94- getTopProbability (Map <String , double > labeledProb) {
117+ MapEntry < String , double > getTopProbability (Map <String , double > labeledProb) {
95118 var pq = PriorityQueue <MapEntry <String , double >>(compare);
96119 pq.addAll (labeledProb.entries);
97120
0 commit comments