Commit 0c75bc37 by vincent

refactor FaceDetection result

parent 1cba5ada
...@@ -108,7 +108,7 @@ ...@@ -108,7 +108,7 @@
descriptors.forEach((descriptor, i) => { descriptors.forEach((descriptor, i) => {
const bestMatch = getBestMatch(trainDescriptorsByClass, descriptor) const bestMatch = getBestMatch(trainDescriptorsByClass, descriptor)
const text = `${bestMatch.distance < maxDistance ? bestMatch.className : 'unkown'} (${bestMatch.distance})` const text = `${bestMatch.distance < maxDistance ? bestMatch.className : 'unkown'} (${bestMatch.distance})`
const { x, y, height: boxHeight } = detectionsForSize[i].box const { x, y, height: boxHeight } = detectionsForSize[i].getBox()
faceapi.drawText( faceapi.drawText(
canvas.getContext('2d'), canvas.getContext('2d'),
x, x,
......
/**
 * Axis-aligned rectangle defined by its top-left corner (x, y) and its size.
 * Coordinates may be absolute pixels or relative ([0, 1]) depending on context.
 */
export class Rect {
  // Parameter properties declare and assign the public fields in one place;
  // the resulting instance shape (x, y, width, height) is unchanged.
  constructor(
    public x: number,
    public y: number,
    public width: number,
    public height: number
  ) {}
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult'; import { FaceDetection } from './faceDetectionNet/FaceDetection';
import { NetInput } from './NetInput';
import { getImageTensor } from './getImageTensor'; import { getImageTensor } from './getImageTensor';
import { NetInput } from './NetInput';
import { TNetInput } from './types'; import { TNetInput } from './types';
/** /**
...@@ -18,7 +18,7 @@ import { TNetInput } from './types'; ...@@ -18,7 +18,7 @@ import { TNetInput } from './types';
*/ */
export function extractFaceTensors( export function extractFaceTensors(
image: tf.Tensor | NetInput | TNetInput, image: tf.Tensor | NetInput | TNetInput,
detections: FaceDetectionResult[] detections: FaceDetection[]
): tf.Tensor4D[] { ): tf.Tensor4D[] {
return tf.tidy(() => { return tf.tidy(() => {
const imgTensor = getImageTensor(image) const imgTensor = getImageTensor(image)
...@@ -27,7 +27,7 @@ export function extractFaceTensors( ...@@ -27,7 +27,7 @@ export function extractFaceTensors(
const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape
const faceTensors = detections.map(det => { const faceTensors = detections.map(det => {
const { x, y, width, height } = det.forSize(imgWidth, imgHeight).box const { x, y, width, height } = det.forSize(imgWidth, imgHeight).getBox()
return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels]) return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
}) })
......
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult'; import { FaceDetection } from './faceDetectionNet/FaceDetection';
import { createCanvas, getContext2dOrThrow } from './utils'; import { createCanvas, getContext2dOrThrow } from './utils';
/** /**
...@@ -10,12 +10,12 @@ import { createCanvas, getContext2dOrThrow } from './utils'; ...@@ -10,12 +10,12 @@ import { createCanvas, getContext2dOrThrow } from './utils';
*/ */
export function extractFaces( export function extractFaces(
image: HTMLCanvasElement, image: HTMLCanvasElement,
detections: FaceDetectionResult[] detections: FaceDetection[]
): HTMLCanvasElement[] { ): HTMLCanvasElement[] {
const ctx = getContext2dOrThrow(image) const ctx = getContext2dOrThrow(image)
return detections.map(det => { return detections.map(det => {
const { x, y, width, height } = det.forSize(image.width, image.height).box const { x, y, width, height } = det.forSize(image.width, image.height).getBox()
const faceImg = createCanvas({ width, height }) const faceImg = createCanvas({ width, height })
getContext2dOrThrow(faceImg) getContext2dOrThrow(faceImg)
......
import { Rect } from '../Rect';
import { Dimensions } from '../types';
/**
 * A single face detection result: a confidence score and a bounding box in
 * absolute pixel coordinates, remembered together with the image dimensions
 * that were used to compute that box.
 */
export class FaceDetection {
  private _score: number
  private _box: Rect
  private _imageWidth: number
  private _imageHeight: number

  /**
   * @param score Detection confidence.
   * @param relativeBox Bounding box with coordinates relative to the image size
   *   (assumed to be in the [0, 1] range — not validated here).
   * @param imageDims Pixel dimensions of the image the box refers to.
   */
  constructor(
    score: number,
    relativeBox: Rect,
    imageDims: Dimensions
  ) {
    this._score = score
    this._imageWidth = imageDims.width
    this._imageHeight = imageDims.height
    // Scale the relative box up to absolute pixel coordinates, truncating
    // to whole pixels.
    this._box = new Rect(
      Math.floor(relativeBox.x * this._imageWidth),
      Math.floor(relativeBox.y * this._imageHeight),
      Math.floor(relativeBox.width * this._imageWidth),
      Math.floor(relativeBox.height * this._imageHeight)
    )
  }

  public getScore() {
    return this._score
  }

  public getBox() {
    return this._box
  }

  /** The box scaled back down to coordinates relative to the stored image size. */
  public getRelativeBox() {
    const { x, y, width, height } = this._box
    return new Rect(
      x / this._imageWidth,
      y / this._imageHeight,
      width / this._imageWidth,
      height / this._imageHeight
    )
  }

  /**
   * Returns a new FaceDetection with the box re-projected onto an image of the
   * given size (same score, relative coordinates preserved up to flooring).
   */
  public forSize(width: number, height: number): FaceDetection {
    return new FaceDetection(this._score, this.getRelativeBox(), { width, height })
  }
}
\ No newline at end of file
import { FaceDetectionNet } from './types';
/**
 * Face detection result holding the raw box edges in coordinates relative to
 * the image size, clamped into the [0, 1] range on construction.
 */
export class FaceDetectionResult {
  private _score: number
  private _topRelative: number
  private _leftRelative: number
  private _bottomRelative: number
  private _rightRelative: number

  /**
   * @param score Detection confidence.
   * @param topRelative / leftRelative / bottomRelative / rightRelative
   *   Box edges relative to image size; top/left are clamped to >= 0,
   *   bottom/right to <= 1.0.
   */
  constructor(
    score: number,
    topRelative: number,
    leftRelative: number,
    bottomRelative: number,
    rightRelative: number
  ) {
    this._score = score
    // Fix: the original chained these assignments with trailing comma
    // operators ("...,"), which is legal but error-prone — plain statements
    // have identical behavior and can't silently swallow a dropped line.
    this._topRelative = Math.max(0, topRelative)
    this._leftRelative = Math.max(0, leftRelative)
    this._bottomRelative = Math.min(1.0, bottomRelative)
    this._rightRelative = Math.min(1.0, rightRelative)
  }

  /**
   * Projects the relative box onto an image of the given pixel size.
   * @returns The score plus an absolute pixel-coordinate box; width/height are
   *   derived from the floored edges so the box never overshoots the image.
   */
  public forSize(width: number, height: number): FaceDetectionNet.Detection {
    const x = Math.floor(this._leftRelative * width)
    const y = Math.floor(this._topRelative * height)
    return {
      score: this._score,
      box: {
        x,
        y,
        width: Math.floor(this._rightRelative * width) - x,
        height: Math.floor(this._bottomRelative * height) - y
      }
    }
  }
}
\ No newline at end of file
...@@ -3,14 +3,15 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -3,14 +3,15 @@ import * as tf from '@tensorflow/tfjs-core';
import { getImageTensor } from '../getImageTensor'; import { getImageTensor } from '../getImageTensor';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { padToSquare } from '../padToSquare'; import { padToSquare } from '../padToSquare';
import { TNetInput } from '../types'; import { TNetInput, Dimensions } from '../types';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { FaceDetectionResult } from './FaceDetectionResult'; import { FaceDetection } from './FaceDetection';
import { mobileNetV1 } from './mobileNetV1'; import { mobileNetV1 } from './mobileNetV1';
import { nonMaxSuppression } from './nonMaxSuppression'; import { nonMaxSuppression } from './nonMaxSuppression';
import { outputLayer } from './outputLayer'; import { outputLayer } from './outputLayer';
import { predictionLayer } from './predictionLayer'; import { predictionLayer } from './predictionLayer';
import { resizeLayer } from './resizeLayer'; import { resizeLayer } from './resizeLayer';
import { Rect } from '../Rect';
export function faceDetectionNet(weights: Float32Array) { export function faceDetectionNet(weights: Float32Array) {
const params = extractParams(weights) const params = extractParams(weights)
...@@ -40,9 +41,10 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -40,9 +41,10 @@ export function faceDetectionNet(weights: Float32Array) {
input: tf.Tensor | NetInput, input: tf.Tensor | NetInput,
minConfidence: number = 0.8, minConfidence: number = 0.8,
maxResults: number = 100, maxResults: number = 100,
): Promise<FaceDetectionResult[]> { ): Promise<FaceDetection[]> {
let paddedHeightRelative = 1, paddedWidthRelative = 1 let paddedHeightRelative = 1, paddedWidthRelative = 1
let imageDimensions: Dimensions | undefined
const { const {
boxes: _boxes, boxes: _boxes,
...@@ -51,6 +53,7 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -51,6 +53,7 @@ export function faceDetectionNet(weights: Float32Array) {
let imgTensor = getImageTensor(input) let imgTensor = getImageTensor(input)
const [height, width] = imgTensor.shape.slice(1) const [height, width] = imgTensor.shape.slice(1)
imageDimensions = { width, height }
imgTensor = padToSquare(imgTensor) imgTensor = padToSquare(imgTensor)
paddedHeightRelative = imgTensor.shape[1] / height paddedHeightRelative = imgTensor.shape[1] / height
...@@ -80,13 +83,26 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -80,13 +83,26 @@ export function faceDetectionNet(weights: Float32Array) {
) )
const results = indices const results = indices
.map(idx => new FaceDetectionResult( .map(idx => {
const [top, bottom] = [
Math.max(0, boxes.get(idx, 0)),
Math.min(1.0, boxes.get(idx, 2))
].map(val => val * paddedHeightRelative)
const [left, right] = [
Math.max(0, boxes.get(idx, 1)),
Math.min(1.0, boxes.get(idx, 3))
].map(val => val * paddedWidthRelative)
return new FaceDetection(
scoresData[idx], scoresData[idx],
boxes.get(idx, 0) * paddedHeightRelative, new Rect(
boxes.get(idx, 1) * paddedWidthRelative, left,
boxes.get(idx, 2) * paddedHeightRelative, top,
boxes.get(idx, 3) * paddedWidthRelative right - left,
)) bottom - top
),
imageDimensions as Dimensions
)
})
boxes.dispose() boxes.dispose()
scores.dispose() scores.dispose()
......
...@@ -62,15 +62,4 @@ export namespace FaceDetectionNet { ...@@ -62,15 +62,4 @@ export namespace FaceDetectionNet {
prediction_layer_params: PredictionLayerParams, prediction_layer_params: PredictionLayerParams,
output_layer_params: OutputLayerParams output_layer_params: OutputLayerParams
} }
export type Detection = {
score: number
box: {
x: number,
y: number,
width: number,
height: number
}
}
} }
import { FaceDetectionNet } from './faceDetectionNet/types'; import { FaceDetection } from './faceDetectionNet/FaceDetection';
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks'; import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types'; import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
...@@ -115,7 +115,7 @@ export function drawText( ...@@ -115,7 +115,7 @@ export function drawText(
export function drawDetection( export function drawDetection(
canvasArg: string | HTMLCanvasElement, canvasArg: string | HTMLCanvasElement,
detection: FaceDetectionNet.Detection | FaceDetectionNet.Detection[], detection: FaceDetection | FaceDetection[],
options?: DrawBoxOptions & DrawTextOptions & { withScore: boolean } options?: DrawBoxOptions & DrawTextOptions & { withScore: boolean }
) { ) {
const canvas = getElement(canvasArg) const canvas = getElement(canvasArg)
...@@ -129,16 +129,11 @@ export function drawDetection( ...@@ -129,16 +129,11 @@ export function drawDetection(
detectionArray.forEach((det) => { detectionArray.forEach((det) => {
const { const {
score,
box
} = det
const {
x, x,
y, y,
width, width,
height height
} = box } = det.getBox()
const drawOptions = Object.assign( const drawOptions = Object.assign(
getDefaultDrawOptions(), getDefaultDrawOptions(),
...@@ -161,7 +156,7 @@ export function drawDetection( ...@@ -161,7 +156,7 @@ export function drawDetection(
ctx, ctx,
x, x,
y, y,
`${round(score)}`, `${round(det.getScore())}`,
drawOptions drawOptions
) )
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment