Commit f3469d09 by vincent

prediction layer + params

parent 8ad95e33
...@@ -2,28 +2,16 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,28 +2,16 @@ import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionNet } from './types'; import { FaceDetectionNet } from './types';
function boxEncodingPredictionLayer( function convWithBias(
x: tf.Tensor4D, x: tf.Tensor4D,
params: FaceDetectionNet.ConvWithBiasParams params: FaceDetectionNet.ConvWithBiasParams
) { ) {
return tf.tidy(() => { return tf.tidy(() =>
tf.add(
// TODO tf.conv2d(x, params.filters, [1, 1], 'same'),
return x params.bias
)
}) )
}
function classPredictionLayer(
x: tf.Tensor4D,
params: FaceDetectionNet.ConvWithBiasParams
) {
return tf.tidy(() => {
// TODO
return x
})
} }
export function boxPredictionLayer( export function boxPredictionLayer(
...@@ -33,13 +21,15 @@ export function boxPredictionLayer( ...@@ -33,13 +21,15 @@ export function boxPredictionLayer(
) { ) {
return tf.tidy(() => { return tf.tidy(() => {
const batchSize = x.shape[0]
const boxPredictionEncoding = tf.reshape( const boxPredictionEncoding = tf.reshape(
boxEncodingPredictionLayer(x, params.box_encoding_predictor_params), convWithBias(x, params.box_encoding_predictor_params),
[x.shape[0], size, 1, 4] [batchSize, size, 1, 4]
) )
const classPrediction = tf.reshape( const classPrediction = tf.reshape(
classPredictionLayer(x, params.class_predictor_params), convWithBias(x, params.class_predictor_params),
[x.shape[0], size, 3] [batchSize, size, 3]
) )
return { return {
......
...@@ -20,19 +20,42 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) ...@@ -20,19 +20,42 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
} }
} }
function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.PointwiseConvParams { function extractConvWithBiasParams(
const filters = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut]) channelsIn: number,
const batch_norm_offset = tf.tensor1d(extractWeights(channelsOut)) channelsOut: number,
filterSize: number
): FaceDetectionNet.ConvWithBiasParams {
const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut]
)
const bias = tf.tensor1d(extractWeights(channelsOut))
return { return {
filters, filters,
batch_norm_offset bias
}
}
function extractPointwiseConvParams(
channelsIn: number,
channelsOut: number,
filterSize: number
): FaceDetectionNet.PointwiseConvParams {
const {
filters,
bias
} = extractConvWithBiasParams(channelsIn, channelsOut, filterSize)
return {
filters,
batch_norm_offset: bias
} }
} }
function extractConvPairParams(channelsIn: number, channelsOut: number): FaceDetectionNet.MobileNetV1.ConvPairParams { function extractConvPairParams(channelsIn: number, channelsOut: number): FaceDetectionNet.MobileNetV1.ConvPairParams {
const depthwise_conv_params = extractDepthwiseConvParams(channelsIn) const depthwise_conv_params = extractDepthwiseConvParams(channelsIn)
const pointwise_conv_params = extractPointwiseConvParams(channelsIn, channelsOut) const pointwise_conv_params = extractPointwiseConvParams(channelsIn, channelsOut, 1)
return { return {
depthwise_conv_params, depthwise_conv_params,
...@@ -42,11 +65,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) ...@@ -42,11 +65,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params { function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params {
const conv_0_params = { const conv_0_params = extractPointwiseConvParams(3, 32, 3)
filters: tf.tensor4d(extractWeights(3 * 3 * 3 * 32), [3, 3, 3, 32]),
batch_norm_offset: tf.tensor1d(extractWeights(32))
}
const channelNumPairs = [ const channelNumPairs = [
[32, 64], [32, 64],
...@@ -75,32 +94,101 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) ...@@ -75,32 +94,101 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
} }
function extractPredictionLayerParams(): FaceDetectionNet.PredictionParams {
const conv_0_params = extractPointwiseConvParams(1024, 256, 1)
const conv_1_params = extractPointwiseConvParams(256, 512, 3)
const conv_2_params = extractPointwiseConvParams(512, 128, 1)
const conv_3_params = extractPointwiseConvParams(128, 256, 3)
const conv_4_params = extractPointwiseConvParams(256, 128, 1)
const conv_5_params = extractPointwiseConvParams(128, 256, 3)
const conv_6_params = extractPointwiseConvParams(256, 64, 1)
const conv_7_params = extractPointwiseConvParams(64, 128, 3)
const box_encoding_0_predictor_params = extractConvWithBiasParams(512, 12, 1)
const class_predictor_0_params = extractConvWithBiasParams(512, 9, 1)
const box_encoding_1_predictor_params = extractConvWithBiasParams(1024, 24, 1)
const class_predictor_1_params = extractConvWithBiasParams(1024, 18, 1)
const box_encoding_2_predictor_params = extractConvWithBiasParams(512, 24, 1)
const class_predictor_2_params = extractConvWithBiasParams(512, 18, 1)
const box_encoding_3_predictor_params = extractConvWithBiasParams(256, 24, 1)
const class_predictor_3_params = extractConvWithBiasParams(256, 18, 1)
const box_encoding_4_predictor_params = extractConvWithBiasParams(256, 24, 1)
const class_predictor_4_params = extractConvWithBiasParams(256, 18, 1)
const box_encoding_5_predictor_params = extractConvWithBiasParams(128, 24, 1)
const class_predictor_5_params = extractConvWithBiasParams(128, 18, 1)
const box_predictor_0_params = {
box_encoding_predictor_params: box_encoding_0_predictor_params,
class_predictor_params: class_predictor_0_params
}
const box_predictor_1_params = {
box_encoding_predictor_params: box_encoding_1_predictor_params,
class_predictor_params: class_predictor_1_params
}
const box_predictor_2_params = {
box_encoding_predictor_params: box_encoding_2_predictor_params,
class_predictor_params: class_predictor_2_params
}
const box_predictor_3_params = {
box_encoding_predictor_params: box_encoding_3_predictor_params,
class_predictor_params: class_predictor_3_params
}
const box_predictor_4_params = {
box_encoding_predictor_params: box_encoding_4_predictor_params,
class_predictor_params: class_predictor_4_params
}
const box_predictor_5_params = {
box_encoding_predictor_params: box_encoding_5_predictor_params,
class_predictor_params: class_predictor_5_params
}
return {
conv_0_params,
conv_1_params,
conv_2_params,
conv_3_params,
conv_4_params,
conv_5_params,
conv_6_params,
conv_7_params,
box_predictor_0_params,
box_predictor_1_params,
box_predictor_2_params,
box_predictor_3_params,
box_predictor_4_params,
box_predictor_5_params
}
}
return { return {
extractMobilenetV1Params extractMobilenetV1Params,
extractPredictionLayerParams
} }
} }
export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams { export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams {
const extractWeights = (numWeights: number): Float32Array => { const extractWeights = (numWeights: number): Float32Array => {
console.log(numWeights)
const ret = weights.slice(0, numWeights) const ret = weights.slice(0, numWeights)
weights = weights.slice(numWeights) weights = weights.slice(numWeights)
return ret return ret
} }
const { const {
extractMobilenetV1Params extractMobilenetV1Params,
extractPredictionLayerParams
} = extractorsFactory(extractWeights) } = extractorsFactory(extractWeights)
const mobilenetv1_params = extractMobilenetV1Params() const mobilenetv1_params = extractMobilenetV1Params()
const prediction_layer_params = extractPredictionLayerParams()
if (weights.length !== 0) { if (weights.length !== 0) {
throw new Error(`weights remaing after extract: ${weights.length}`) throw new Error(`weights remaing after extract: ${weights.length}`)
} }
return { return {
mobilenetv1_params mobilenetv1_params,
prediction_layer_params
} }
} }
\ No newline at end of file
...@@ -4,6 +4,7 @@ import { isFloat } from '../utils'; ...@@ -4,6 +4,7 @@ import { isFloat } from '../utils';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { mobileNetV1 } from './mobileNetV1'; import { mobileNetV1 } from './mobileNetV1';
import { resizeLayer } from './resizeLayer'; import { resizeLayer } from './resizeLayer';
import { predictionLayer } from './predictionLayer';
function fromData(input: number[]): tf.Tensor4D { function fromData(input: number[]): tf.Tensor4D {
const pxPerChannel = input.length / 3 const pxPerChannel = input.length / 3
...@@ -47,19 +48,18 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -47,19 +48,18 @@ export function faceDetectionNet(weights: Float32Array) {
? fromImageData(imgDataArray) ? fromImageData(imgDataArray)
: fromData(input as number[]) : fromData(input as number[])
let out = resizeLayer(imgTensor) as tf.Tensor4D const resized = resizeLayer(imgTensor) as tf.Tensor4D
out = mobileNetV1(out, params.mobilenetv1_params) const features = mobileNetV1(resized, params.mobilenetv1_params)
const {
boxPredictions,
classPredictions
} = predictionLayer(features.out, features.conv11, params.prediction_layer_params)
return {
// boxpredictor0: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6 boxPredictions,
// boxpredictor1: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6 classPredictions
// boxpredictor2: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_2_3x3_s2_512/Relu6 }
// boxpredictor3: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_3_3x3_s2_256/Relu6
// boxpredictor4: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_4_3x3_s2_256/Relu6
// boxpredictor5: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6
return out
}) })
} }
......
...@@ -26,8 +26,6 @@ function depthwiseConvLayer( ...@@ -26,8 +26,6 @@ function depthwiseConvLayer(
}) })
} }
function getStridesForLayerIdx(layerIdx: number): [number, number] { function getStridesForLayerIdx(layerIdx: number): [number, number] {
return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1] return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1]
} }
...@@ -35,13 +33,26 @@ function getStridesForLayerIdx(layerIdx: number): [number, number] { ...@@ -35,13 +33,26 @@ function getStridesForLayerIdx(layerIdx: number): [number, number] {
export function mobileNetV1(x: tf.Tensor4D, params: FaceDetectionNet.MobileNetV1.Params) { export function mobileNetV1(x: tf.Tensor4D, params: FaceDetectionNet.MobileNetV1.Params) {
return tf.tidy(() => { return tf.tidy(() => {
let conv11 = null
let out = pointwiseConvLayer(x, params.conv_0_params, [2, 2]) let out = pointwiseConvLayer(x, params.conv_0_params, [2, 2])
params.conv_pair_params.forEach((param, i) => { params.conv_pair_params.forEach((param, i) => {
const depthwiseConvStrides = getStridesForLayerIdx(i + 1) const layerIdx = i + 1
const depthwiseConvStrides = getStridesForLayerIdx(layerIdx)
out = depthwiseConvLayer(out, param.depthwise_conv_params, depthwiseConvStrides) out = depthwiseConvLayer(out, param.depthwise_conv_params, depthwiseConvStrides)
out = pointwiseConvLayer(out, param.pointwise_conv_params, [1, 1]) out = pointwiseConvLayer(out, param.pointwise_conv_params, [1, 1])
if (layerIdx === 11) {
conv11 = out
}
}) })
return out
if (conv11 === null) {
throw new Error('mobileNetV1 - output of conv layer 11 is null')
}
return {
out,
conv11: conv11 as any
}
}) })
} }
\ No newline at end of file
...@@ -4,19 +4,19 @@ import { boxPredictionLayer } from './boxPredictionLayer'; ...@@ -4,19 +4,19 @@ import { boxPredictionLayer } from './boxPredictionLayer';
import { pointwiseConvLayer } from './pointwiseConvLayer'; import { pointwiseConvLayer } from './pointwiseConvLayer';
import { FaceDetectionNet } from './types'; import { FaceDetectionNet } from './types';
export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) { export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) {
return tf.tidy(() => { return tf.tidy(() => {
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1]) const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
const conv1 = pointwiseConvLayer(x, params.conv_1_params, [2, 2]) const conv1 = pointwiseConvLayer(conv0, params.conv_1_params, [2, 2])
const conv2 = pointwiseConvLayer(x, params.conv_2_params, [1, 1]) const conv2 = pointwiseConvLayer(conv1, params.conv_2_params, [1, 1])
const conv3 = pointwiseConvLayer(x, params.conv_3_params, [2, 2]) const conv3 = pointwiseConvLayer(conv2, params.conv_3_params, [2, 2])
const conv4 = pointwiseConvLayer(x, params.conv_4_params, [1, 1]) const conv4 = pointwiseConvLayer(conv3, params.conv_4_params, [1, 1])
const conv5 = pointwiseConvLayer(x, params.conv_5_params, [2, 2]) const conv5 = pointwiseConvLayer(conv4, params.conv_5_params, [2, 2])
const conv6 = pointwiseConvLayer(x, params.conv_4_params, [1, 1]) const conv6 = pointwiseConvLayer(conv5, params.conv_6_params, [1, 1])
const conv7 = pointwiseConvLayer(x, params.conv_5_params, [2, 2]) const conv7 = pointwiseConvLayer(conv6, params.conv_7_params, [2, 2])
const boxPrediction0 = boxPredictionLayer(x, params.box_predictor_0_params, 3072) const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params, 3072)
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536) const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536)
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384) const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384)
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96) const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96)
...@@ -30,7 +30,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict ...@@ -30,7 +30,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict
boxPrediction3.boxPredictionEncoding, boxPrediction3.boxPredictionEncoding,
boxPrediction4.boxPredictionEncoding, boxPrediction4.boxPredictionEncoding,
boxPrediction5.boxPredictionEncoding boxPrediction5.boxPredictionEncoding
]) ], 1)
const classPredictions = tf.concat([ const classPredictions = tf.concat([
boxPrediction0.classPrediction, boxPrediction0.classPrediction,
...@@ -39,7 +39,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict ...@@ -39,7 +39,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict
boxPrediction3.classPrediction, boxPrediction3.classPrediction,
boxPrediction4.classPrediction, boxPrediction4.classPrediction,
boxPrediction5.classPrediction boxPrediction5.classPrediction
]) ], 1)
return { return {
boxPredictions, boxPredictions,
......
...@@ -35,8 +35,8 @@ export namespace FaceDetectionNet { ...@@ -35,8 +35,8 @@ export namespace FaceDetectionNet {
} }
export type BoxPredictionParams = { export type BoxPredictionParams = {
class_predictor_params: ConvWithBiasParams
box_encoding_predictor_params: ConvWithBiasParams box_encoding_predictor_params: ConvWithBiasParams
class_predictor_params: ConvWithBiasParams
} }
export type PredictionParams = { export type PredictionParams = {
...@@ -46,6 +46,8 @@ export namespace FaceDetectionNet { ...@@ -46,6 +46,8 @@ export namespace FaceDetectionNet {
conv_3_params: PointwiseConvParams conv_3_params: PointwiseConvParams
conv_4_params: PointwiseConvParams conv_4_params: PointwiseConvParams
conv_5_params: PointwiseConvParams conv_5_params: PointwiseConvParams
conv_6_params: PointwiseConvParams
conv_7_params: PointwiseConvParams
box_predictor_0_params: BoxPredictionParams box_predictor_0_params: BoxPredictionParams
box_predictor_1_params: BoxPredictionParams box_predictor_1_params: BoxPredictionParams
box_predictor_2_params: BoxPredictionParams box_predictor_2_params: BoxPredictionParams
...@@ -55,7 +57,8 @@ export namespace FaceDetectionNet { ...@@ -55,7 +57,8 @@ export namespace FaceDetectionNet {
} }
export type NetParams = { export type NetParams = {
mobilenetv1_params: MobileNetV1.Params mobilenetv1_params: MobileNetV1.Params,
prediction_layer_params: PredictionParams
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment