Skip to content

Commit 0fc8b4e

Browse files
Merge pull request #1 from daj/master
Support all architectures using 'org.tensorflow:tensorflow-android' dependency
2 parents 99eeff6 + be08d3d commit 0fc8b4e

File tree

4 files changed

+24
-25
lines changed

4 files changed

+24
-25
lines changed

app/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ dependencies {
4242
})
4343
compile 'com.android.support:appcompat-v7:25.2.0'
4444
testCompile 'junit:junit:4.12'
45-
compile files('libs/libandroid_tensorflow_inference_java.jar')
45+
compile 'org.tensorflow:tensorflow-android:1.2.0'
4646
}
-26.2 KB
Binary file not shown.

app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java

+23-24
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,6 @@
1616

1717
package com.mindorks.tensorflowexample;
1818

19-
import android.content.res.AssetManager;
20-
import android.os.Trace;
21-
import android.util.Log;
22-
23-
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
24-
2519
import java.io.BufferedReader;
2620
import java.io.IOException;
2721
import java.io.InputStreamReader;
@@ -31,6 +25,12 @@
3125
import java.util.PriorityQueue;
3226
import java.util.Vector;
3327

28+
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
29+
30+
import android.content.res.AssetManager;
31+
import android.support.v4.os.TraceCompat;
32+
import android.util.Log;
33+
3434
/**
3535
* Created by amitshekhar on 16/03/17.
3636
*/
@@ -40,7 +40,7 @@
4040
*/
4141
public class TensorFlowImageClassifier implements Classifier {
4242

43-
private static final String TAG = "TensorFlowImageClassifier";
43+
private static final String TAG = "TFImageClassifier";
4444

4545
// Only return this many results with at least this confidence.
4646
private static final int MAX_RESULTS = 3;
@@ -58,6 +58,8 @@ public class TensorFlowImageClassifier implements Classifier {
5858

5959
private TensorFlowInferenceInterface inferenceInterface;
6060

61+
private boolean runStats = false;
62+
6163
private TensorFlowImageClassifier() {
6264
}
6365

@@ -96,10 +98,8 @@ public static Classifier create(
9698
}
9799
br.close();
98100

99-
c.inferenceInterface = new TensorFlowInferenceInterface();
100-
if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
101-
throw new RuntimeException("TF initialization failed");
102-
}
101+
c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
102+
103103
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
104104
int numClasses =
105105
(int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
@@ -120,23 +120,22 @@ public static Classifier create(
120120
@Override
121121
public List<Recognition> recognizeImage(final float[] pixels) {
122122
// Log this method so that it can be analyzed with systrace.
123-
Trace.beginSection("recognizeImage");
123+
TraceCompat.beginSection("recognizeImage");
124124

125125
// Copy the input data into TensorFlow.
126-
Trace.beginSection("fillNodeFloat");
127-
inferenceInterface.fillNodeFloat(
128-
inputName, new int[]{inputSize * inputSize}, pixels);
129-
Trace.endSection();
126+
TraceCompat.beginSection("feed");
127+
inferenceInterface.feed(inputName, pixels, new long[]{inputSize * inputSize});
128+
TraceCompat.endSection();
130129

131130
// Run the inference call.
132-
Trace.beginSection("runInference");
133-
inferenceInterface.runInference(outputNames);
134-
Trace.endSection();
131+
TraceCompat.beginSection("run");
132+
inferenceInterface.run(outputNames, runStats);
133+
TraceCompat.endSection();
135134

136135
// Copy the output Tensor back into the output array.
137-
Trace.beginSection("readNodeFloat");
138-
inferenceInterface.readNodeFloat(outputName, outputs);
139-
Trace.endSection();
136+
TraceCompat.beginSection("fetch");
137+
inferenceInterface.fetch(outputName, outputs);
138+
TraceCompat.endSection();
140139

141140
// Find the best classifications.
142141
PriorityQueue<Recognition> pq =
@@ -161,13 +160,13 @@ public int compare(Recognition lhs, Recognition rhs) {
161160
for (int i = 0; i < recognitionsSize; ++i) {
162161
recognitions.add(pq.poll());
163162
}
164-
Trace.endSection(); // "recognizeImage"
163+
TraceCompat.endSection(); // "recognizeImage"
165164
return recognitions;
166165
}
167166

168167
@Override
169168
public void enableStatLogging(boolean debug) {
170-
inferenceInterface.enableStatLogging(debug);
169+
runStats = debug;
171170
}
172171

173172
@Override
Binary file not shown.

0 commit comments

Comments
 (0)