16
16
17
17
package com .mindorks .tensorflowexample ;
18
18
19
- import android .content .res .AssetManager ;
20
- import android .os .Trace ;
21
- import android .util .Log ;
22
-
23
- import org .tensorflow .contrib .android .TensorFlowInferenceInterface ;
24
-
25
19
import java .io .BufferedReader ;
26
20
import java .io .IOException ;
27
21
import java .io .InputStreamReader ;
31
25
import java .util .PriorityQueue ;
32
26
import java .util .Vector ;
33
27
28
+ import org .tensorflow .contrib .android .TensorFlowInferenceInterface ;
29
+
30
+ import android .content .res .AssetManager ;
31
+ import android .support .v4 .os .TraceCompat ;
32
+ import android .util .Log ;
33
+
34
34
/**
35
35
* Created by amitshekhar on 16/03/17.
36
36
*/
40
40
*/
41
41
public class TensorFlowImageClassifier implements Classifier {
42
42
43
- private static final String TAG = "TensorFlowImageClassifier " ;
43
+ private static final String TAG = "TFImageClassifier " ;
44
44
45
45
// Only return this many results with at least this confidence.
46
46
private static final int MAX_RESULTS = 3 ;
@@ -58,6 +58,8 @@ public class TensorFlowImageClassifier implements Classifier {
58
58
59
59
private TensorFlowInferenceInterface inferenceInterface ;
60
60
61
+ private boolean runStats = false ;
62
+
61
63
private TensorFlowImageClassifier () {
62
64
}
63
65
@@ -96,10 +98,8 @@ public static Classifier create(
96
98
}
97
99
br .close ();
98
100
99
- c .inferenceInterface = new TensorFlowInferenceInterface ();
100
- if (c .inferenceInterface .initializeTensorFlow (assetManager , modelFilename ) != 0 ) {
101
- throw new RuntimeException ("TF initialization failed" );
102
- }
101
+ c .inferenceInterface = new TensorFlowInferenceInterface (assetManager , modelFilename );
102
+
103
103
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
104
104
int numClasses =
105
105
(int ) c .inferenceInterface .graph ().operation (outputName ).output (0 ).shape ().size (1 );
@@ -120,23 +120,22 @@ public static Classifier create(
120
120
@ Override
121
121
public List <Recognition > recognizeImage (final float [] pixels ) {
122
122
// Log this method so that it can be analyzed with systrace.
123
- Trace .beginSection ("recognizeImage" );
123
+ TraceCompat .beginSection ("recognizeImage" );
124
124
125
125
// Copy the input data into TensorFlow.
126
- Trace .beginSection ("fillNodeFloat" );
127
- inferenceInterface .fillNodeFloat (
128
- inputName , new int []{inputSize * inputSize }, pixels );
129
- Trace .endSection ();
126
+ TraceCompat .beginSection ("feed" );
127
+ inferenceInterface .feed (inputName , pixels , new long []{inputSize * inputSize });
128
+ TraceCompat .endSection ();
130
129
131
130
// Run the inference call.
132
- Trace .beginSection ("runInference " );
133
- inferenceInterface .runInference (outputNames );
134
- Trace .endSection ();
131
+ TraceCompat .beginSection ("run " );
132
+ inferenceInterface .run (outputNames , runStats );
133
+ TraceCompat .endSection ();
135
134
136
135
// Copy the output Tensor back into the output array.
137
- Trace .beginSection ("readNodeFloat " );
138
- inferenceInterface .readNodeFloat (outputName , outputs );
139
- Trace .endSection ();
136
+ TraceCompat .beginSection ("fetch " );
137
+ inferenceInterface .fetch (outputName , outputs );
138
+ TraceCompat .endSection ();
140
139
141
140
// Find the best classifications.
142
141
PriorityQueue <Recognition > pq =
@@ -161,13 +160,13 @@ public int compare(Recognition lhs, Recognition rhs) {
161
160
for (int i = 0 ; i < recognitionsSize ; ++i ) {
162
161
recognitions .add (pq .poll ());
163
162
}
164
- Trace .endSection (); // "recognizeImage"
163
+ TraceCompat .endSection (); // "recognizeImage"
165
164
return recognitions ;
166
165
}
167
166
168
167
@ Override
169
168
public void enableStatLogging (boolean debug ) {
170
- inferenceInterface . enableStatLogging ( debug ) ;
169
+ runStats = debug ;
171
170
}
172
171
173
172
@ Override
0 commit comments