Commit 0ea55f92 by vincent

fixed mtcnn face landmark positions

parent abf82d46
import { FaceDetection } from './FaceDetection'; import { FaceDetection } from './FaceDetection';
import { FaceLandmarks } from './FaceLandmarks'; import { FaceLandmarks } from './FaceLandmarks';
import { FaceLandmarks68 } from './FaceLandmarks68';
export class FaceDetectionWithLandmarks { export class FaceDetectionWithLandmarks<TFaceLandmarks extends FaceLandmarks = FaceLandmarks68> {
private _detection: FaceDetection private _detection: FaceDetection
private _relativeLandmarks: FaceLandmarks private _unshiftedLandmarks: TFaceLandmarks
constructor( constructor(
detection: FaceDetection, detection: FaceDetection,
relativeLandmarks: FaceLandmarks unshiftedLandmarks: TFaceLandmarks
) { ) {
this._detection = detection this._detection = detection
this._relativeLandmarks = relativeLandmarks this._unshiftedLandmarks = unshiftedLandmarks
} }
public get detection(): FaceDetection { return this._detection } public get detection(): FaceDetection { return this._detection }
public get relativeLandmarks(): FaceLandmarks { return this._relativeLandmarks } public get unshiftedLandmarks(): TFaceLandmarks { return this._unshiftedLandmarks }
public get alignedRect(): FaceDetection { public get alignedRect(): FaceDetection {
const rect = this.landmarks.align() const rect = this.landmarks.align()
...@@ -22,14 +23,18 @@ export class FaceDetectionWithLandmarks { ...@@ -22,14 +23,18 @@ export class FaceDetectionWithLandmarks {
return new FaceDetection(this._detection.score, rect.rescale(imageDims.reverse()), imageDims) return new FaceDetection(this._detection.score, rect.rescale(imageDims.reverse()), imageDims)
} }
public get landmarks(): FaceLandmarks { public get landmarks(): TFaceLandmarks {
const { x, y } = this.detection.box const { x, y } = this.detection.box
return this._relativeLandmarks.shift(x, y) return this._unshiftedLandmarks.shiftBy(x, y)
} }
public forSize(width: number, height: number): FaceDetectionWithLandmarks { // aliases for backward compatibility
get faceDetection(): FaceDetection { return this.detection }
get faceLandmarks(): TFaceLandmarks { return this.landmarks }
public forSize(width: number, height: number): FaceDetectionWithLandmarks<TFaceLandmarks> {
const resizedDetection = this._detection.forSize(width, height) const resizedDetection = this._detection.forSize(width, height)
const resizedLandmarks = this._relativeLandmarks.forSize(resizedDetection.box.width, resizedDetection.box.height) const resizedLandmarks = this._unshiftedLandmarks.forSize<TFaceLandmarks>(resizedDetection.box.width, resizedDetection.box.height)
return new FaceDetectionWithLandmarks(resizedDetection, resizedLandmarks) return new FaceDetectionWithLandmarks<TFaceLandmarks>(resizedDetection, resizedLandmarks)
} }
} }
\ No newline at end of file
...@@ -25,10 +25,7 @@ export class FaceLandmarks { ...@@ -25,10 +25,7 @@ export class FaceLandmarks {
) )
} }
public getShift(): Point { public get shift(): Point { return new Point(this._shift.x, this._shift.y) }
return new Point(this._shift.x, this._shift.y)
}
public get imageWidth(): number { return this._imgDims.width } public get imageWidth(): number { return this._imgDims.width }
public get imageHeight(): number { return this._imgDims.height } public get imageHeight(): number { return this._imgDims.height }
public get positions(): Point[] { return this._positions } public get positions(): Point[] { return this._positions }
...@@ -45,7 +42,7 @@ export class FaceLandmarks { ...@@ -45,7 +42,7 @@ export class FaceLandmarks {
) )
} }
public shift<T extends FaceLandmarks>(x: number, y: number): T { public shiftBy<T extends FaceLandmarks>(x: number, y: number): T {
return new (this.constructor as any)( return new (this.constructor as any)(
this.relativePositions, this.relativePositions,
this._imgDims, this._imgDims,
...@@ -54,7 +51,7 @@ export class FaceLandmarks { ...@@ -54,7 +51,7 @@ export class FaceLandmarks {
} }
public shiftByPoint<T extends FaceLandmarks>(pt: Point): T { public shiftByPoint<T extends FaceLandmarks>(pt: Point): T {
return this.shift(pt.x, pt.y) return this.shiftBy(pt.x, pt.y)
} }
/** /**
...@@ -76,7 +73,7 @@ export class FaceLandmarks { ...@@ -76,7 +73,7 @@ export class FaceLandmarks {
? detection.box.floor() ? detection.box.floor()
: detection : detection
return this.shift(box.x, box.y).align() return this.shiftBy(box.x, box.y).align()
} }
const centers = this.getRefPointsForAlignment() const centers = this.getRefPointsForAlignment()
......
import { FaceDetection } from './FaceDetection'; import { FaceDetection } from './FaceDetection';
import { FaceDetectionWithLandmarks } from './FaceDetectionWithLandmarks'; import { FaceDetectionWithLandmarks } from './FaceDetectionWithLandmarks';
import { FaceLandmarks } from './FaceLandmarks'; import { FaceLandmarks } from './FaceLandmarks';
import { FaceLandmarks68 } from './FaceLandmarks68';
export class FullFaceDescription extends FaceDetectionWithLandmarks { export class FullFaceDescription<TFaceLandmarks extends FaceLandmarks = FaceLandmarks68> extends FaceDetectionWithLandmarks<TFaceLandmarks> {
private _descriptor: Float32Array private _descriptor: Float32Array
constructor( constructor(
detection: FaceDetection, detection: FaceDetection,
landmarks: FaceLandmarks, unshiftedLandmarks: TFaceLandmarks,
descriptor: Float32Array descriptor: Float32Array
) { ) {
super(detection, landmarks) super(detection, unshiftedLandmarks)
this._descriptor = descriptor this._descriptor = descriptor
} }
...@@ -18,8 +19,8 @@ export class FullFaceDescription extends FaceDetectionWithLandmarks { ...@@ -18,8 +19,8 @@ export class FullFaceDescription extends FaceDetectionWithLandmarks {
return this._descriptor return this._descriptor
} }
public forSize(width: number, height: number): FullFaceDescription { public forSize(width: number, height: number): FullFaceDescription<TFaceLandmarks> {
const { detection, landmarks } = super.forSize(width, height) const { detection, landmarks } = super.forSize(width, height)
return new FullFaceDescription(detection, landmarks, this.descriptor) return new FullFaceDescription<TFaceLandmarks>(detection, landmarks, this.descriptor)
} }
} }
\ No newline at end of file
...@@ -2,13 +2,13 @@ import { TNetInput } from 'tfjs-image-recognition-base'; ...@@ -2,13 +2,13 @@ import { TNetInput } from 'tfjs-image-recognition-base';
import { ITinyYolov2Options } from 'tfjs-tiny-yolov2'; import { ITinyYolov2Options } from 'tfjs-tiny-yolov2';
import { FaceDetection } from '../classes/FaceDetection'; import { FaceDetection } from '../classes/FaceDetection';
import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';
import { FaceLandmarks68 } from '../classes/FaceLandmarks68'; import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
import { FaceLandmark68Net } from '../faceLandmarkNet/FaceLandmark68Net'; import { FaceLandmark68Net } from '../faceLandmarkNet/FaceLandmark68Net';
import { FaceLandmark68TinyNet } from '../faceLandmarkNet/FaceLandmark68TinyNet'; import { FaceLandmark68TinyNet } from '../faceLandmarkNet/FaceLandmark68TinyNet';
import { FaceRecognitionNet } from '../faceRecognitionNet/FaceRecognitionNet'; import { FaceRecognitionNet } from '../faceRecognitionNet/FaceRecognitionNet';
import { Mtcnn } from '../mtcnn/Mtcnn'; import { Mtcnn } from '../mtcnn/Mtcnn';
import { MtcnnOptions } from '../mtcnn/MtcnnOptions'; import { MtcnnOptions } from '../mtcnn/MtcnnOptions';
import { MtcnnResult } from '../mtcnn/MtcnnResult';
import { SsdMobilenetv1 } from '../ssdMobilenetv1/SsdMobilenetv1'; import { SsdMobilenetv1 } from '../ssdMobilenetv1/SsdMobilenetv1';
import { SsdMobilenetv1Options } from '../ssdMobilenetv1/SsdMobilenetv1Options'; import { SsdMobilenetv1Options } from '../ssdMobilenetv1/SsdMobilenetv1Options';
import { TinyFaceDetector } from '../tinyFaceDetector/TinyFaceDetector'; import { TinyFaceDetector } from '../tinyFaceDetector/TinyFaceDetector';
...@@ -63,7 +63,7 @@ export const tinyYolov2 = (input: TNetInput, options: ITinyYolov2Options): Promi ...@@ -63,7 +63,7 @@ export const tinyYolov2 = (input: TNetInput, options: ITinyYolov2Options): Promi
* @param options (optional, default: see MtcnnOptions constructor for default parameters). * @param options (optional, default: see MtcnnOptions constructor for default parameters).
* @returns Bounding box of each face with score and 5 point face landmarks. * @returns Bounding box of each face with score and 5 point face landmarks.
*/ */
export const mtcnn = (input: TNetInput, options: MtcnnOptions): Promise<MtcnnResult[]> => export const mtcnn = (input: TNetInput, options: MtcnnOptions): Promise<FaceDetectionWithLandmarks[]> =>
nets.mtcnn.forward(input, options) nets.mtcnn.forward(input, options)
/** /**
......
...@@ -9,12 +9,12 @@ import { extractParams } from './extractParams'; ...@@ -9,12 +9,12 @@ import { extractParams } from './extractParams';
import { getSizesForScale } from './getSizesForScale'; import { getSizesForScale } from './getSizesForScale';
import { loadQuantizedParams } from './loadQuantizedParams'; import { loadQuantizedParams } from './loadQuantizedParams';
import { IMtcnnOptions, MtcnnOptions } from './MtcnnOptions'; import { IMtcnnOptions, MtcnnOptions } from './MtcnnOptions';
import { MtcnnResult } from './MtcnnResult';
import { pyramidDown } from './pyramidDown'; import { pyramidDown } from './pyramidDown';
import { stage1 } from './stage1'; import { stage1 } from './stage1';
import { stage2 } from './stage2'; import { stage2 } from './stage2';
import { stage3 } from './stage3'; import { stage3 } from './stage3';
import { NetParams } from './types'; import { NetParams } from './types';
import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';
export class Mtcnn extends NeuralNetwork<NetParams> { export class Mtcnn extends NeuralNetwork<NetParams> {
...@@ -25,7 +25,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> { ...@@ -25,7 +25,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
public async forwardInput( public async forwardInput(
input: NetInput, input: NetInput,
forwardParams: IMtcnnOptions = {} forwardParams: IMtcnnOptions = {}
): Promise<{ results: MtcnnResult[], stats: any }> { ): Promise<{ results: FaceDetectionWithLandmarks[], stats: any }> {
const { params } = this const { params } = this
...@@ -101,7 +101,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> { ...@@ -101,7 +101,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet, stats) const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet, stats)
stats.total_stage3 = Date.now() - ts stats.total_stage3 = Date.now() - ts
const results = out3.boxes.map((box, idx) => new MtcnnResult( const results = out3.boxes.map((box, idx) => new FaceDetectionWithLandmarks(
new FaceDetection( new FaceDetection(
out3.scores[idx], out3.scores[idx],
new Rect( new Rect(
...@@ -116,8 +116,8 @@ export class Mtcnn extends NeuralNetwork<NetParams> { ...@@ -116,8 +116,8 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
} }
), ),
new FaceLandmarks5( new FaceLandmarks5(
out3.points[idx].map(pt => pt.div(new Point(width, height))), out3.points[idx].map(pt => pt.sub(new Point(box.left, box.top)).div(new Point(box.width, box.height))),
{ width, height } { width: box.width, height: box.height }
) )
)) ))
...@@ -127,7 +127,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> { ...@@ -127,7 +127,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
public async forward( public async forward(
input: TNetInput, input: TNetInput,
forwardParams: IMtcnnOptions = {} forwardParams: IMtcnnOptions = {}
): Promise<MtcnnResult[]> { ): Promise<FaceDetectionWithLandmarks[]> {
return ( return (
await this.forwardInput( await this.forwardInput(
await toNetInput(input), await toNetInput(input),
...@@ -139,7 +139,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> { ...@@ -139,7 +139,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
public async forwardWithStats( public async forwardWithStats(
input: TNetInput, input: TNetInput,
forwardParams: IMtcnnOptions = {} forwardParams: IMtcnnOptions = {}
): Promise<{ results: MtcnnResult[], stats: any }> { ): Promise<{ results: FaceDetectionWithLandmarks[], stats: any }> {
return this.forwardInput( return this.forwardInput(
await toNetInput(input), await toNetInput(input),
forwardParams forwardParams
......
import { FaceDetection } from '../classes/FaceDetection';
import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';
import { FaceLandmarks5 } from '../classes/FaceLandmarks5';
/**
 * Detection-with-landmarks result that keeps the legacy accessor names
 * `faceDetection` and `faceLandmarks` available to existing callers.
 */
export class MtcnnResult extends FaceDetectionWithLandmarks {
  // aliases for backward compatibility

  /** Alias of the inherited `detection` accessor. */
  get faceDetection(): FaceDetection { return this.detection }

  /**
   * Alias of the inherited `landmarks` accessor.
   *
   * This must delegate to `landmarks`: returning `this.faceLandmarks` would
   * re-enter this very getter and recurse until the call stack overflows.
   * NOTE(review): the cast assumes instances are always constructed with
   * FaceLandmarks5 landmarks; confirm against the construction site.
   */
  get faceLandmarks(): FaceLandmarks5 { return this.landmarks as FaceLandmarks5 }
}
\ No newline at end of file
...@@ -46,7 +46,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> { ...@@ -46,7 +46,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
public async locateFaces( public async locateFaces(
input: TNetInput, input: TNetInput,
options: ISsdMobilenetv1Options options: ISsdMobilenetv1Options = {}
): Promise<FaceDetection[]> { ): Promise<FaceDetection[]> {
const { maxResults, minConfidence } = new SsdMobilenetv1Options(options) const { maxResults, minConfidence } = new SsdMobilenetv1Options(options)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment