Implementation of a TensorFlow Lite model on Android

jemaloQiu ・2 min read

Recently, in an interview, I was asked about my experience deploying trained TensorFlow models on Android. I had tried an Android project cloned from GitHub that embeds a TFLite model, but I had not yet deployed a model of my own in an Android application. So I did that exercise today and successfully got my CNN model working on my Redmi Note 8 Pro.


CNN model

Here is the code for training a CNN model on the MNIST data set. The model is then converted to a TFLite model, to be used in the Android application for recognizing handwritten digits.


import os

# Hide TensorFlow's C++ INFO and WARNING logs; this must be set before importing tensorflow.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"

import tensorflow as tf
from tensorflow.keras import layers, models



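# Load MNIST: 28x28 grayscale digits with integer labels 0-9.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train[0, :, :])
## x_train.shape => (60000, 28, 28), y_train.shape => (60000,)
## x_test.shape => (10000, 28, 28), y_test.shape => (10000,)

# Add a trailing channel dimension so each image has shape (28, 28, 1).
# Pixel values are left as raw 0-255 intensities; no rescaling is applied.
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)

# The labels are already 1-D, so squeeze only converts them to tensors.
y_train = tf.squeeze(y_train)
y_test = tf.squeeze(y_test)

print("Dataset info: ", x_train.shape, y_train.shape, x_test.shape, y_test.shape)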

batch_size = 128

train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.batch(batch_size)

# Pull one batch to confirm the shapes before training.
train_iter = iter(train_db)
sample = next(train_iter)
print(sample[0].shape, sample[1].shape)

##  build a standard cnn model
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
# No softmax here: the model outputs raw logits, matching from_logits=True in the loss.
model.add(layers.Dense(10))

model.summary()


model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

train_history = model.fit(train_db, epochs=10, validation_data=test_db)

## once the model has been trained, convert it to a TFLite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('qiu_mnist_model.tflite', 'wb') as f:
    f.write(tflite_model)

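Because the last layer is `Dense(10)` with no softmax and the loss uses `from_logits=True`, the converted model returns ten raw logits per image rather than probabilities, so the app has to turn them into a digit itself. Below is a minimal, framework-free sketch of that post-processing step. The function names `softmax` and `predict_digit` and the sample logits are illustrative assumptions, not taken from the original project.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtract the maximum first so exp() cannot overflow on large logits.
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_digit(logits):
    """Return (digit, confidence) for one 10-element logit vector."""
    if len(logits) != 10:
        raise ValueError("expected 10 logits, one per digit 0-9")
    probs = softmax(logits)
    digit = max(range(10), key=lambda i: probs[i])
    return digit, probs[digit]


# Made-up logits for illustration; a real vector would come from the model output.
example_logits = [-3.1, 0.4, 1.2, 9.7, -0.5, 2.0, -4.2, 0.0, 1.1, 0.3]
digit, confidence = predict_digit(example_logits)
print(f"predicted digit: {digit}, confidence: {confidence:.3f}")
```

Taking the argmax of the logits directly gives the same digit; the softmax is only needed if the app displays a confidence value.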

Implementation in Android app

I referred to this post to obtain the original Android project and imported its Kotlin version into Android Studio. However, loading my own model into it initially caused some bugs.

My own model is placed in the project's assets directory:
[Screenshot: model file in the assets directory]

The most important thing for this work is the following Gradle setting:
[Screenshot: Gradle setting]

After about 15 minutes of debugging and code changes, I got my model working.

Check out the video (there is still an accuracy issue):

I will upload the Android project's source code to my GitHub repo once I finish cleaning up the code and improving the performance.
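One thing that may be worth checking for the accuracy issue is whether the app feeds the model the same kind of input it was trained on. The training script above passes raw MNIST intensities (0 to 255, light digit on a dark background) without rescaling, so an app that divides pixels by 255 or draws dark strokes on a light canvas would present inputs the model has never seen. This is only a hypothesis, since the app code is not shown here. The sketch below mirrors, in Python, the conversion such an app would need, assuming the pixels arrive as packed ARGB integers (the format Android's `Bitmap.getPixels` uses); `argb_to_gray` and `to_model_input` are hypothetical helper names.

```python
def argb_to_gray(pixel):
    """Average the R, G and B channels of one packed ARGB pixel (0-255 result)."""
    r = (pixel >> 16) & 0xFF
    g = (pixel >> 8) & 0xFF
    b = pixel & 0xFF
    return (r + g + b) / 3.0


def to_model_input(argb_pixels, invert=False):
    """Turn 28*28 packed ARGB pixels into a flat list of floats in 0-255.

    No division by 255 is applied, matching the unscaled training data.
    Set invert=True when the canvas has dark strokes on a light background,
    since MNIST digits are light strokes on a dark background.
    """
    if len(argb_pixels) != 28 * 28:
        raise ValueError("expected 784 pixels for a 28x28 image")
    grays = [argb_to_gray(p) for p in argb_pixels]
    if invert:
        grays = [255.0 - g for g in grays]
    return grays


# Illustrative canvas: all white except one black pixel (a dark stroke).
white, black = 0xFFFFFFFF, 0xFF000000
canvas = [white] * (28 * 28)
canvas[0] = black

values = to_model_input(canvas, invert=True)
print(values[0], values[1], len(values))
```

If the model were instead retrained on inputs scaled to [0, 1], the app-side conversion would have to divide by 255 as well; what matters is that both sides agree.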

References
  1. https://www.tensorflow.org/lite/performance/post_training_quantization
  2. https://margaretmz.medium.com/e2e-tfkeras-tflite-android-273acde6588
