Commit 1d133c54 by vincent

fixed MTCNN forwardParams initialization for optional params

parent 5d874922
......@@ -73,9 +73,9 @@ export function computeFaceDescriptor(
export function mtcnn(
input: TNetInput,
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
forwardParams: MtcnnForwardParams
): Promise<MtcnnResult[]> {
return nets.mtcnn.forward(input, forwardParameters)
return nets.mtcnn.forward(input, forwardParams)
}
export const allFaces: (
......
......@@ -28,7 +28,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
public async forwardInput(
input: NetInput,
{ minFaceSize, scaleFactor, maxNumScales, scoreThresholds, scaleSteps } = getDefaultMtcnnForwardParams()
forwardParams: MtcnnForwardParams
): Promise<{ results: MtcnnResult[], stats: any }> {
const { params } = this
......@@ -64,6 +64,14 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
const [height, width] = imgTensor.shape.slice(1)
const {
minFaceSize,
scaleFactor,
maxNumScales,
scoreThresholds,
scaleSteps
} = Object.assign({}, getDefaultMtcnnForwardParams(), forwardParams)
const scales = scaleSteps || pyramidDown(minFaceSize, scaleFactor, [height, width])
.filter(scale => {
const sizes = getSizesForScale(scale, [height, width])
......@@ -124,23 +132,23 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
public async forward(
input: TNetInput,
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
forwardParams: MtcnnForwardParams
): Promise<MtcnnResult[]> {
return (
await this.forwardInput(
await toNetInput(input, true, true),
forwardParameters
forwardParams
)
).results
}
public async forwardWithStats(
input: TNetInput,
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
forwardParams: MtcnnForwardParams
): Promise<{ results: MtcnnResult[], stats: any }> {
return this.forwardInput(
await toNetInput(input, true, true),
forwardParameters
forwardParams
)
}
......
import { MtcnnForwardParams } from './types';
export function getDefaultMtcnnForwardParams(): MtcnnForwardParams {
export function getDefaultMtcnnForwardParams() {
return {
minFaceSize: 20,
scaleFactor: 0.709,
......
......@@ -47,9 +47,9 @@ export type MtcnnResult = {
}
export type MtcnnForwardParams = {
minFaceSize: number
scaleFactor: number
maxNumScales: number
scoreThresholds: number[]
minFaceSize?: number
scaleFactor?: number
maxNumScales?: number
scoreThresholds?: number[]
scaleSteps?: number[]
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment