Commit 4205fc29 by vincent

allow face alignment before classification

parent 256ee654
-import { Dimensions, getCenterPoint, IDimensions, Point, Rect } from 'tfjs-image-recognition-base';
+import { Box, Dimensions, getCenterPoint, IBoundingBox, IDimensions, IRect, Point, Rect } from 'tfjs-image-recognition-base';
+import { minBbox } from '../minBbox';
 import { FaceDetection } from './FaceDetection';

 // face alignment constants

@@ -71,16 +72,28 @@ export class FaceLandmarks implements IFaceLandmarks {
    * @returns The bounding box of the aligned face.
    */
   public align(
-    detection?: FaceDetection | Rect
-  ): Rect {
+    detection?: FaceDetection | IRect | IBoundingBox | null,
+    options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { }
+  ): Box {
     if (detection) {
       const box = detection instanceof FaceDetection
         ? detection.box.floor()
-        : detection
-      return this.shiftBy(box.x, box.y).align()
+        : new Box(detection)
+      return this.shiftBy(box.x, box.y).align(null, options)
     }
+
+    const { useDlibAlignment, minBoxPadding } = Object.assign({}, { useDlibAlignment: false, minBoxPadding: 0.2 }, options)
+
+    if (useDlibAlignment) {
+      return this.alignDlib()
+    }
+
+    return this.alignMinBbox(minBoxPadding)
+  }
+
+  private alignDlib(): Box {
     const centers = this.getRefPointsForAlignment()
     const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers

@@ -97,6 +110,11 @@ export class FaceLandmarks implements IFaceLandmarks {
     return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y))
   }

+  private alignMinBbox(padding: number): Box {
+    const box = minBbox(this.positions)
+    return box.pad(box.width * padding, box.height * padding)
+  }
+
   protected getRefPointsForAlignment(): Point[] {
     throw new Error('getRefPointsForAlignment not implemented by base class')
   }
......
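The reworked align() defaults to a padded minimum bounding box around the landmark positions (minBoxPadding 0.2) and only uses the previous eye/mouth-center alignment when useDlibAlignment is set. A minimal usage sketch; the surrounding detection call and variable names are illustrative, not part of this commit:

import * as faceapi from 'face-api.js';

async function alignmentExample(input: HTMLImageElement) {
  const result = await faceapi.detectSingleFace(input).withFaceLandmarks()
  if (!result) { return }

  // default: minimum bounding box around the landmark positions, padded by 20%
  const minBboxAligned = result.landmarks.align(result.detection)

  // opt back into the dlib-style alignment based on eye and mouth centers
  const dlibAligned = result.landmarks.align(null, { useDlibAlignment: true })

  return { minBboxAligned, dlibAligned }
}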
@@ -25,17 +25,17 @@ export class ComputeAllFaceDescriptorsTask<
     const parentResults = await this.parentTask

-    const alignedRects = parentResults.map(({ alignedRect }) => alignedRect)
-    const alignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
-      ? await extractFaceTensors(this.input, alignedRects)
-      : await extractFaces(this.input, alignedRects)
+    const dlibAlignedRects = parentResults.map(({ landmarks }) => landmarks.align(null, { useDlibAlignment: true }))
+    const dlibAlignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
+      ? await extractFaceTensors(this.input, dlibAlignedRects)
+      : await extractFaces(this.input, dlibAlignedRects)

     const results = await Promise.all(parentResults.map(async (parentResult, i) => {
-      const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(alignedFaces[i]) as Float32Array
+      const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(dlibAlignedFaces[i]) as Float32Array
       return extendWithFaceDescriptor<TSource>(parentResult, descriptor)
     }))

-    alignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())
+    dlibAlignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())

     return results
   }

@@ -52,10 +52,10 @@ export class ComputeSingleFaceDescriptorTask<
       return
     }

-    const { alignedRect } = parentResult
+    const dlibAlignedRect = parentResult.landmarks.align(null, { useDlibAlignment: true })
     const alignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
-      ? await extractFaceTensors(this.input, [alignedRect])
-      : await extractFaces(this.input, [alignedRect])
+      ? await extractFaceTensors(this.input, [dlibAlignedRect])
+      : await extractFaces(this.input, [dlibAlignedRect])

     const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(alignedFaces[0]) as Float32Array

     alignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())
......
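With this change the descriptor tasks no longer reuse the parent result's alignedRect; they always recompute a dlib-aligned box from the landmarks before cropping. Roughly what the single-face task now does, sketched with the public helpers (variable names and the plain-image input are illustrative assumptions):

import * as faceapi from 'face-api.js';

async function descriptorFromLandmarks(
  input: HTMLImageElement,
  parentResult: faceapi.WithFaceLandmarks<faceapi.WithFaceDetection<{}>>
) {
  // descriptors are always computed from the dlib-style aligned box
  const dlibAlignedRect = parentResult.landmarks.align(null, { useDlibAlignment: true })
  const faces = await faceapi.extractFaces(input, [dlibAlignedRect])
  const descriptor = await faceapi.nets.faceRecognitionNet.computeFaceDescriptor(faces[0]) as Float32Array
  return descriptor
}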
@@ -10,7 +10,10 @@ import { extendWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFac
 import { ComposableTask } from './ComposableTask';
 import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
 import { nets } from './nets';
-import { PredictAllFaceExpressionsTask, PredictSingleFaceExpressionTask } from './PredictFaceExpressionsTask';
+import {
+  PredictAllFaceExpressionsWithFaceAlignmentTask,
+  PredictSingleFaceExpressionsWithFaceAlignmentTask,
+} from './PredictFaceExpressionsTask';

 export class DetectFaceLandmarksTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
   constructor(

@@ -52,6 +55,10 @@ export class DetectAllFaceLandmarksTask<
     )
   }

+  withFaceExpressions(): PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
+    return new PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
+  }
+
   withFaceDescriptors(): ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>> {
     return new ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>>(this, this.input)
   }

@@ -80,6 +87,10 @@ export class DetectSingleFaceLandmarksTask<
     return extendWithFaceLandmarks<TSource>(parentResult, landmarks)
   }

+  withFaceExpressions(): PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
+    return new PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(this, this.input)
+  }
+
   withFaceDescriptor(): ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>> {
     return new ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>>(this, this.input)
   }
......
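This is what enables expressions to be predicted after landmark detection, on the aligned face crops, with descriptors still chainable at the end. A sketch of the resulting composition (model loading and the input element are assumed to be handled elsewhere):

import * as faceapi from 'face-api.js';

async function fullPipeline(input: HTMLImageElement) {
  return faceapi
    .detectAllFaces(input)
    .withFaceLandmarks()
    .withFaceExpressions()   // PredictAllFaceExpressionsWithFaceAlignmentTask
    .withFaceDescriptors()   // still available after the expression step
}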
@@ -8,7 +8,7 @@ import { TinyFaceDetectorOptions } from '../tinyFaceDetector/TinyFaceDetectorOpt
 import { ComposableTask } from './ComposableTask';
 import { DetectAllFaceLandmarksTask, DetectSingleFaceLandmarksTask } from './DetectFaceLandmarksTasks';
 import { nets } from './nets';
-import { PredictAllFaceExpressionsTask, PredictSingleFaceExpressionTask } from './PredictFaceExpressionsTask';
+import { PredictAllFaceExpressionsTask, PredictSingleFaceExpressionsTask } from './PredictFaceExpressionsTask';
 import { FaceDetectionOptions } from './types';

 export class DetectFacesTaskBase<TReturn> extends ComposableTask<TReturn> {

@@ -101,8 +101,8 @@ export class DetectSingleFaceTask extends DetectFacesTaskBase<FaceDetection | un
     )
   }

-  withFaceExpressions(): PredictSingleFaceExpressionTask<WithFaceDetection<{}>> {
-    return new PredictSingleFaceExpressionTask<WithFaceDetection<{}>>(
+  withFaceExpressions(): PredictSingleFaceExpressionsTask<WithFaceDetection<{}>> {
+    return new PredictSingleFaceExpressionsTask<WithFaceDetection<{}>>(
       this.runAndExtendWithFaceDetection(),
       this.input
     )
......
@@ -5,8 +5,9 @@ import { extractFaces, extractFaceTensors } from '../dom';
 import { FaceExpressions } from '../faceExpressionNet/FaceExpressions';
 import { WithFaceDetection } from '../factories/WithFaceDetection';
 import { extendWithFaceExpressions, WithFaceExpressions } from '../factories/WithFaceExpressions';
+import { isWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFaceLandmarks';
 import { ComposableTask } from './ComposableTask';
-import { DetectAllFaceLandmarksTask, DetectSingleFaceLandmarksTask } from './DetectFaceLandmarksTasks';
+import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
 import { nets } from './nets';

 export class PredictFaceExpressionsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {

@@ -26,10 +27,12 @@ export class PredictAllFaceExpressionsTask<
     const parentResults = await this.parentTask

-    const detections = parentResults.map(parentResult => parentResult.detection)
+    const faceBoxes = parentResults.map(
+      parentResult => isWithFaceLandmarks(parentResult) ? parentResult.alignedRect : parentResult.detection
+    )
     const faces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
-      ? await extractFaceTensors(this.input, detections)
-      : await extractFaces(this.input, detections)
+      ? await extractFaceTensors(this.input, faceBoxes)
+      : await extractFaces(this.input, faceBoxes)

     const faceExpressionsByFace = await Promise.all(faces.map(
       face => nets.faceExpressionNet.predictExpressions(face)

@@ -41,13 +44,9 @@ export class PredictAllFaceExpressionsTask<
       (parentResult, i) => extendWithFaceExpressions<TSource>(parentResult, faceExpressionsByFace[i])
     )
   }
-
-  withFaceLandmarks(): DetectAllFaceLandmarksTask<WithFaceExpressions<TSource>> {
-    return new DetectAllFaceLandmarksTask(this, this.input, false)
-  }
 }

-export class PredictSingleFaceExpressionTask<
+export class PredictSingleFaceExpressionsTask<
   TSource extends WithFaceDetection<{}>
 > extends PredictFaceExpressionsTaskBase<WithFaceExpressions<TSource> | undefined, TSource | undefined> {

@@ -58,10 +57,10 @@ export class PredictSingleFaceExpressionTask<
       return
     }

-    const { detection } = parentResult
+    const faceBox = isWithFaceLandmarks(parentResult) ? parentResult.alignedRect : parentResult.detection
     const faces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
-      ? await extractFaceTensors(this.input, [detection])
-      : await extractFaces(this.input, [detection])
+      ? await extractFaceTensors(this.input, [faceBox])
+      : await extractFaces(this.input, [faceBox])

     const faceExpressions = await nets.faceExpressionNet.predictExpressions(faces[0]) as FaceExpressions

@@ -69,8 +68,22 @@ export class PredictSingleFaceExpressionTask<
     return extendWithFaceExpressions(parentResult, faceExpressions)
   }
+}

-  withFaceLandmarks(): DetectSingleFaceLandmarksTask<WithFaceExpressions<TSource>> {
-    return new DetectSingleFaceLandmarksTask(this, this.input, false)
-  }
-}
+export class PredictAllFaceExpressionsWithFaceAlignmentTask<
+  TSource extends WithFaceLandmarks<WithFaceDetection<{}>>
+> extends PredictAllFaceExpressionsTask<TSource> {
+
+  withFaceDescriptors(): ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>> {
+    return new ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>>(this, this.input)
+  }
+}
+
+export class PredictSingleFaceExpressionsWithFaceAlignmentTask<
+  TSource extends WithFaceLandmarks<WithFaceDetection<{}>>
+> extends PredictSingleFaceExpressionsTask<TSource> {
+
+  withFaceDescriptor(): ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>> {
+    return new ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>>(this, this.input)
+  }
+}
\ No newline at end of file
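The expression tasks now pick the crop box depending on what the parent task produced: the aligned box when landmarks are available, the raw detection box otherwise. The two single-face variants side by side (illustrative sketch, input and models assumed loaded):

import * as faceapi from 'face-api.js';

async function expressionsExample(input: HTMLImageElement) {
  // without landmarks: faces are cropped from the raw detection box
  const plain = await faceapi.detectSingleFace(input).withFaceExpressions()

  // with landmarks: faces are cropped from parentResult.alignedRect,
  // and a face descriptor can still be chained on afterwards
  const aligned = await faceapi
    .detectSingleFace(input)
    .withFaceLandmarks()
    .withFaceExpressions()
    .withFaceDescriptor()

  return { plain, aligned }
}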
import { BoundingBox, IPoint } from 'tfjs-image-recognition-base';

export function minBbox(pts: IPoint[]): BoundingBox {
  const xs = pts.map(pt => pt.x)
  const ys = pts.map(pt => pt.y)

  const minX = xs.reduce((min, x) => x < min ? x : min, Infinity)
  const minY = ys.reduce((min, y) => y < min ? y : min, Infinity)
  const maxX = xs.reduce((max, x) => max < x ? x : max, 0)
  const maxY = ys.reduce((max, y) => max < y ? y : max, 0)

  return new BoundingBox(minX, minY, maxX, maxY)
}
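A quick usage sketch for the new minBbox helper (the points and the relative import path are made-up sample values):

import { Point } from 'tfjs-image-recognition-base';
import { minBbox } from './minBbox';

const pts = [new Point(12, 34), new Point(56, 20), new Point(40, 80)]
const box = minBbox(pts)
// BoundingBox { left: 12, top: 20, right: 56, bottom: 80 }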
@@ -10,6 +10,7 @@ import {
 } from 'tfjs-image-recognition-base';

 import { depthwiseSeparableConv } from '../common/depthwiseSeparableConv';
+import { bgrToRgbTensor } from '../mtcnn/bgrToRgbTensor';
 import { extractParams } from './extractParams';
 import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
 import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types';

@@ -54,8 +55,9 @@ export class TinyXception extends NeuralNetwork<TinyXceptionParams> {
     return tf.tidy(() => {
       const batchTensor = input.toBatchTensor(112, true)
+      const batchTensorRgb = bgrToRgbTensor(batchTensor)
       const meanRgb = [122.782, 117.001, 104.298]
-      const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D
+      const normalized = normalize(batchTensorRgb, meanRgb).div(tf.scalar(256)) as tf.Tensor4D

       let out = tf.relu(conv(normalized, params.entry_flow.conv_in, [2, 2]))
       out = reductionBlock(out, params.entry_flow.reduction_block_0, false)
......
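The TinyXception input is now converted from BGR to RGB before mean subtraction, and the scaling denominator changes from 255 to 256. A minimal sketch of what the preprocessing amounts to, assuming bgrToRgbTensor simply reverses the channel axis of an NHWC tensor and normalize is a per-channel mean subtraction:

import * as tf from '@tensorflow/tfjs-core';

function preprocess(batchTensor: tf.Tensor4D): tf.Tensor4D {
  return tf.tidy(() => {
    const rgb = tf.reverse(batchTensor, [3])                     // BGR -> RGB on the channel axis
    const meanRgb = tf.tensor1d([122.782, 117.001, 104.298])
    return rgb.sub(meanRgb).div(tf.scalar(256)) as tf.Tensor4D   // mean-subtract, then scale by 1/256
  })
}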