TensorFlow Lite on Android


What is TensorFlow Lite?

TensorFlow Lite is a set of tools to help developers run TensorFlow models on mobile, embedded, and IoT devices. It enables on-device machine learning inference with low latency and a small binary size.

It’s currently supported on Android and iOS via a C++ API, and there is also a Java wrapper for Android developers. Additionally, on Android devices that support it, the interpreter can use the Android Neural Networks API (NNAPI) for hardware acceleration; otherwise it defaults to the CPU for execution. In this article I’ll focus on how you use it in an Android app.
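
As a minimal sketch of that choice (assuming the org.tensorflow:tensorflow-lite Java API used throughout this article; modelBuffer stands in for the MappedByteBuffer we load from assets in Step 6), NNAPI can be requested through Interpreter.Options:

import org.tensorflow.lite.Interpreter;

// modelBuffer: a MappedByteBuffer of the .tflite model (see loadModelFile() in Step 6).
Interpreter.Options options = new Interpreter.Options();
options.setUseNNAPI(true); // request NNAPI acceleration; otherwise execution stays on the CPU
Interpreter interpreter = new Interpreter(modelBuffer, options);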

TensorFlow Lite consists of two main components:

  • The TensorFlow Lite interpreter, which runs specially optimized models on many different hardware types, including mobile phones, embedded Linux devices, and microcontrollers.
  • The TensorFlow Lite converter, which converts TensorFlow models into an efficient form for use by the interpreter, and can introduce optimizations to improve binary size and performance.

On-device inference like this has some key benefits:

  • Latency: there’s no round-trip to a server
  • Privacy: no data needs to leave the device
  • Connectivity: an Internet connection isn’t required
  • Power consumption: network connections are power hungry

Below, I am going to share how to implement and use the TensorFlow Lite library in an Android project.

NOTE → Before going through the code below, you can check the GitHub repo https://github.com/amitshekhariitbhu/Android-TensorFlow-Lite-Example for a TensorFlow Lite implementation; I followed it as well. It’s one of the best-maintained example projects available on GitHub.

Step 1 → Add the dependency

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
}
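
Note → If loading the model later fails, a common cause is that AAPT compressed the .tflite asset; the memory-mapping in Step 6 needs the file stored uncompressed. A small build.gradle addition (as recommended in the TensorFlow Lite docs) prevents that:

android {
    aaptOptions {
        // keep .tflite assets uncompressed so they can be memory-mapped
        noCompress "tflite"
    }
}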

Step 2 → (Optional) To reduce APK size, you can restrict the native libraries to the armeabi-v7a and arm64-v8a ABIs (omitting x86 and x86_64) by adding the following to your build.gradle:

android {
    defaultConfig {
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
}

Step 3 → Create a layout with a camera view, activity_tensorflow.xml

<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/activity_main"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:paddingBottom="@dimen/activity_vertical_margin"
    android:paddingLeft="@dimen/activity_horizontal_margin"
    android:paddingRight="@dimen/activity_horizontal_margin"
    android:paddingTop="@dimen/activity_vertical_margin"
    tools:context="com.example.myapp.tensorflow.TensorFlowActivity">

    <com.wonderkiln.camerakit.CameraView
        android:id="@+id/cameraView"
        android:layout_width="300dp"
        android:layout_height="300dp"
        android:layout_gravity="center|top" />

    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="80dp"
        android:layout_gravity="center|top"
        android:layout_marginTop="300dp"
        android:gravity="center"
        android:orientation="horizontal">

        <ImageView
            android:id="@+id/imageViewResult"
            android:layout_width="75dp"
            android:layout_height="75dp"
            android:padding="2dp" />

        <TextView
            android:id="@+id/textViewResult"
            android:layout_width="match_parent"
            android:layout_height="80dp"
            android:fadeScrollbars="false"
            android:gravity="center"
            android:maxLines="15"
            android:scrollbars="vertical"
            android:textColor="@android:color/black" />

    </LinearLayout>

    <Button
        android:id="@+id/btnToggleCamera"
        android:layout_width="match_parent"
        android:layout_height="48dp"
        android:layout_gravity="bottom|center"
        android:layout_marginBottom="50dp"
        android:text="toggle_camera"
        android:textAllCaps="false"
        android:textColor="@android:color/black" />

    <Button
        android:id="@+id/btnDetectObject"
        android:layout_width="match_parent"
        android:layout_height="48dp"
        android:layout_gravity="bottom|center"
        android:text="detect_object"
        android:textAllCaps="false"
        android:textColor="@android:color/black"
        android:visibility="gone" />

</FrameLayout>

Note → For the CameraView used in the layout, add the CameraKit dependency to your build.gradle:

implementation 'com.wonderkiln:camerakit:0.13.1'

Step 4 → Add the mobilenet_quant_v1_224.tflite model file (along with the labels.txt it expects; see Step 5) to your assets folder.

You can download the files from GitHub: https://github.com/amitshekhariitbhu/Android-TensorFlow-Lite-Example/tree/master/app/src/main/assets

Step 5 → Create a TensorFlowActivity class and set MODEL_PATH to the path of the .tflite file in your assets folder.

The whole class should look like this:

public class TensorFlowActivity extends AppCompatActivity {

    private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";
    private static final boolean QUANT = true;
    private static final String LABEL_PATH = "labels.txt";
    private static final int INPUT_SIZE = 224;

    private Classifier classifier;

    private Executor executor = Executors.newSingleThreadExecutor();
    private TextView textViewResult;
    private Button btnDetectObject, btnToggleCamera;
    private ImageView imageViewResult;
    private CameraView cameraView;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_tensorflow);
        cameraView = findViewById(R.id.cameraView);
        imageViewResult = findViewById(R.id.imageViewResult);
        textViewResult = findViewById(R.id.textViewResult);
        textViewResult.setMovementMethod(new ScrollingMovementMethod());

        btnToggleCamera = findViewById(R.id.btnToggleCamera);
        btnDetectObject = findViewById(R.id.btnDetectObject);

        cameraView.addCameraKitListener(new CameraKitEventListener() {
            @Override
            public void onEvent(CameraKitEvent cameraKitEvent) {
            }

            @Override
            public void onError(CameraKitError cameraKitError) {
            }

            @Override
            public void onImage(CameraKitImage cameraKitImage) {
                // Scale the captured frame down to the model's 224x224 input size.
                Bitmap bitmap = cameraKitImage.getBitmap();
                bitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, false);
                imageViewResult.setImageBitmap(bitmap);

                final List<Classifier.Recognition> results = classifier.recognizeImage(bitmap);
                textViewResult.setText(results.toString());
            }

            @Override
            public void onVideo(CameraKitVideo cameraKitVideo) {
            }
        });

        btnToggleCamera.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                cameraView.toggleFacing();
            }
        });

        btnDetectObject.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                cameraView.captureImage();
            }
        });

        initTensorFlowAndLoadModel();
    }

    @Override
    protected void onResume() {
        super.onResume();
        cameraView.start();
    }

    @Override
    protected void onPause() {
        cameraView.stop();
        super.onPause();
    }

    @Override
    protected void onDestroy() {
        super.onDestroy();
        executor.execute(new Runnable() {
            @Override
            public void run() {
                classifier.close();
            }
        });
    }

    private void initTensorFlowAndLoadModel() {
        // Load the model off the main thread; show the detect button once it's ready.
        executor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    classifier = TensorFlowImageClassifier.create(
                            getAssets(),
                            MODEL_PATH,
                            LABEL_PATH,
                            INPUT_SIZE,
                            QUANT);
                    makeButtonVisible();
                } catch (final Exception e) {
                    throw new RuntimeException("Error initializing TensorFlow!", e);
                }
            }
        });
    }

    private void makeButtonVisible() {
        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                btnDetectObject.setVisibility(View.VISIBLE);
            }
        });
    }
}

Step 6 → Create a TensorFlowImageClassifier class that preprocesses each image and extracts the recognition results.

Note → To get the results for an image, you just call the run method of the Interpreter, passing the image data and an output array sized to the label list:

interpreter.run(byteBuffer, result);
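
For the quantized MobileNet used in this example, the output is a single row of bytes, one per label; here is a minimal sketch of sizing the output and turning a raw byte into a confidence (the same 255.0f scale the class below uses):

// One output byte per label for the quantized model.
byte[][] result = new byte[1][labelList.size()];
interpreter.run(byteBuffer, result);
// Dequantize a raw score into a [0, 1] confidence, as getSortedResultByte() does below.
float confidence = (result[0][0] & 0xFF) / 255.0f;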

The whole class is below.

public class TensorFlowImageClassifier implements Classifier {

    private static final int MAX_RESULTS = 3;
    private static final int BATCH_SIZE = 1;
    private static final int PIXEL_SIZE = 3;
    private static final float THRESHOLD = 0.1f;

    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 128.0f;

    private Interpreter interpreter;
    private int inputSize;
    private List<String> labelList;
    private boolean quant;

    private TensorFlowImageClassifier() {
    }

    static Classifier create(AssetManager assetManager,
                             String modelPath,
                             String labelPath,
                             int inputSize,
                             boolean quant) throws IOException {

        TensorFlowImageClassifier classifier = new TensorFlowImageClassifier();
        classifier.interpreter = new Interpreter(classifier.loadModelFile(assetManager, modelPath), new Interpreter.Options());
        classifier.labelList = classifier.loadLabelList(assetManager, labelPath);
        classifier.inputSize = inputSize;
        classifier.quant = quant;

        return classifier;
    }

    @Override
    public List<Recognition> recognizeImage(Bitmap bitmap) {
        ByteBuffer byteBuffer = convertBitmapToByteBuffer(bitmap);
        if (quant) {
            byte[][] result = new byte[1][labelList.size()];
            interpreter.run(byteBuffer, result);
            return getSortedResultByte(result);
        } else {
            float[][] result = new float[1][labelList.size()];
            interpreter.run(byteBuffer, result);
            return getSortedResultFloat(result);
        }
    }

    @Override
    public void close() {
        interpreter.close();
        interpreter = null;
    }

    // Memory-map the model file from assets (the asset must be stored uncompressed).
    private MappedByteBuffer loadModelFile(AssetManager assetManager, String modelPath) throws IOException {
        AssetFileDescriptor fileDescriptor = assetManager.openFd(modelPath);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    private List<String> loadLabelList(AssetManager assetManager, String labelPath) throws IOException {
        List<String> labelList = new ArrayList<>();
        BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open(labelPath)));
        String line;
        while ((line = reader.readLine()) != null) {
            labelList.add(line);
        }
        reader.close();
        return labelList;
    }

    // Pack the bitmap's RGB channels into the interpreter's input buffer:
    // raw bytes for the quantized model, normalized floats otherwise.
    private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
        ByteBuffer byteBuffer;

        if (quant) {
            byteBuffer = ByteBuffer.allocateDirect(BATCH_SIZE * inputSize * inputSize * PIXEL_SIZE);
        } else {
            byteBuffer = ByteBuffer.allocateDirect(4 * BATCH_SIZE * inputSize * inputSize * PIXEL_SIZE);
        }

        byteBuffer.order(ByteOrder.nativeOrder());
        int[] intValues = new int[inputSize * inputSize];
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        int pixel = 0;
        for (int i = 0; i < inputSize; ++i) {
            for (int j = 0; j < inputSize; ++j) {
                final int val = intValues[pixel++];
                if (quant) {
                    byteBuffer.put((byte) ((val >> 16) & 0xFF));
                    byteBuffer.put((byte) ((val >> 8) & 0xFF));
                    byteBuffer.put((byte) (val & 0xFF));
                } else {
                    byteBuffer.putFloat((((val >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    byteBuffer.putFloat((((val >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                    byteBuffer.putFloat((((val) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
                }
            }
        }
        return byteBuffer;
    }

    @SuppressLint("DefaultLocale")
    private List<Recognition> getSortedResultByte(byte[][] labelProbArray) {

        // Keep the results ordered by confidence, highest first.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<>(
                        MAX_RESULTS,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });

        for (int i = 0; i < labelList.size(); ++i) {
            float confidence = (labelProbArray[0][i] & 0xff) / 255.0f;
            if (confidence > THRESHOLD) {
                pq.add(new Recognition("" + i,
                        labelList.size() > i ? labelList.get(i) : "unknown",
                        confidence, quant));
            }
        }

        final ArrayList<Recognition> recognitions = new ArrayList<>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }

        return recognitions;
    }

    @SuppressLint("DefaultLocale")
    private List<Recognition> getSortedResultFloat(float[][] labelProbArray) {

        PriorityQueue<Recognition> pq =
                new PriorityQueue<>(
                        MAX_RESULTS,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });

        for (int i = 0; i < labelList.size(); ++i) {
            float confidence = labelProbArray[0][i];
            if (confidence > THRESHOLD) {
                pq.add(new Recognition("" + i,
                        labelList.size() > i ? labelList.get(i) : "unknown",
                        confidence, quant));
            }
        }

        final ArrayList<Recognition> recognitions = new ArrayList<>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }

        return recognitions;
    }

}

Step 7 → Create a Classifier interface, which the activity uses to obtain the recognition results for an image.

public interface Classifier {

    class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance of
         * the object.
         */
        private final String id;

        /**
         * Display name for the recognition.
         */
        private final String title;

        /**
         * Whether the model features quantized or float weights.
         */
        private final boolean quant;

        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private final Float confidence;

        public Recognition(
                final String id, final String title, final Float confidence, final boolean quant) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.quant = quant;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }

            if (title != null) {
                resultString += title + " ";
            }

            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }

            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);

    void close();
}
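
To see how the pieces fit together, here is a minimal sketch of the full flow; "test.jpg" is a hypothetical image asset used purely for illustration:

// Sketch: create the classifier, classify one bitmap, clean up.
// "test.jpg" is a hypothetical asset; any Bitmap works once scaled to 224x224.
try {
    Classifier classifier = TensorFlowImageClassifier.create(
            getAssets(), "mobilenet_quant_v1_224.tflite", "labels.txt", 224, true);
    Bitmap source = BitmapFactory.decodeStream(getAssets().open("test.jpg"));
    Bitmap input = Bitmap.createScaledBitmap(source, 224, 224, false);
    for (Classifier.Recognition recognition : classifier.recognizeImage(input)) {
        Log.d("TFLite", recognition.toString());
    }
    classifier.close();
} catch (IOException e) {
    Log.e("TFLite", "Failed to load model or labels", e);
}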

Conclusion

And those are the steps to implement TensorFlow Lite in your Android project.
