Commit 0ea55f92 by vincent

fixed mtcnn face landmark positions

parent abf82d46
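In short: MTCNN's stage-3 landmarks were stored normalized against the whole input image, but the `landmarks` getter of `FaceDetectionWithLandmarks` shifts the stored points by the detection box origin, so the returned positions ended up offset by that origin. The change below stores the points relative to the detection box instead. A minimal sketch of the corrected normalization, with purely illustrative numbers:

```ts
// A landmark at absolute pixel position (120, 80) in a 640x480 image,
// inside a face box at (100, 60) of size 50x50.
const absolute = { x: 120, y: 80 }
const box = { left: 100, top: 60, width: 50, height: 50 }
const image = { width: 640, height: 480 }

// before: normalized against the full image; shifting these by the box origin
// later moves the point to (220, 140) instead of (120, 80)
const before = { x: absolute.x / image.width, y: absolute.y / image.height }

// after: translated into the box, then normalized against the box dimensions
const after = {
  x: (absolute.x - box.left) / box.width,  // 0.4
  y: (absolute.y - box.top) / box.height,  // 0.4
}

// reading the landmarks back rescales by the box size and shifts by the box
// origin, recovering the original absolute position
console.log(before)
console.log(after.x * box.width + box.left, after.y * box.height + box.top) // 120 80
```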
import { FaceDetection } from './FaceDetection';
import { FaceLandmarks } from './FaceLandmarks';
+import { FaceLandmarks68 } from './FaceLandmarks68';

-export class FaceDetectionWithLandmarks {
+export class FaceDetectionWithLandmarks<TFaceLandmarks extends FaceLandmarks = FaceLandmarks68> {

  private _detection: FaceDetection
-  private _relativeLandmarks: FaceLandmarks
+  private _unshiftedLandmarks: TFaceLandmarks

  constructor(
    detection: FaceDetection,
-    relativeLandmarks: FaceLandmarks
+    unshiftedLandmarks: TFaceLandmarks
  ) {
    this._detection = detection
-    this._relativeLandmarks = relativeLandmarks
+    this._unshiftedLandmarks = unshiftedLandmarks
  }

  public get detection(): FaceDetection { return this._detection }
-  public get relativeLandmarks(): FaceLandmarks { return this._relativeLandmarks }
+  public get unshiftedLandmarks(): TFaceLandmarks { return this._unshiftedLandmarks }

  public get alignedRect(): FaceDetection {
    const rect = this.landmarks.align()
@@ -22,14 +23,18 @@ export class FaceDetectionWithLandmarks {
    return new FaceDetection(this._detection.score, rect.rescale(imageDims.reverse()), imageDims)
  }

-  public get landmarks(): FaceLandmarks {
+  public get landmarks(): TFaceLandmarks {
    const { x, y } = this.detection.box
-    return this._relativeLandmarks.shift(x, y)
+    return this._unshiftedLandmarks.shiftBy(x, y)
  }

-  public forSize(width: number, height: number): FaceDetectionWithLandmarks {
+  // aliases for backward compatibility
+  get faceDetection(): FaceDetection { return this.detection }
+  get faceLandmarks(): TFaceLandmarks { return this.landmarks }
+
+  public forSize(width: number, height: number): FaceDetectionWithLandmarks<TFaceLandmarks> {
    const resizedDetection = this._detection.forSize(width, height)
-    const resizedLandmarks = this._relativeLandmarks.forSize(resizedDetection.box.width, resizedDetection.box.height)
-    return new FaceDetectionWithLandmarks(resizedDetection, resizedLandmarks)
+    const resizedLandmarks = this._unshiftedLandmarks.forSize<TFaceLandmarks>(resizedDetection.box.width, resizedDetection.box.height)
+    return new FaceDetectionWithLandmarks<TFaceLandmarks>(resizedDetection, resizedLandmarks)
  }
}
\ No newline at end of file
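For orientation, a rough usage sketch of the reworked generic class. The `FaceLandmarks5` constructor arguments mirror the MTCNN change further down (box-relative points plus the box dimensions); the import path of `Point` and `Rect` is an assumption, and all values are illustrative:

```ts
import { Point, Rect } from 'tfjs-image-recognition-base' // assumed export location
import { FaceDetection } from './FaceDetection'
import { FaceDetectionWithLandmarks } from './FaceDetectionWithLandmarks'
import { FaceLandmarks5 } from './FaceLandmarks5'

// detection box at (100, 60), size 50x50, inside a 640x480 image
// (FaceDetection takes a relative rect plus the image dimensions)
const imageDims = { width: 640, height: 480 }
const detection = new FaceDetection(
  0.98,
  new Rect(100 / 640, 60 / 480, 50 / 640, 50 / 480),
  imageDims
)

// five landmark points stored relative to the detection box
const unshiftedLandmarks = new FaceLandmarks5(
  [new Point(0.3, 0.4), new Point(0.7, 0.4), new Point(0.5, 0.6), new Point(0.35, 0.8), new Point(0.65, 0.8)],
  { width: 50, height: 50 }
)

const result = new FaceDetectionWithLandmarks<FaceLandmarks5>(detection, unshiftedLandmarks)

// the landmarks getter shifts the box-relative points by the box origin,
// yielding positions in the coordinate space of the input image
console.log(result.landmarks.positions)
```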
@@ -25,10 +25,7 @@ export class FaceLandmarks {
    )
  }

-  public getShift(): Point {
-    return new Point(this._shift.x, this._shift.y)
-  }
+  public get shift(): Point { return new Point(this._shift.x, this._shift.y) }

  public get imageWidth(): number { return this._imgDims.width }
  public get imageHeight(): number { return this._imgDims.height }
  public get positions(): Point[] { return this._positions }
@@ -45,7 +42,7 @@ export class FaceLandmarks {
    )
  }

-  public shift<T extends FaceLandmarks>(x: number, y: number): T {
+  public shiftBy<T extends FaceLandmarks>(x: number, y: number): T {
    return new (this.constructor as any)(
      this.relativePositions,
      this._imgDims,
@@ -54,7 +51,7 @@ export class FaceLandmarks {
  }

  public shiftByPoint<T extends FaceLandmarks>(pt: Point): T {
-    return this.shift(pt.x, pt.y)
+    return this.shiftBy(pt.x, pt.y)
  }

  /**
@@ -76,7 +73,7 @@ export class FaceLandmarks {
        ? detection.box.floor()
        : detection

-      return this.shift(box.x, box.y).align()
+      return this.shiftBy(box.x, box.y).align()
    }

    const centers = this.getRefPointsForAlignment()
......
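The rename above frees the name `shift` for the new getter. A brief sketch, assuming the `FaceLandmarks68` constructor accepts an arbitrary-length point list (only three of the 68 points are shown) and assuming `Point`'s import path:

```ts
import { Point } from 'tfjs-image-recognition-base' // assumed export location
import { FaceLandmarks68 } from './FaceLandmarks68'

// box-relative landmark set for a 50x50 face crop (values illustrative)
const landmarks = new FaceLandmarks68(
  [new Point(0.2, 0.3), new Point(0.8, 0.3), new Point(0.5, 0.7)],
  { width: 50, height: 50 }
)

// what used to be shift(x, y) is now shiftBy(x, y); getShift() became the shift getter
const moved = landmarks.shiftBy<FaceLandmarks68>(100, 60)
console.log(moved.shift) // Point { x: 100, y: 60 }
console.log(landmarks.shiftByPoint<FaceLandmarks68>(new Point(100, 60)).shift)
```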
import { FaceDetection } from './FaceDetection';
import { FaceDetectionWithLandmarks } from './FaceDetectionWithLandmarks';
import { FaceLandmarks } from './FaceLandmarks';
+import { FaceLandmarks68 } from './FaceLandmarks68';

-export class FullFaceDescription extends FaceDetectionWithLandmarks {
+export class FullFaceDescription<TFaceLandmarks extends FaceLandmarks = FaceLandmarks68> extends FaceDetectionWithLandmarks<TFaceLandmarks> {

  private _descriptor: Float32Array

  constructor(
    detection: FaceDetection,
-    landmarks: FaceLandmarks,
+    unshiftedLandmarks: TFaceLandmarks,
    descriptor: Float32Array
  ) {
-    super(detection, landmarks)
+    super(detection, unshiftedLandmarks)
    this._descriptor = descriptor
  }
@@ -18,8 +19,8 @@ export class FullFaceDescription extends FaceDetectionWithLandmarks {
    return this._descriptor
  }

-  public forSize(width: number, height: number): FullFaceDescription {
+  public forSize(width: number, height: number): FullFaceDescription<TFaceLandmarks> {
    const { detection, landmarks } = super.forSize(width, height)
-    return new FullFaceDescription(detection, landmarks, this.descriptor)
+    return new FullFaceDescription<TFaceLandmarks>(detection, landmarks, this.descriptor)
  }
}
\ No newline at end of file
@@ -2,13 +2,13 @@ import { TNetInput } from 'tfjs-image-recognition-base';
import { ITinyYolov2Options } from 'tfjs-tiny-yolov2';

import { FaceDetection } from '../classes/FaceDetection';
+import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
import { FaceLandmark68Net } from '../faceLandmarkNet/FaceLandmark68Net';
import { FaceLandmark68TinyNet } from '../faceLandmarkNet/FaceLandmark68TinyNet';
import { FaceRecognitionNet } from '../faceRecognitionNet/FaceRecognitionNet';
import { Mtcnn } from '../mtcnn/Mtcnn';
import { MtcnnOptions } from '../mtcnn/MtcnnOptions';
-import { MtcnnResult } from '../mtcnn/MtcnnResult';
import { SsdMobilenetv1 } from '../ssdMobilenetv1/SsdMobilenetv1';
import { SsdMobilenetv1Options } from '../ssdMobilenetv1/SsdMobilenetv1Options';
import { TinyFaceDetector } from '../tinyFaceDetector/TinyFaceDetector';
@@ -63,7 +63,7 @@ export const tinyYolov2 = (input: TNetInput, options: ITinyYolov2Options): Promi
 * @param options (optional, default: see MtcnnOptions constructor for default parameters).
 * @returns Bounding box of each face with score and 5 point face landmarks.
 */
-export const mtcnn = (input: TNetInput, options: MtcnnOptions): Promise<MtcnnResult[]> =>
+export const mtcnn = (input: TNetInput, options: MtcnnOptions): Promise<FaceDetectionWithLandmarks[]> =>
  nets.mtcnn.forward(input, options)

/**
......
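With the return type of `mtcnn()` changed from `MtcnnResult[]` to `FaceDetectionWithLandmarks[]`, existing callers keep working through the `faceDetection` / `faceLandmarks` aliases. A hedged usage sketch of the global API; the input element, the option values, and the assumption that `MtcnnOptions` is exported from the package root are illustrative:

```ts
import * as faceapi from 'face-api.js'

async function detect(input: HTMLVideoElement) {
  const results = await faceapi.mtcnn(input, new faceapi.MtcnnOptions({ minFaceSize: 50 }))

  results.forEach(res => {
    // primary accessors
    console.log(res.detection.score, res.landmarks.positions)
    // backward-compatibility aliases resolve to the same data
    console.log(res.faceDetection.box, res.faceLandmarks.positions)
  })
}
```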
@@ -9,12 +9,12 @@ import { extractParams } from './extractParams';
import { getSizesForScale } from './getSizesForScale';
import { loadQuantizedParams } from './loadQuantizedParams';
import { IMtcnnOptions, MtcnnOptions } from './MtcnnOptions';
-import { MtcnnResult } from './MtcnnResult';
import { pyramidDown } from './pyramidDown';
import { stage1 } from './stage1';
import { stage2 } from './stage2';
import { stage3 } from './stage3';
import { NetParams } from './types';
+import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';

export class Mtcnn extends NeuralNetwork<NetParams> {
@@ -25,7 +25,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
  public async forwardInput(
    input: NetInput,
    forwardParams: IMtcnnOptions = {}
-  ): Promise<{ results: MtcnnResult[], stats: any }> {
+  ): Promise<{ results: FaceDetectionWithLandmarks[], stats: any }> {

    const { params } = this
@@ -101,7 +101,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
    const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet, stats)
    stats.total_stage3 = Date.now() - ts

-    const results = out3.boxes.map((box, idx) => new MtcnnResult(
+    const results = out3.boxes.map((box, idx) => new FaceDetectionWithLandmarks(
      new FaceDetection(
        out3.scores[idx],
        new Rect(
@@ -116,8 +116,8 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
        }
      ),
      new FaceLandmarks5(
-        out3.points[idx].map(pt => pt.div(new Point(width, height))),
-        { width, height }
+        out3.points[idx].map(pt => pt.sub(new Point(box.left, box.top)).div(new Point(box.width, box.height))),
+        { width: box.width, height: box.height }
      )
    ))
@@ -127,7 +127,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
  public async forward(
    input: TNetInput,
    forwardParams: IMtcnnOptions = {}
-  ): Promise<MtcnnResult[]> {
+  ): Promise<FaceDetectionWithLandmarks[]> {
    return (
      await this.forwardInput(
        await toNetInput(input),
@@ -139,7 +139,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
  public async forwardWithStats(
    input: TNetInput,
    forwardParams: IMtcnnOptions = {}
-  ): Promise<{ results: MtcnnResult[], stats: any }> {
+  ): Promise<{ results: FaceDetectionWithLandmarks[], stats: any }> {
    return this.forwardInput(
      await toNetInput(input),
      forwardParams
......
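At the level of the `Mtcnn` class itself, `forwardWithStats` still returns the per-stage timings assigned above, now alongside `FaceDetectionWithLandmarks` results. A sketch under the assumption of an already loaded net; only `total_stage3` appears in the excerpt, so the other stat keys are left out:

```ts
import { Mtcnn } from './Mtcnn'

async function run(net: Mtcnn, inputCanvas: HTMLCanvasElement) {
  const { results, stats } = await net.forwardWithStats(inputCanvas, { minFaceSize: 100 })

  // scores and 5-point landmarks, the latter now positioned relative to each detection box
  results.forEach(res => console.log(res.detection.score, res.landmarks.positions))

  console.log(stats.total_stage3) // stage-3 runtime in ms, as assigned above
}
```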
import { FaceDetection } from '../classes/FaceDetection';
import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';
import { FaceLandmarks5 } from '../classes/FaceLandmarks5';

export class MtcnnResult extends FaceDetectionWithLandmarks<FaceLandmarks5> {
  // aliases for backward compatibility
  get faceDetection(): FaceDetection { return this.detection }
  get faceLandmarks(): FaceLandmarks5 { return this.landmarks }
}
\ No newline at end of file
@@ -46,7 +46,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
  public async locateFaces(
    input: TNetInput,
-    options: ISsdMobilenetv1Options
+    options: ISsdMobilenetv1Options = {}
  ): Promise<FaceDetection[]> {
    const { maxResults, minConfidence } = new SsdMobilenetv1Options(options)
......
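The added `= {}` default makes the options argument of `locateFaces` optional. A small sketch; the model path, input element, and no-argument construction are assumptions:

```ts
import { SsdMobilenetv1 } from './SsdMobilenetv1'

async function run(input: HTMLImageElement) {
  const net = new SsdMobilenetv1()       // no-argument construction assumed
  await net.load('/models')              // weights location assumed

  const detections = await net.locateFaces(input)                      // now valid: defaults apply
  const strict = await net.locateFaces(input, { minConfidence: 0.8 })  // explicit options unchanged
  console.log(detections.length, strict.length)
}
```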