Commit e95dd35d by vincent

allow arbitrary input size

parent 28a77453
...@@ -8,7 +8,7 @@ import { FaceDetection } from '../FaceDetection'; ...@@ -8,7 +8,7 @@ import { FaceDetection } from '../FaceDetection';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { toNetInput } from '../toNetInput'; import { toNetInput } from '../toNetInput';
import { TNetInput } from '../types'; import { TNetInput } from '../types';
import { BOX_ANCHORS, INPUT_SIZES, IOU_THRESHOLD, NUM_BOXES, NUM_CELLS } from './config'; import { BOX_ANCHORS, INPUT_SIZES, IOU_THRESHOLD, NUM_BOXES } from './config';
import { convWithBatchNorm } from './convWithBatchNorm'; import { convWithBatchNorm } from './convWithBatchNorm';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { getDefaultParams } from './getDefaultParams'; import { getDefaultParams } from './getDefaultParams';
...@@ -59,18 +59,19 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> { ...@@ -59,18 +59,19 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
public async locateFaces(input: TNetInput, forwardParams: TinyYolov2ForwardParams = {}): Promise<FaceDetection[]> { public async locateFaces(input: TNetInput, forwardParams: TinyYolov2ForwardParams = {}): Promise<FaceDetection[]> {
const { sizeType, scoreThreshold } = getDefaultParams(forwardParams) const { inputSize: _inputSize, scoreThreshold } = getDefaultParams(forwardParams)
const inputSize = typeof _inputSize === 'string'
? INPUT_SIZES[_inputSize]
: _inputSize
const inputSize = INPUT_SIZES[sizeType] if (typeof inputSize !== 'number') {
const numCells = NUM_CELLS[sizeType] throw new Error(`TinyYolov2 - unkown inputSize: ${inputSize}, expected number or one of xs | sm | md | lg`)
if (!inputSize) {
throw new Error(`TinyYolov2 - unkown sizeType: ${sizeType}, expected one of: xs | sm | md | lg`)
} }
const netInput = await toNetInput(input, true) const netInput = await toNetInput(input, true)
const out = await this.forwardInput(netInput, inputSize) const out = await this.forwardInput(netInput, inputSize)
const numCells = out.shape[1]
const [boxesTensor, scoresTensor] = tf.tidy(() => { const [boxesTensor, scoresTensor] = tf.tidy(() => {
const reshaped = out.reshape([numCells, numCells, NUM_BOXES, 6]) const reshaped = out.reshape([numCells, numCells, NUM_BOXES, 6])
......
import { Point } from '../Point'; import { Point } from '../Point';
export const INPUT_SIZES = { xs: 224, sm: 320, md: 416, lg: 608 } export const INPUT_SIZES = { xs: 224, sm: 320, md: 416, lg: 608 }
export const NUM_CELLS = { xs: 7, sm: 10, md: 13, lg: 19 }
export const NUM_BOXES = 5 export const NUM_BOXES = 5
export const IOU_THRESHOLD = 0.4 export const IOU_THRESHOLD = 0.4
......
...@@ -32,6 +32,6 @@ export enum SizeType { ...@@ -32,6 +32,6 @@ export enum SizeType {
} }
export type TinyYolov2ForwardParams = { export type TinyYolov2ForwardParams = {
sizeType?: SizeType inputSize?: SizeType | number
scoreThreshold?: number scoreThreshold?: number
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment