Skip to content

Commit a09f6b1

Browse files
authored
Merge pull request #9 from am15h/ci
CI and example code update
2 parents bce823e + 2c78169 commit a09f6b1

File tree

18 files changed

+239
-156
lines changed

18 files changed

+239
-156
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: Flutter CI
2+
3+
on:
4+
push:
5+
branches: [ master, actions]
6+
pull_request:
7+
branches: [ master ]
8+
9+
jobs:
10+
android-integration-test:
11+
name: Android
12+
runs-on: macOS-latest
13+
strategy:
14+
matrix:
15+
device:
16+
- "pixel_xl"
17+
fail-fast: false
18+
19+
steps:
20+
- uses: actions/checkout@v1
21+
- uses: actions/setup-java@v1
22+
with:
23+
java-version: '8.x'
24+
- uses: subosito/flutter-action@v1
25+
with:
26+
flutter-version: '1.12.13+hotfix.5'
27+
channel: 'stable'
28+
- name: run tests
29+
timeout-minutes: 30
30+
uses: reactivecircus/android-emulator-runner@v2
31+
env:
32+
ANDROID_SIGN_PWD: ${{ secrets.ANDROID_SIGN_PWD }}
33+
SECRET_REPO: ${{ secrets.SECRET_REPO }}
34+
GITHUB_TOKEN: ${{ secrets.MY_GITHUB_TOKEN }}
35+
with:
36+
api-level: 29
37+
profile: ${{ matrix.device }}
38+
script: |
39+
cd ./example/image_classification && flutter pub get
40+
cd ./example/image_classification && flutter driver --driver='test_driver/image_classification_e2e_test.dart' test/image_classification_e2e.dart
Binary file not shown.
Binary file not shown.
Binary file not shown.
194 KB
Loading

example/image_classification/lib/classifier.dart

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
import 'dart:io';
21
import 'dart:math';
32

3+
import 'package:image/image.dart';
44
import 'package:collection/collection.dart';
55
import 'package:logger/logger.dart';
66
import 'package:tflite_flutter/tflite_flutter.dart';
77
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
88

99
abstract class Classifier {
10-
Interpreter _interpreter;
10+
Interpreter interpreter;
11+
InterpreterOptions _interpreterOptions;
1112

1213
var logger = Logger();
1314

@@ -19,38 +20,51 @@ abstract class Classifier {
1920

2021
TfLiteType _outputType = TfLiteType.uint8;
2122

22-
final String _labelsFileName = 'assets/labels_mobilenet_quant_v1_224.txt';
23+
final String _labelsFileName = 'assets/labels.txt';
2324

2425
final int _labelsLength = 1001;
2526

26-
List<String> _labels;
27+
var _probabilityProcessor;
28+
29+
List<String> labels;
2730

2831
String get modelName;
2932

3033
NormalizeOp get preProcessNormalizeOp;
3134
NormalizeOp get postProcessNormalizeOp;
3235

33-
Classifier() {
34-
_loadModel();
35-
_loadLabels();
36+
Classifier({int numThreads}) {
37+
_interpreterOptions = InterpreterOptions();
38+
39+
if (numThreads != null) {
40+
_interpreterOptions.threads = numThreads;
41+
}
42+
43+
loadModel();
44+
loadLabels();
3645
}
3746

38-
Future<void> _loadModel() async {
39-
_interpreter = await Interpreter.fromAsset(modelName);
40-
if (_interpreter != null) {
47+
Future<void> loadModel() async {
48+
try {
49+
interpreter =
50+
await Interpreter.fromAsset(modelName, options: _interpreterOptions);
4151
print('Interpreter Created Successfully');
42-
_inputShape = _interpreter.getInputTensor(0).shape;
43-
_outputShape = _interpreter.getOutputTensor(0).shape;
44-
_outputType = _interpreter.getOutputTensor(0).type;
52+
53+
_inputShape = interpreter.getInputTensor(0).shape;
54+
_outputShape = interpreter.getOutputTensor(0).shape;
55+
_outputType = interpreter.getOutputTensor(0).type;
56+
4557
_outputBuffer = TensorBuffer.createFixedSize(_outputShape, _outputType);
46-
} else {
47-
print('Unable to create interpreter');
58+
_probabilityProcessor =
59+
TensorProcessorBuilder().add(postProcessNormalizeOp).build();
60+
} catch (e) {
61+
print('Unable to create interpreter, Caught Exception: ${e.toString()}');
4862
}
4963
}
5064

51-
Future<void> _loadLabels() async {
52-
_labels = await FileUtil.loadLabels(_labelsFileName);
53-
if (_labels.length == _labelsLength) {
65+
Future<void> loadLabels() async {
66+
labels = await FileUtil.loadLabels(_labelsFileName);
67+
if (labels.length == _labelsLength) {
5468
print('Labels loaded successfully');
5569
} else {
5670
print('Unable to load labels');
@@ -68,30 +82,39 @@ abstract class Classifier {
6882
.process(_inputImage);
6983
}
7084

71-
Future<Category> predict(File imageFile) async {
72-
_inputImage = TensorImage.fromFile(imageFile);
85+
Category predict(Image image) {
86+
if (interpreter == null) {
87+
throw StateError('Cannot run inference, Intrepreter is null');
88+
}
89+
final pres = DateTime.now().millisecondsSinceEpoch;
90+
_inputImage = TensorImage.fromImage(image);
7391
_inputImage = _preProcess();
92+
final pre = DateTime.now().millisecondsSinceEpoch - pres;
7493

75-
logger.d(_inputImage.tensorBuffer.getDoubleList().sublist(0, 30));
94+
print('Time to load image: $pre ms');
7695

77-
final st = DateTime.now().millisecondsSinceEpoch;
78-
_interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer());
79-
logger.d(
80-
'Run Time :' + (DateTime.now().millisecondsSinceEpoch - st).toString());
96+
final runs = DateTime.now().millisecondsSinceEpoch;
97+
interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer());
98+
final run = DateTime.now().millisecondsSinceEpoch - runs;
8199

82-
final _probabilityProcessor =
83-
TensorProcessorBuilder().add(postProcessNormalizeOp).build();
100+
print('Time to run inference: $run ms');
84101

85102
Map<String, double> labeledProb = TensorLabel.fromList(
86-
_labels, _probabilityProcessor.process(_outputBuffer))
103+
labels, _probabilityProcessor.process(_outputBuffer))
87104
.getMapWithFloatValue();
88-
89105
final pred = getTopProbability(labeledProb);
106+
90107
return Category(pred.key, pred.value);
91108
}
109+
110+
void close() {
111+
if (interpreter != null) {
112+
interpreter.close();
113+
}
114+
}
92115
}
93116

94-
getTopProbability(Map<String, double> labeledProb) {
117+
MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) {
95118
var pq = PriorityQueue<MapEntry<String, double>>(compare);
96119
pq.addAll(labeledProb.entries);
97120

example/image_classification/lib/classifier_float.dart

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ import 'package:imageclassification/classifier.dart';
22
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
33

44
class ClassifierFloat extends Classifier {
5+
ClassifierFloat({int numThreads}) : super(numThreads: numThreads);
6+
57
@override
68
String get modelName => 'mobilenet_v1_1.0_224.tflite';
79

example/image_classification/lib/classifier_quant.dart

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ import 'package:imageclassification/classifier.dart';
22
import 'package:tflite_flutter_helper/src/common/ops/normailze_op.dart';
33

44
class ClassifierQuant extends Classifier {
5+
ClassifierQuant({int numThreads: 1}) : super(numThreads: numThreads);
6+
57
@override
68
String get modelName => 'mobilenet_v1_1.0_224_quant.tflite';
79

example/image_classification/lib/main.dart

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
import 'dart:io';
2-
import 'package:flutter/foundation.dart' as f;
2+
import 'package:image/image.dart' as img;
33
import 'package:flutter/material.dart';
44
import 'package:image_picker/image_picker.dart';
55
import 'package:imageclassification/classifier.dart';
66
import 'package:imageclassification/classifier_quant.dart';
77
import 'package:logger/logger.dart';
88
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
99

10-
import 'classifier_float.dart';
11-
1210
void main() => runApp(MyApp());
1311

1412
class MyApp extends StatelessWidget {
1513
@override
1614
Widget build(BuildContext context) {
1715
return MaterialApp(
18-
title: 'Flutter Demo',
16+
title: 'Image Classification',
1917
theme: ThemeData(
2018
primarySwatch: Colors.orange,
2119
),
@@ -43,8 +41,16 @@ class _MyHomePageState extends State<MyHomePage> {
4341

4442
Image _imageWidget;
4543

44+
img.Image fox;
45+
4646
Category category;
4747

48+
@override
49+
void initState() {
50+
super.initState();
51+
_classifier = ClassifierQuant();
52+
}
53+
4854
Future getImage() async {
4955
final pickedFile = await picker.getImage(source: ImageSource.gallery);
5056

@@ -58,21 +64,12 @@ class _MyHomePageState extends State<MyHomePage> {
5864
});
5965
}
6066

61-
@override
62-
void initState() {
63-
super.initState();
64-
_classifier = ClassifierQuant();
65-
}
67+
void _predict() async {
68+
img.Image imageInput = img.decodeImage(_image.readAsBytesSync());
69+
var pred = _classifier.predict(imageInput);
6670

67-
void _predict() {
68-
int st = DateTime.now().millisecondsSinceEpoch;
69-
final pred = _classifier.predict(_image);
70-
pred.then((category) {
71-
int en = DateTime.now().millisecondsSinceEpoch;
72-
logger.d('Total Time: ${en - st}');
73-
setState(() {
74-
this.category = category;
75-
});
71+
setState(() {
72+
this.category = pred;
7673
});
7774
}
7875

0 commit comments

Comments
 (0)