Commit 93959e13 by vincent

implemented face alignment from landmarks

parent 5f46b72d
src/Point.ts
@@ -27,4 +27,16 @@ export class Point implements IPoint {
   public div(pt: IPoint): Point {
     return new Point(this.x / pt.x, this.y / pt.y)
   }
+
+  public abs(): Point {
+    return new Point(Math.abs(this.x), Math.abs(this.y))
+  }
+
+  public magnitude(): number {
+    return Math.sqrt(Math.pow(this.x, 2) + Math.pow(this.y, 2))
+  }
+
+  public floor(): Point {
+    return new Point(Math.floor(this.x), Math.floor(this.y))
+  }
 }
\ No newline at end of file
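For orientation, a minimal sketch of how the new Point helpers behave (values are hypothetical; sub() already exists on the class, as the alignment code below relies on it):

```ts
import { Point } from './Point';

// magnitude() is the Euclidean length of the point treated as a vector,
// so sub() followed by magnitude() gives the distance between two points.
const a = new Point(3, 4)
const b = new Point(0, 0)
a.sub(b).magnitude()        // 5
new Point(-1.7, 2.3).abs()  // Point(1.7, 2.3)
new Point(1.7, 2.3).floor() // Point(1, 2)
```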
src/commons/getCenterPoint.ts (new file)
+import { Point } from '../Point';
+
+export function getCenterPoint(pts: Point[]): Point {
+  return pts.reduce((sum, pt) => sum.add(pt), new Point(0, 0))
+    .div(new Point(pts.length, pts.length))
+}
\ No newline at end of file
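getCenterPoint computes the centroid: it sums all points, then divides component-wise by the count. For example (hypothetical triangle):

```ts
import { Point } from '../Point';
import { getCenterPoint } from './getCenterPoint';

getCenterPoint([new Point(0, 0), new Point(4, 0), new Point(2, 6)]) // Point(2, 2)
```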
src/extractFaceTensors.ts
@@ -3,22 +3,22 @@ import * as tf from '@tensorflow/tfjs-core';
 import { FaceDetection } from './faceDetectionNet/FaceDetection';
 import { getImageTensor } from './getImageTensor';
 import { NetInput } from './NetInput';
+import { Rect } from './Rect';
 import { TNetInput } from './types';

 /**
  * Extracts the tensors of the image regions containing the detected faces.
- * Returned tensors have to be disposed manually once you don't need them anymore!
- * Useful if you want to compute the face descriptors for the face
- * images. Using this method is faster then extracting a canvas for each face and
+ * Useful if you want to compute the face descriptors for the face images.
+ * Using this method is faster than extracting a canvas for each face and
  * converting them to tensors individually.
  *
  * @param input The image that face detection has been performed on.
- * @param detections The face detection results for that image.
+ * @param detections The face detection results or face bounding boxes for that image.
  * @returns Tensors of the corresponding image region for each detected face.
  */
 export function extractFaceTensors(
   image: tf.Tensor | NetInput | TNetInput,
-  detections: FaceDetection[]
+  detections: Array<FaceDetection | Rect>
 ): tf.Tensor4D[] {
   return tf.tidy(() => {
     const imgTensor = getImageTensor(image)
@@ -26,10 +26,14 @@ export function extractFaceTensors(
     // TODO handle batches
     const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape

-    const faceTensors = detections.map(det => {
-      const { x, y, width, height } = det.forSize(imgWidth, imgHeight).getBox().floor()
-      return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
-    })
+    const boxes = detections.map(
+      det => det instanceof FaceDetection
+        ? det.forSize(imgWidth, imgHeight).getBox().floor()
+        : det
+    )
+    const faceTensors = boxes.map(({ x, y, width, height }) =>
+      tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
+    )

     return faceTensors
   })
...
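With detections widened to Array<FaceDetection | Rect>, a plain box can now be passed without running a detector first. A usage sketch (the input tensor and box are hypothetical; the returned tensors still need manual disposal):

```ts
import * as tf from '@tensorflow/tfjs-core';
import { extractFaceTensors } from './extractFaceTensors';
import { Rect } from './Rect';

// Dummy 640x480 RGB batch standing in for a real image tensor.
const inputImage = tf.zeros([1, 480, 640, 3]) as tf.Tensor4D

// Crop a 120x120 region at (50, 60) directly from a Rect.
const faceTensors = extractFaceTensors(inputImage, [new Rect(50, 60, 120, 120)])

faceTensors.forEach(t => t.dispose()) // dispose once no longer needed
inputImage.dispose()
```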
src/extractFaces.ts
 import { FaceDetection } from './faceDetectionNet/FaceDetection';
+import { Rect } from './Rect';
 import { createCanvas, getContext2dOrThrow } from './utils';

 /**
  * Extracts the image regions containing the detected faces.
  *
  * @param input The image that face detection has been performed on.
- * @param detections The face detection results for that image.
+ * @param detections The face detection results or face bounding boxes for that image.
  * @returns The Canvases of the corresponding image region for each detected face.
  */
 export function extractFaces(
   image: HTMLCanvasElement,
-  detections: FaceDetection[]
+  detections: Array<FaceDetection | Rect>
 ): HTMLCanvasElement[] {
   const ctx = getContext2dOrThrow(image)

-  return detections.map(det => {
-    const { x, y, width, height } = det.forSize(image.width, image.height).getBox().floor()
+  const boxes = detections.map(
+    det => det instanceof FaceDetection
+      ? det.forSize(image.width, image.height).getBox().floor()
+      : det
+  )
+
+  return boxes.map(({ x, y, width, height }) => {
     const faceImg = createCanvas({ width, height })
     getContext2dOrThrow(faceImg)
       .putImageData(ctx.getImageData(x, y, width, height), 0, 0)
...
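The canvas variant mirrors this; a minimal sketch (the input canvas and box are hypothetical):

```ts
import { extractFaces } from './extractFaces';
import { Rect } from './Rect';

const input = document.getElementById('input') as HTMLCanvasElement

// One cropped canvas per box; FaceDetection results and plain Rects can be mixed.
const faceCanvases = extractFaces(input, [new Rect(50, 60, 120, 120)])
faceCanvases.forEach(c => document.body.appendChild(c))
```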
src/faceLandmarkNet/FaceLandmarks.ts
-import { Point, IPoint } from '../Point';
+import { getCenterPoint } from '../commons/getCenterPoint';
+import { FaceDetection } from '../faceDetectionNet/FaceDetection';
+import { Point } from '../Point';
+import { Rect } from '../Rect';
 import { Dimensions } from '../types';

+// face alignment constants
+const relX = 0.5
+const relY = 0.43
+const relScale = 0.45
+
 export class FaceLandmarks {
-  private _faceLandmarks: Point[]
   private _imageWidth: number
   private _imageHeight: number
   private _shift: Point
+  private _faceLandmarks: Point[]

   constructor(
     relativeFaceLandmarkPositions: Point[],
@@ -21,41 +29,53 @@ export class FaceLandmarks {
     )
   }

-  public getPositions() {
+  public getShift(): Point {
+    return new Point(this._shift.x, this._shift.y)
+  }
+
+  public getImageWidth(): number {
+    return this._imageWidth
+  }
+
+  public getImageHeight(): number {
+    return this._imageHeight
+  }
+
+  public getPositions(): Point[] {
     return this._faceLandmarks
   }

-  public getRelativePositions() {
+  public getRelativePositions(): Point[] {
     return this._faceLandmarks.map(
       pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
     )
   }

-  public getJawOutline() {
+  public getJawOutline(): Point[] {
     return this._faceLandmarks.slice(0, 17)
   }

-  public getLeftEyeBrow() {
+  public getLeftEyeBrow(): Point[] {
     return this._faceLandmarks.slice(17, 22)
   }

-  public getRightEyeBrow() {
+  public getRightEyeBrow(): Point[] {
     return this._faceLandmarks.slice(22, 27)
   }

-  public getNose() {
+  public getNose(): Point[] {
     return this._faceLandmarks.slice(27, 36)
   }

-  public getLeftEye() {
+  public getLeftEye(): Point[] {
     return this._faceLandmarks.slice(36, 42)
   }

-  public getRightEye() {
+  public getRightEye(): Point[] {
     return this._faceLandmarks.slice(42, 48)
   }

-  public getMouth() {
+  public getMouth(): Point[] {
     return this._faceLandmarks.slice(48, 68)
   }
@@ -73,4 +93,46 @@ export class FaceLandmarks {
       new Point(x, y)
     )
   }
+
+  /**
+   * Aligns the face landmarks after face detection from the relative positions of the face's
+   * bounding box, or from its current shift. This function should be used to align the face
+   * images after face detection has been performed, before they are passed to the face
+   * recognition net. This will make the computed face descriptor more accurate.
+   *
+   * @param detection (optional) The bounding box of the face or the face detection result. If
+   * no argument is passed, the positions of the face landmarks are assumed to be relative to
+   * their current shift.
+   * @returns The bounding box of the aligned face.
+   */
+  public align(
+    detection?: FaceDetection | Rect
+  ): Rect {
+    if (detection) {
+      const box = detection instanceof FaceDetection
+        ? detection.getBox().floor()
+        : detection
+
+      return this.shift(box.x, box.y).align()
+    }
+
+    const centers = [
+      this.getLeftEye(),
+      this.getRightEye(),
+      this.getMouth()
+    ].map(getCenterPoint)
+
+    const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers
+    const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
+    const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2
+
+    const size = Math.floor(eyeToMouthDist / relScale)
+
+    const refPoint = getCenterPoint(centers)
+    // TODO: pad in case rectangle is out of image bounds
+    const x = Math.floor(Math.max(0, refPoint.x - (relX * size)))
+    const y = Math.floor(Math.max(0, refPoint.y - (relY * size)))
+
+    return new Rect(x, y, size, size)
+  }
 }
\ No newline at end of file
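Tracing the parameterless branch with hypothetical landmark centers makes the geometry concrete: the square's side is the mean eye-to-mouth distance divided by relScale (0.45), and the centroid of the eye and mouth centers lands at the relative position (relX, relY) = (0.5, 0.43) inside the square.

```ts
import { getCenterPoint } from '../commons/getCenterPoint';
import { Point } from '../Point';

// Hypothetical landmark centers:
const leftEyeCenter = new Point(60, 60)
const rightEyeCenter = new Point(100, 60)
const mouthCenter = new Point(80, 105)

const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2 // ≈ 49.24

const size = Math.floor(eyeToMouthDist / 0.45) // 109
const refPoint = getCenterPoint([leftEyeCenter, rightEyeCenter, mouthCenter]) // Point(80, 75)

const x = Math.floor(Math.max(0, refPoint.x - 0.5 * size))  // 25
const y = Math.floor(Math.max(0, refPoint.y - 0.43 * size)) // 28
// => the aligned face box is Rect(25, 28, 109, 109)
```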
src/padToSquare.ts
@@ -4,7 +4,7 @@ import * as tf from '@tensorflow/tfjs-core';
 /**
  * Pads the smaller dimension of an image tensor with zeros, such that width === height.
  *
  * @param imgTensor The image tensor.
- * @param isCenterImage (optional, default: false) If true, add padding on both sides of the image, such that the image
+ * @param isCenterImage (optional, default: false) If true, add padding on both sides of the image, such that the image is centered.
  * @returns The padded tensor with width === height.
  */
 export function padToSquare(
...
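A shape-level illustration of the documented behavior, not the library's implementation (which side receives the padding in the non-centered case is an assumption here):

```ts
import * as tf from '@tensorflow/tfjs-core';

// A 300x200 image (NHWC) padded along the width to 300x300.
// With isCenterImage the zeros are split across both sides (50 + 50);
// otherwise the full 100 columns of zeros go on one side.
const img = tf.zeros([1, 300, 200, 3])
const padded = tf.pad(img, [[0, 0], [0, 0], [50, 50], [0, 0]])
console.log(padded.shape) // [1, 300, 300, 3]
```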
src/utils.ts
+import * as tf from '@tensorflow/tfjs-core';
+
 import { FaceDetection } from './faceDetectionNet/FaceDetection';
 import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
-import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
 import { Point } from './Point';
+import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';

 export function isFloat(num: number) {
   return num % 1 !== 0
@@ -68,6 +70,18 @@ export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
   })
 }

+export async function imageTensorToCanvas(
+  imgTensor: tf.Tensor4D,
+  canvas?: HTMLCanvasElement
+): Promise<HTMLCanvasElement> {
+  const targetCanvas = canvas || document.createElement('canvas')
+
+  const [_, height, width, numChannels] = imgTensor.shape
+  await tf.toPixels(imgTensor.as3D(height, width, numChannels).toInt(), targetCanvas)
+
+  return targetCanvas
+}
+
 export function getDefaultDrawOptions(): DrawOptions {
   return {
     color: 'blue',
...
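A sketch tying the new pieces together (the box is hypothetical; imageTensorToCanvas is the helper added above):

```ts
import * as tf from '@tensorflow/tfjs-core';
import { extractFaceTensors } from './extractFaceTensors';
import { Rect } from './Rect';
import { imageTensorToCanvas } from './utils';

async function cropToCanvas(img: tf.Tensor4D): Promise<HTMLCanvasElement> {
  // Extract one face region as a tensor, then render it into a canvas.
  const [faceTensor] = extractFaceTensors(img, [new Rect(50, 60, 120, 120)])
  const faceCanvas = await imageTensorToCanvas(faceTensor)
  faceTensor.dispose()
  return faceCanvas
}
```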