Commit 18f23a1c by vincent

fixed memory leaks + accept Tensors and HTMLCanvasElement as inputs

parent 3ca0f4af
import { FaceDetectionNet } from './types';
/**
 * A single detection: a confidence score plus a bounding box whose edges are
 * stored as fractions of the image size and only turned into absolute
 * coordinates on demand via `forSize`.
 */
export class FaceDetectionResult {
  private readonly score: number
  private readonly top: number
  private readonly left: number
  private readonly bottom: number
  private readonly right: number

  /**
   * @param score - Detection score, stored unchanged.
   * @param top - Top edge as a fraction of the height; values below 0 become 0.
   * @param left - Left edge as a fraction of the width; values below 0 become 0.
   * @param bottom - Bottom edge as a fraction of the height; values above 1.0 become 1.0.
   * @param right - Right edge as a fraction of the width; values above 1.0 become 1.0.
   */
  constructor(
    score: number,
    top: number,
    left: number,
    bottom: number,
    right: number
  ) {
    this.score = score
    // Each assignment is its own statement (previously they were chained with
    // the comma operator). Only one side of each edge is clamped: top/left are
    // floored at 0, bottom/right are capped at 1.0.
    this.top = Math.max(0, top)
    this.left = Math.max(0, left)
    this.bottom = Math.min(1.0, bottom)
    this.right = Math.min(1.0, right)
  }

  /**
   * Scales the stored fractional box to a concrete size.
   *
   * @param width - Target width; multiplies `left` and `right`.
   * @param height - Target height; multiplies `top` and `bottom`.
   * @returns The score together with the scaled box.
   */
  public forSize(width: number, height: number): FaceDetectionNet.Detection {
    return {
      score: this.score,
      box: {
        top: this.top * height,
        left: this.left * width,
        bottom: this.bottom * height,
        right: this.right * width
      }
    }
  }
}
\ No newline at end of file
......@@ -2,12 +2,12 @@ import * as tf from '@tensorflow/tfjs-core';
import { isFloat } from '../utils';
import { extractParams } from './extractParams';
import { FaceDetectionResult } from './FaceDetectionResult';
import { mobileNetV1 } from './mobileNetV1';
import { resizeLayer } from './resizeLayer';
import { predictionLayer } from './predictionLayer';
import { outputLayer } from './outputLayer';
import { nonMaxSuppression } from './nonMaxSuppression';
import { FaceDetectionNet } from './types';
import { outputLayer } from './outputLayer';
import { predictionLayer } from './predictionLayer';
import { resizeLayer } from './resizeLayer';
function fromData(input: number[]): tf.Tensor4D {
const pxPerChannel = input.length / 3
......@@ -21,6 +21,7 @@ function fromData(input: number[]): tf.Tensor4D {
}
function fromImageData(input: ImageData[]) {
return tf.tidy(() => {
const idx = input.findIndex(data => !(data instanceof ImageData))
if (idx !== -1) {
throw new Error(`expected input at index ${idx} to be instanceof ImageData`)
......@@ -31,9 +32,12 @@ function fromImageData(input: ImageData[]) {
.map(data => tf.expandDims(data, 0)) as tf.Tensor4D[]
return tf.cast(tf.concat(imgTensors, 0), 'float32')
})
}
function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
return tf.tidy(() => {
const [_, height, width] = imgTensor.shape
if (height === width) {
return imgTensor
......@@ -45,10 +49,25 @@ function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
}
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
return tf.concat([imgTensor, pad], 1)
})
}
function getImgTensor(input: ImageData|ImageData[]|number[]) {
function getImgTensor(input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[]) {
return tf.tidy(() => {
if (input instanceof HTMLCanvasElement) {
return tf.cast(
tf.expandDims(tf.fromPixels(input), 0), 'float32'
) as tf.Tensor4D
}
if (input instanceof tf.Tensor) {
const rank = input.shape.length
if (rank !== 3 && rank !== 4) {
throw new Error('input tensor must be of rank 3 or 4')
}
return tf.cast(
rank === 3 ? tf.expandDims(input, 0) : input, 'float32'
) as tf.Tensor4D
}
const imgDataArray = input instanceof ImageData
? [input]
......@@ -58,11 +77,9 @@ function getImgTensor(input: ImageData|ImageData[]|number[]) {
: null
)
return padToSquare(
imgDataArray !== null
return imgDataArray !== null
? fromImageData(imgDataArray)
: fromData(input as number[])
)
})
}
......@@ -85,31 +102,47 @@ export function faceDetectionNet(weights: Float32Array) {
})
}
function forward(input: ImageData|ImageData[]|number[]) {
function forward(input: tf.Tensor|ImageData|ImageData[]|number[]) {
return tf.tidy(
() => forwardTensor(padToSquare(getImgTensor(input)))
)
}
async function locateFaces(
input: ImageData|ImageData[]|number[],
input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[],
minConfidence: number = 0.8,
maxResults: number = 100,
): Promise<FaceDetectionNet.Detection[]> {
const imgTensor = getImgTensor(input)
const [_, height, width] = imgTensor.shape
): Promise<FaceDetectionResult[]> {
let paddedHeightRelative = 1, paddedWidthRelative = 1
const {
boxes: _boxes,
scores: _scores
} = forwardTensor(imgTensor)
} = tf.tidy(() => {
let imgTensor = getImgTensor(input)
const [_, height, width] = imgTensor.shape
imgTensor = padToSquare(imgTensor)
paddedHeightRelative = imgTensor.shape[1] / height
paddedWidthRelative = imgTensor.shape[2] / width
return forwardTensor(imgTensor)
})
// TODO batches
const boxes = _boxes[0]
const scores = _scores[0]
for (let i = 1; i < _boxes.length; i++) {
_boxes[i].dispose()
_scores[i].dispose()
}
// TODO find a better way to filter by minConfidence
//const ts = Date.now()
const scoresData = Array.from(await scores.data())
//console.log('await data:', (Date.now() - ts))
const iouThreshold = 0.5
const indices = nonMaxSuppression(
......@@ -120,17 +153,19 @@ export function faceDetectionNet(weights: Float32Array) {
minConfidence
)
return indices
.map(idx => ({
score: scoresData[idx],
box: {
top: Math.max(0, height * boxes.get(idx, 0)),
left: Math.max(0, width * boxes.get(idx, 1)),
bottom: Math.min(height, height * boxes.get(idx, 2)),
right: Math.min(width, width * boxes.get(idx, 3))
}
}))
const results = indices
.map(idx => new FaceDetectionResult(
scoresData[idx],
boxes.get(idx, 0) * paddedHeightRelative,
boxes.get(idx, 1) * paddedWidthRelative,
boxes.get(idx, 2) * paddedHeightRelative,
boxes.get(idx, 3) * paddedWidthRelative
))
boxes.dispose()
scores.dispose()
return results
}
return {
......
......@@ -15,6 +15,13 @@ function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingContext2
return ctx
}
/**
 * Picks the object from which a media element's `width`/`height` are read.
 *
 * For a video the dimensions come from `videoWidth`/`videoHeight` and are
 * returned as a fresh `{ width, height }` object. Any other media is handed
 * back unchanged, so the caller reads the element's own `width`/`height`.
 */
function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) {
  return media instanceof HTMLVideoElement
    ? { width: media.videoWidth, height: media.videoHeight }
    : media
}
/**
 * Reports whether `num` has a non-zero remainder when divided by 1.
 *
 * @param num - Value to inspect.
 * @returns `true` when `num % 1` is anything other than 0, which includes
 *          the case where the remainder is `NaN`; `false` otherwise.
 */
export function isFloat(num: number) {
  const remainder = num % 1
  return remainder !== 0
}
......@@ -43,7 +50,7 @@ export function drawMediaToCanvas(
throw new Error('drawMediaToCanvas - expected media to be of type: HTMLImageElement | HTMLVideoElement')
}
const { width, height } = dims || media
const { width, height } = dims || getMediaDimensions(media)
canvas.width = width
canvas.height = height
......@@ -59,7 +66,7 @@ export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement, dim
const ctx = drawMediaToCanvas(document.createElement('canvas'), media)
const { width, height } = dims || media
const { width, height } = dims || getMediaDimensions(media)
return ctx.getImageData(0, 0, width, height)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment