Commit 930f85b0 by vincent

face alignment from 5 point face landmarks + allFacesMtcnn

parent 049997bc
import { Point } from './Point'; import { getCenterPoint } from './commons/getCenterPoint';
import { FaceDetection } from './FaceDetection';
import { IPoint, Point } from './Point';
import { Rect } from './Rect';
import { Dimensions } from './types'; import { Dimensions } from './types';
// face alignment constants
const relX = 0.5
const relY = 0.43
const relScale = 0.45
export class FaceLandmarks { export class FaceLandmarks {
protected _imageWidth: number protected _imageWidth: number
protected _imageHeight: number protected _imageHeight: number
...@@ -42,4 +50,65 @@ export class FaceLandmarks { ...@@ -42,4 +50,65 @@ export class FaceLandmarks {
pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight)) pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
) )
} }
/**
 * Returns a copy of these landmarks, rescaled to an image of the given
 * dimensions. Instantiates via the runtime constructor so that subclasses
 * (FaceLandmarks5 / FaceLandmarks68) yield instances of their own type.
 */
public forSize<T extends FaceLandmarks>(width: number, height: number): T {
  const Ctor = this.constructor as any
  return new Ctor(this.getRelativePositions(), { width, height })
}
/**
 * Returns a copy of these landmarks translated by (x, y) in image
 * coordinates, keeping the current image dimensions. Uses the runtime
 * constructor so subclasses return their own type.
 */
public shift<T extends FaceLandmarks>(x: number, y: number): T {
  const Ctor = this.constructor as any
  const dimensions = { width: this._imageWidth, height: this._imageHeight }
  return new Ctor(this.getRelativePositions(), dimensions, new Point(x, y))
}
/** Convenience overload of shift() taking a point instead of separate coordinates. */
public shiftByPoint<T extends FaceLandmarks>(pt: IPoint): T {
  const { x, y } = pt
  return this.shift<T>(x, y)
}
/**
 * Computes the bounding box of the aligned face from these landmarks.
 * Use this after face detection and before passing the cropped face to the
 * face recognition net — alignment makes the computed descriptor more accurate.
 *
 * @param detection (optional) The bounding box of the face or the face
 * detection result. When omitted, the landmark positions are assumed to be
 * relative to their current shift.
 * @returns The bounding box of the aligned face.
 */
public align(
detection?: FaceDetection | Rect
): Rect {
  // With an explicit detection: translate the landmarks into the box's
  // coordinate frame, then align relative to that shift.
  if (detection) {
    const box = detection instanceof FaceDetection
      ? detection.getBox().floor()
      : detection
    return this.shift(box.x, box.y).align()
  }

  const refPoints = this.getRefPointsForAlignment()
  const [leftEye, rightEye, mouth] = refPoints

  // Average eye-to-mouth distance determines the crop size.
  const dist = (pt: Point) => mouth.sub(pt).magnitude()
  const eyeToMouthDist = (dist(leftEye) + dist(rightEye)) / 2
  const size = Math.floor(eyeToMouthDist / relScale)

  const center = getCenterPoint(refPoints)
  // TODO: pad in case rectangle is out of image bounds
  const x = Math.floor(Math.max(0, center.x - (relX * size)))
  const y = Math.floor(Math.max(0, center.y - (relY * size)))
  // Clamp width/height so the crop never exceeds the image bounds.
  return new Rect(x, y, Math.min(size, this._imageWidth - x), Math.min(size, this._imageHeight - y))
}
/**
 * Returns the three reference points used by align(): left eye center,
 * right eye center and mouth center, in that order. The base class has no
 * landmark layout, so subclasses must override this.
 */
protected getRefPointsForAlignment(): Point[] {
throw new Error('getRefPointsForAlignment not implemented by base class')
}
} }
\ No newline at end of file
import { FaceDetection } from './FaceDetection'; import { FaceDetection } from './FaceDetection';
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68'; import { FaceLandmarks } from './FaceLandmarks';
export class FullFaceDescription { export class FullFaceDescription {
constructor( constructor(
private _detection: FaceDetection, private _detection: FaceDetection,
private _landmarks: FaceLandmarks68, private _landmarks: FaceLandmarks,
private _descriptor: Float32Array private _descriptor: Float32Array
) {} ) {}
...@@ -12,7 +12,7 @@ export class FullFaceDescription { ...@@ -12,7 +12,7 @@ export class FullFaceDescription {
return this._detection return this._detection
} }
public get landmarks(): FaceLandmarks68 { public get landmarks(): FaceLandmarks {
return this._landmarks return this._landmarks
} }
......
...@@ -2,14 +2,16 @@ import { extractFaceTensors } from './extractFaceTensors'; ...@@ -2,14 +2,16 @@ import { extractFaceTensors } from './extractFaceTensors';
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet'; import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet'; import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68'; import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
import { FullFaceDescription } from './FullFaceDescription'; import { FullFaceDescription } from './FullFaceDescription';
import { Mtcnn } from './mtcnn/Mtcnn';
import { MtcnnForwardParams } from './mtcnn/types';
import { Rect } from './Rect';
import { TNetInput } from './types'; import { TNetInput } from './types';
export function allFacesFactory( export function allFacesFactory(
detectionNet: FaceDetectionNet, detectionNet: FaceDetectionNet,
landmarkNet: FaceLandmarkNet, landmarkNet: FaceLandmarkNet,
recognitionNet: FaceRecognitionNet computeDescriptors: (input: TNetInput, alignedFaceBoxes: Rect[], useBatchProcessing: boolean) => Promise<Float32Array[]>
) { ) {
return async function( return async function(
input: TNetInput, input: TNetInput,
...@@ -32,20 +34,42 @@ export function allFacesFactory( ...@@ -32,20 +34,42 @@ export function allFacesFactory(
const alignedFaceBoxes = faceLandmarksByFace.map( const alignedFaceBoxes = faceLandmarksByFace.map(
(landmarks, i) => landmarks.align(detections[i].getBox()) (landmarks, i) => landmarks.align(detections[i].getBox())
) )
const alignedFaceTensors = await extractFaceTensors(input, alignedFaceBoxes)
const descriptors = useBatchProcessing const descriptors = await computeDescriptors(input, alignedFaceBoxes, useBatchProcessing)
? await recognitionNet.computeFaceDescriptor(alignedFaceTensors) as Float32Array[]
: await Promise.all(alignedFaceTensors.map(
faceTensor => recognitionNet.computeFaceDescriptor(faceTensor)
)) as Float32Array[]
alignedFaceTensors.forEach(t => t.dispose())
return detections.map((detection, i) => return detections.map((detection, i) =>
new FullFaceDescription( new FullFaceDescription(
detection, detection,
faceLandmarksByFace[i].shiftByPoint(detection.getBox()), faceLandmarksByFace[i].shiftByPoint<FaceLandmarks68>(detection.getBox()),
descriptors[i]
)
)
}
}
export function allFacesMtcnnFactory(
mtcnn: Mtcnn,
computeDescriptors: (input: TNetInput, alignedFaceBoxes: Rect[], useBatchProcessing: boolean) => Promise<Float32Array[]>
) {
return async function(
input: TNetInput,
mtcnnForwardParams: MtcnnForwardParams,
useBatchProcessing: boolean = false
): Promise<FullFaceDescription[]> {
const results = await mtcnn.forward(input, mtcnnForwardParams)
const alignedFaceBoxes = results.map(
({ faceLandmarks }) => faceLandmarks.align()
)
const descriptors = await computeDescriptors(input, alignedFaceBoxes, useBatchProcessing)
return results.map(({ faceDetection, faceLandmarks }, i) =>
new FullFaceDescription(
faceDetection,
faceLandmarks,
descriptors[i] descriptors[i]
) )
) )
......
import { getCenterPoint } from '../commons/getCenterPoint'; import { getCenterPoint } from '../commons/getCenterPoint';
import { FaceDetection } from '../FaceDetection'; import { FaceDetection } from '../FaceDetection';
import { FaceLandmarks } from '../FaceLandmarks'; import { FaceLandmarks } from '../FaceLandmarks';
import { IPoint, Point } from '../Point'; import { Point } from '../Point';
import { Rect } from '../Rect'; import { Rect } from '../Rect';
// face alignment constants
const relX = 0.5
const relY = 0.43
const relScale = 0.45
export class FaceLandmarks68 extends FaceLandmarks { export class FaceLandmarks68 extends FaceLandmarks {
public getJawOutline(): Point[] { public getJawOutline(): Point[] {
return this._faceLandmarks.slice(0, 17) return this._faceLandmarks.slice(0, 17)
...@@ -38,64 +33,11 @@ export class FaceLandmarks68 extends FaceLandmarks { ...@@ -38,64 +33,11 @@ export class FaceLandmarks68 extends FaceLandmarks {
return this._faceLandmarks.slice(48, 68) return this._faceLandmarks.slice(48, 68)
} }
/**
 * Reference points for alignment, computed from the 68 point landmarks:
 * centers of the left eye, right eye and mouth point groups.
 *
 * NOTE(review): this span was diff-mangled — the old forSize/shift/align
 * members (moved to the FaceLandmarks base class) were interleaved with the
 * new override; reconstructed here from the new side of the diff.
 */
protected getRefPointsForAlignment(): Point[] {
  return [
    this.getLeftEye(),
    this.getRightEye(),
    this.getMouth()
  ].map(getCenterPoint)
}
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { allFacesFactory } from './allFacesFactory'; import { allFacesFactory, allFacesMtcnnFactory } from './allFacesFactory';
import { extractFaceTensors } from './extractFaceTensors';
import { FaceDetection } from './FaceDetection'; import { FaceDetection } from './FaceDetection';
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet'; import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet'; import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68'; import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet'; import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
import { FullFaceDescription } from './FullFaceDescription'; import { FullFaceDescription } from './FullFaceDescription';
import { getDefaultMtcnnForwardParams } from './mtcnn/getDefaultMtcnnForwardParams';
import { Mtcnn } from './mtcnn/Mtcnn'; import { Mtcnn } from './mtcnn/Mtcnn';
import { MtcnnForwardParams, MtcnnResult } from './mtcnn/types'; import { MtcnnForwardParams, MtcnnResult } from './mtcnn/types';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import { Rect } from './Rect';
import { TNetInput } from './types'; import { TNetInput } from './types';
export const detectionNet = new FaceDetectionNet() export const detectionNet = new FaceDetectionNet()
...@@ -22,7 +23,7 @@ export const recognitionNet = new FaceRecognitionNet() ...@@ -22,7 +23,7 @@ export const recognitionNet = new FaceRecognitionNet()
export const nets = { export const nets = {
ssdMobilenet: detectionNet, ssdMobilenet: detectionNet,
faceLandmark68Net: landmarkNet, faceLandmark68Net: landmarkNet,
faceNet: recognitionNet, faceRecognitionNet: recognitionNet,
mtcnn: new Mtcnn() mtcnn: new Mtcnn()
} }
...@@ -35,7 +36,7 @@ export function loadFaceLandmarkModel(url: string) { ...@@ -35,7 +36,7 @@ export function loadFaceLandmarkModel(url: string) {
} }
export function loadFaceRecognitionModel(url: string) { export function loadFaceRecognitionModel(url: string) {
return nets.faceNet.load(url) return nets.faceRecognitionNet.load(url)
} }
export function loadMtcnnModel(url: string) { export function loadMtcnnModel(url: string) {
...@@ -68,7 +69,7 @@ export function detectLandmarks( ...@@ -68,7 +69,7 @@ export function detectLandmarks(
export function computeFaceDescriptor( export function computeFaceDescriptor(
input: TNetInput input: TNetInput
): Promise<Float32Array | Float32Array[]> { ): Promise<Float32Array | Float32Array[]> {
return nets.faceNet.computeFaceDescriptor(input) return nets.faceRecognitionNet.computeFaceDescriptor(input)
} }
export function mtcnn( export function mtcnn(
...@@ -85,5 +86,32 @@ export const allFaces: ( ...@@ -85,5 +86,32 @@ export const allFaces: (
) => Promise<FullFaceDescription[]> = allFacesFactory( ) => Promise<FullFaceDescription[]> = allFacesFactory(
detectionNet, detectionNet,
landmarkNet, landmarkNet,
recognitionNet computeDescriptorsFactory(nets.faceRecognitionNet)
) )
\ No newline at end of file
// Full pipeline using MTCNN: detect faces + 5 point landmarks, align each
// face from its landmarks, then compute a descriptor per aligned face.
export const allFacesMtcnn: (
input: tf.Tensor | NetInput | TNetInput,
mtcnnForwardParams: MtcnnForwardParams,
useBatchProcessing?: boolean
) => Promise<FullFaceDescription[]> = allFacesMtcnnFactory(
nets.mtcnn,
computeDescriptorsFactory(nets.faceRecognitionNet)
)
/**
 * Creates the descriptor computation step shared by allFaces and
 * allFacesMtcnn: extracts the aligned face regions as tensors, runs the
 * recognition net over them (batched or one-by-one) and disposes the
 * intermediate tensors afterwards.
 */
function computeDescriptorsFactory(
  recognitionNet: FaceRecognitionNet
) {
  return async function(input: TNetInput, alignedFaceBoxes: Rect[], useBatchProcessing: boolean) {
    const faceTensors = await extractFaceTensors(input, alignedFaceBoxes)

    let descriptors: Float32Array[]
    if (useBatchProcessing) {
      descriptors = await recognitionNet.computeFaceDescriptor(faceTensors) as Float32Array[]
    } else {
      const perFace = faceTensors.map(
        faceTensor => recognitionNet.computeFaceDescriptor(faceTensor)
      )
      descriptors = await Promise.all(perFace) as Float32Array[]
    }

    // Free the extracted face tensors once the descriptors are computed.
    faceTensors.forEach(t => t.dispose())
    return descriptors
  }
}
\ No newline at end of file
import { getCenterPoint } from '../commons/getCenterPoint';
import { FaceLandmarks } from '../FaceLandmarks'; import { FaceLandmarks } from '../FaceLandmarks';
import { IPoint, Point } from '../Point'; import { Point } from '../Point';
export class FaceLandmarks5 extends FaceLandmarks { export class FaceLandmarks5 extends FaceLandmarks {
public forSize(width: number, height: number): FaceLandmarks5 { protected getRefPointsForAlignment(): Point[] {
return new FaceLandmarks5( const pts = this.getPositions()
this.getRelativePositions(), return [
{ width, height } pts[0],
) pts[1],
} getCenterPoint([pts[3], pts[4]])
]
public shift(x: number, y: number): FaceLandmarks5 {
return new FaceLandmarks5(
this.getRelativePositions(),
{ width: this._imageWidth, height: this._imageHeight },
new Point(x, y)
)
}
public shiftByPoint(pt: IPoint): FaceLandmarks5 {
return this.shift(pt.x, pt.y)
} }
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment