Commit 8ad95e33 by vincent

init box prediction

parent 2efff4f7
import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionNet } from './types';
/**
 * Placeholder for the box encoding predictor of a box prediction head.
 *
 * Not implemented yet: `params` is currently unused and the input tensor is
 * handed back unchanged from inside a `tf.tidy` scope.
 *
 * @param x Input feature map.
 * @param params Convolution weights reserved for the pending implementation.
 * @returns The unmodified input `x`.
 */
function boxEncodingPredictionLayer(
  x: tf.Tensor4D,
  params: FaceDetectionNet.ConvWithBiasParams
) {
  // TODO: compute the box encodings from `x` using `params`
  return tf.tidy(() => x)
}
/**
 * Placeholder for the class predictor of a box prediction head.
 *
 * Not implemented yet: `params` is currently unused and the input tensor is
 * handed back unchanged from inside a `tf.tidy` scope.
 *
 * @param x Input feature map.
 * @param params Convolution weights reserved for the pending implementation.
 * @returns The unmodified input `x`.
 */
function classPredictionLayer(
  x: tf.Tensor4D,
  params: FaceDetectionNet.ConvWithBiasParams
) {
  // TODO: compute the class scores from `x` using `params`
  return tf.tidy(() => x)
}
/**
 * Runs both predictors of a single box prediction head on a feature map and
 * reshapes their outputs to one row per box.
 *
 * @param x Input feature map; its first dimension is reused as the leading
 *   dimension of both outputs.
 * @param params Weights of the box encoding predictor and the class predictor.
 * @param size Number of boxes, used as the second dimension of both outputs.
 * @returns `boxPredictionEncoding` reshaped to `[x.shape[0], size, 1, 4]` and
 *   `classPrediction` reshaped to `[x.shape[0], size, 3]`.
 */
export function boxPredictionLayer(
  x: tf.Tensor4D,
  params: FaceDetectionNet.BoxPredictionParams,
  size: number
) {
  return tf.tidy(() => {
    const batchSize = x.shape[0]

    // 4 values per box
    const encodings = boxEncodingPredictionLayer(x, params.box_encoding_predictor_params)
    const boxPredictionEncoding = tf.reshape(encodings, [batchSize, size, 1, 4])

    // 3 values per box
    const classScores = classPredictionLayer(x, params.class_predictor_params)
    const classPrediction = tf.reshape(classScores, [batchSize, size, 3])

    return { boxPredictionEncoding, classPrediction }
  })
}
\ No newline at end of file
...@@ -2,7 +2,7 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,7 +2,7 @@ import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionNet } from './types'; import { FaceDetectionNet } from './types';
function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number) => Float32Array) { function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
function extractDepthwiseConvParams(numChannels: number): FaceDetectionNet.MobileNetV1.DepthwiseConvParams { function extractDepthwiseConvParams(numChannels: number): FaceDetectionNet.MobileNetV1.DepthwiseConvParams {
const filters = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1]) const filters = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1])
...@@ -20,7 +20,7 @@ function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number ...@@ -20,7 +20,7 @@ function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number
} }
} }
function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.MobileNetV1.PointwiseConvParams { function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.PointwiseConvParams {
const filters = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]) const filters = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
const batch_norm_offset = tf.tensor1d(extractWeights(channelsOut)) const batch_norm_offset = tf.tensor1d(extractWeights(channelsOut))
...@@ -40,22 +40,6 @@ function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number ...@@ -40,22 +40,6 @@ function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number
} }
} }
return {
extractPointwiseConvParams,
extractConvPairParams
}
}
function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
const {
extractPointwiseConvParams,
extractConvPairParams
} = mobilenetV1WeightsExtractorsFactory(extractWeights)
function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params { function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params {
const conv_0_params = { const conv_0_params = {
......
...@@ -50,6 +50,15 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -50,6 +50,15 @@ export function faceDetectionNet(weights: Float32Array) {
let out = resizeLayer(imgTensor) as tf.Tensor4D let out = resizeLayer(imgTensor) as tf.Tensor4D
out = mobileNetV1(out, params.mobilenetv1_params) out = mobileNetV1(out, params.mobilenetv1_params)
// boxpredictor0: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6
// boxpredictor1: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6
// boxpredictor2: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_2_3x3_s2_512/Relu6
// boxpredictor3: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_3_3x3_s2_256/Relu6
// boxpredictor4: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_4_3x3_s2_256/Relu6
// boxpredictor5: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6
return out return out
}) })
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { pointwiseConvLayer } from './pointwiseConvLayer';
import { FaceDetectionNet } from './types'; import { FaceDetectionNet } from './types';
const epsilon = 0.0010000000474974513 const epsilon = 0.0010000000474974513
...@@ -25,19 +26,7 @@ function depthwiseConvLayer( ...@@ -25,19 +26,7 @@ function depthwiseConvLayer(
}) })
} }
/**
 * Convolves `x` with the given filters using 'same' padding, adds the
 * per-channel offset and applies a ReLU.
 *
 * @param x Input feature map.
 * @param params Convolution filters and the offset added after the convolution.
 * @param strides Strides passed to the convolution.
 * @returns The activated feature map.
 */
function pointwiseConvLayer(
  x: tf.Tensor4D,
  params: FaceDetectionNet.MobileNetV1.PointwiseConvParams,
  strides: [number, number]
) {
  const { filters, batch_norm_offset } = params
  return tf.tidy(() => {
    const convolved = tf.conv2d(x, filters, strides, 'same')
    // the annotation keeps the result typed as a rank 4 tensor
    const shifted: tf.Tensor4D = tf.add(convolved, batch_norm_offset)
    return tf.relu(shifted)
  })
}
function getStridesForLayerIdx(layerIdx: number): [number, number] { function getStridesForLayerIdx(layerIdx: number): [number, number] {
return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1] return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1]
......
import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionNet } from './types';
/**
 * Convolves `x` with the given filters using 'same' padding, adds the
 * per-channel offset and applies a ReLU.
 *
 * @param x Input feature map.
 * @param params Convolution filters and the offset added after the convolution.
 * @param strides Strides passed to the convolution.
 * @returns The activated feature map.
 */
export function pointwiseConvLayer(
  x: tf.Tensor4D,
  params: FaceDetectionNet.PointwiseConvParams,
  strides: [number, number]
) {
  const { filters, batch_norm_offset } = params
  return tf.tidy(() => {
    const convolved = tf.conv2d(x, filters, strides, 'same')
    // the annotation keeps the result typed as a rank 4 tensor
    const shifted: tf.Tensor4D = tf.add(convolved, batch_norm_offset)
    return tf.relu(shifted)
  })
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { boxPredictionLayer } from './boxPredictionLayer';
import { pointwiseConvLayer } from './pointwiseConvLayer';
import { FaceDetectionNet } from './types';
/**
 * Runs the extra feature layers on top of the given feature map and collects
 * the outputs of all six box predictors.
 *
 * The extra layers form a chain: every layer consumes the output of the
 * previous one, and each layer with strides [2, 2] shrinks the feature map
 * before it is handed to the next box predictor. This matches the expected
 * box counts (384, 96, 24, 6), which drop by a factor of four per stage.
 *
 * @param x Input feature map.
 * @param params Weights of the extra feature layers and the box predictors.
 * @returns `boxPredictions` and `classPredictions`, each being the outputs of
 *   the six box predictors concatenated along the box dimension (axis 1).
 */
export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) {
  return tf.tidy(() => {
    const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
    const conv1 = pointwiseConvLayer(conv0, params.conv_1_params, [2, 2])
    const conv2 = pointwiseConvLayer(conv1, params.conv_2_params, [1, 1])
    const conv3 = pointwiseConvLayer(conv2, params.conv_3_params, [2, 2])
    const conv4 = pointwiseConvLayer(conv3, params.conv_4_params, [1, 1])
    const conv5 = pointwiseConvLayer(conv4, params.conv_5_params, [2, 2])
    // NOTE(review): conv6 / conv7 reuse conv_4_params / conv_5_params because
    // PredictionParams declares no conv_6_params / conv_7_params. Presumably
    // dedicated weights are intended here - confirm and extend the type.
    const conv6 = pointwiseConvLayer(conv5, params.conv_4_params, [1, 1])
    const conv7 = pointwiseConvLayer(conv6, params.conv_5_params, [2, 2])

    // NOTE(review): predictors 0 and 1 both read `x` but expect different box
    // counts (3072 vs 1536) - presumably predictor 0 should read an earlier,
    // larger feature map that is not available to this function yet; verify.
    const boxPrediction0 = boxPredictionLayer(x, params.box_predictor_0_params, 3072)
    const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536)
    const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384)
    const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96)
    const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4_params, 24)
    const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5_params, 6)

    // The predictor outputs only differ in their box dimension, so they have
    // to be joined along axis 1; the default axis 0 requires equal box counts.
    const boxPredictions = tf.concat([
      boxPrediction0.boxPredictionEncoding,
      boxPrediction1.boxPredictionEncoding,
      boxPrediction2.boxPredictionEncoding,
      boxPrediction3.boxPredictionEncoding,
      boxPrediction4.boxPredictionEncoding,
      boxPrediction5.boxPredictionEncoding
    ], 1)

    const classPredictions = tf.concat([
      boxPrediction0.classPrediction,
      boxPrediction1.classPrediction,
      boxPrediction2.classPrediction,
      boxPrediction3.classPrediction,
      boxPrediction4.classPrediction,
      boxPrediction5.classPrediction
    ], 1)

    return {
      boxPredictions,
      classPredictions
    }
  })
}
\ No newline at end of file
...@@ -2,6 +2,11 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,6 +2,11 @@ import * as tf from '@tensorflow/tfjs-core';
export namespace FaceDetectionNet { export namespace FaceDetectionNet {
/** Weights of a pointwise convolution layer (consumed by `pointwiseConvLayer`). */
export type PointwiseConvParams = {
// convolution filters passed to tf.conv2d
filters: tf.Tensor4D
// offset added to the convolution output before the ReLU
batch_norm_offset: tf.Tensor1D
}
export namespace MobileNetV1 { export namespace MobileNetV1 {
export type DepthwiseConvParams = { export type DepthwiseConvParams = {
...@@ -12,11 +17,6 @@ export namespace FaceDetectionNet { ...@@ -12,11 +17,6 @@ export namespace FaceDetectionNet {
batch_norm_variance: tf.Tensor1D batch_norm_variance: tf.Tensor1D
} }
/** Weights of a pointwise convolution layer (consumed by `pointwiseConvLayer`). */
export type PointwiseConvParams = {
// convolution filters passed to tf.conv2d
filters: tf.Tensor4D
// offset added to the convolution output before the ReLU
batch_norm_offset: tf.Tensor1D
}
export type ConvPairParams = { export type ConvPairParams = {
depthwise_conv_params: DepthwiseConvParams depthwise_conv_params: DepthwiseConvParams
pointwise_conv_params: PointwiseConvParams pointwise_conv_params: PointwiseConvParams
...@@ -29,6 +29,31 @@ export namespace FaceDetectionNet { ...@@ -29,6 +29,31 @@ export namespace FaceDetectionNet {
} }
/**
 * Weights handed to the box encoding and class predictors of a box
 * prediction head.
 */
export type ConvWithBiasParams = {
// presumably the convolution filters - the consuming layers are still TODO stubs
filters: tf.Tensor4D
// presumably a per-channel bias added after the convolution - TODO confirm
bias: tf.Tensor1D
}
/** Weights of one box prediction head (consumed by `boxPredictionLayer`). */
export type BoxPredictionParams = {
// weights passed to the class predictor
class_predictor_params: ConvWithBiasParams
// weights passed to the box encoding predictor
box_encoding_predictor_params: ConvWithBiasParams
}
/** Weights of the prediction stage (consumed by `predictionLayer`). */
export type PredictionParams = {
// weights of the extra feature layers
// NOTE(review): predictionLayer builds eight conv layers (conv0..conv7) but only
// six parameter sets are declared here - confirm whether conv_6_params and
// conv_7_params are missing
conv_0_params: PointwiseConvParams
conv_1_params: PointwiseConvParams
conv_2_params: PointwiseConvParams
conv_3_params: PointwiseConvParams
conv_4_params: PointwiseConvParams
conv_5_params: PointwiseConvParams
// weights of the six box prediction heads
box_predictor_0_params: BoxPredictionParams
box_predictor_1_params: BoxPredictionParams
box_predictor_2_params: BoxPredictionParams
box_predictor_3_params: BoxPredictionParams
box_predictor_4_params: BoxPredictionParams
box_predictor_5_params: BoxPredictionParams
}
export type NetParams = { export type NetParams = {
mobilenetv1_params: MobileNetV1.Params mobilenetv1_params: MobileNetV1.Params
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment