Implementing machine learning models in a mobile application

Hello! My name is Nikita Gribkov, and I am a Flutter developer at AGIMA. In this article I will tell you about the TensorFlow Lite framework, which lets you integrate machine learning models into a mobile application. It comes in handy if you need features such as speech recognition or image classification. I'll show you how to train models and then work with them.

The technology makes it possible to create personalized and intelligent solutions for users, which is why it is in high demand. If our goal is to make an application more usable and inclusive, we will most likely have to use ML.

Here are some examples of tasks for which the technology is 100% suitable:

  • image classification: so that the application can recognize objects in photos or videos (for example, Google Lens);

  • Natural Language Processing (NLP): in applications with voice assistants or chatbots, ML processes speech and texts (for example, Siri or Google Assistant);

  • personalization: ML algorithms analyze user behavior and offer personalized content or recommendations;

  • voice recognition: used in speech-to-text and voice-command applications.

There are several ways to integrate machine learning models into an application: you can use ML Kit from Firebase or pure Dart libraries. Personally, I have been working with TensorFlow Lite (TFLite), which is arguably the most common solution for this task.

Its main (but not only) advantage is that it works offline, when the device has no Internet connection. I also like that TFLite is optimized for devices with limited resources, which is convenient. Let's look at how the framework works.

Preparing the model for use with TFLite

Before integrating TFLite into a Flutter application, you need to prepare a model. This involves training it in TensorFlow and converting it to the .tflite format.

Step 1: Create and train a model in TensorFlow

First, the model has to be trained in TensorFlow. Here's a simple example of creating and training a model in Python:

import tensorflow as tf
from tensorflow.keras import layers

# Create a simple model for image classification
model = tf.keras.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

Model compilation:

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
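
The training step below assumes that train_images and train_labels are already in memory. For this example they can be taken from the MNIST dataset that ships with Keras, with pixel values scaled to the [0, 1] range:

# Load MNIST and normalize pixel values to [0, 1]
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0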

Training a model on MNIST data:

model.fit(train_images, train_labels, epochs=5)

Saving the model:

model.save("model.h5")

The network consists of a Flatten layer that turns the 28×28-pixel image into a one-dimensional vector, a hidden Dense layer with 128 neurons, and an output layer with 10 neurons, one per class.
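
If you want to double-check this structure, Keras can print it for you: model.summary() lists the layers and the number of trainable parameters.

model.summary()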

The model is compiled using the Adam optimizer and the Sparse Categorical Crossentropy loss function.

It is then trained on MNIST data for 5 epochs and saved to the file “model.h5”.
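
Before converting the model, it is worth running a quick sanity check on held-out data. Assuming test_images and test_labels were loaded together with the training set, as in the snippet above:

model.evaluate(test_images, test_labels)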

Step 2: Convert the model to TFLite format

After training, the model needs to be converted to the .tflite format using the TFLite converter.

Sample code for model conversion:

# Load the trained model
model = tf.keras.models.load_model('model.h5')

# Convert the model to the TFLite format
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the model in the .tflite format
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

You now have a .tflite model that you can integrate into your Flutter app.
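
Before moving on to Flutter, you can also sanity-check the converted file with the Python TFLite interpreter. Here is a minimal sketch that reuses the test_images array loaded earlier and runs a single sample through the model:

import numpy as np

# Load the converted model and allocate its tensors
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed one 28x28 sample and read back the 10 class scores
sample = test_images[0].astype(np.float32)[np.newaxis, ...]
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]['index']))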

Integrating TFLite into a Flutter application

To work with TFLite in Flutter, you need the tflite_flutter plugin. Its repository is a TensorFlow-managed fork of the tflite_flutter_plugin project.

Step 1: Install the required dependencies

Open the pubspec.yaml file of your Flutter project and add the dependencies:

dependencies:
  flutter:
    sdk: flutter
  tflite_flutter: ^0.11.0
  tflite_flutter_helper_plus: ^0.0.2

Step 2: Prepare the model

Copy your model.tflite file to the assets folder of the project. Then specify the path to the model in the assets section of pubspec.yaml:

flutter:
  assets:
    - assets/model.tflite
    - assets/labels.txt # if you have a labels file

Step 3: Load and use the model in Flutter code

Now let's write the code that loads the model and runs predictions on the Flutter side.

Import packages:

import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper_plus/tflite_flutter_helper_plus.dart';

Model loading:

late Interpreter _interpreter;

Future<void> loadModel() async {
  try {
    // Load the model from assets
    _interpreter = await Interpreter.fromAsset('model.tflite');
    print('Model loaded successfully');
  } catch (e) {
    print('Error loading model: $e');
  }
}

This code converts the image into a Float32List array: it walks over every pixel of the image, extracts the red, green and blue channel values, normalizes them using the given mean and std, and fills the array. (The getRed, getGreen and getBlue helpers come from the 3.x versions of the image package; in 4.x a pixel exposes its channels directly.)

Float32List imageToByteListFloat32(
    img.Image image, int inputSize, double mean, double std) {
  var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
  var buffer = Float32List.view(convertedBytes.buffer);
  int pixelIndex = 0;
  for (var i = 0; i < inputSize; i++) {
    for (var j = 0; j < inputSize; j++) {
      var pixel = image.getPixel(j, i);
      buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);
      buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);
      buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);
    }
  }
  return convertedBytes;
}

Making predictions

To make a prediction, the input data needs to be converted into a format the model expects, for example an image into a tensor (an array of numbers). Keep in mind that the helper above produces a three-channel (RGB) tensor, while the grayscale MNIST model from the first section expects a single 28×28 channel, so the preprocessing (or the model's input shape) has to match your particular model.

Future<void> classifyImage(File image) async {
  // Convert the image into an input tensor
  final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;
  var inputImage = img.copyResize(imageInput, width: 28, height: 28);
  var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);

  // Prepare the output tensor
  var output = List.filled(10, 0).reshape([1, 10]);

  // Run the prediction
  _interpreter.run(input, output);

  setState(() {
    _result = "Prediction: ${output.toString()}";
  });
}

A full Flutter example using TFLite

import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:image_picker/image_picker.dart';
import 'dart:io';
import 'package:image/image.dart' as img;

class MyHomePage extends StatefulWidget {
  @override
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  late Interpreter _interpreter;
  File? _image;
  final picker = ImagePicker();
  String _result = "No predictions yet";

  @override
  void initState() {
    super.initState();
    loadModel();
  }

  Future<void> loadModel() async {
    try {
      _interpreter = await Interpreter.fromAsset('model.tflite');
      print('Model loaded');
    } catch (e) {
      print('Error loading model: $e');
    }
  }

  Future<void> pickImage() async {
    final pickedFile = await picker.pickImage(source: ImageSource.gallery);
    // The user may cancel the picker, so bail out if nothing was selected
    if (pickedFile == null) return;

    setState(() {
      _image = File(pickedFile.path);
    });

    classifyImage(_image!);
  }

  Future<void> classifyImage(File image) async {
    // Convert the image into an input tensor
    final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;
    var inputImage = img.copyResize(imageInput, width: 28, height: 28);
    var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);

    // Prepare the output tensor
    var output = List.filled(10, 0).reshape([1, 10]);

    // Run the prediction
    _interpreter.run(input, output);

    setState(() {
      _result = "Prediction: ${output.toString()}";
    });
  }

  Float32List imageToByteListFloat32(
      img.Image image, int inputSize, double mean, double std) {
    var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
    var buffer = Float32List.view(convertedBytes.buffer);
    int pixelIndex = 0;
    for (var i = 0; i < inputSize; i++) {
      for (var j = 0; j < inputSize; j++) {
        var pixel = image.getPixel(j, i);
        buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);
        buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);
        buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);
      }
    }
    return convertedBytes;
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: Text('TFLite Classifier')),
      body: Column(
        children: [
          _image == null ? Text('Select an image') : Image.file(_image!),
          ElevatedButton(
            onPressed: pickImage,
            child: Text('Pick an image'),
          ),
          Text(_result),
        ],
      ),
    );
  }
}

Optimizing the model for mobile devices

To improve performance on mobile devices, you can use the following approaches:

  • Model quantization. It reduces the model's size and speeds up inference by lowering the precision of its numerical representations.

  • Parallel execution. Using several CPU cores to speed up inference; in tflite_flutter the number of threads is configured through InterpreterOptions when creating the interpreter.

Example code for quantizing a model:

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('quantized_model.tflite', 'wb') as f:
    f.write(tflite_model)
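
A quick way to see the effect of quantization is to compare the file sizes of the two models; this assumes both files were saved as shown above:

import os

print('original: ', os.path.getsize('model.tflite'), 'bytes')
print('quantized:', os.path.getsize('quantized_model.tflite'), 'bytes')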

What's the result?

As a result, we get an application with working ML models. The longest stage is training the models; everything else is a matter of technique. I hope the examples above will help you get through the integration quickly.

If you have any questions, ask them in the comments and I will answer. Also, subscribe to the channel of our colleague Sasha Vorozhishchev: he writes a lot about Flutter and about mobile development in general.
