aseem wangoo

Flutter Web and Machine Learning

In case it helped :)
Pass Me A Coffee!!

We will cover how to implement:

  1. Machine Learning using TensorFlow
  2. Feature Extraction from an image

Prerequisite:

This article uses the concept of calling JavaScript functions from Flutter Web, which is explained in detail in a separate article.


Machine Learning using TensorFlow in Flutter Web

Article here: https://flatteredwithflutter.com/machine-learning-in-flutter-web/

We will use TensorFlow.js, a JavaScript library for training and deploying machine learning models in the browser and in Node.js.

Setup:

Using Script Tags

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>

Add the above script tag inside the head section of your index.html file.

That's it.


Implementing a Model in Flutter Web

What we will do:

  1. Create a linear model
  2. Train the model
  3. Enter a sample value and get the output

Explanation:

We will create a linear model that follows the formula y = 2x - 1. For instance:

  1. when x = -1, y = -3
  2. when x = 0, y = -1, and so on.

We will give 12 as a sample input and predict its value with this model.
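The training pairs and the answer we hope the model approximates for 12 follow directly from y = 2x - 1. This plain-JavaScript snippet (no TensorFlow.js involved) computes them; `target`, `trainingXs` and `trainingYs` are illustrative names, not part of the article's code.

```javascript
// The function the model is supposed to learn: y = 2x - 1.
// Illustrative helper only; the real model never sees this formula.
function target(x) {
  return 2 * x - 1;
}

// The same x values used for training later in the article.
const trainingXs = [-1, 0, 1, 2, 3, 4];
const trainingYs = trainingXs.map(target);

console.log(trainingYs); // values: [-3, -1, 1, 3, 5, 7]
console.log(target(12)); // 23
```

A trained model should predict something close to 23 for x = 12, but not necessarily exactly 23.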

Create the model

  1. Create a JS file (in our case ml.js)
  2. Define a function in it (in our case learnLinear):

async function learnLinear(input) {}

Initialize a sequential model using tf.sequential:

const model = tf.sequential();

A sequential model is any model where the outputs of one layer are the inputs to the next layer.
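To make the "output of one layer feeds the next" idea concrete, here is a minimal plain-JavaScript sketch in which each "layer" is just a function. `sequential`, `timesTwo` and `minusOne` are hypothetical helpers, not TensorFlow.js APIs; real layers also hold trainable weights.

```javascript
// Hypothetical sketch: chain "layers" so each one's output is the next one's input.
function sequential(layers) {
  return (input) => layers.reduce((value, layer) => layer(value), input);
}

// Two toy "layers" (plain functions, no weights).
const timesTwo = (x) => 2 * x;
const minusOne = (x) => x - 1;

const model = sequential([timesTwo, minusOne]);
console.log(model(12)); // 23

// Order matters: (12 - 1) * 2
console.log(sequential([minusOne, timesTwo])(12)); // 22
```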

Let's add a dense layer to this model using tf.layers.dense. Because it declares inputShape, it also serves as the model's input layer.


model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

Parameters:

  • units (number): size of the output space. We output just a single number.
  • inputShape: the shape of each input. We provide each input as an array of length 1.

Finally, we add this layer to our sequential model using model.add.
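With units: 1, inputShape: [1] and no activation (tf.layers.dense defaults to a linear activation), the layer computes output = input * kernel + bias for each row of the batch. The plain-JavaScript sketch below mimics that forward pass; `denseForward` is a hypothetical helper, and the kernel and bias values are hand-picked rather than trained.

```javascript
// Hypothetical sketch of a one-unit dense layer's forward pass.
// batch: array of rows, each row holding a single input value (shape [n, 1]).
// kernel, bias: hand-picked numbers standing in for the layer's weights.
function denseForward(batch, kernel, bias) {
  return batch.map(([x]) => [x * kernel + bias]);
}

const batch = [[-1], [0], [12]];
// With kernel = 2 and bias = -1 the layer computes exactly 2x - 1.
console.log(denseForward(batch, 2, -1)); // values: [[-3], [-1], [23]]
```

Training is the process of finding a kernel and bias close to 2 and -1 from the example data alone.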


Next, we need to compile the model:

model.compile({
  loss: 'meanSquaredError',
  optimizer: 'sgd'
});

We use model.compile to compile the model.

Parameters:

loss: the quantity training tries to minimize. Cross-entropy and mean squared error are the two most common loss functions for training neural networks; mean squared error is the usual choice for a regression problem like this one.

optimizer: the string name of an optimizer, in our case Stochastic Gradient Descent ('sgd').
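Mean squared error is the average of the squared differences between predictions and targets. This plain-JavaScript sketch shows the idea behind 'meanSquaredError'; `meanSquaredError` here is a hypothetical helper, not TensorFlow.js's implementation.

```javascript
// Hypothetical sketch of mean squared error over two equal-length arrays.
function meanSquaredError(predictions, targets) {
  if (predictions.length !== targets.length) {
    throw new Error("predictions and targets must have the same length");
  }
  let sum = 0;
  for (let i = 0; i < predictions.length; i++) {
    const diff = predictions[i] - targets[i];
    sum += diff * diff; // squaring penalizes large errors more than small ones
  }
  return sum / predictions.length;
}

const ys = [-3, -1, 1, 3, 5, 7];
console.log(meanSquaredError(ys, ys)); // 0 (perfect predictions)
// Predicting 0 everywhere: (9 + 1 + 1 + 9 + 25 + 49) / 6, roughly 15.67
console.log(meanSquaredError([0, 0, 0, 0, 0, 0], ys));
```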


Next, we need to train the model:

// INPUT -> shape [6, 1]: 6 rows, 1 column
const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

We define the x-axis input, called xs, using tf.tensor2d.

Parameters:

values: the values of the tensor. Can be a nested array of numbers or a flat array. In our case [-1, 0, 1, 2, 3, 4].

shape: the shape of the tensor. If not provided, it is inferred from values. In our case the tensor has 6 rows and 1 column, hence [6, 1].

Similarly, we define the y-axis output, called ys, using tf.tensor2d.
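A flat array plus a shape describes a 2D layout in row-major order: with shape [6, 1], each of the 6 values becomes its own row. The plain-JavaScript sketch below performs that mapping; `toMatrix` is a hypothetical helper for illustration, not part of TensorFlow.js.

```javascript
// Hypothetical helper: lay out a flat array as rows x cols, row by row.
function toMatrix(values, [rows, cols]) {
  if (values.length !== rows * cols) {
    throw new Error(`Cannot fit ${values.length} values into [${rows}, ${cols}]`);
  }
  const matrix = [];
  for (let r = 0; r < rows; r++) {
    matrix.push(values.slice(r * cols, (r + 1) * cols));
  }
  return matrix;
}

// 6 rows, each holding one value: [[-1], [0], [1], [2], [3], [4]]
console.log(toMatrix([-1, 0, 1, 2, 3, 4], [6, 1]));
```

The values-times-shape check matters: 6 values cannot be arranged as, say, [4, 2].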

// TRAIN MODEL -> EPOCHS (ITERATIONS)
await model.fit(xs, ys, { epochs: 250 });

Now we train the model using model.fit.

Parameters:

  • x: the input data as a tf.Tensor (an array of tensors is used for models with several inputs), in our case xs
  • y: the target data as a tf.Tensor, in our case ys
  • epochs: the number of times to iterate over the training data.
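To see roughly what compile plus fit do here, below is a plain-JavaScript sketch of full-batch gradient descent on the mean squared error of a one-unit linear layer (prediction = w * x + b). It is not TensorFlow.js code and its numbers will not match a real run: it assumes w and b both start at 0 (TensorFlow.js initializes the weight randomly), a learning rate of 0.01, and that each epoch processes all six samples as a single batch. `trainLinear` is a hypothetical helper.

```javascript
// Hypothetical sketch of gradient descent on mean squared error
// for prediction = w * x + b. Assumes zero initialization and lr = 0.01.
function trainLinear(xs, ys, { epochs = 250, learningRate = 0.01 } = {}) {
  let w = 0;
  let b = 0;
  const n = xs.length;
  const lossHistory = [];

  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    let loss = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i]; // prediction minus target
      loss += error * error;
      // Partial derivatives of the mean squared error w.r.t. w and b
      gradW += (2 * error * xs[i]) / n;
      gradB += (2 * error) / n;
    }
    lossHistory.push(loss / n); // loss before this epoch's update
    // One gradient-descent step against the gradient
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b, lossHistory };
}

const { w, b, lossHistory } = trainLinear([-1, 0, 1, 2, 3, 4], [-3, -1, 1, 3, 5, 7]);
console.log(`w=${w.toFixed(3)} b=${b.toFixed(3)}`);
console.log(`prediction for 12: ${(w * 12 + b).toFixed(3)}`);
```

Under these assumptions the prediction for 12 ends up close to, but not exactly, 23 after 250 epochs; more epochs bring it closer.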

Now that we have trained the model, let's test it by predicting values with model.predict:

// PREDICT THE VALUE NOW...
const predictions = model.predict(tf.tensor2d([input], [1, 1]));

const result = predictions.dataSync();
console.log('Res', result[0]); // number

// Return the prediction so the awaiting Dart code receives it
return result[0];

Parameters:

x: the input data as a tf.Tensor. In our case it is a [1, 1] tensor holding the single value passed from Dart.

The result is stored in the predictions variable. To retrieve the data, we call:

dataSync: synchronously downloads the values from the tf.Tensor as a typed array.
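For a float32 tensor, dataSync() hands back a flat Float32Array. Since predictions has shape [1, 1], that array holds exactly one element, which is why the code reads result[0]. The snippet below builds a Float32Array by hand to show how such a value behaves; it does not call TensorFlow.js, and the value 23 is an arbitrary placeholder, not a real model output.

```javascript
// Hand-built stand-in for predictions.dataSync() on a [1, 1] float32 tensor.
// 23 is a placeholder value, not an actual prediction.
const result = new Float32Array([23]);

console.log(result.length); // 1
console.log(result[0]); // 23
console.log(Array.isArray(result)); // false: a typed array, not a plain Array

// Float32 storage rounds values, so 0.1 does not read back as exactly 0.1.
console.log(new Float32Array([0.1])[0] === 0.1); // false
```

The single-precision rounding is one more reason not to expect a prediction of exactly 23.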

Get Predicted Value in Flutter Web

In the above step, we created the TensorFlow model as a JS function that accepts a parameter:

async function learnLinear(input) {}

  1. Import the package

import 'package:js/js_util.dart' as jsutil;

2. Create a Dart file that calls the JS function:

@JS()
library main;

import 'package:js/js.dart';

// learnLinear is async, so it returns a JS Promise rather than a number
@JS('learnLinear')
external Object linearModel(int number);

Note: learnLinear is the same JS function we defined in the section above.

3. As our function is async, we need to await its result:

await jsutil.promiseToFuture<num>(linearModel(12));

We use promiseToFuture for this, which converts a JavaScript Promise to a Dart Future.
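The reason promiseToFuture is needed: a JavaScript async function always returns a Promise, even when its body returns a plain number. The sketch below uses `fakeLearnLinear`, a hypothetical stand-in that skips training and returns 2x - 1 directly, to show what the Dart side actually receives.

```javascript
// Hypothetical stand-in for learnLinear: no model, just the formula.
async function fakeLearnLinear(input) {
  return 2 * input - 1;
}

const returned = fakeLearnLinear(12);
console.log(returned instanceof Promise); // true: not the number itself
returned.then((value) => console.log(value)); // 23, once the Promise resolves
```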


Let's call this function from a button now:

OutlineButton(
  onPressed: () async {
    await jsutil.promiseToFuture<num>(linearModel(12));
  },
  child: const Text('Linear Model x=12'),
)

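We provided 12 as the input value. The output should be close to 23 (2 * 12 - 1), though not necessarily exactly 23, because the trained model only approximates the formula.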

Feature Extraction From Image

For the feature extraction, we use an existing model called MobileNet.

MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification and detection, similar to how other popular large-scale models are used.

It takes browser-based image elements (<img>, <video>, <canvas>) as input and returns an array of the most likely predictions and their confidences.

1. Setup:

Using Script Tags

Add the MobileNet script tag (from the @tensorflow-models/mobilenet package) inside the head section of your index.html file, after the TensorFlow.js script tag.

2. Function in JS:

We define an image tag inside the body of index.html:

<img id="img" src="" hidden>

Define a function in JS:

async function classifyImage() {}

Get a reference to the image element:

const img = document.getElementById('img');

Load the MobileNet model and classify the selected image:

// LOAD MOBILENET MODEL
const model = await mobilenet.load();

// CLASSIFY THE IMAGE
const predictions = await model.classify(img);
console.log('Pred >>>', predictions);

return predictions;

predictions is an array that looks like this:

[{
className: "Egyptian cat",
probability: 0.8380282521247864
}, {
className: "tabby, tabby cat",
probability: 0.04644153267145157
}, {
className: "Siamese cat, Siamese",
probability: 0.024488523602485657
}]
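Each entry is a plain object with className and probability, so ordinary array methods are enough to work with the result. In this sketch, `samplePredictions` copies the example output above, while `topPrediction` and `formatPredictions` are hypothetical helpers, not part of the MobileNet API.

```javascript
// Copy of the example output above.
const samplePredictions = [
  { className: "Egyptian cat", probability: 0.8380282521247864 },
  { className: "tabby, tabby cat", probability: 0.04644153267145157 },
  { className: "Siamese cat, Siamese", probability: 0.024488523602485657 },
];

// Hypothetical helper: pick the highest-probability entry without relying on order.
function topPrediction(predictions) {
  return predictions.reduce((best, p) => (p.probability > best.probability ? p : best));
}

// Hypothetical helper: human-readable labels with percentages.
function formatPredictions(predictions) {
  return predictions.map((p) => `${p.className}: ${(p.probability * 100).toFixed(1)}%`);
}

console.log(topPrediction(samplePredictions).className); // Egyptian cat
console.log(formatPredictions(samplePredictions));
// ["Egyptian cat: 83.8%", "tabby, tabby cat: 4.6%", "Siamese cat, Siamese: 2.4%"]
```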

Finally, return these predictions.

3. Function in Dart:

@JS()
library main;

import 'package:js/js.dart';

@JS('learnLinear')
external Object linearModel(int number);

// classifyImage is async too, so it also returns a JS Promise
@JS('classifyImage')
external Object imageClassifier();

Note: This file was already created in the section above; we only added the last declaration. The name classifyImage matches the JS function created in step 2.

4. Call the function from a button:

OutlineButton(
  onPressed: () async {
    await jsutil.promiseToFuture<List<Object>>(imageClassifier());
  },
  child: const Text('Feature Extraction'),
)

The future returned by promiseToFuture resolves to a List<Object>. To extract the results, we need to convert this list into a custom model class.

5. Convert into a Custom Model

We create a custom class called ImageResults:

@JS()
@anonymous
class ImageResults {
  external factory ImageResults({
    String className,
    num probability,
  });

  external String get className;
  external num get probability;

  Map toMap() {
    return {
      'className': className,
      'probability': probability,
    };
  }
}

First, we convert each Object into a JSON string, and then parse that string into an ImageResults object:

List<ImageResults> listOfImageResults(List<Object> _val) {
  final _listOfMap = <ImageResults>[];
  
  for (final item in _val) {
    final _jsString = stringify(item);
    _listOfMap.add(jsonObject(_jsString));
  }
  return _listOfMap;
}

stringify is bound to JavaScript's JSON.stringify:

@JS('JSON.stringify')
external String stringify(Object obj);

The string is then converted into an ImageResults object by jsonObject, which is bound to JSON.parse:

@JS('JSON.parse')
external ImageResults jsonObject(String str);
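On the JavaScript side, the Dart helpers above amount to a JSON.stringify / JSON.parse round trip: each prediction becomes a string and then a fresh plain object with the same fields. This snippet reuses the first entry of the example output; the variable names are illustrative.

```javascript
// First entry of the example predictions above.
const prediction = { className: "Egyptian cat", probability: 0.8380282521247864 };

const asString = JSON.stringify(prediction); // what stringify returns to Dart
console.log(asString);

const parsed = JSON.parse(asString); // what jsonObject returns to Dart
console.log(parsed.className); // Egyptian cat
console.log(parsed.probability === prediction.probability); // true: numbers round-trip
console.log(parsed === prediction); // false: a new object, not the original
```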

Now you can access the values in Dart, for example inside a list of widgets such as a Column's children:

for (final ImageResults _item in _listOfMap) ...[
  Text('ClassName : ${_item.className}'),
  Text('Probability : ${_item.probability}\n'),
]


Hosted URL: https://fir-signin-4477d.firebaseapp.com/#/

Source code for the Flutter Web app.
