Commit 5d874922 by vincent

implemented mtcnn model loading from url + expose mtcnn to global api + fixed some minor issues

parent 4eed7a39
-import { Rect } from '../Rect';
-import { Dimensions } from '../types';
+import { Rect } from './Rect';
+import { Dimensions } from './types';
 export class FaceDetection {
   private _score: number
......
-import { FaceDetection } from './faceDetectionNet/FaceDetection';
+import { FaceDetection } from './FaceDetection';
 import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
 export class FullFaceDescription {
......
-import { FaceDetection } from '../faceDetectionNet/FaceDetection';
+import { FaceDetection } from '../FaceDetection';
 import { FaceLandmarks68 } from '../faceLandmarkNet';
 import { FaceLandmarks } from '../FaceLandmarks';
 import { Point } from '../Point';
......
 import * as tf from '@tensorflow/tfjs-core';
-import { FaceDetection } from './faceDetectionNet/FaceDetection';
+import { FaceDetection } from './FaceDetection';
 import { Rect } from './Rect';
 import { toNetInput } from './toNetInput';
 import { TNetInput } from './types';
......
-import { FaceDetection } from './faceDetectionNet/FaceDetection';
+import { FaceDetection } from './FaceDetection';
 import { Rect } from './Rect';
 import { toNetInput } from './toNetInput';
 import { TNetInput } from './types';
 import { createCanvas, getContext2dOrThrow, imageTensorToCanvas } from './utils';
 import * as tf from '@tensorflow/tfjs-core';
 /**
  * Extracts the image regions containing the detected faces.
......
 import * as tf from '@tensorflow/tfjs-core';
 import { NeuralNetwork } from '../commons/NeuralNetwork';
+import { FaceDetection } from '../FaceDetection';
 import { NetInput } from '../NetInput';
 import { Rect } from '../Rect';
 import { toNetInput } from '../toNetInput';
 import { TNetInput } from '../types';
 import { extractParams } from './extractParams';
-import { FaceDetection } from './FaceDetection';
 import { loadQuantizedParams } from './loadQuantizedParams';
 import { mobileNetV1 } from './mobileNetV1';
 import { nonMaxSuppression } from './nonMaxSuppression';
......
 import { FaceDetectionNet } from './FaceDetectionNet';
 export * from './FaceDetectionNet';
-export * from './FaceDetection';
-export function faceDetectionNet(weights: Float32Array) {
+export function createFaceDetectionNet(weights: Float32Array) {
   const net = new FaceDetectionNet()
   net.extractWeights(weights)
   return net
 }
+
+export function faceDetectionNet(weights: Float32Array) {
+  console.warn('faceDetectionNet(weights: Float32Array) will be deprecated in future, use createFaceDetectionNet instead')
+  return createFaceDetectionNet(weights)
+}
\ No newline at end of file
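A minimal sketch of how the renamed factory might be consumed; the weight-file path and the assumption that createFaceDetectionNet is re-exported from the package entry point are not confirmed by this diff:

```ts
import { createFaceDetectionNet } from 'face-api.js';

// fetch the raw (non-quantized) weights and build the net;
// '/models/face_detection.weights' is a hypothetical file name
async function initDetectionNet() {
  const res = await fetch('/models/face_detection.weights');
  const weights = new Float32Array(await res.arrayBuffer());
  return createFaceDetectionNet(weights);
}
```

The old `faceDetectionNet(weights)` entry point keeps working, but now logs a deprecation warning and delegates to `createFaceDetectionNet`.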
-import { tf } from '..';
+import * as tf from '@tensorflow/tfjs-core';
 import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensors';
 import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
-import { isTensor1D, isTensor3D, isTensor4D } from '../commons/isTensor';
+import { isTensor3D } from '../commons/isTensor';
 import { loadWeightMap } from '../commons/loadWeightMap';
 import { ConvParams, ParamMapping } from '../commons/types';
 import { BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
......
 import { getCenterPoint } from '../commons/getCenterPoint';
-import { FaceDetection } from '../faceDetectionNet/FaceDetection';
+import { FaceDetection } from '../FaceDetection';
 import { FaceLandmarks } from '../FaceLandmarks';
 import { IPoint, Point } from '../Point';
 import { Rect } from '../Rect';
-import { Dimensions } from '../types';
 // face alignment constants
 const relX = 0.5
@@ -70,7 +69,7 @@ export class FaceLandmarks68 extends FaceLandmarks {
    * @returns The bounding box of the aligned face.
    */
   public align(
-    detection?: Rect
+    detection?: FaceDetection | Rect
   ): Rect {
     if (detection) {
       const box = detection instanceof FaceDetection
......
@@ -3,8 +3,13 @@ import { FaceLandmarkNet } from './FaceLandmarkNet';
 export * from './FaceLandmarkNet';
 export * from './FaceLandmarks68';
-export function faceLandmarkNet(weights: Float32Array) {
+export function createFaceLandmarkNet(weights: Float32Array) {
   const net = new FaceLandmarkNet()
   net.extractWeights(weights)
   return net
 }
+
+export function faceLandmarkNet(weights: Float32Array) {
+  console.warn('faceLandmarkNet(weights: Float32Array) will be deprecated in future, use createFaceLandmarkNet instead')
+  return createFaceLandmarkNet(weights)
+}
\ No newline at end of file
@@ -2,8 +2,13 @@ import { FaceRecognitionNet } from './FaceRecognitionNet';
 export * from './FaceRecognitionNet';
-export function faceRecognitionNet(weights: Float32Array) {
+export function createFaceRecognitionNet(weights: Float32Array) {
   const net = new FaceRecognitionNet()
   net.extractWeights(weights)
   return net
 }
+
+export function faceRecognitionNet(weights: Float32Array) {
+  console.warn('faceRecognitionNet(weights: Float32Array) will be deprecated in future, use createFaceRecognitionNet instead')
+  return createFaceRecognitionNet(weights)
+}
\ No newline at end of file
 import * as tf from '@tensorflow/tfjs-core';
 import { allFacesFactory } from './allFacesFactory';
-import { FaceDetection } from './faceDetectionNet/FaceDetection';
+import { FaceDetection } from './FaceDetection';
 import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
 import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
 import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
 import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
 import { FullFaceDescription } from './FullFaceDescription';
+import { getDefaultMtcnnForwardParams } from './mtcnn/getDefaultMtcnnForwardParams';
+import { Mtcnn } from './mtcnn/Mtcnn';
+import { MtcnnForwardParams, MtcnnResult } from './mtcnn/types';
 import { NetInput } from './NetInput';
 import { TNetInput } from './types';
@@ -14,23 +17,37 @@ export const detectionNet = new FaceDetectionNet()
 export const landmarkNet = new FaceLandmarkNet()
 export const recognitionNet = new FaceRecognitionNet()
+// nets need more specific names, to avoid ambiguity in future
+// when alternative net implementations are provided
+export const nets = {
+  ssdMobilenet: detectionNet,
+  faceLandmark68Net: landmarkNet,
+  faceNet: recognitionNet,
+  mtcnn: new Mtcnn()
+}

 export function loadFaceDetectionModel(url: string) {
-  return detectionNet.load(url)
+  return nets.ssdMobilenet.load(url)
 }

 export function loadFaceLandmarkModel(url: string) {
-  return landmarkNet.load(url)
+  return nets.faceLandmark68Net.load(url)
 }

 export function loadFaceRecognitionModel(url: string) {
-  return recognitionNet.load(url)
+  return nets.faceNet.load(url)
 }

+export function loadMtcnnModel(url: string) {
+  return nets.mtcnn.load(url)
+}

 export function loadModels(url: string) {
   return Promise.all([
     loadFaceDetectionModel(url),
     loadFaceLandmarkModel(url),
-    loadFaceRecognitionModel(url)
+    loadFaceRecognitionModel(url),
+    loadMtcnnModel(url)
   ])
 }
@@ -39,19 +56,26 @@ export function locateFaces(
   minConfidence?: number,
   maxResults?: number
 ): Promise<FaceDetection[]> {
-  return detectionNet.locateFaces(input, minConfidence, maxResults)
+  return nets.ssdMobilenet.locateFaces(input, minConfidence, maxResults)
 }

 export function detectLandmarks(
   input: TNetInput
 ): Promise<FaceLandmarks68 | FaceLandmarks68[]> {
-  return landmarkNet.detectLandmarks(input)
+  return nets.faceLandmark68Net.detectLandmarks(input)
 }

 export function computeFaceDescriptor(
   input: TNetInput
 ): Promise<Float32Array | Float32Array[]> {
-  return recognitionNet.computeFaceDescriptor(input)
+  return nets.faceNet.computeFaceDescriptor(input)
 }

+export function mtcnn(
+  input: TNetInput,
+  forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
+): Promise<MtcnnResult[]> {
+  return nets.mtcnn.forward(input, forwardParameters)
+}

 export const allFaces: (
......
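To illustrate the new global API, a minimal usage sketch; the model directory, the image element, and the parameter values are assumptions for illustration only:

```ts
import * as faceapi from 'face-api.js';

async function detectWithMtcnn(input: HTMLImageElement) {
  // fetches the quantized MTCNN weights from the given directory (path is hypothetical)
  await faceapi.loadMtcnnModel('/models');

  // all fields of MtcnnForwardParams except scaleSteps are required,
  // so overrides are passed as a complete object
  const results = await faceapi.mtcnn(input, {
    minFaceSize: 40,
    scaleFactor: 0.709,
    maxNumScales: 10,
    scoreThresholds: [0.6, 0.7, 0.7]
  });

  results.forEach(({ faceDetection, faceLandmarks }) => {
    console.log(faceDetection, faceLandmarks);
  });
}
```

Calling `faceapi.mtcnn(input)` without a second argument falls back to `getDefaultMtcnnForwardParams()`.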
@@ -4,6 +4,8 @@ export {
   tf
 }
+export * from './FaceDetection';
 export * from './FullFaceDescription';
 export * from './NetInput';
 export * from './Point';
......
 import * as tf from '@tensorflow/tfjs-core';
 import { NeuralNetwork } from '../commons/NeuralNetwork';
-import { FaceDetection } from '../faceDetectionNet/FaceDetection';
+import { FaceDetection } from '../FaceDetection';
 import { NetInput } from '../NetInput';
 import { Point } from '../Point';
 import { Rect } from '../Rect';
 import { toNetInput } from '../toNetInput';
 import { TNetInput } from '../types';
 import { bgrToRgbTensor } from './bgrToRgbTensor';
+import { CELL_SIZE } from './config';
 import { extractParams } from './extractParams';
 import { FaceLandmarks5 } from './FaceLandmarks5';
+import { getDefaultMtcnnForwardParams } from './getDefaultMtcnnForwardParams';
 import { getSizesForScale } from './getSizesForScale';
+import { loadQuantizedParams } from './loadQuantizedParams';
 import { pyramidDown } from './pyramidDown';
 import { stage1 } from './stage1';
 import { stage2 } from './stage2';
 import { stage3 } from './stage3';
-import { MtcnnResult, NetParams } from './types';
+import { MtcnnForwardParams, MtcnnResult, NetParams } from './types';

 export class Mtcnn extends NeuralNetwork<NetParams> {
@@ -25,10 +28,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
   public async forwardInput(
     input: NetInput,
-    minFaceSize: number = 20,
-    scaleFactor: number = 0.709,
-    maxNumScales: number = 10,
-    scoreThresholds: number[] = [0.6, 0.7, 0.7]
+    { minFaceSize, scaleFactor, maxNumScales, scoreThresholds, scaleSteps } = getDefaultMtcnnForwardParams()
   ): Promise<{ results: MtcnnResult[], stats: any }> {
     const { params } = this
@@ -64,10 +64,10 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
     const [height, width] = imgTensor.shape.slice(1)
-    const scales = pyramidDown(minFaceSize, scaleFactor, [height, width])
+    const scales = scaleSteps || pyramidDown(minFaceSize, scaleFactor, [height, width])
       .filter(scale => {
         const sizes = getSizesForScale(scale, [height, width])
-        return Math.min(sizes.width, sizes.height) > 48
+        return Math.min(sizes.width, sizes.height) > CELL_SIZE
       })
       .slice(0, maxNumScales)
@@ -124,38 +124,31 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
   public async forward(
     input: TNetInput,
-    minFaceSize: number = 20,
-    scaleFactor: number = 0.709,
-    maxNumScales: number = 10,
-    scoreThresholds: number[] = [0.6, 0.7, 0.7]
+    forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
   ): Promise<MtcnnResult[]> {
     return (
       await this.forwardInput(
         await toNetInput(input, true, true),
-        minFaceSize,
-        scaleFactor,
-        maxNumScales,
-        scoreThresholds
+        forwardParameters
       )
     ).results
   }

   public async forwardWithStats(
     input: TNetInput,
-    minFaceSize: number = 20,
-    scaleFactor: number = 0.709,
-    maxNumScales: number = 10,
-    scoreThresholds: number[] = [0.6, 0.7, 0.7]
+    forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
   ): Promise<{ results: MtcnnResult[], stats: any }> {
     return this.forwardInput(
       await toNetInput(input, true, true),
-      minFaceSize,
-      scaleFactor,
-      maxNumScales,
-      scoreThresholds
+      forwardParameters
     )
   }

+  // none of the param tensors are quantized yet
+  protected loadQuantizedParams(uri: string | undefined) {
+    return loadQuantizedParams(uri)
+  }

   protected extractParams(weights: Float32Array) {
     return extractParams(weights)
   }
......
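For clarity, a sketch of the call-site migration implied by the signature change; `mtcnnNet` and `input` stand in for any loaded Mtcnn instance and input, inside an async context, and the values shown are simply the previous defaults:

```ts
// before this commit: positional arguments
// const results = await mtcnnNet.forward(input, 20, 0.709, 10, [0.6, 0.7, 0.7])

// after this commit: a single MtcnnForwardParams object
// (omitting it falls back to getDefaultMtcnnForwardParams())
const results = await mtcnnNet.forward(input, {
  minFaceSize: 20,
  scaleFactor: 0.709,
  maxNumScales: 10,
  scoreThresholds: [0.6, 0.7, 0.7]
})
```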
@@ -55,10 +55,10 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings
   const conv4 = extractConvParams(64, 128, 2, 'onet/conv4')
   const prelu4_alpha = extractPReluParams(128, 'onet/prelu4_alpha')
   const fc1 = extractFCParams(1152, 256, 'onet/fc1')
-  const prelu5_alpha = extractPReluParams(256, 'onet/prelu4_alpha')
+  const prelu5_alpha = extractPReluParams(256, 'onet/prelu5_alpha')
   const fc2_1 = extractFCParams(256, 2, 'onet/fc2_1')
   const fc2_2 = extractFCParams(256, 4, 'onet/fc2_2')
-  const fc2_3 = extractFCParams(256, 10, 'onet/fc2_2')
+  const fc2_3 = extractFCParams(256, 10, 'onet/fc2_3')
   return { ...sharedParams, conv4, prelu4_alpha, fc1, prelu5_alpha, fc2_1, fc2_2, fc2_3 }
 }
......
+import { MtcnnForwardParams } from './types';
+
+export function getDefaultMtcnnForwardParams(): MtcnnForwardParams {
+  return {
+    minFaceSize: 20,
+    scaleFactor: 0.709,
+    maxNumScales: 10,
+    scoreThresholds: [0.6, 0.7, 0.7]
+  }
+}
\ No newline at end of file
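Since getDefaultMtcnnForwardParams returns a plain object, a caller that only wants to tweak one value could spread the defaults; this is just an illustrative pattern (the relative import paths assume a file next to src/mtcnn/types.ts), not something this commit adds:

```ts
import { getDefaultMtcnnForwardParams } from './getDefaultMtcnnForwardParams';
import { MtcnnForwardParams } from './types';

// keep all defaults, but raise the minimum detectable face size (40 is an arbitrary example)
const params: MtcnnForwardParams = {
  ...getDefaultMtcnnForwardParams(),
  minFaceSize: 40
};
```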
 import { Mtcnn } from './Mtcnn';
 export * from './Mtcnn';
+export * from './FaceLandmarks5';
-export function mtcnn(weights: Float32Array) {
+export function createMtcnn(weights: Float32Array) {
   const net = new Mtcnn()
   net.extractWeights(weights)
   return net
......
+import * as tf from '@tensorflow/tfjs-core';
+import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensors';
+import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
+import { loadWeightMap } from '../commons/loadWeightMap';
+import { ConvParams, FCParams, ParamMapping } from '../commons/types';
+import { NetParams, ONetParams, PNetParams, RNetParams, SharedParams } from './types';
+
+const DEFAULT_MODEL_NAME = 'mtcnn_model'
+
+function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
+
+  const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
+
+  function extractConvParams(prefix: string): ConvParams {
+    const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/weights`, 4, `${prefix}/filters`)
+    const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
+    return { filters, bias }
+  }
+
+  function extractFCParams(prefix: string): FCParams {
+    const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/weights`, 2)
+    const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
+    return { weights, bias }
+  }
+
+  function extractPReluParams(paramPath: string): tf.Tensor1D {
+    return extractWeightEntry<tf.Tensor1D>(paramPath, 1)
+  }
+
+  function extractSharedParams(prefix: string): SharedParams {
+    const conv1 = extractConvParams(`${prefix}/conv1`)
+    const prelu1_alpha = extractPReluParams(`${prefix}/prelu1_alpha`)
+    const conv2 = extractConvParams(`${prefix}/conv2`)
+    const prelu2_alpha = extractPReluParams(`${prefix}/prelu2_alpha`)
+    const conv3 = extractConvParams(`${prefix}/conv3`)
+    const prelu3_alpha = extractPReluParams(`${prefix}/prelu3_alpha`)
+    return { conv1, prelu1_alpha, conv2, prelu2_alpha, conv3, prelu3_alpha }
+  }
+
+  function extractPNetParams(): PNetParams {
+    const sharedParams = extractSharedParams('pnet')
+    const conv4_1 = extractConvParams('pnet/conv4_1')
+    const conv4_2 = extractConvParams('pnet/conv4_2')
+    return { ...sharedParams, conv4_1, conv4_2 }
+  }
+
+  function extractRNetParams(): RNetParams {
+    const sharedParams = extractSharedParams('rnet')
+    const fc1 = extractFCParams('rnet/fc1')
+    const prelu4_alpha = extractPReluParams('rnet/prelu4_alpha')
+    const fc2_1 = extractFCParams('rnet/fc2_1')
+    const fc2_2 = extractFCParams('rnet/fc2_2')
+    return { ...sharedParams, fc1, prelu4_alpha, fc2_1, fc2_2 }
+  }
+
+  function extractONetParams(): ONetParams {
+    const sharedParams = extractSharedParams('onet')
+    const conv4 = extractConvParams('onet/conv4')
+    const prelu4_alpha = extractPReluParams('onet/prelu4_alpha')
+    const fc1 = extractFCParams('onet/fc1')
+    const prelu5_alpha = extractPReluParams('onet/prelu5_alpha')
+    const fc2_1 = extractFCParams('onet/fc2_1')
+    const fc2_2 = extractFCParams('onet/fc2_2')
+    const fc2_3 = extractFCParams('onet/fc2_3')
+    return { ...sharedParams, conv4, prelu4_alpha, fc1, prelu5_alpha, fc2_1, fc2_2, fc2_3 }
+  }
+
+  return {
+    extractPNetParams,
+    extractRNetParams,
+    extractONetParams
+  }
+}
+
+export async function loadQuantizedParams(
+  uri: string | undefined
+): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
+  const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
+  const paramMappings: ParamMapping[] = []
+
+  const {
+    extractPNetParams,
+    extractRNetParams,
+    extractONetParams
+  } = extractorsFactory(weightMap, paramMappings)
+
+  const pnet = extractPNetParams()
+  const rnet = extractRNetParams()
+  const onet = extractONetParams()
+
+  disposeUnusedWeightTensors(weightMap, paramMappings)
+
+  return { params: { pnet, rnet, onet }, paramMappings }
+}
\ No newline at end of file
-import { tf } from '..';
+import * as tf from '@tensorflow/tfjs-core';
 import { BoundingBox } from './BoundingBox';
 import { extractImagePatches } from './extractImagePatches';
 import { nms } from './nms';
......
+import * as tf from '@tensorflow/tfjs-core';
 import { Point } from '../Point';
 import { BoundingBox } from './BoundingBox';
 import { extractImagePatches } from './extractImagePatches';
 import { nms } from './nms';
 import { ONet } from './ONet';
 import { ONetParams } from './types';
-import { tf } from '..';

 export async function stage3(
   img: HTMLCanvasElement,
......
-import { tf } from '..';
+import * as tf from '@tensorflow/tfjs-core';
 import { ConvParams, FCParams } from '../commons/types';
-import { FaceDetection } from '../faceDetectionNet/FaceDetection';
+import { FaceDetection } from '../FaceDetection';
 import { FaceLandmarks5 } from './FaceLandmarks5';

 export type SharedParams = {
@@ -43,4 +44,12 @@ export type NetParams = {
 export type MtcnnResult = {
   faceDetection: FaceDetection,
   faceLandmarks: FaceLandmarks5
 }
\ No newline at end of file
+
+export type MtcnnForwardParams = {
+  minFaceSize: number
+  scaleFactor: number
+  maxNumScales: number
+  scoreThresholds: number[]
+  scaleSteps?: number[]
+}
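One note on the new optional scaleSteps field: when it is provided, the MTCNN forward pass uses those scales in place of the pyramid derived from minFaceSize and scaleFactor. A small illustrative sketch (the scale values are arbitrary, and the import path assumes the object is built next to src/mtcnn/types.ts):

```ts
import { MtcnnForwardParams } from './types';

const paramsWithExplicitScales: MtcnnForwardParams = {
  minFaceSize: 20,
  scaleFactor: 0.709,
  maxNumScales: 10,
  scoreThresholds: [0.6, 0.7, 0.7],
  // when set, these scales replace the generated image pyramid
  scaleSteps: [0.6, 0.4, 0.2, 0.1]
};
```

Such an object could then be passed to `nets.mtcnn.forward(input, paramsWithExplicitScales)` or to the global `mtcnn` helper.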