Commit f3469d09 by vincent

prediction layer + params

parent 8ad95e33
......@@ -2,28 +2,16 @@ import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionNet } from './types';
function boxEncodingPredictionLayer(
function convWithBias(
x: tf.Tensor4D,
params: FaceDetectionNet.ConvWithBiasParams
) {
return tf.tidy(() => {
// TODO
return x
})
}
function classPredictionLayer(
x: tf.Tensor4D,
params: FaceDetectionNet.ConvWithBiasParams
) {
return tf.tidy(() => {
// TODO
return x
})
return tf.tidy(() =>
tf.add(
tf.conv2d(x, params.filters, [1, 1], 'same'),
params.bias
)
)
}
export function boxPredictionLayer(
......@@ -33,13 +21,15 @@ export function boxPredictionLayer(
) {
return tf.tidy(() => {
const batchSize = x.shape[0]
const boxPredictionEncoding = tf.reshape(
boxEncodingPredictionLayer(x, params.box_encoding_predictor_params),
[x.shape[0], size, 1, 4]
convWithBias(x, params.box_encoding_predictor_params),
[batchSize, size, 1, 4]
)
const classPrediction = tf.reshape(
classPredictionLayer(x, params.class_predictor_params),
[x.shape[0], size, 3]
convWithBias(x, params.class_predictor_params),
[batchSize, size, 3]
)
return {
......
......@@ -20,19 +20,42 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
}
}
function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.PointwiseConvParams {
const filters = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
const batch_norm_offset = tf.tensor1d(extractWeights(channelsOut))
function extractConvWithBiasParams(
channelsIn: number,
channelsOut: number,
filterSize: number
): FaceDetectionNet.ConvWithBiasParams {
const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut]
)
const bias = tf.tensor1d(extractWeights(channelsOut))
return {
filters,
bias
}
}
function extractPointwiseConvParams(
channelsIn: number,
channelsOut: number,
filterSize: number
): FaceDetectionNet.PointwiseConvParams {
const {
filters,
bias
} = extractConvWithBiasParams(channelsIn, channelsOut, filterSize)
return {
filters,
batch_norm_offset
batch_norm_offset: bias
}
}
function extractConvPairParams(channelsIn: number, channelsOut: number): FaceDetectionNet.MobileNetV1.ConvPairParams {
const depthwise_conv_params = extractDepthwiseConvParams(channelsIn)
const pointwise_conv_params = extractPointwiseConvParams(channelsIn, channelsOut)
const pointwise_conv_params = extractPointwiseConvParams(channelsIn, channelsOut, 1)
return {
depthwise_conv_params,
......@@ -42,11 +65,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params {
const conv_0_params = {
filters: tf.tensor4d(extractWeights(3 * 3 * 3 * 32), [3, 3, 3, 32]),
batch_norm_offset: tf.tensor1d(extractWeights(32))
}
const conv_0_params = extractPointwiseConvParams(3, 32, 3)
const channelNumPairs = [
[32, 64],
......@@ -75,32 +94,101 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
}
function extractPredictionLayerParams(): FaceDetectionNet.PredictionParams {
const conv_0_params = extractPointwiseConvParams(1024, 256, 1)
const conv_1_params = extractPointwiseConvParams(256, 512, 3)
const conv_2_params = extractPointwiseConvParams(512, 128, 1)
const conv_3_params = extractPointwiseConvParams(128, 256, 3)
const conv_4_params = extractPointwiseConvParams(256, 128, 1)
const conv_5_params = extractPointwiseConvParams(128, 256, 3)
const conv_6_params = extractPointwiseConvParams(256, 64, 1)
const conv_7_params = extractPointwiseConvParams(64, 128, 3)
const box_encoding_0_predictor_params = extractConvWithBiasParams(512, 12, 1)
const class_predictor_0_params = extractConvWithBiasParams(512, 9, 1)
const box_encoding_1_predictor_params = extractConvWithBiasParams(1024, 24, 1)
const class_predictor_1_params = extractConvWithBiasParams(1024, 18, 1)
const box_encoding_2_predictor_params = extractConvWithBiasParams(512, 24, 1)
const class_predictor_2_params = extractConvWithBiasParams(512, 18, 1)
const box_encoding_3_predictor_params = extractConvWithBiasParams(256, 24, 1)
const class_predictor_3_params = extractConvWithBiasParams(256, 18, 1)
const box_encoding_4_predictor_params = extractConvWithBiasParams(256, 24, 1)
const class_predictor_4_params = extractConvWithBiasParams(256, 18, 1)
const box_encoding_5_predictor_params = extractConvWithBiasParams(128, 24, 1)
const class_predictor_5_params = extractConvWithBiasParams(128, 18, 1)
const box_predictor_0_params = {
box_encoding_predictor_params: box_encoding_0_predictor_params,
class_predictor_params: class_predictor_0_params
}
const box_predictor_1_params = {
box_encoding_predictor_params: box_encoding_1_predictor_params,
class_predictor_params: class_predictor_1_params
}
const box_predictor_2_params = {
box_encoding_predictor_params: box_encoding_2_predictor_params,
class_predictor_params: class_predictor_2_params
}
const box_predictor_3_params = {
box_encoding_predictor_params: box_encoding_3_predictor_params,
class_predictor_params: class_predictor_3_params
}
const box_predictor_4_params = {
box_encoding_predictor_params: box_encoding_4_predictor_params,
class_predictor_params: class_predictor_4_params
}
const box_predictor_5_params = {
box_encoding_predictor_params: box_encoding_5_predictor_params,
class_predictor_params: class_predictor_5_params
}
return {
conv_0_params,
conv_1_params,
conv_2_params,
conv_3_params,
conv_4_params,
conv_5_params,
conv_6_params,
conv_7_params,
box_predictor_0_params,
box_predictor_1_params,
box_predictor_2_params,
box_predictor_3_params,
box_predictor_4_params,
box_predictor_5_params
}
}
return {
extractMobilenetV1Params
extractMobilenetV1Params,
extractPredictionLayerParams
}
}
export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams {
const extractWeights = (numWeights: number): Float32Array => {
console.log(numWeights)
const ret = weights.slice(0, numWeights)
weights = weights.slice(numWeights)
return ret
}
const {
extractMobilenetV1Params
extractMobilenetV1Params,
extractPredictionLayerParams
} = extractorsFactory(extractWeights)
const mobilenetv1_params = extractMobilenetV1Params()
const prediction_layer_params = extractPredictionLayerParams()
if (weights.length !== 0) {
throw new Error(`weights remaing after extract: ${weights.length}`)
}
return {
mobilenetv1_params
mobilenetv1_params,
prediction_layer_params
}
}
\ No newline at end of file
......@@ -4,6 +4,7 @@ import { isFloat } from '../utils';
import { extractParams } from './extractParams';
import { mobileNetV1 } from './mobileNetV1';
import { resizeLayer } from './resizeLayer';
import { predictionLayer } from './predictionLayer';
function fromData(input: number[]): tf.Tensor4D {
const pxPerChannel = input.length / 3
......@@ -47,19 +48,18 @@ export function faceDetectionNet(weights: Float32Array) {
? fromImageData(imgDataArray)
: fromData(input as number[])
let out = resizeLayer(imgTensor) as tf.Tensor4D
out = mobileNetV1(out, params.mobilenetv1_params)
const resized = resizeLayer(imgTensor) as tf.Tensor4D
const features = mobileNetV1(resized, params.mobilenetv1_params)
const {
boxPredictions,
classPredictions
} = predictionLayer(features.out, features.conv11, params.prediction_layer_params)
// boxpredictor0: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6
// boxpredictor1: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6
// boxpredictor2: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_2_3x3_s2_512/Relu6
// boxpredictor3: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_3_3x3_s2_256/Relu6
// boxpredictor4: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_4_3x3_s2_256/Relu6
// boxpredictor5: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6
return out
return {
boxPredictions,
classPredictions
}
})
}
......
......@@ -26,8 +26,6 @@ function depthwiseConvLayer(
})
}
function getStridesForLayerIdx(layerIdx: number): [number, number] {
return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1]
}
......@@ -35,13 +33,26 @@ function getStridesForLayerIdx(layerIdx: number): [number, number] {
export function mobileNetV1(x: tf.Tensor4D, params: FaceDetectionNet.MobileNetV1.Params) {
return tf.tidy(() => {
let conv11 = null
let out = pointwiseConvLayer(x, params.conv_0_params, [2, 2])
params.conv_pair_params.forEach((param, i) => {
const depthwiseConvStrides = getStridesForLayerIdx(i + 1)
const layerIdx = i + 1
const depthwiseConvStrides = getStridesForLayerIdx(layerIdx)
out = depthwiseConvLayer(out, param.depthwise_conv_params, depthwiseConvStrides)
out = pointwiseConvLayer(out, param.pointwise_conv_params, [1, 1])
if (layerIdx === 11) {
conv11 = out
}
})
return out
if (conv11 === null) {
throw new Error('mobileNetV1 - output of conv layer 11 is null')
}
return {
out,
conv11: conv11 as any
}
})
}
\ No newline at end of file
......@@ -4,19 +4,19 @@ import { boxPredictionLayer } from './boxPredictionLayer';
import { pointwiseConvLayer } from './pointwiseConvLayer';
import { FaceDetectionNet } from './types';
export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) {
export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) {
return tf.tidy(() => {
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
const conv1 = pointwiseConvLayer(x, params.conv_1_params, [2, 2])
const conv2 = pointwiseConvLayer(x, params.conv_2_params, [1, 1])
const conv3 = pointwiseConvLayer(x, params.conv_3_params, [2, 2])
const conv4 = pointwiseConvLayer(x, params.conv_4_params, [1, 1])
const conv5 = pointwiseConvLayer(x, params.conv_5_params, [2, 2])
const conv6 = pointwiseConvLayer(x, params.conv_4_params, [1, 1])
const conv7 = pointwiseConvLayer(x, params.conv_5_params, [2, 2])
const conv1 = pointwiseConvLayer(conv0, params.conv_1_params, [2, 2])
const conv2 = pointwiseConvLayer(conv1, params.conv_2_params, [1, 1])
const conv3 = pointwiseConvLayer(conv2, params.conv_3_params, [2, 2])
const conv4 = pointwiseConvLayer(conv3, params.conv_4_params, [1, 1])
const conv5 = pointwiseConvLayer(conv4, params.conv_5_params, [2, 2])
const conv6 = pointwiseConvLayer(conv5, params.conv_6_params, [1, 1])
const conv7 = pointwiseConvLayer(conv6, params.conv_7_params, [2, 2])
const boxPrediction0 = boxPredictionLayer(x, params.box_predictor_0_params, 3072)
const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params, 3072)
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536)
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384)
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96)
......@@ -30,7 +30,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict
boxPrediction3.boxPredictionEncoding,
boxPrediction4.boxPredictionEncoding,
boxPrediction5.boxPredictionEncoding
])
], 1)
const classPredictions = tf.concat([
boxPrediction0.classPrediction,
......@@ -39,7 +39,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict
boxPrediction3.classPrediction,
boxPrediction4.classPrediction,
boxPrediction5.classPrediction
])
], 1)
return {
boxPredictions,
......
......@@ -35,8 +35,8 @@ export namespace FaceDetectionNet {
}
export type BoxPredictionParams = {
class_predictor_params: ConvWithBiasParams
box_encoding_predictor_params: ConvWithBiasParams
class_predictor_params: ConvWithBiasParams
}
export type PredictionParams = {
......@@ -46,6 +46,8 @@ export namespace FaceDetectionNet {
conv_3_params: PointwiseConvParams
conv_4_params: PointwiseConvParams
conv_5_params: PointwiseConvParams
conv_6_params: PointwiseConvParams
conv_7_params: PointwiseConvParams
box_predictor_0_params: BoxPredictionParams
box_predictor_1_params: BoxPredictionParams
box_predictor_2_params: BoxPredictionParams
......@@ -55,7 +57,8 @@ export namespace FaceDetectionNet {
}
export type NetParams = {
mobilenetv1_params: MobileNetV1.Params
mobilenetv1_params: MobileNetV1.Params,
prediction_layer_params: PredictionParams
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment