Commit 23d4664d by vincent

decode layer

parent bc94ce37
...@@ -94,7 +94,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) ...@@ -94,7 +94,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
} }
function extractPredictionLayerParams(): FaceDetectionNet.PredictionParams { function extractPredictionLayerParams(): FaceDetectionNet.PredictionLayerParams {
const conv_0_params = extractPointwiseConvParams(1024, 256, 1) const conv_0_params = extractPointwiseConvParams(1024, 256, 1)
const conv_1_params = extractPointwiseConvParams(256, 512, 3) const conv_1_params = extractPointwiseConvParams(256, 512, 3)
const conv_2_params = extractPointwiseConvParams(512, 128, 1) const conv_2_params = extractPointwiseConvParams(512, 128, 1)
...@@ -182,6 +182,13 @@ export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams ...@@ -182,6 +182,13 @@ export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams
const mobilenetv1_params = extractMobilenetV1Params() const mobilenetv1_params = extractMobilenetV1Params()
const prediction_layer_params = extractPredictionLayerParams() const prediction_layer_params = extractPredictionLayerParams()
const extra_dim = tf.tensor3d(
extractWeights(5118 * 4),
[1, 5118, 4]
)
const output_layer_params = {
extra_dim
}
if (weights.length !== 0) { if (weights.length !== 0) {
throw new Error(`weights remaing after extract: ${weights.length}`) throw new Error(`weights remaing after extract: ${weights.length}`)
...@@ -189,6 +196,7 @@ export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams ...@@ -189,6 +196,7 @@ export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams
return { return {
mobilenetv1_params, mobilenetv1_params,
prediction_layer_params prediction_layer_params,
output_layer_params
} }
} }
\ No newline at end of file
...@@ -5,6 +5,7 @@ import { extractParams } from './extractParams'; ...@@ -5,6 +5,7 @@ import { extractParams } from './extractParams';
import { mobileNetV1 } from './mobileNetV1'; import { mobileNetV1 } from './mobileNetV1';
import { resizeLayer } from './resizeLayer'; import { resizeLayer } from './resizeLayer';
import { predictionLayer } from './predictionLayer'; import { predictionLayer } from './predictionLayer';
import { outputLayer } from './outputLayer';
function fromData(input: number[]): tf.Tensor4D { function fromData(input: number[]): tf.Tensor4D {
const pxPerChannel = input.length / 3 const pxPerChannel = input.length / 3
...@@ -56,10 +57,9 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -56,10 +57,9 @@ export function faceDetectionNet(weights: Float32Array) {
classPredictions classPredictions
} = predictionLayer(features.out, features.conv11, params.prediction_layer_params) } = predictionLayer(features.out, features.conv11, params.prediction_layer_params)
return { const decoded = outputLayer(boxPredictions, classPredictions, params.output_layer_params)
boxPredictions,
classPredictions return decoded
}
}) })
} }
......
import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionNet } from './types';
function batchMultiClassNonMaxSuppressionLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
// TODO
return x0
}
function getCenterCoordinatesAndSizesLayer(x: tf.Tensor2D) {
const vec = tf.unstack(tf.transpose(x, [1, 0]))
const sizes = [
tf.sub(vec[2], vec[0]),
tf.sub(vec[3], vec[1])
]
const centers = [
tf.add(vec[0], tf.div(sizes[0], tf.scalar(2))),
tf.add(vec[1], tf.div(sizes[1], tf.scalar(2)))
]
return {
sizes,
centers
}
}
function decodeLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
const {
sizes,
centers
} = getCenterCoordinatesAndSizesLayer(x0)
const vec = tf.unstack(tf.transpose(x1, [1, 0]))
const div0_out = tf.div(tf.mul(tf.exp(tf.div(vec[2], tf.scalar(5))), sizes[0]), tf.scalar(2))
const add0_out = tf.add(tf.mul(tf.exp(tf.div(vec[0], tf.scalar(10))), sizes[0]), centers[0])
const div1_out = tf.div(tf.mul(tf.exp(tf.div(vec[3], tf.scalar(5))), sizes[1]), tf.scalar(2))
const add1_out = tf.add(tf.mul(tf.exp(tf.div(vec[1], tf.scalar(10))), sizes[1]), centers[1])
return tf.transpose(
tf.stack([
tf.sub(div0_out, add0_out),
tf.sub(div1_out, add1_out),
tf.add(div0_out, add0_out),
tf.add(div1_out, add1_out)
]),
[1, 0]
)
}
export function outputLayer(
boxPredictions: tf.Tensor4D,
classPredictions: tf.Tensor4D,
params: FaceDetectionNet.OutputLayerParams
) {
return tf.tidy(() => {
const batchSize = boxPredictions.shape[0]
const decoded = decodeLayer(
tf.reshape(tf.tile(params.extra_dim, [batchSize, 1, 1]), [-1, 4]) as tf.Tensor2D,
tf.reshape(boxPredictions, [-1, 4]) as tf.Tensor2D
)
const in1 = tf.sigmoid(tf.slice(classPredictions, [0, 0, 1], [-1, -1, -1]))
const in2 = tf.expandDims(tf.reshape(decoded, [batchSize, 5118, 4]), 2)
return decoded
})
}
\ No newline at end of file
...@@ -4,7 +4,7 @@ import { boxPredictionLayer } from './boxPredictionLayer'; ...@@ -4,7 +4,7 @@ import { boxPredictionLayer } from './boxPredictionLayer';
import { pointwiseConvLayer } from './pointwiseConvLayer'; import { pointwiseConvLayer } from './pointwiseConvLayer';
import { FaceDetectionNet } from './types'; import { FaceDetectionNet } from './types';
export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) { export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: FaceDetectionNet.PredictionLayerParams) {
return tf.tidy(() => { return tf.tidy(() => {
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1]) const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
...@@ -30,7 +30,7 @@ export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: Fac ...@@ -30,7 +30,7 @@ export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: Fac
boxPrediction3.boxPredictionEncoding, boxPrediction3.boxPredictionEncoding,
boxPrediction4.boxPredictionEncoding, boxPrediction4.boxPredictionEncoding,
boxPrediction5.boxPredictionEncoding boxPrediction5.boxPredictionEncoding
], 1) ], 1) as tf.Tensor4D
const classPredictions = tf.concat([ const classPredictions = tf.concat([
boxPrediction0.classPrediction, boxPrediction0.classPrediction,
...@@ -39,7 +39,7 @@ export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: Fac ...@@ -39,7 +39,7 @@ export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: Fac
boxPrediction3.classPrediction, boxPrediction3.classPrediction,
boxPrediction4.classPrediction, boxPrediction4.classPrediction,
boxPrediction5.classPrediction boxPrediction5.classPrediction
], 1) ], 1) as tf.Tensor4D
return { return {
boxPredictions, boxPredictions,
......
...@@ -39,7 +39,7 @@ export namespace FaceDetectionNet { ...@@ -39,7 +39,7 @@ export namespace FaceDetectionNet {
class_predictor_params: ConvWithBiasParams class_predictor_params: ConvWithBiasParams
} }
export type PredictionParams = { export type PredictionLayerParams = {
conv_0_params: PointwiseConvParams conv_0_params: PointwiseConvParams
conv_1_params: PointwiseConvParams conv_1_params: PointwiseConvParams
conv_2_params: PointwiseConvParams conv_2_params: PointwiseConvParams
...@@ -56,9 +56,14 @@ export namespace FaceDetectionNet { ...@@ -56,9 +56,14 @@ export namespace FaceDetectionNet {
box_predictor_5_params: BoxPredictionParams box_predictor_5_params: BoxPredictionParams
} }
export type OutputLayerParams = {
extra_dim: tf.Tensor3D
}
export type NetParams = { export type NetParams = {
mobilenetv1_params: MobileNetV1.Params, mobilenetv1_params: MobileNetV1.Params,
prediction_layer_params: PredictionParams prediction_layer_params: PredictionLayerParams,
output_layer_params: OutputLayerParams
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment