Face Detection in Flutter with TensorFlow Lite

Background

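Our project needed a face-based check-in module. Most existing solutions stream video to the backend and let the server handle both face detection and recognition. Our server's compute and bandwidth are limited, so we chose a different split: the device detects faces and uploads only the detected face to the server for comparison. A quick search turned up little material on running model inference in Flutter, so I am documenting the approach here.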

tflite_flutter

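In Flutter, TensorFlow Lite is provided by the tflite_flutter plugin. The plugin calls the native TensorFlow Lite APIs on Android and on iOS and exposes them through a single Dart interface. Before you can use the plugin, you must install its native dependencies. The plugin's page lists the steps in detail: https://pub.flutter-io.cn/packages/tflite_flutter. In practice, a single install script downloads the dependencies automatically. You also need the model in TFLite format, stored in the assets folder. So that Flutter bundles the assets folder, add the following to pubspec.yaml: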

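    flutter:
      assets:
        - assets/
        - assets/models/

Flutter does not include asset subdirectories recursively, so a subfolder such as assets/models/ must be listed explicitly. Older tflite_flutter releases resolve Interpreter.fromAsset("models/...") to assets/models/....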

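For face detection we chose BlazeFace, the model Google officially recommends. The model files are available at https://github.com/ibaiGorordo/BlazeFace-TFLite-Inference/tree/main/models. The _back file is meant for the rear camera and the _front file for the front camera. We use the front model.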

Calling the model from Dart

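The tflite_flutter Interpreter class loads the model and runs inference. getInputTensor() returns the model's input tensor, and its shape gives the input size the model expects.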

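    // Load the model from the app's assets (see pubspec.yaml above).
    _interpreter =
        await Interpreter.fromAsset("models/face_detection_front.tflite");
    // Input shape is [1, height, width, 3]: 128 x 128 for the front model.
    _inputShape = _interpreter.getInputTensor(0).shape;
    _imageProcessor = ImageProcessorBuilder()
        .add(ResizeOp(
            _inputShape[1], _inputShape[2], ResizeMethod.NEAREST_NEIGHBOUR))
        .add(_normalizeInput) // normalization op, defined elsewhere (not shown)
        .build();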
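The snippet does not define _normalizeInput. Assume the front BlazeFace model expects inputs in [-1, 1], the range used in MediaPipe's pipeline. In that case _normalizeInput would be NormalizeOp(127.5, 127.5) from tflite_flutter_helper, which computes (value - mean) / stddev for each channel. The plain-Dart function below applies the same arithmetic to raw 8-bit RGB values, which is useful for checking the model input by hand. normalizeToUnitRange is a hypothetical name.

```dart
import 'dart:typed_data';

/// Maps 8-bit channel values in [0, 255] to floats in [-1, 1] using
/// (v - mean) / stddev with mean = stddev = 127.5. This is the arithmetic
/// NormalizeOp(127.5, 127.5) would apply (assumed BlazeFace input range).
Float32List normalizeToUnitRange(Uint8List rgb,
    {double mean = 127.5, double stddev = 127.5}) {
  final out = Float32List(rgb.length);
  for (var i = 0; i < rgb.length; i++) {
    out[i] = (rgb[i] - mean) / stddev;
  }
  return out;
}
```

With other mean/stddev values the same function covers models that expect [0, 1] input (mean 0, stddev 255).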

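Porting the TFLite Java code to Dart gives the helper functions below. getAnchors generates the SSD anchor boxes. process turns the raw classifier outputs into one confidence per anchor. convertToDetections decodes the boxes that pass the score threshold. The option classes (AnchorOption, OptionsFace), Anchor, Detection, decodeBox and convertToDetection are not shown.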

    import 'dart:math';

    // Anchor scale, interpolated linearly between minScale and maxScale.
    double calculateScale(
        double minScale, double maxScale, int strideIndex, int numStrides) {
      if (numStrides == 1) return (minScale + maxScale) * 0.5;
      return minScale +
          (maxScale - minScale) * strideIndex / (numStrides - 1.0);
    }

    // Generates SSD anchors (normalized centers and sizes), one set per
    // feature-map cell, in the same order as the model's output rows.
    List<Anchor> getAnchors(AnchorOption options) {
      final anchors = <Anchor>[];
      if (options.stridesSize != options.numLayers) {
        print('strides_size and num_layers must be equal.');
        return anchors;
      }
      int layerID = 0;
      while (layerID < options.stridesSize) {
        final anchorHeight = <double>[];
        final anchorWidth = <double>[];
        final aspectRatios = <double>[];
        final scales = <double>[];

        // Consecutive layers with the same stride share one feature map.
        int lastSameStrideLayer = layerID;
        while (lastSameStrideLayer < options.stridesSize &&
            options.strides[lastSameStrideLayer] == options.strides[layerID]) {
          final scale = calculateScale(options.minScale, options.maxScale,
              lastSameStrideLayer, options.stridesSize);
          if (lastSameStrideLayer == 0 && options.reduceBoxesInLowestLayer) {
            // Fixed set of three boxes for the lowest layer.
            aspectRatios.addAll([1.0, 2.0, 0.5]);
            scales.addAll([0.1, scale, scale]);
          } else {
            for (final ratio in options.aspectRatios) {
              aspectRatios.add(ratio);
              scales.add(scale);
            }
            // Extra anchor whose scale lies between this layer and the next.
            if (options.interpolatedScaleAspectRatio > 0.0) {
              final scaleNext = lastSameStrideLayer == options.stridesSize - 1
                  ? 1.0
                  : calculateScale(options.minScale, options.maxScale,
                      lastSameStrideLayer + 1, options.stridesSize);
              scales.add(sqrt(scale * scaleNext));
              aspectRatios.add(options.interpolatedScaleAspectRatio);
            }
          }
          lastSameStrideLayer++;
        }
        for (int i = 0; i < aspectRatios.length; i++) {
          final ratioSqrt = sqrt(aspectRatios[i]);
          anchorHeight.add(scales[i] / ratioSqrt);
          anchorWidth.add(scales[i] * ratioSqrt);
        }

        int featureMapHeight;
        int featureMapWidth;
        if (options.featureMapHeightSize > 0) {
          featureMapHeight = options.featureMapHeight[layerID];
          featureMapWidth = options.featureMapWidth[layerID];
        } else {
          final stride = options.strides[layerID];
          featureMapHeight = (options.inputSizeHeight / stride).ceil();
          featureMapWidth = (options.inputSizeWidth / stride).ceil();
        }

        for (int y = 0; y < featureMapHeight; y++) {
          for (int x = 0; x < featureMapWidth; x++) {
            for (int anchorID = 0; anchorID < anchorHeight.length; anchorID++) {
              final xCenter = (x + options.anchorOffsetX) / featureMapWidth;
              final yCenter = (y + options.anchorOffsetY) / featureMapHeight;
              // With fixedAnchorSize every anchor is 1 x 1 (the BlazeFace case).
              final w = options.fixedAnchorSize ? 1.0 : anchorWidth[anchorID];
              final h = options.fixedAnchorSize ? 1.0 : anchorHeight[anchorID];
              anchors.add(Anchor(xCenter, yCenter, h, w));
            }
          }
        }
        layerID = lastSameStrideLayer;
      }
      return anchors;
    }
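getAnchors is generic, and the article does not give the AnchorOption values for BlazeFace. Assume MediaPipe's face_detection_front settings: 128 × 128 input, strides [8, 16, 16, 16], anchor offset 0.5, aspect ratio 1.0 plus one interpolated scale, and fixed_anchor_size = true. Under these settings every anchor is 1 × 1 and only the centers matter. Each stride-8 cell gets 2 anchors, and each stride-16 cell gets 6 because three layers share that stride. The total is 16 · 16 · 2 + 8 · 8 · 6 = 896 anchors, one per output row. The compact special case below is a quick way to check that count. SimpleAnchor and fixedSizeAnchors are hypothetical names.

```dart
/// Minimal anchor record (hypothetical; mirrors Anchor(x, y, h, w) above).
class SimpleAnchor {
  final double xCenter, yCenter, h, w;
  const SimpleAnchor(this.xCenter, this.yCenter, this.h, this.w);
}

/// Anchor generation for the fixed_anchor_size == true case only.
/// Consecutive layers with the same stride share a feature map, and each
/// such layer contributes [anchorsPerLayer] anchors per cell.
List<SimpleAnchor> fixedSizeAnchors({
  int inputSize = 128,
  List<int> strides = const [8, 16, 16, 16],
  int anchorsPerLayer = 2, // aspect ratio 1.0 + one interpolated scale
  double offset = 0.5,
}) {
  final anchors = <SimpleAnchor>[];
  var layer = 0;
  while (layer < strides.length) {
    // Find the end of the run of layers sharing this stride.
    var last = layer;
    while (last < strides.length && strides[last] == strides[layer]) {
      last++;
    }
    final perCell = (last - layer) * anchorsPerLayer;
    final fm = (inputSize / strides[layer]).ceil(); // feature map side
    for (var y = 0; y < fm; y++) {
      for (var x = 0; x < fm; x++) {
        for (var k = 0; k < perCell; k++) {
          anchors.add(
              SimpleAnchor((x + offset) / fm, (y + offset) / fm, 1.0, 1.0));
        }
      }
    }
    layer = last;
  }
  return anchors;
}
```

If the count does not equal the first dimension of the model's output tensors, the anchor options do not match the model.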
    // Decodes every box whose score passes minScoreThresh into a Detection.
    List<Detection> convertToDetections(
        List<double> rawBoxes,
        List<Anchor> anchors,
        List<double> detectionScores,
        List<int> detectionClasses,
        OptionsFace options) {
      final outputDetections = <Detection>[];
      for (int i = 0; i < options.numBoxes; i++) {
        if (detectionScores[i] < options.minScoreThresh) continue;
        // Assumed layout (as in MediaPipe): [ymin, xmin, ymax, xmax].
        final List<double> boxData = decodeBox(rawBoxes, i, anchors, options);
        final detection = convertToDetection(
            boxData[0],
            boxData[1],
            boxData[2],
            boxData[3],
            detectionScores[i],
            detectionClasses[i],
            options.flipVertically);
        outputDetections.add(detection);
      }
      return outputDetections;
    }
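convertToDetections calls decodeBox and convertToDetection, and neither is shown. Below is a sketch that follows MediaPipe's box decoding under these assumptions: reverse_output_order = true (each regression row starts with x, y, w, h offsets, then keypoints), 16 values per anchor, and a scale of 128 (the input size). FaceBox, decodeBlazeFaceBox and toFaceBox are hypothetical names. They take explicit parameters instead of the article's options object.

```dart
/// Hypothetical result type: a box in normalized [0, 1] image coordinates.
class FaceBox {
  final double xMin, yMin, width, height, score;
  const FaceBox(this.xMin, this.yMin, this.width, this.height, this.score);
}

/// Decodes row [i] of the regression output: offsets are divided by
/// [scale], multiplied by the anchor size and added to the anchor center.
/// With fixed-size anchors, anchorW = anchorH = 1.
/// Returns [yMin, xMin, yMax, xMax].
List<double> decodeBlazeFaceBox(
  List<double> raw,
  int i, {
  required double anchorX,
  required double anchorY,
  double anchorW = 1.0,
  double anchorH = 1.0,
  int numCoords = 16,
  double scale = 128.0, // assumed: x/y/w/h scale equals the input size
}) {
  final o = i * numCoords;
  final xCenter = raw[o] / scale * anchorW + anchorX;
  final yCenter = raw[o + 1] / scale * anchorH + anchorY;
  final w = raw[o + 2] / scale * anchorW;
  final h = raw[o + 3] / scale * anchorH;
  return [yCenter - h / 2, xCenter - w / 2, yCenter + h / 2, xCenter + w / 2];
}

/// Converts [yMin, xMin, yMax, xMax] into an x/y/width/height box,
/// optionally flipping the vertical axis.
FaceBox toFaceBox(List<double> b, double score, {bool flipVertically = false}) {
  final yMin = flipVertically ? 1.0 - b[2] : b[0];
  return FaceBox(b[1], yMin, b[3] - b[1], b[2] - b[0], score);
}
```

The remaining 12 values in each row are six (x, y) keypoints. They decode the same way as the box center.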
    // Turns raw classifier logits into one (score, class) pair per anchor,
    // then decodes the boxes that pass the score threshold.
    List<Detection> process({
      required OptionsFace options,
      required List<double> rawScores,
      required List<double> rawBoxes,
      required List<Anchor> anchors,
    }) {
      final detectionScores = <double>[];
      final detectionClasses = <int>[];

      for (int i = 0; i < options.numBoxes; i++) {
        int classId = -1;
        double maxScore = double.negativeInfinity;
        for (int scoreIdx = 0; scoreIdx < options.numClasses; scoreIdx++) {
          double score = rawScores[i * options.numClasses + scoreIdx];
          if (options.sigmoidScore) {
            // Optional clipping keeps exp() in a safe range.
            if (options.scoreClippingThresh > 0) {
              score = score
                  .clamp(-options.scoreClippingThresh, options.scoreClippingThresh)
                  .toDouble();
            }
            score = 1.0 / (1.0 + exp(-score));
          }
          // Track the best class whether or not clipping/sigmoid is enabled.
          if (maxScore < score) {
            maxScore = score;
            classId = scoreIdx;
          }
        }
        detectionClasses.add(classId);
        detectionScores.add(maxScore);
      }
      return convertToDetections(
          rawBoxes, anchors, detectionScores, detectionClasses, options);
    }
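Following MediaPipe's tensors-to-detections logic, the score step clips each raw logit (optionally), applies a sigmoid, and then keeps the best class. The same step written as plain functions is shown below. logitToScore and passingAnchors are hypothetical names. The default clipping threshold of 100 is an assumption taken from MediaPipe's BlazeFace configuration.

```dart
import 'dart:math';

/// Converts one raw classifier logit into a probability: clip it to
/// [-clippingThresh, clippingThresh] when the threshold is positive,
/// then apply the logistic sigmoid.
double logitToScore(double raw, {double clippingThresh = 100.0}) {
  var s = raw;
  if (clippingThresh > 0) {
    s = s.clamp(-clippingThresh, clippingThresh).toDouble();
  }
  return 1.0 / (1.0 + exp(-s));
}

/// Indices of the anchors whose score reaches [minScore]. This is a cheap
/// pre-filter that runs before box decoding and NMS.
List<int> passingAnchors(List<double> rawScores, double minScore) {
  final out = <int>[];
  for (var i = 0; i < rawScores.length; i++) {
    if (logitToScore(rawScores[i]) >= minScore) out.add(i);
  }
  return out;
}
```

The sigmoid is monotonic, so score ≥ 0.75 holds exactly when the raw logit is ≥ ln 3 ≈ 1.0986. A single-class model can therefore compare logits directly and skip exp() for most anchors.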

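To run detection on the camera stream, we call CameraController's startImageStream. It takes a callback with a CameraImage parameter and invokes it for every frame. To keep the load down, we process at most about one frame per second. A flag skips incoming frames while an inference is still running.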

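    Future<void> _onStream() async {
      await cameraController.startImageStream((CameraImage image) async {
        // Drop frames while a previous frame is still being processed.
        if (_isDetecting) return;
        _isDetecting = true;
        try {
          await _tfLite(image); // assumes _tfLite is async (returns a Future)
          // Throttle to roughly one inference per second.
          await Future.delayed(const Duration(seconds: 1));
        } finally {
          _isDetecting = false;
        }
      });
    }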
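The flag-and-delay pattern can also be written as a small helper that is testable without a camera. The sketch below is timestamp-based. FrameThrottler, tryStart and done are hypothetical names. It admits a frame only when nothing is running and at least one interval has passed since the last admitted frame.

```dart
/// Hypothetical helper: lets through at most one frame per [interval],
/// and never while the previous frame is still being processed.
class FrameThrottler {
  final Duration interval;
  bool _busy = false;
  DateTime? _lastStart;

  FrameThrottler({this.interval = const Duration(seconds: 1)});

  /// Returns true if a frame arriving at [now] should be processed.
  /// The caller must call [done] once processing has finished.
  bool tryStart(DateTime now) {
    if (_busy) return false;
    final last = _lastStart;
    if (last != null && now.difference(last) < interval) return false;
    _busy = true;
    _lastStart = now;
    return true;
  }

  void done() => _busy = false;
}
```

In the stream callback, call tryStart(DateTime.now()) first and return early when it yields false. Then await _tfLite(image) inside a try block and call done() in finally. Unlike a fixed delay, this never analyses a frame that has been waiting for a second.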

On Android, however, the stream delivers frames in YUV420 format, so they must first be converted to RGB:

    import 'package:camera/camera.dart';
    import 'package:image/image.dart' as img; // image 3.x API

    // Converts an Android YUV420 CameraImage into an RGB img.Image.
    Future<img.Image> convertYUV420toImageColor(CameraImage image) async {
      // Use the Y-plane row stride as the buffer width, so y * width + x
      // indexes the Y plane directly; the row padding is cropped below.
      final int width = image.planes[0].bytesPerRow;
      final int height = image.height;
      final int uvRowStride = image.planes[1].bytesPerRow;
      final int uvPixelStride = image.planes[1].bytesPerPixel!;
      final buffer = img.Image(width, height);
      for (int y = 0; y < height; y++) {
        for (int x = 0; x < width; x++) {
          // U and V are subsampled 2x2: one chroma sample per 2x2 block.
          final int uvIndex = uvPixelStride * (x ~/ 2) + uvRowStride * (y ~/ 2);
          final int index = y * width + x;
          if (index >= image.planes[0].bytes.length ||
              uvIndex >= image.planes[1].bytes.length) {
            continue;
          }
          final yp = image.planes[0].bytes[index];
          final up = image.planes[1].bytes[uvIndex];
          final vp = image.planes[2].bytes[uvIndex];
          // Fixed-point YUV -> RGB with chroma centered at 128.
          int r = (yp + vp * 1436 / 1024 - 179).round().clamp(0, 255);
          int g = (yp - up * 46549 / 131072 + 44 - vp * 93604 / 131072 + 91)
              .round()
              .clamp(0, 255);
          int b = (yp + up * 1814 / 1024 - 227).round().clamp(0, 255);
          // image 3.x stores each pixel as 0xAABBGGRR.
          buffer.data[index] = (0xFF << 24) | (b << 16) | (g << 8) | r;
        }
      }
      // Drop the row padding, then rotate the frame upright.
      return img.copyRotate(
          img.copyCrop(buffer, 0, 0, image.width, image.height), -90);
    }
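The constants above are a fixed-point approximation of full-range BT.601 YUV to RGB conversion with chroma centered at 128. For red, 1436/1024 ≈ 1.402 and 179 ≈ 1.402 × 128. For blue, 1814/1024 ≈ 1.772 and 227 ≈ 1.772 × 128. The green line follows the same pattern. Writing the per-pixel step with the centered chroma made explicit makes the coefficients easier to read. The sketch below does that. clampByte, yuvToRgb and packAbgr are hypothetical names.

```dart
/// Clamps an int to the 0..255 byte range.
int clampByte(int v) => v < 0 ? 0 : (v > 255 ? 255 : v);

/// Full-range BT.601 YUV -> RGB for one pixel. Chroma (u, v) is centered
/// at 128, so neutral chroma (u = v = 128) produces a gray pixel.
List<int> yuvToRgb(int y, int u, int v) {
  final cb = u - 128;
  final cr = v - 128;
  final r = clampByte((y + 1.402 * cr).round());
  final g = clampByte((y - 0.344136 * cb - 0.714136 * cr).round());
  final b = clampByte((y + 1.772 * cb).round());
  return [r, g, b];
}

/// Packs RGB into the opaque 0xAABBGGRR word that buffer.data expects
/// in the image 3.x package.
int packAbgr(int r, int g, int b) => (0xFF << 24) | (b << 16) | (g << 8) | r;
```

A gray test frame (all U and V bytes equal to 128) is a quick sanity check: every output pixel should equal its Y value.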

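Next, we convert the img.Image into a TensorImage and preprocess it with the ImageProcessor from the tflite_flutter_helper library. We then feed the TensorImage to the interpreter and read the results. The model has two outputs. The first is a regression tensor with the box for every candidate region: x, y, w, h offsets relative to the region's anchor, followed by keypoints. The second is a classification tensor with one confidence score (a raw logit) per region.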

    // Convert the camera frame and preprocess it (resize + normalize).
    _img = await convertYUV420toImageColor(image);
    TensorImage tensorImage = TensorImage.fromImage(_img);
    _image = tensorImage.image; // keep the full-size frame for cropping
    tensorImage = _imageProcessor.process(tensorImage);

    // Allocate buffers that match the two output tensors.
    TensorBuffer output0 = TensorBuffer.createFixedSize(
        _interpreter.getOutputTensor(0).shape,
        _interpreter.getOutputTensor(0).type);
    TensorBuffer output1 = TensorBuffer.createFixedSize(
        _interpreter.getOutputTensor(1).shape,
        _interpreter.getOutputTensor(1).type);

    Map<int, ByteBuffer> outputs = {0: output0.buffer, 1: output1.buffer};

    _interpreter.runForMultipleInputs([tensorImage.buffer], outputs);

    // Output 0: box regression; output 1: per-anchor class logits.
    List<double> regression = output0.getDoubleList();
    List<double> classificators = output1.getDoubleList();
    List<Detection> detections = process(
        options: options,
        rawScores: classificators,
        rawBoxes: regression,
        anchors: _anchors);
    List<Detection> _detections = origNms(detections, 0.75);
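origNms is not shown either. Assuming it is standard greedy non-maximum suppression, a sketch follows. Box, iou and greedyNms are hypothetical names. In this sketch the second argument is an IoU overlap threshold, not a confidence threshold (the confidence threshold is options.minScoreThresh inside process()). Which of the two the 0.75 means in the original depends on how origNms is implemented.

```dart
import 'dart:math';

/// Hypothetical normalized box with a confidence score.
class Box {
  final double xMin, yMin, xMax, yMax, score;
  const Box(this.xMin, this.yMin, this.xMax, this.yMax, this.score);
  double get area => max(0.0, xMax - xMin) * max(0.0, yMax - yMin);
}

/// Intersection over union of two boxes (0 when they do not overlap).
double iou(Box a, Box b) {
  final w = min(a.xMax, b.xMax) - max(a.xMin, b.xMin);
  final h = min(a.yMax, b.yMax) - max(a.yMin, b.yMin);
  if (w <= 0 || h <= 0) return 0.0;
  final inter = w * h;
  return inter / (a.area + b.area - inter);
}

/// Greedy NMS: visit boxes from highest to lowest score and keep a box
/// only if its IoU with every box kept so far is <= [iouThreshold].
List<Box> greedyNms(List<Box> boxes, double iouThreshold) {
  final sorted = [...boxes]..sort((a, b) => b.score.compareTo(a.score));
  final kept = <Box>[];
  for (final candidate in sorted) {
    if (kept.every((k) => iou(k, candidate) <= iouThreshold)) {
      kept.add(candidate);
    }
  }
  return kept;
}
```

MediaPipe's own face detection graph uses a weighted variant that averages overlapping boxes instead of discarding them. The greedy form is simpler and is usually sufficient for picking a single face.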

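We set the confidence threshold to 0.75 and filter the detections with it. When at least one qualifying region remains (that is, _detections is non-empty), we crop that region and send it to the backend for verification.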
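Before cropping, a normalized detection must be mapped back to pixel coordinates of the full-size frame and clamped to its bounds. ResizeOp stretches the frame to the model size without keeping the aspect ratio, so normalized x and y map directly to the frame's width and height. The sketch below assumes the box is given as normalized (xMin, yMin, width, height) and adds an optional margin around the face. PixelRect, toPixelRect and margin are hypothetical names.

```dart
/// Hypothetical integer crop rectangle in pixels.
class PixelRect {
  final int left, top, width, height;
  const PixelRect(this.left, this.top, this.width, this.height);
}

int _clampInt(int v, int lo, int hi) => v < lo ? lo : (v > hi ? hi : v);

/// Maps a normalized box onto an [imageWidth] x [imageHeight] image. The box
/// is optionally enlarged by [margin] (a fraction of the box size on each
/// side) and clamped to the image bounds.
PixelRect toPixelRect(double xMin, double yMin, double w, double h,
    int imageWidth, int imageHeight,
    {double margin = 0.0}) {
  final left =
      _clampInt(((xMin - w * margin) * imageWidth).floor(), 0, imageWidth);
  final top =
      _clampInt(((yMin - h * margin) * imageHeight).floor(), 0, imageHeight);
  final right = _clampInt(
      ((xMin + w * (1 + margin)) * imageWidth).ceil(), 0, imageWidth);
  final bottom = _clampInt(
      ((yMin + h * (1 + margin)) * imageHeight).ceil(), 0, imageHeight);
  return PixelRect(left, top, right - left, bottom - top);
}
```

The resulting rectangle can be passed to img.copyCrop(_image, left, top, width, height) (image 3.x, positional arguments) to cut out the face before uploading it.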
