Commit 1c89e90a by vincent

face landmark net now accepts batch inputs

parent b8d62591
...@@ -18,25 +18,10 @@ export class NetInput { ...@@ -18,25 +18,10 @@ export class NetInput {
return return
} }
// if input is batch type, make sure every canvas has the same dimensions this._canvases.push(createCanvasFromMedia(media, dims))
const canvasDims = this.dims || dims
this._canvases.push(createCanvasFromMedia(media, canvasDims))
} }
public get canvases() : HTMLCanvasElement[] { public get canvases() : HTMLCanvasElement[] {
return this._canvases return this._canvases
} }
public get width() : number {
return (this._canvases[0] || {}).width
}
public get height() : number {
return (this._canvases[0] || {}).height
}
public get dims() : Dimensions | null {
const { width, height } = this
return (width > 0 && height > 0) ? { width, height } : null
}
} }
\ No newline at end of file
...@@ -3,6 +3,7 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -3,6 +3,7 @@ import * as tf from '@tensorflow/tfjs-core';
import { extractFaceTensors } from './extractFaceTensors'; import { extractFaceTensors } from './extractFaceTensors';
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet'; import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet'; import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet'; import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
import { FullFaceDescription } from './FullFaceDescription'; import { FullFaceDescription } from './FullFaceDescription';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
...@@ -23,7 +24,8 @@ export function allFacesFactory( ...@@ -23,7 +24,8 @@ export function allFacesFactory(
const faceTensors = await extractFaceTensors(input, detections) const faceTensors = await extractFaceTensors(input, detections)
const faceLandmarksByFace = await Promise.all(faceTensors.map( const faceLandmarksByFace = await Promise.all(faceTensors.map(
faceTensor => landmarkNet.detectLandmarks(faceTensor) faceTensor => landmarkNet.detectLandmarks(faceTensor)
)) )) as FaceLandmarks[]
faceTensors.forEach(t => t.dispose()) faceTensors.forEach(t => t.dispose())
const alignedFaceBoxes = await Promise.all(faceLandmarksByFace.map( const alignedFaceBoxes = await Promise.all(faceLandmarksByFace.map(
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { tensorTo4D } from './tensorTo4D';
export function getImageTensor(input: tf.Tensor | NetInput): tf.Tensor4D { export function getImageTensor(input: tf.Tensor | NetInput): tf.Tensor4D {
return tf.tidy(() => { return tf.tidy(() => {
if (input instanceof tf.Tensor) { if (input instanceof tf.Tensor) {
const rank = input.shape.length return tensorTo4D(input)
if (rank !== 3 && rank !== 4) {
throw new Error('input tensor must be of rank 3 or 4')
}
return (rank === 3 ? input.expandDims(0) : input).toFloat() as tf.Tensor4D
} }
if (!(input instanceof NetInput)) { if (!(input instanceof NetInput)) {
throw new Error('getImageTensor - expected input to be a tensor or an instance of NetInput') throw new Error('getImageTensor - expected input to be a tensor or an instance of NetInput')
} }
return tf.concat( if (input.canvases.length > 1) {
input.canvases.map(canvas => throw new Error('getImageTensor - batch input is not accepted here')
tf.fromPixels(canvas).expandDims(0).toFloat() }
)
) as tf.Tensor4D return tf.fromPixels(input.canvases[0]).expandDims(0).toFloat() as tf.Tensor4D
}) })
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
export function tensorTo4D(input: tf.Tensor): tf.Tensor4D {
if (input.rank !== 3 && input.rank !== 4) {
throw new Error('tensorTo4D - input tensor must be of rank 3 or 4')
}
return tf.tidy(
() => input.rank === 3 ? input.expandDims(0) : input
) as tf.Tensor4D
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { NetInput } from '../NetInput';
import { padToSquare } from '../padToSquare';
import { tensorTo4D } from './tensorTo4D';
import { BatchReshapeInfo } from './types';
export function toInputTensor(
input: tf.Tensor | tf.Tensor[] | NetInput,
inputSize: number,
center: boolean = true
): { batchTensor: tf.Tensor4D, batchInfo: BatchReshapeInfo[] } {
if (!(input instanceof tf.Tensor) && !(input instanceof NetInput)) {
throw new Error('toInputTensor - expected input to be a tensor of an instance of NetInput')
}
return tf.tidy(() => {
const inputTensors = input instanceof NetInput
? input.canvases.map(c => tf.expandDims(tf.fromPixels(c)))
: [tensorTo4D(input)]
const preprocessedTensors: tf.Tensor4D[] = []
const batchInfo: BatchReshapeInfo[] = []
inputTensors.forEach((inputTensor: tf.Tensor4D) => {
const [originalHeight, originalWidth] = inputTensor.shape.slice(1)
let imgTensor = padToSquare(inputTensor.toFloat(), center)
const [heightAfterPadding, widthAfterPadding] = imgTensor.shape.slice(1)
if (heightAfterPadding !== inputSize || widthAfterPadding !== inputSize) {
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize])
}
preprocessedTensors.push(imgTensor)
batchInfo.push({
originalWidth,
originalHeight,
paddingX: widthAfterPadding - originalWidth,
paddingY: heightAfterPadding - originalHeight
})
})
const batchSize = inputTensors.length
return {
batchTensor: tf.stack(preprocessedTensors).as4D(batchSize, inputSize, inputSize, 3),
batchInfo
}
})
}
\ No newline at end of file
...@@ -5,4 +5,11 @@ export type ConvParams = { ...@@ -5,4 +5,11 @@ export type ConvParams = {
bias: tf.Tensor1D bias: tf.Tensor1D
} }
export type ExtractWeightsFunction = (numWeights: number) => Float32Array export type ExtractWeightsFunction = (numWeights: number) => Float32Array
\ No newline at end of file
export type BatchReshapeInfo = {
originalWidth: number
originalHeight: number
paddingX: number
paddingY: number
}
...@@ -4,8 +4,8 @@ import { getImageTensor } from './commons/getImageTensor'; ...@@ -4,8 +4,8 @@ import { getImageTensor } from './commons/getImageTensor';
import { FaceDetection } from './faceDetectionNet/FaceDetection'; import { FaceDetection } from './faceDetectionNet/FaceDetection';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import { Rect } from './Rect'; import { Rect } from './Rect';
import { TNetInput } from './types';
import { toNetInput } from './toNetInput'; import { toNetInput } from './toNetInput';
import { TNetInput } from './types';
/** /**
* Extracts the tensors of the image regions containing the detected faces. * Extracts the tensors of the image regions containing the detected faces.
...@@ -29,8 +29,7 @@ export async function extractFaceTensors( ...@@ -29,8 +29,7 @@ export async function extractFaceTensors(
return tf.tidy(() => { return tf.tidy(() => {
const imgTensor = getImageTensor(image) const imgTensor = getImageTensor(image)
// TODO handle batches const [imgHeight, imgWidth, numChannels] = imgTensor.shape.slice(1)
const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape
const boxes = detections.map( const boxes = detections.map(
det => det instanceof FaceDetection det => det instanceof FaceDetection
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { convLayer } from '../commons/convLayer'; import { convLayer } from '../commons/convLayer';
import { getImageTensor } from '../commons/getImageTensor'; import { toInputTensor } from '../commons/toInputTensor';
import { ConvParams } from '../commons/types'; import { ConvParams } from '../commons/types';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { padToSquare } from '../padToSquare';
import { Point } from '../Point'; import { Point } from '../Point';
import { toNetInput } from '../toNetInput'; import { toNetInput } from '../toNetInput';
import { Dimensions, TNetInput } from '../types'; import { TNetInput } from '../types';
import { isEven } from '../utils'; import { isEven } from '../utils';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { FaceLandmarks } from './FaceLandmarks'; import { FaceLandmarks } from './FaceLandmarks';
...@@ -43,7 +42,7 @@ export class FaceLandmarkNet { ...@@ -43,7 +42,7 @@ export class FaceLandmarkNet {
this._params = extractParams(weights) this._params = extractParams(weights)
} }
public forwardTensor(imgTensor: tf.Tensor4D): tf.Tensor2D { public forwardTensor(input: tf.Tensor | NetInput): tf.Tensor2D {
const params = this._params const params = this._params
if (!params) { if (!params) {
...@@ -51,17 +50,9 @@ export class FaceLandmarkNet { ...@@ -51,17 +50,9 @@ export class FaceLandmarkNet {
} }
return tf.tidy(() => { return tf.tidy(() => {
const [batchSize, height, width] = imgTensor.shape.slice() const { batchTensor, batchInfo } = toInputTensor(input, 128, true)
let x = padToSquare(imgTensor, true) let out = conv(batchTensor, params.conv0_params)
const [heightAfterPadding, widthAfterPadding] = x.shape.slice(1)
// work with 128 x 128 sized face images
if (heightAfterPadding !== 128 || widthAfterPadding !== 128) {
x = tf.image.resizeBilinear(x, [128, 128])
}
let out = conv(x, params.conv0_params)
out = maxPool(out) out = maxPool(out)
out = conv(out, params.conv1_params) out = conv(out, params.conv1_params)
out = conv(out, params.conv2_params) out = conv(out, params.conv2_params)
...@@ -76,26 +67,38 @@ export class FaceLandmarkNet { ...@@ -76,26 +67,38 @@ export class FaceLandmarkNet {
const fc0 = tf.relu(fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params)) const fc0 = tf.relu(fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params))
const fc1 = fullyConnectedLayer(fc0, params.fc1_params) const fc1 = fullyConnectedLayer(fc0, params.fc1_params)
const createInterleavedTensor = (fillX: number, fillY: number) => const createInterleavedTensor = (fillX: number, fillY: number) =>
tf.stack([ tf.stack([
tf.fill([68], fillX), tf.fill([68], fillX),
tf.fill([68], fillY) tf.fill([68], fillY)
], 1).as2D(batchSize, 136) ], 1).as2D(1, 136).as1D()
/* shift coordinates back, to undo centered padding /* shift coordinates back, to undo centered padding
((x * widthAfterPadding) - shiftX) / width x = ((x * widthAfterPadding) - shiftX) / width
((y * heightAfterPadding) - shiftY) / height y = ((y * heightAfterPadding) - shiftY) / height
*/ */
const shiftX = Math.floor(Math.abs(widthAfterPadding - width) / 2)
const shiftY = Math.floor(Math.abs(heightAfterPadding - height) / 2) const landmarkTensors = fc1
const landmarkTensor = fc1 .mul(tf.stack(batchInfo.map(info =>
.mul(createInterleavedTensor(widthAfterPadding, heightAfterPadding)) createInterleavedTensor(
.sub(createInterleavedTensor(shiftX, shiftY)) info.paddingX + info.originalWidth,
.div(createInterleavedTensor(width, height)) info.paddingY + info.originalHeight
)
return landmarkTensor as tf.Tensor2D )))
.sub(tf.stack(batchInfo.map(info =>
createInterleavedTensor(
Math.floor(info.paddingX / 2),
Math.floor(info.paddingY / 2)
)
)))
.div(tf.stack(batchInfo.map(info =>
createInterleavedTensor(
info.originalWidth,
info.originalHeight
)
)))
return landmarkTensors as tf.Tensor2D
}) })
} }
...@@ -104,34 +107,35 @@ export class FaceLandmarkNet { ...@@ -104,34 +107,35 @@ export class FaceLandmarkNet {
? input ? input
: await toNetInput(input) : await toNetInput(input)
return this.forwardTensor(getImageTensor(netInput)) return this.forwardTensor(netInput)
} }
public async detectLandmarks(input: tf.Tensor | NetInput | TNetInput) { public async detectLandmarks(input: tf.Tensor | NetInput | TNetInput): Promise<FaceLandmarks | FaceLandmarks[]> {
const netInput = input instanceof tf.Tensor const netInput = input instanceof tf.Tensor
? input ? input
: await toNetInput(input) : await toNetInput(input)
let imageDimensions: Dimensions | undefined const landmarkTensors = tf.unstack(this.forwardTensor(netInput))
const outTensor = tf.tidy(() => {
const imgTensor = getImageTensor(netInput)
const [height, width] = imgTensor.shape.slice(1) const landmarksForBatch = await Promise.all(landmarkTensors.map(
imageDimensions = { width, height } async (landmarkTensor, batchIdx) => {
const landmarksArray = Array.from(await landmarkTensor.data())
landmarkTensor.dispose()
return this.forwardTensor(imgTensor) const xCoords = landmarksArray.filter((_, i) => isEven(i))
}) const yCoords = landmarksArray.filter((_, i) => !isEven(i))
const faceLandmarksArray = Array.from(await outTensor.data()) const [height, width] = netInput instanceof tf.Tensor
outTensor.dispose() ? netInput.shape.slice(1)
: [netInput.canvases[batchIdx].height, netInput.canvases[batchIdx].width]
const xCoords = faceLandmarksArray.filter((_, i) => isEven(i)) return new FaceLandmarks(
const yCoords = faceLandmarksArray.filter((_, i) => !isEven(i)) Array(68).fill(0).map((_, i) => new Point(xCoords[i], yCoords[i])),
{ height, width }
)
}
))
return new FaceLandmarks( return landmarksForBatch.length === 1 ? landmarksForBatch[0] : landmarksForBatch
Array(68).fill(0).map((_, i) => new Point(xCoords[i], yCoords[i])),
imageDimensions as Dimensions
)
} }
} }
\ No newline at end of file
...@@ -6,9 +6,9 @@ import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet'; ...@@ -6,9 +6,9 @@ import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet'; import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks'; import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet'; import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
import { FullFaceDescription } from './FullFaceDescription';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import { TNetInput } from './types'; import { TNetInput } from './types';
import { FullFaceDescription } from './FullFaceDescription';
export const detectionNet = new FaceDetectionNet() export const detectionNet = new FaceDetectionNet()
export const landmarkNet = new FaceLandmarkNet() export const landmarkNet = new FaceLandmarkNet()
...@@ -44,7 +44,7 @@ export function locateFaces( ...@@ -44,7 +44,7 @@ export function locateFaces(
export function detectLandmarks( export function detectLandmarks(
input: tf.Tensor | NetInput | TNetInput input: tf.Tensor | NetInput | TNetInput
): Promise<FaceLandmarks> { ): Promise<FaceLandmarks | FaceLandmarks[]> {
return landmarkNet.detectLandmarks(input) return landmarkNet.detectLandmarks(input)
} }
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { isEven } from './utils';
/** /**
* Pads the smaller dimension of an image tensor with zeros, such that width === height. * Pads the smaller dimension of an image tensor with zeros, such that width === height.
* *
......
[{"x": 9.995004907250404, "y": 53.55449616909027}, {"x": 12.50796876847744, "y": 71.41348421573639}, {"x": 16.677917540073395, "y": 88.59677910804749}, {"x": 22.6475290954113, "y": 104.6014130115509}, {"x": 30.59161528944969, "y": 119.35952603816986}, {"x": 41.422560811042786, "y": 132.23226964473724}, {"x": 54.74700182676315, "y": 142.4335777759552}, {"x": 70.32481580972672, "y": 149.33189749717712}, {"x": 87.31497824192047, "y": 150.50972700119019}, {"x": 103.98584604263306, "y": 145.98273038864136}, {"x": 117.90181696414948, "y": 135.19554734230042}, {"x": 128.67935299873352, "y": 121.79077863693237}, {"x": 136.7296814918518, "y": 105.85636496543884}, {"x": 140.29521346092224, "y": 88.25878500938416}, {"x": 140.9232795238495, "y": 70.16736567020416}, {"x": 140.2374029159546, "y": 52.73242145776749}, {"x": 137.97148168087006, "y": 34.537942707538605}, {"x": 14.37721811234951, "y": 33.1049881875515}, {"x": 22.6781465113163, "y": 24.685607850551605}, {"x": 34.36600640416145, "y": 21.1758591234684}, {"x": 46.24761343002319, "y": 22.49436378479004}, {"x": 57.12086856365204, "y": 26.742971688508987}, {"x": 81.21025264263153, "y": 23.014162480831146}, {"x": 92.2086775302887, "y": 15.48520065844059}, {"x": 104.77548837661743, "y": 11.306393891572952}, {"x": 117.67798662185669, "y": 11.740228906273842}, {"x": 127.28274464607239, "y": 18.115675449371338}, {"x": 69.62742805480957, "y": 41.51403307914734}, {"x": 70.82946002483368, "y": 55.146731436252594}, {"x": 71.84555232524872, "y": 68.59723627567291}, {"x": 73.0046421289444, "y": 81.93029165267944}, {"x": 60.417647659778595, "y": 88.01697492599487}, {"x": 67.98770427703857, "y": 90.65443575382233}, {"x": 76.07284784317017, "y": 91.86699986457825}, {"x": 84.35145914554596, "y": 88.2117748260498}, {"x": 90.86072444915771, "y": 83.67109894752502}, {"x": 28.828849643468857, "y": 47.794362902641296}, {"x": 36.311765760183334, "y": 43.33548992872238}, {"x": 44.95347887277603, "y": 43.20283681154251}, {"x": 52.85406410694122, "y": 48.07424694299698}, {"x": 44.7566494345665, "y": 49.9691516160965}, {"x": 35.997654497623444, "y": 50.32083839178085}, {"x": 89.51361179351807, "y": 42.501528561115265}, {"x": 97.55686819553375, "y": 35.38782298564911}, {"x": 106.73499405384064, "y": 33.59129726886749}, {"x": 114.8474782705307, "y": 36.34611591696739}, {"x": 108.40394496917725, "y": 40.97002297639847}, {"x": 98.98389279842377, "y": 42.47862249612808}, {"x": 53.17014008760452, "y": 109.99322533607483}, {"x": 62.47727572917938, "y": 105.5664449930191}, {"x": 72.82306104898453, "y": 102.35638618469238}, {"x": 80.85319697856903, "y": 103.16510796546936}, {"x": 89.42103087902069, "y": 100.08856952190399}, {"x": 99.8135894536972, "y": 100.0435084104538}, {"x": 109.78849232196808, "y": 101.60946249961853}, {"x": 101.90783143043518, "y": 115.25328755378723}, {"x": 92.65078604221344, "y": 122.49988317489624}, {"x": 83.28675627708435, "y": 125.00075697898865}, {"x": 74.31002408266068, "y": 125.16917288303375}, {"x": 63.37641924619675, "y": 121.38420939445496}, {"x": 57.85811394453049, "y": 110.28821468353271}, {"x": 73.27612638473511, "y": 108.2253098487854}, {"x": 81.34059011936188, "y": 108.251291513443}, {"x": 89.9710088968277, "y": 105.99507093429565}, {"x": 104.85810041427612, "y": 103.1228095293045}, {"x": 90.87785482406616, "y": 112.19902038574219}, {"x": 82.05846548080444, "y": 114.52528238296509}, {"x": 73.8232746720314, "y": 114.64338898658752}]
\ No newline at end of file
...@@ -5,35 +5,64 @@ import { expectMaxDelta } from '../../utils'; ...@@ -5,35 +5,64 @@ import { expectMaxDelta } from '../../utils';
describe('faceLandmarkNet', () => { describe('faceLandmarkNet', () => {
let imgEl: HTMLImageElement let imgEl1: HTMLImageElement
let imgEl2: HTMLImageElement
let faceLandmarkPositions1: Point[]
let faceLandmarkPositions2: Point[]
beforeAll(async () => { beforeAll(async () => {
const img = await (await fetch('base/test/images/face.png')).blob() const img1 = await (await fetch('base/test/images/face1.png')).blob()
imgEl = await faceapi.bufferToImage(img) imgEl1 = await faceapi.bufferToImage(img1)
const img2 = await (await fetch('base/test/images/face2.png')).blob()
imgEl2 = await faceapi.bufferToImage(img2)
faceLandmarkPositions1 = await (await fetch('base/test/data/faceLandmarkPositions1.json')).json()
faceLandmarkPositions2 = await (await fetch('base/test/data/faceLandmarkPositions2.json')).json()
}) })
describe('uncompressed weights', () => { describe('uncompressed weights', () => {
let faceLandmarkNet: faceapi.FaceLandmarkNet, faceLandmarkPositions: Point[] let faceLandmarkNet: faceapi.FaceLandmarkNet
beforeAll(async () => { beforeAll(async () => {
const res = await fetch('base/weights/uncompressed/face_landmark_68_model.weights') const res = await fetch('base/weights/uncompressed/face_landmark_68_model.weights')
const weights = new Float32Array(await res.arrayBuffer()) const weights = new Float32Array(await res.arrayBuffer())
faceLandmarkNet = faceapi.faceLandmarkNet(weights) faceLandmarkNet = faceapi.faceLandmarkNet(weights)
faceLandmarkPositions = await (await fetch('base/test/data/faceLandmarkPositions.json')).json()
}) })
it('computes face landmarks', async () => { it('computes face landmarks', async () => {
const { width, height } = imgEl const { width, height } = imgEl1
const result = await faceLandmarkNet.detectLandmarks(imgEl) as FaceLandmarks const result = await faceLandmarkNet.detectLandmarks(imgEl1) as FaceLandmarks
expect(result.getImageWidth()).toEqual(width) expect(result.getImageWidth()).toEqual(width)
expect(result.getImageHeight()).toEqual(height) expect(result.getImageHeight()).toEqual(height)
expect(result.getShift().x).toEqual(0) expect(result.getShift().x).toEqual(0)
expect(result.getShift().y).toEqual(0) expect(result.getShift().y).toEqual(0)
result.getPositions().forEach(({ x, y }, i) => { result.getPositions().forEach(({ x, y }, i) => {
expectMaxDelta(x, faceLandmarkPositions[i].x, 0.1) expectMaxDelta(x, faceLandmarkPositions1[i].x, 0.1)
expectMaxDelta(y, faceLandmarkPositions[i].y, 0.1) expectMaxDelta(y, faceLandmarkPositions1[i].y, 0.1)
})
})
it('computes face landmarks for batch input', async () => {
const imgEls = [imgEl1, imgEl2]
const faceLandmarkPositions = [
faceLandmarkPositions1,
faceLandmarkPositions2
]
const results = await faceLandmarkNet.detectLandmarks(imgEls) as FaceLandmarks[]
expect(Array.isArray(results)).toBe(true)
expect(results.length).toEqual(2)
results.forEach((result, batchIdx) => {
const { width, height } = imgEls[batchIdx]
expect(result.getImageWidth()).toEqual(width)
expect(result.getImageHeight()).toEqual(height)
expect(result.getShift().x).toEqual(0)
expect(result.getShift().y).toEqual(0)
result.getPositions().forEach(({ x, y }, i) => {
expectMaxDelta(x, faceLandmarkPositions[batchIdx][i].x, 0.1)
expectMaxDelta(y, faceLandmarkPositions[batchIdx][i].y, 0.1)
})
}) })
}) })
...@@ -41,25 +70,47 @@ describe('faceLandmarkNet', () => { ...@@ -41,25 +70,47 @@ describe('faceLandmarkNet', () => {
describe('quantized weights', () => { describe('quantized weights', () => {
let faceLandmarkNet: faceapi.FaceLandmarkNet, faceLandmarkPositions: Point[] let faceLandmarkNet: faceapi.FaceLandmarkNet
beforeAll(async () => { beforeAll(async () => {
faceLandmarkNet = new faceapi.FaceLandmarkNet() faceLandmarkNet = new faceapi.FaceLandmarkNet()
await faceLandmarkNet.load('base/weights') await faceLandmarkNet.load('base/weights')
faceLandmarkPositions = await (await fetch('base/test/data/faceLandmarkPositions.json')).json()
}) })
it('computes face landmarks', async () => { it('computes face landmarks', async () => {
const { width, height } = imgEl const { width, height } = imgEl1
const result = await faceLandmarkNet.detectLandmarks(imgEl) as FaceLandmarks const result = await faceLandmarkNet.detectLandmarks(imgEl1) as FaceLandmarks
expect(result.getImageWidth()).toEqual(width) expect(result.getImageWidth()).toEqual(width)
expect(result.getImageHeight()).toEqual(height) expect(result.getImageHeight()).toEqual(height)
expect(result.getShift().x).toEqual(0) expect(result.getShift().x).toEqual(0)
expect(result.getShift().y).toEqual(0) expect(result.getShift().y).toEqual(0)
result.getPositions().forEach(({ x, y }, i) => { result.getPositions().forEach(({ x, y }, i) => {
expectMaxDelta(x, faceLandmarkPositions[i].x, 2) expectMaxDelta(x, faceLandmarkPositions1[i].x, 2)
expectMaxDelta(y, faceLandmarkPositions[i].y, 2) expectMaxDelta(y, faceLandmarkPositions1[i].y, 2)
})
})
it('computes face landmarks for batch input', async () => {
const imgEls = [imgEl1, imgEl2]
const faceLandmarkPositions = [
faceLandmarkPositions1,
faceLandmarkPositions2
]
const results = await faceLandmarkNet.detectLandmarks(imgEls) as FaceLandmarks[]
expect(Array.isArray(results)).toBe(true)
expect(results.length).toEqual(2)
results.forEach((result, batchIdx) => {
const { width, height } = imgEls[batchIdx]
expect(result.getImageWidth()).toEqual(width)
expect(result.getImageHeight()).toEqual(height)
expect(result.getShift().x).toEqual(0)
expect(result.getShift().y).toEqual(0)
result.getPositions().forEach(({ x, y }, i) => {
expectMaxDelta(x, faceLandmarkPositions[batchIdx][i].x, 3)
expectMaxDelta(y, faceLandmarkPositions[batchIdx][i].y, 3)
})
}) })
}) })
......
...@@ -9,7 +9,7 @@ describe('faceRecognitionNet', () => { ...@@ -9,7 +9,7 @@ describe('faceRecognitionNet', () => {
const weights = new Float32Array(await res.arrayBuffer()) const weights = new Float32Array(await res.arrayBuffer())
faceRecognitionNet = faceapi.faceRecognitionNet(weights) faceRecognitionNet = faceapi.faceRecognitionNet(weights)
const img = await (await fetch('base/test/images/face.png')).blob() const img = await (await fetch('base/test/images/face1.png')).blob()
imgEl = await faceapi.bufferToImage(img) imgEl = await faceapi.bufferToImage(img)
faceDescriptor = await (await fetch('base/test/data/faceDescriptor.json')).json() faceDescriptor = await (await fetch('base/test/data/faceDescriptor.json')).json()
}) })
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment