Commit 9fef1428 by vincent

build param mappings for remaining nets (detection, recognition) to make them disposable

parent a6a68a56
...@@ -32,6 +32,11 @@ export class Rect implements IRect { ...@@ -32,6 +32,11 @@ export class Rect implements IRect {
return new Rect(x, y, width, height) return new Rect(x, y, width, height)
} }
public pad(padX: number, padY: number): Rect {
let { x, y, width, height } = this
return new Rect(x - (padX / 2), y - (padY / 2), width + padX, height + padY)
}
public floor(): Rect { public floor(): Rect {
return new Rect( return new Rect(
Math.floor(this.x), Math.floor(this.x),
......
...@@ -7,6 +7,8 @@ export class NeuralNetwork<TNetParams> { ...@@ -7,6 +7,8 @@ export class NeuralNetwork<TNetParams> {
protected _params: TNetParams | undefined = undefined protected _params: TNetParams | undefined = undefined
protected _paramMappings: ParamMapping[] = [] protected _paramMappings: ParamMapping[] = []
constructor(private _name: string) {}
public get params(): TNetParams | undefined { public get params(): TNetParams | undefined {
return this._params return this._params
} }
...@@ -53,11 +55,44 @@ export class NeuralNetwork<TNetParams> { ...@@ -53,11 +55,44 @@ export class NeuralNetwork<TNetParams> {
}) })
} }
public dispose() { public dispose(throwOnRedispose: boolean = true) {
this.getParamList().forEach(param => param.tensor.dispose()) this.getParamList().forEach(param => {
if (throwOnRedispose && param.tensor.isDisposed) {
throw new Error(`param tensor has already been disposed for path ${param.path}`)
}
param.tensor.dispose()
})
this._params = undefined this._params = undefined
} }
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
if (weightsOrUrl instanceof Float32Array) {
this.extractWeights(weightsOrUrl)
return
}
if (weightsOrUrl && typeof weightsOrUrl !== 'string') {
throw new Error(`${this._name}.load - expected model uri, or weights as Float32Array`)
}
const {
paramMappings,
params
} = await this.loadQuantizedParams(weightsOrUrl)
this._paramMappings = paramMappings
this._params = params
}
public extractWeights(weights: Float32Array) {
const {
paramMappings,
params
} = this.extractParams(weights)
this._paramMappings = paramMappings
this._params = params
}
private traversePropertyPath(paramPath: string) { private traversePropertyPath(paramPath: string) {
if (!this.params) { if (!this.params) {
throw new Error(`traversePropertyPath - model has no loaded params`) throw new Error(`traversePropertyPath - model has no loaded params`)
...@@ -78,4 +113,12 @@ export class NeuralNetwork<TNetParams> { ...@@ -78,4 +113,12 @@ export class NeuralNetwork<TNetParams> {
return { obj, objProp } return { obj, objProp }
} }
protected loadQuantizedParams(_: any): Promise<{ params: TNetParams, paramMappings: ParamMapping[] }> {
throw new Error(`${this._name}.loadQuantizedParams - not implemented`)
}
protected extractParams(_: any): { params: TNetParams, paramMappings: ParamMapping[] } {
throw new Error(`${this._name}.extractParams - not implemented`)
}
} }
\ No newline at end of file
import { ParamMapping } from './types';
export function disposeUnusedWeightTensors(weightMap: any, paramMappings: ParamMapping[]) {
Object.keys(weightMap).forEach(path => {
if (!paramMappings.some(pm => pm.originalPath === path)) {
weightMap[path].dispose()
}
})
}
import * as tf from '@tensorflow/tfjs-core';
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
return function (
channelsIn: number,
channelsOut: number,
filterSize: number,
mappedPrefix: string
): ConvParams {
const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut]
)
const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` }
)
return {
filters,
bias
}
}
}
\ No newline at end of file
import { isTensor } from './isTensor';
export function extractWeightEntry(weightMap: any, path: string, paramRank: number) {
const tensor = weightMap[path]
if (!isTensor(tensor, paramRank)) {
throw new Error(`expected weightMap[${path}] to be a Tensor${paramRank}D, instead have ${tensor}`)
}
return { path, tensor }
}
\ No newline at end of file
import { isTensor } from './isTensor';
import { ParamMapping } from './types';
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
return function<T> (originalPath: string, paramRank: number, mappedPath?: string): T {
const tensor = weightMap[originalPath]
if (!isTensor(tensor, paramRank)) {
throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`)
}
paramMappings.push(
{ originalPath, paramPath: mappedPath || originalPath }
)
return tensor
}
}
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { NeuralNetwork } from '../commons/NeuralNetwork';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { Rect } from '../Rect'; import { Rect } from '../Rect';
import { toNetInput } from '../toNetInput'; import { toNetInput } from '../toNetInput';
...@@ -13,28 +14,17 @@ import { outputLayer } from './outputLayer'; ...@@ -13,28 +14,17 @@ import { outputLayer } from './outputLayer';
import { predictionLayer } from './predictionLayer'; import { predictionLayer } from './predictionLayer';
import { NetParams } from './types'; import { NetParams } from './types';
export class FaceDetectionNet { export class FaceDetectionNet extends NeuralNetwork<NetParams> {
private _params: NetParams constructor() {
super('FaceDetectionNet')
public async load(weightsOrUrl?: Float32Array | string): Promise<void> {
if (weightsOrUrl instanceof Float32Array) {
this.extractWeights(weightsOrUrl)
return
}
if (weightsOrUrl && typeof weightsOrUrl !== 'string') {
throw new Error('FaceDetectionNet.load - expected model uri, or weights as Float32Array')
}
this._params = await loadQuantizedParams(weightsOrUrl)
}
public extractWeights(weights: Float32Array) {
this._params = extractParams(weights)
} }
public forwardInput(input: NetInput) { public forwardInput(input: NetInput) {
if (!this._params) {
const { params } = this
if (!params) {
throw new Error('FaceDetectionNet - load model before inference') throw new Error('FaceDetectionNet - load model before inference')
} }
...@@ -42,14 +32,14 @@ export class FaceDetectionNet { ...@@ -42,14 +32,14 @@ export class FaceDetectionNet {
const batchTensor = input.toBatchTensor(512, false) const batchTensor = input.toBatchTensor(512, false)
const x = tf.sub(tf.mul(batchTensor, tf.scalar(0.007843137718737125)), tf.scalar(1)) as tf.Tensor4D const x = tf.sub(tf.mul(batchTensor, tf.scalar(0.007843137718737125)), tf.scalar(1)) as tf.Tensor4D
const features = mobileNetV1(x, this._params.mobilenetv1_params) const features = mobileNetV1(x, params.mobilenetv1)
const { const {
boxPredictions, boxPredictions,
classPredictions classPredictions
} = predictionLayer(features.out, features.conv11, this._params.prediction_layer_params) } = predictionLayer(features.out, features.conv11, params.prediction_layer)
return outputLayer(boxPredictions, classPredictions, this._params.output_layer_params) return outputLayer(boxPredictions, classPredictions, params.output_layer)
}) })
} }
...@@ -91,7 +81,6 @@ export class FaceDetectionNet { ...@@ -91,7 +81,6 @@ export class FaceDetectionNet {
minConfidence minConfidence
) )
const paddedHeightRelative = (netInput.getPaddings(0).y + netInput.getInputHeight(0)) / netInput.getInputHeight(0) const paddedHeightRelative = (netInput.getPaddings(0).y + netInput.getInputHeight(0)) / netInput.getInputHeight(0)
const paddedWidthRelative = (netInput.getPaddings(0).x + netInput.getInputWidth(0)) / netInput.getInputWidth(0) const paddedWidthRelative = (netInput.getPaddings(0).x + netInput.getInputWidth(0)) / netInput.getInputWidth(0)
...@@ -125,4 +114,12 @@ export class FaceDetectionNet { ...@@ -125,4 +114,12 @@ export class FaceDetectionNet {
return results return results
} }
protected loadQuantizedParams(uri: string | undefined) {
return loadQuantizedParams(uri)
}
protected extractParams(weights: Float32Array) {
return extractParams(weights)
}
} }
\ No newline at end of file
...@@ -13,11 +13,11 @@ export function boxPredictionLayer( ...@@ -13,11 +13,11 @@ export function boxPredictionLayer(
const batchSize = x.shape[0] const batchSize = x.shape[0]
const boxPredictionEncoding = tf.reshape( const boxPredictionEncoding = tf.reshape(
convLayer(x, params.box_encoding_predictor_params), convLayer(x, params.box_encoding_predictor),
[batchSize, -1, 1, 4] [batchSize, -1, 1, 4]
) )
const classPrediction = tf.reshape( const classPrediction = tf.reshape(
convLayer(x, params.class_predictor_params), convLayer(x, params.class_predictor),
[batchSize, -1, 3] [batchSize, -1, 3]
) )
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { extractWeightsFactory } from '../commons/extractWeightsFactory'; import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { ConvParams } from '../commons/types'; import { ConvParams, ExtractWeightsFunction, ParamMapping } from '../commons/types';
import { MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types'; import { MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) { function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
function extractDepthwiseConvParams(numChannels: number, mappedPrefix: string): MobileNetV1.DepthwiseConvParams {
function extractDepthwiseConvParams(numChannels: number): MobileNetV1.DepthwiseConvParams {
const filters = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1]) const filters = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1])
const batch_norm_scale = tf.tensor1d(extractWeights(numChannels)) const batch_norm_scale = tf.tensor1d(extractWeights(numChannels))
const batch_norm_offset = tf.tensor1d(extractWeights(numChannels)) const batch_norm_offset = tf.tensor1d(extractWeights(numChannels))
const batch_norm_mean = tf.tensor1d(extractWeights(numChannels)) const batch_norm_mean = tf.tensor1d(extractWeights(numChannels))
const batch_norm_variance = tf.tensor1d(extractWeights(numChannels)) const batch_norm_variance = tf.tensor1d(extractWeights(numChannels))
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/batch_norm_scale` },
{ paramPath: `${mappedPrefix}/batch_norm_offset` },
{ paramPath: `${mappedPrefix}/batch_norm_mean` },
{ paramPath: `${mappedPrefix}/batch_norm_variance` }
)
return { return {
filters, filters,
batch_norm_scale, batch_norm_scale,
...@@ -25,29 +34,36 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) ...@@ -25,29 +34,36 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
function extractConvParams( function extractConvParams(
channelsIn: number, channelsIn: number,
channelsOut: number, channelsOut: number,
filterSize: number filterSize: number,
mappedPrefix: string,
isPointwiseConv?: boolean
): ConvParams { ): ConvParams {
const filters = tf.tensor4d( const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize), extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut] [filterSize, filterSize, channelsIn, channelsOut]
) )
const bias = tf.tensor1d(extractWeights(channelsOut)) const bias = tf.tensor1d(extractWeights(channelsOut))
return { paramMappings.push(
filters, { paramPath: `${mappedPrefix}/filters` },
bias { paramPath: `${mappedPrefix}/${isPointwiseConv ? 'batch_norm_offset' : 'bias'}` }
} )
return { filters, bias }
} }
function extractPointwiseConvParams( function extractPointwiseConvParams(
channelsIn: number, channelsIn: number,
channelsOut: number, channelsOut: number,
filterSize: number filterSize: number,
mappedPrefix: string
): PointwiseConvParams { ): PointwiseConvParams {
const { const {
filters, filters,
bias bias
} = extractConvParams(channelsIn, channelsOut, filterSize) } = extractConvParams(channelsIn, channelsOut, filterSize, mappedPrefix, true)
return { return {
filters, filters,
...@@ -57,115 +73,118 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) ...@@ -57,115 +73,118 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
function extractConvPairParams( function extractConvPairParams(
channelsIn: number, channelsIn: number,
channelsOut: number channelsOut: number,
mappedPrefix: string
): MobileNetV1.ConvPairParams { ): MobileNetV1.ConvPairParams {
const depthwise_conv_params = extractDepthwiseConvParams(channelsIn)
const pointwise_conv_params = extractPointwiseConvParams(channelsIn, channelsOut, 1)
return { const depthwise_conv = extractDepthwiseConvParams(channelsIn, `${mappedPrefix}/depthwise_conv`)
depthwise_conv_params, const pointwise_conv = extractPointwiseConvParams(channelsIn, channelsOut, 1, `${mappedPrefix}/pointwise_conv`)
pointwise_conv_params
} return { depthwise_conv, pointwise_conv }
} }
function extractMobilenetV1Params(): MobileNetV1.Params { function extractMobilenetV1Params(): MobileNetV1.Params {
const conv_0_params = extractPointwiseConvParams(3, 32, 3) const conv_0 = extractPointwiseConvParams(3, 32, 3, 'mobilenetv1/conv_0')
const channelNumPairs = [ const conv_1 = extractConvPairParams(32, 64, 'mobilenetv1/conv_1')
[32, 64], const conv_2 = extractConvPairParams(64, 128, 'mobilenetv1/conv_2')
[64, 128], const conv_3 = extractConvPairParams(128, 128, 'mobilenetv1/conv_3')
[128, 128], const conv_4 = extractConvPairParams(128, 256, 'mobilenetv1/conv_4')
[128, 256], const conv_5 = extractConvPairParams(256, 256, 'mobilenetv1/conv_5')
[256, 256], const conv_6 = extractConvPairParams(256, 512, 'mobilenetv1/conv_6')
[256, 512], const conv_7 = extractConvPairParams(512, 512, 'mobilenetv1/conv_7')
[512, 512], const conv_8 = extractConvPairParams(512, 512, 'mobilenetv1/conv_8')
[512, 512], const conv_9 = extractConvPairParams(512, 512, 'mobilenetv1/conv_9')
[512, 512], const conv_10 = extractConvPairParams(512, 512, 'mobilenetv1/conv_10')
[512, 512], const conv_11 = extractConvPairParams(512, 512, 'mobilenetv1/conv_11')
[512, 512], const conv_12 = extractConvPairParams(512, 1024, 'mobilenetv1/conv_12')
[512, 1024], const conv_13 = extractConvPairParams(1024, 1024, 'mobilenetv1/conv_13')
[1024, 1024]
]
const conv_pair_params = channelNumPairs.map(
([channelsIn, channelsOut]) => extractConvPairParams(channelsIn, channelsOut)
)
return { return {
conv_0_params, conv_0,
conv_pair_params conv_1,
conv_2,
conv_3,
conv_4,
conv_5,
conv_6,
conv_7,
conv_8,
conv_9,
conv_10,
conv_11,
conv_12,
conv_13
} }
} }
function extractPredictionLayerParams(): PredictionLayerParams { function extractPredictionLayerParams(): PredictionLayerParams {
const conv_0_params = extractPointwiseConvParams(1024, 256, 1) const conv_0 = extractPointwiseConvParams(1024, 256, 1, 'prediction_layer/conv_0')
const conv_1_params = extractPointwiseConvParams(256, 512, 3) const conv_1 = extractPointwiseConvParams(256, 512, 3, 'prediction_layer/conv_1')
const conv_2_params = extractPointwiseConvParams(512, 128, 1) const conv_2 = extractPointwiseConvParams(512, 128, 1, 'prediction_layer/conv_2')
const conv_3_params = extractPointwiseConvParams(128, 256, 3) const conv_3 = extractPointwiseConvParams(128, 256, 3, 'prediction_layer/conv_3')
const conv_4_params = extractPointwiseConvParams(256, 128, 1) const conv_4 = extractPointwiseConvParams(256, 128, 1, 'prediction_layer/conv_4')
const conv_5_params = extractPointwiseConvParams(128, 256, 3) const conv_5 = extractPointwiseConvParams(128, 256, 3, 'prediction_layer/conv_5')
const conv_6_params = extractPointwiseConvParams(256, 64, 1) const conv_6 = extractPointwiseConvParams(256, 64, 1, 'prediction_layer/conv_6')
const conv_7_params = extractPointwiseConvParams(64, 128, 3) const conv_7 = extractPointwiseConvParams(64, 128, 3, 'prediction_layer/conv_7')
const box_encoding_0_predictor_params = extractConvParams(512, 12, 1) const box_encoding_0_predictor = extractConvParams(512, 12, 1, 'prediction_layer/box_predictor_0/box_encoding_predictor')
const class_predictor_0_params = extractConvParams(512, 9, 1) const class_predictor_0 = extractConvParams(512, 9, 1, 'prediction_layer/box_predictor_0/class_predictor')
const box_encoding_1_predictor_params = extractConvParams(1024, 24, 1) const box_encoding_1_predictor = extractConvParams(1024, 24, 1, 'prediction_layer/box_predictor_1/box_encoding_predictor')
const class_predictor_1_params = extractConvParams(1024, 18, 1) const class_predictor_1 = extractConvParams(1024, 18, 1, 'prediction_layer/box_predictor_1/class_predictor')
const box_encoding_2_predictor_params = extractConvParams(512, 24, 1) const box_encoding_2_predictor = extractConvParams(512, 24, 1, 'prediction_layer/box_predictor_2/box_encoding_predictor')
const class_predictor_2_params = extractConvParams(512, 18, 1) const class_predictor_2 = extractConvParams(512, 18, 1, 'prediction_layer/box_predictor_2/class_predictor')
const box_encoding_3_predictor_params = extractConvParams(256, 24, 1) const box_encoding_3_predictor = extractConvParams(256, 24, 1, 'prediction_layer/box_predictor_3/box_encoding_predictor')
const class_predictor_3_params = extractConvParams(256, 18, 1) const class_predictor_3 = extractConvParams(256, 18, 1, 'prediction_layer/box_predictor_3/class_predictor')
const box_encoding_4_predictor_params = extractConvParams(256, 24, 1) const box_encoding_4_predictor = extractConvParams(256, 24, 1, 'prediction_layer/box_predictor_4/box_encoding_predictor')
const class_predictor_4_params = extractConvParams(256, 18, 1) const class_predictor_4 = extractConvParams(256, 18, 1, 'prediction_layer/box_predictor_4/class_predictor')
const box_encoding_5_predictor_params = extractConvParams(128, 24, 1) const box_encoding_5_predictor = extractConvParams(128, 24, 1, 'prediction_layer/box_predictor_5/box_encoding_predictor')
const class_predictor_5_params = extractConvParams(128, 18, 1) const class_predictor_5 = extractConvParams(128, 18, 1, 'prediction_layer/box_predictor_5/class_predictor')
const box_predictor_0_params = { const box_predictor_0 = {
box_encoding_predictor_params: box_encoding_0_predictor_params, box_encoding_predictor: box_encoding_0_predictor,
class_predictor_params: class_predictor_0_params class_predictor: class_predictor_0
} }
const box_predictor_1_params = { const box_predictor_1 = {
box_encoding_predictor_params: box_encoding_1_predictor_params, box_encoding_predictor: box_encoding_1_predictor,
class_predictor_params: class_predictor_1_params class_predictor: class_predictor_1
} }
const box_predictor_2_params = { const box_predictor_2 = {
box_encoding_predictor_params: box_encoding_2_predictor_params, box_encoding_predictor: box_encoding_2_predictor,
class_predictor_params: class_predictor_2_params class_predictor: class_predictor_2
} }
const box_predictor_3_params = { const box_predictor_3 = {
box_encoding_predictor_params: box_encoding_3_predictor_params, box_encoding_predictor: box_encoding_3_predictor,
class_predictor_params: class_predictor_3_params class_predictor: class_predictor_3
} }
const box_predictor_4_params = { const box_predictor_4 = {
box_encoding_predictor_params: box_encoding_4_predictor_params, box_encoding_predictor: box_encoding_4_predictor,
class_predictor_params: class_predictor_4_params class_predictor: class_predictor_4
} }
const box_predictor_5_params = { const box_predictor_5 = {
box_encoding_predictor_params: box_encoding_5_predictor_params, box_encoding_predictor: box_encoding_5_predictor,
class_predictor_params: class_predictor_5_params class_predictor: class_predictor_5
} }
return { return {
conv_0_params, conv_0,
conv_1_params, conv_1,
conv_2_params, conv_2,
conv_3_params, conv_3,
conv_4_params, conv_4,
conv_5_params, conv_5,
conv_6_params, conv_6,
conv_7_params, conv_7,
box_predictor_0_params, box_predictor_0,
box_predictor_1_params, box_predictor_1,
box_predictor_2_params, box_predictor_2,
box_predictor_3_params, box_predictor_3,
box_predictor_4_params, box_predictor_4,
box_predictor_5_params box_predictor_5
} }
} }
return { return {
extractMobilenetV1Params, extractMobilenetV1Params,
extractPredictionLayerParams extractPredictionLayerParams
...@@ -173,7 +192,10 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) ...@@ -173,7 +192,10 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
} }
export function extractParams(weights: Float32Array): NetParams { export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const { const {
extractWeights, extractWeights,
getRemainingWeights getRemainingWeights
...@@ -182,25 +204,30 @@ export function extractParams(weights: Float32Array): NetParams { ...@@ -182,25 +204,30 @@ export function extractParams(weights: Float32Array): NetParams {
const { const {
extractMobilenetV1Params, extractMobilenetV1Params,
extractPredictionLayerParams extractPredictionLayerParams
} = extractorsFactory(extractWeights) } = extractorsFactory(extractWeights, paramMappings)
const mobilenetv1_params = extractMobilenetV1Params() const mobilenetv1 = extractMobilenetV1Params()
const prediction_layer_params = extractPredictionLayerParams() const prediction_layer = extractPredictionLayerParams()
const extra_dim = tf.tensor3d( const extra_dim = tf.tensor3d(
extractWeights(5118 * 4), extractWeights(5118 * 4),
[1, 5118, 4] [1, 5118, 4]
) )
const output_layer_params = { const output_layer = {
extra_dim extra_dim
} }
paramMappings.push({ paramPath: 'output_layer/extra_dim' })
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
} }
return { return {
mobilenetv1_params, params: {
prediction_layer_params, mobilenetv1,
output_layer_params prediction_layer,
output_layer
},
paramMappings
} }
} }
\ No newline at end of file
import { isTensor1D, isTensor4D, isTensor3D } from '../commons/isTensor'; import { tf } from '..';
import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensors';
import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
import { isTensor1D, isTensor3D, isTensor4D } from '../commons/isTensor';
import { loadWeightMap } from '../commons/loadWeightMap'; import { loadWeightMap } from '../commons/loadWeightMap';
import { BoxPredictionParams, MobileNetV1, PointwiseConvParams, PredictionLayerParams } from './types'; import { ConvParams, ParamMapping } from '../commons/types';
import { BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
const DEFAULT_MODEL_NAME = 'face_detection_model' const DEFAULT_MODEL_NAME = 'face_detection_model'
function extractorsFactory(weightMap: any) { function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
function extractPointwiseConvParams(prefix: string, idx: number): PointwiseConvParams { const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
const pointwise_conv_params = { function extractPointwiseConvParams(prefix: string, idx: number, mappedPrefix: string): PointwiseConvParams {
filters: weightMap[`${prefix}/Conv2d_${idx}_pointwise/weights`],
batch_norm_offset: weightMap[`${prefix}/Conv2d_${idx}_pointwise/convolution_bn_offset`]
}
if (!isTensor4D(pointwise_conv_params.filters)) {
throw new Error(`expected weightMap[${prefix}/Conv2d_${idx}_pointwise/weights] to be a Tensor4D, instead have ${pointwise_conv_params.filters}`)
}
if (!isTensor1D(pointwise_conv_params.batch_norm_offset)) { const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/Conv2d_${idx}_pointwise/weights`, 4, `${mappedPrefix}/filters`)
throw new Error(`expected weightMap[${prefix}/Conv2d_${idx}_pointwise/convolution_bn_offset] to be a Tensor1D, instead have ${pointwise_conv_params.batch_norm_offset}`) const batch_norm_offset = extractWeightEntry<tf.Tensor1D>(`${prefix}/Conv2d_${idx}_pointwise/convolution_bn_offset`, 1, `${mappedPrefix}/batch_norm_offset`)
}
return pointwise_conv_params return { filters, batch_norm_offset }
} }
function extractConvPairParams(idx: number): MobileNetV1.ConvPairParams { function extractConvPairParams(idx: number): MobileNetV1.ConvPairParams {
const depthwise_conv_params = { const mappedPrefix = `mobilenetv1/conv_${idx}`
filters: weightMap[`MobilenetV1/Conv2d_${idx}_depthwise/depthwise_weights`], const prefixDepthwiseConv = `MobilenetV1/Conv2d_${idx}_depthwise`
batch_norm_scale: weightMap[`MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/gamma`], const mappedPrefixDepthwiseConv = `${mappedPrefix}/depthwise_conv`
batch_norm_offset: weightMap[`MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/beta`], const mappedPrefixPointwiseConv = `${mappedPrefix}/pointwise_conv`
batch_norm_mean: weightMap[`MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/moving_mean`],
batch_norm_variance: weightMap[`MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/moving_variance`],
}
if (!isTensor4D(depthwise_conv_params.filters)) {
throw new Error(`expected weightMap[MobilenetV1/Conv2d_${idx}_depthwise/depthwise_weights] to be a Tensor4D, instead have ${depthwise_conv_params.filters}`)
}
if (!isTensor1D(depthwise_conv_params.batch_norm_scale)) {
throw new Error(`expected weightMap[MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/gamma] to be a Tensor1D, instead have ${depthwise_conv_params.batch_norm_scale}`)
}
if (!isTensor1D(depthwise_conv_params.batch_norm_offset)) {
throw new Error(`expected weightMap[MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/beta] to be a Tensor1D, instead have ${depthwise_conv_params.batch_norm_offset}`)
}
if (!isTensor1D(depthwise_conv_params.batch_norm_mean)) {
throw new Error(`expected weightMap[MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/moving_mean] to be a Tensor1D, instead have ${depthwise_conv_params.batch_norm_mean}`)
}
if (!isTensor1D(depthwise_conv_params.batch_norm_variance)) { const filters = extractWeightEntry<tf.Tensor4D>(`${prefixDepthwiseConv}/depthwise_weights`, 4, `${mappedPrefixDepthwiseConv}/filters`)
throw new Error(`expected weightMap[MobilenetV1/Conv2d_${idx}_depthwise/BatchNorm/moving_variance] to be a Tensor1D, instead have ${depthwise_conv_params.batch_norm_variance}`) const batch_norm_scale = extractWeightEntry<tf.Tensor1D>(`${prefixDepthwiseConv}/BatchNorm/gamma`, 1, `${mappedPrefixDepthwiseConv}/batch_norm_scale`)
} const batch_norm_offset = extractWeightEntry<tf.Tensor1D>(`${prefixDepthwiseConv}/BatchNorm/beta`, 1, `${mappedPrefixDepthwiseConv}/batch_norm_offset`)
const batch_norm_mean = extractWeightEntry<tf.Tensor1D>(`${prefixDepthwiseConv}/BatchNorm/moving_mean`, 1, `${mappedPrefixDepthwiseConv}/batch_norm_mean`)
const batch_norm_variance = extractWeightEntry<tf.Tensor1D>(`${prefixDepthwiseConv}/BatchNorm/moving_variance`, 1, `${mappedPrefixDepthwiseConv}/batch_norm_variance`)
return { return {
depthwise_conv_params, depthwise_conv: {
pointwise_conv_params: extractPointwiseConvParams('MobilenetV1', idx) filters,
batch_norm_scale,
batch_norm_offset,
batch_norm_mean,
batch_norm_variance
},
pointwise_conv: extractPointwiseConvParams('MobilenetV1', idx, mappedPrefixPointwiseConv)
} }
} }
function extractMobilenetV1Params(): MobileNetV1.Params { function extractMobilenetV1Params(): MobileNetV1.Params {
return { return {
conv_0_params: extractPointwiseConvParams('MobilenetV1', 0), conv_0: extractPointwiseConvParams('MobilenetV1', 0, 'mobilenetv1/conv_0'),
conv_pair_params: Array(13).fill(0).map((_, i) => extractConvPairParams(i + 1)) conv_1: extractConvPairParams(1),
conv_2: extractConvPairParams(2),
conv_3: extractConvPairParams(3),
conv_4: extractConvPairParams(4),
conv_5: extractConvPairParams(5),
conv_6: extractConvPairParams(6),
conv_7: extractConvPairParams(7),
conv_8: extractConvPairParams(8),
conv_9: extractConvPairParams(9),
conv_10: extractConvPairParams(10),
conv_11: extractConvPairParams(11),
conv_12: extractConvPairParams(12),
conv_13: extractConvPairParams(13)
} }
} }
function extractBoxPredictorParams(idx: number): BoxPredictionParams { function extractConvParams(prefix: string, mappedPrefix: string): ConvParams {
const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/weights`, 4, `${mappedPrefix}/filters`)
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/biases`, 1, `${mappedPrefix}/bias`)
const params = { return { filters, bias }
box_encoding_predictor_params: { }
filters: weightMap[`Prediction/BoxPredictor_${idx}/BoxEncodingPredictor/weights`],
bias: weightMap[`Prediction/BoxPredictor_${idx}/BoxEncodingPredictor/biases`]
},
class_predictor_params: {
filters: weightMap[`Prediction/BoxPredictor_${idx}/ClassPredictor/weights`],
bias: weightMap[`Prediction/BoxPredictor_${idx}/ClassPredictor/biases`]
}
}
if (!isTensor4D(params.box_encoding_predictor_params.filters)) {
throw new Error(`expected weightMap[Prediction/BoxPredictor_${idx}/BoxEncodingPredictor/weights] to be a Tensor4D, instead have ${params.box_encoding_predictor_params.filters}`)
}
if (!isTensor1D(params.box_encoding_predictor_params.bias)) {
throw new Error(`expected weightMap[Prediction/BoxPredictor_${idx}/BoxEncodingPredictor/biases] to be a Tensor1D, instead have ${params.box_encoding_predictor_params.bias}`)
}
if (!isTensor4D(params.class_predictor_params.filters)) { function extractBoxPredictorParams(idx: number): BoxPredictionParams {
throw new Error(`expected weightMap[Prediction/BoxPredictor_${idx}/ClassPredictor/weights] to be a Tensor4D, instead have ${params.class_predictor_params.filters}`)
}
if (!isTensor1D(params.class_predictor_params.bias)) { const box_encoding_predictor = extractConvParams(
throw new Error(`expected weightMap[Prediction/BoxPredictor_${idx}/ClassPredictor/biases] to be a Tensor1D, instead have ${params.class_predictor_params.bias}`) `Prediction/BoxPredictor_${idx}/BoxEncodingPredictor`,
} `prediction_layer/box_predictor_${idx}/box_encoding_predictor`
)
const class_predictor = extractConvParams(
`Prediction/BoxPredictor_${idx}/ClassPredictor`,
`prediction_layer/box_predictor_${idx}/class_predictor`
)
return params return { box_encoding_predictor, class_predictor }
} }
function extractPredictionLayerParams(): PredictionLayerParams { function extractPredictionLayerParams(): PredictionLayerParams {
return { return {
conv_0_params: extractPointwiseConvParams('Prediction', 0), conv_0: extractPointwiseConvParams('Prediction', 0, 'prediction_layer/conv_0'),
conv_1_params: extractPointwiseConvParams('Prediction', 1), conv_1: extractPointwiseConvParams('Prediction', 1, 'prediction_layer/conv_1'),
conv_2_params: extractPointwiseConvParams('Prediction', 2), conv_2: extractPointwiseConvParams('Prediction', 2, 'prediction_layer/conv_2'),
conv_3_params: extractPointwiseConvParams('Prediction', 3), conv_3: extractPointwiseConvParams('Prediction', 3, 'prediction_layer/conv_3'),
conv_4_params: extractPointwiseConvParams('Prediction', 4), conv_4: extractPointwiseConvParams('Prediction', 4, 'prediction_layer/conv_4'),
conv_5_params: extractPointwiseConvParams('Prediction', 5), conv_5: extractPointwiseConvParams('Prediction', 5, 'prediction_layer/conv_5'),
conv_6_params: extractPointwiseConvParams('Prediction', 6), conv_6: extractPointwiseConvParams('Prediction', 6, 'prediction_layer/conv_6'),
conv_7_params: extractPointwiseConvParams('Prediction', 7), conv_7: extractPointwiseConvParams('Prediction', 7, 'prediction_layer/conv_7'),
box_predictor_0_params: extractBoxPredictorParams(0), box_predictor_0: extractBoxPredictorParams(0),
box_predictor_1_params: extractBoxPredictorParams(1), box_predictor_1: extractBoxPredictorParams(1),
box_predictor_2_params: extractBoxPredictorParams(2), box_predictor_2: extractBoxPredictorParams(2),
box_predictor_3_params: extractBoxPredictorParams(3), box_predictor_3: extractBoxPredictorParams(3),
box_predictor_4_params: extractBoxPredictorParams(4), box_predictor_4: extractBoxPredictorParams(4),
box_predictor_5_params: extractBoxPredictorParams(5) box_predictor_5: extractBoxPredictorParams(5)
} }
} }
...@@ -124,24 +110,34 @@ function extractorsFactory(weightMap: any) { ...@@ -124,24 +110,34 @@ function extractorsFactory(weightMap: any) {
} }
} }
export async function loadQuantizedParams(uri: string | undefined): Promise<any> {//Promise<NetParams> { export async function loadQuantizedParams(
uri: string | undefined
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME) const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
const paramMappings: ParamMapping[] = []
const { const {
extractMobilenetV1Params, extractMobilenetV1Params,
extractPredictionLayerParams extractPredictionLayerParams
} = extractorsFactory(weightMap) } = extractorsFactory(weightMap, paramMappings)
const extra_dim = weightMap['Output/extra_dim'] const extra_dim = weightMap['Output/extra_dim']
paramMappings.push({ originalPath: 'Output/extra_dim', paramPath: 'output_layer/extra_dim' })
if (!isTensor3D(extra_dim)) { if (!isTensor3D(extra_dim)) {
throw new Error(`expected weightMap['Output/extra_dim'] to be a Tensor3D, instead have ${extra_dim}`) throw new Error(`expected weightMap['Output/extra_dim'] to be a Tensor3D, instead have ${extra_dim}`)
} }
return { const params = {
mobilenetv1_params: extractMobilenetV1Params(), mobilenetv1: extractMobilenetV1Params(),
prediction_layer_params: extractPredictionLayerParams(), prediction_layer: extractPredictionLayerParams(),
output_layer_params: { output_layer: {
extra_dim extra_dim
} }
} }
disposeUnusedWeightTensors(weightMap, paramMappings)
return { params, paramMappings }
} }
\ No newline at end of file
...@@ -34,13 +34,29 @@ export function mobileNetV1(x: tf.Tensor4D, params: MobileNetV1.Params) { ...@@ -34,13 +34,29 @@ export function mobileNetV1(x: tf.Tensor4D, params: MobileNetV1.Params) {
return tf.tidy(() => { return tf.tidy(() => {
let conv11 = null let conv11 = null
let out = pointwiseConvLayer(x, params.conv_0_params, [2, 2]) let out = pointwiseConvLayer(x, params.conv_0, [2, 2])
params.conv_pair_params.forEach((param, i) => { const convPairParams = [
params.conv_1,
params.conv_2,
params.conv_3,
params.conv_4,
params.conv_5,
params.conv_6,
params.conv_7,
params.conv_8,
params.conv_9,
params.conv_10,
params.conv_11,
params.conv_12,
params.conv_13
]
convPairParams.forEach((param, i) => {
const layerIdx = i + 1 const layerIdx = i + 1
const depthwiseConvStrides = getStridesForLayerIdx(layerIdx) const depthwiseConvStrides = getStridesForLayerIdx(layerIdx)
out = depthwiseConvLayer(out, param.depthwise_conv_params, depthwiseConvStrides) out = depthwiseConvLayer(out, param.depthwise_conv, depthwiseConvStrides)
out = pointwiseConvLayer(out, param.pointwise_conv_params, [1, 1]) out = pointwiseConvLayer(out, param.pointwise_conv, [1, 1])
if (layerIdx === 11) { if (layerIdx === 11) {
conv11 = out conv11 = out
} }
......
...@@ -11,21 +11,21 @@ export function predictionLayer( ...@@ -11,21 +11,21 @@ export function predictionLayer(
) { ) {
return tf.tidy(() => { return tf.tidy(() => {
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1]) const conv0 = pointwiseConvLayer(x, params.conv_0, [1, 1])
const conv1 = pointwiseConvLayer(conv0, params.conv_1_params, [2, 2]) const conv1 = pointwiseConvLayer(conv0, params.conv_1, [2, 2])
const conv2 = pointwiseConvLayer(conv1, params.conv_2_params, [1, 1]) const conv2 = pointwiseConvLayer(conv1, params.conv_2, [1, 1])
const conv3 = pointwiseConvLayer(conv2, params.conv_3_params, [2, 2]) const conv3 = pointwiseConvLayer(conv2, params.conv_3, [2, 2])
const conv4 = pointwiseConvLayer(conv3, params.conv_4_params, [1, 1]) const conv4 = pointwiseConvLayer(conv3, params.conv_4, [1, 1])
const conv5 = pointwiseConvLayer(conv4, params.conv_5_params, [2, 2]) const conv5 = pointwiseConvLayer(conv4, params.conv_5, [2, 2])
const conv6 = pointwiseConvLayer(conv5, params.conv_6_params, [1, 1]) const conv6 = pointwiseConvLayer(conv5, params.conv_6, [1, 1])
const conv7 = pointwiseConvLayer(conv6, params.conv_7_params, [2, 2]) const conv7 = pointwiseConvLayer(conv6, params.conv_7, [2, 2])
const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params) const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0)
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params) const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1)
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params) const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2)
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params) const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3)
const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4_params) const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4)
const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5_params) const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5)
const boxPredictions = tf.concat([ const boxPredictions = tf.concat([
boxPrediction0.boxPredictionEncoding, boxPrediction0.boxPredictionEncoding,
......
...@@ -18,37 +18,49 @@ export namespace MobileNetV1 { ...@@ -18,37 +18,49 @@ export namespace MobileNetV1 {
} }
export type ConvPairParams = { export type ConvPairParams = {
depthwise_conv_params: DepthwiseConvParams depthwise_conv: DepthwiseConvParams
pointwise_conv_params: PointwiseConvParams pointwise_conv: PointwiseConvParams
} }
export type Params = { export type Params = {
conv_0_params: PointwiseConvParams conv_0: PointwiseConvParams
conv_pair_params: ConvPairParams[] conv_1: ConvPairParams
conv_2: ConvPairParams
conv_3: ConvPairParams
conv_4: ConvPairParams
conv_5: ConvPairParams
conv_6: ConvPairParams
conv_7: ConvPairParams
conv_8: ConvPairParams
conv_9: ConvPairParams
conv_10: ConvPairParams
conv_11: ConvPairParams
conv_12: ConvPairParams
conv_13: ConvPairParams
} }
} }
export type BoxPredictionParams = { export type BoxPredictionParams = {
box_encoding_predictor_params: ConvParams box_encoding_predictor: ConvParams
class_predictor_params: ConvParams class_predictor: ConvParams
} }
export type PredictionLayerParams = { export type PredictionLayerParams = {
conv_0_params: PointwiseConvParams conv_0: PointwiseConvParams
conv_1_params: PointwiseConvParams conv_1: PointwiseConvParams
conv_2_params: PointwiseConvParams conv_2: PointwiseConvParams
conv_3_params: PointwiseConvParams conv_3: PointwiseConvParams
conv_4_params: PointwiseConvParams conv_4: PointwiseConvParams
conv_5_params: PointwiseConvParams conv_5: PointwiseConvParams
conv_6_params: PointwiseConvParams conv_6: PointwiseConvParams
conv_7_params: PointwiseConvParams conv_7: PointwiseConvParams
box_predictor_0_params: BoxPredictionParams box_predictor_0: BoxPredictionParams
box_predictor_1_params: BoxPredictionParams box_predictor_1: BoxPredictionParams
box_predictor_2_params: BoxPredictionParams box_predictor_2: BoxPredictionParams
box_predictor_3_params: BoxPredictionParams box_predictor_3: BoxPredictionParams
box_predictor_4_params: BoxPredictionParams box_predictor_4: BoxPredictionParams
box_predictor_5_params: BoxPredictionParams box_predictor_5: BoxPredictionParams
} }
export type OutputLayerParams = { export type OutputLayerParams = {
...@@ -56,7 +68,7 @@ export type OutputLayerParams = { ...@@ -56,7 +68,7 @@ export type OutputLayerParams = {
} }
export type NetParams = { export type NetParams = {
mobilenetv1_params: MobileNetV1.Params, mobilenetv1: MobileNetV1.Params,
prediction_layer_params: PredictionLayerParams, prediction_layer: PredictionLayerParams,
output_layer_params: OutputLayerParams output_layer: OutputLayerParams
} }
...@@ -24,36 +24,13 @@ function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4 ...@@ -24,36 +24,13 @@ function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4
export class FaceLandmarkNet extends NeuralNetwork<NetParams> { export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> { constructor() {
if (weightsOrUrl instanceof Float32Array) { super('FaceLandmarkNet')
this.extractWeights(weightsOrUrl)
return
}
if (weightsOrUrl && typeof weightsOrUrl !== 'string') {
throw new Error('FaceLandmarkNet.load - expected model uri, or weights as Float32Array')
}
const {
paramMappings,
params
} = await loadQuantizedParams(weightsOrUrl)
this._paramMappings = paramMappings
this._params = params
}
public extractWeights(weights: Float32Array) {
const {
paramMappings,
params
} = extractParams(weights)
this._paramMappings = paramMappings
this._params = params
} }
public forwardInput(input: NetInput): tf.Tensor2D { public forwardInput(input: NetInput): tf.Tensor2D {
const params = this._params
const { params } = this
if (!params) { if (!params) {
throw new Error('FaceLandmarkNet - load model before inference') throw new Error('FaceLandmarkNet - load model before inference')
...@@ -62,20 +39,20 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> { ...@@ -62,20 +39,20 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
return tf.tidy(() => { return tf.tidy(() => {
const batchTensor = input.toBatchTensor(128, true) const batchTensor = input.toBatchTensor(128, true)
let out = conv(batchTensor, params.conv0_params) let out = conv(batchTensor, params.conv0)
out = maxPool(out) out = maxPool(out)
out = conv(out, params.conv1_params) out = conv(out, params.conv1)
out = conv(out, params.conv2_params) out = conv(out, params.conv2)
out = maxPool(out) out = maxPool(out)
out = conv(out, params.conv3_params) out = conv(out, params.conv3)
out = conv(out, params.conv4_params) out = conv(out, params.conv4)
out = maxPool(out) out = maxPool(out)
out = conv(out, params.conv5_params) out = conv(out, params.conv5)
out = conv(out, params.conv6_params) out = conv(out, params.conv6)
out = maxPool(out, [1, 1]) out = maxPool(out, [1, 1])
out = conv(out, params.conv7_params) out = conv(out, params.conv7)
const fc0 = tf.relu(fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params)) const fc0 = tf.relu(fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0))
const fc1 = fullyConnectedLayer(fc0, params.fc1_params) const fc1 = fullyConnectedLayer(fc0, params.fc1)
const createInterleavedTensor = (fillX: number, fillY: number) => const createInterleavedTensor = (fillX: number, fillY: number) =>
tf.stack([ tf.stack([
...@@ -145,4 +122,12 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> { ...@@ -145,4 +122,12 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
? landmarksForBatch ? landmarksForBatch
: landmarksForBatch[0] : landmarksForBatch[0]
} }
protected loadQuantizedParams(uri: string | undefined) {
return loadQuantizedParams(uri)
}
protected extractParams(weights: Float32Array) {
return extractParams(weights)
}
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
import { extractWeightsFactory } from '../commons/extractWeightsFactory'; import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { ParamMapping } from '../commons/types'; import { ConvParams, ParamMapping } from '../commons/types';
import { FCParams, NetParams } from './types'; import { FCParams, NetParams } from './types';
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } { export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [] const paramMappings: ParamMapping[] = []
const { const {
...@@ -13,9 +13,29 @@ export function extractParams(weights: Float32Array): { params: NetParams, param ...@@ -13,9 +13,29 @@ export function extractParams(weights: Float32Array): { params: NetParams, param
getRemainingWeights getRemainingWeights
} = extractWeightsFactory(weights) } = extractWeightsFactory(weights)
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings) function extractConvParams(
channelsIn: number,
channelsOut: number,
filterSize: number,
mappedPrefix: string
): ConvParams {
const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut]
)
const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` }
)
return { filters, bias }
}
function extractFcParams(channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams { function extractFcParams(channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams {
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]) const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
const fc_bias = tf.tensor1d(extractWeights(channelsOut)) const fc_bias = tf.tensor1d(extractWeights(channelsOut))
...@@ -30,16 +50,16 @@ export function extractParams(weights: Float32Array): { params: NetParams, param ...@@ -30,16 +50,16 @@ export function extractParams(weights: Float32Array): { params: NetParams, param
} }
} }
const conv0_params = extractConvParams(3, 32, 3, 'conv0_params') const conv0 = extractConvParams(3, 32, 3, 'conv0')
const conv1_params = extractConvParams(32, 64, 3, 'conv1_params') const conv1 = extractConvParams(32, 64, 3, 'conv1')
const conv2_params = extractConvParams(64, 64, 3, 'conv2_params') const conv2 = extractConvParams(64, 64, 3, 'conv2')
const conv3_params = extractConvParams(64, 64, 3, 'conv3_params') const conv3 = extractConvParams(64, 64, 3, 'conv3')
const conv4_params = extractConvParams(64, 64, 3, 'conv4_params') const conv4 = extractConvParams(64, 64, 3, 'conv4')
const conv5_params = extractConvParams(64, 128, 3, 'conv5_params') const conv5 = extractConvParams(64, 128, 3, 'conv5')
const conv6_params = extractConvParams(128, 128, 3, 'conv6_params') const conv6 = extractConvParams(128, 128, 3, 'conv6')
const conv7_params = extractConvParams(128, 256, 3, 'conv7_params') const conv7 = extractConvParams(128, 256, 3, 'conv7')
const fc0_params = extractFcParams(6400, 1024, 'fc0_params') const fc0 = extractFcParams(6400, 1024, 'fc0')
const fc1_params = extractFcParams(1024, 136, 'fc1_params') const fc1 = extractFcParams(1024, 136, 'fc1')
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
...@@ -48,16 +68,16 @@ export function extractParams(weights: Float32Array): { params: NetParams, param ...@@ -48,16 +68,16 @@ export function extractParams(weights: Float32Array): { params: NetParams, param
return { return {
paramMappings, paramMappings,
params: { params: {
conv0_params, conv0,
conv1_params, conv1,
conv2_params, conv2,
conv3_params, conv3,
conv4_params, conv4,
conv5_params, conv5,
conv6_params, conv6,
conv7_params, conv7,
fc0_params, fc0,
fc1_params fc1
} }
} }
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { extractWeightEntry } from '../commons/extractWeightEntry'; import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensors';
import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
import { loadWeightMap } from '../commons/loadWeightMap'; import { loadWeightMap } from '../commons/loadWeightMap';
import { ConvParams, ParamMapping } from '../commons/types'; import { ConvParams, ParamMapping } from '../commons/types';
import { FCParams, NetParams } from './types'; import { FCParams, NetParams } from './types';
...@@ -9,30 +10,20 @@ const DEFAULT_MODEL_NAME = 'face_landmark_68_model' ...@@ -9,30 +10,20 @@ const DEFAULT_MODEL_NAME = 'face_landmark_68_model'
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) { function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
function extractConvParams(prefix: string, mappedPrefix: string): ConvParams { function extractConvParams(prefix: string, mappedPrefix: string): ConvParams {
const filtersEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 4) const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/kernel`, 4, `${mappedPrefix}/filters`)
const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1) const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1, `${mappedPrefix}/bias`)
paramMappings.push(
{ originalPath: filtersEntry.path, paramPath: `${mappedPrefix}/filters` }, return { filters, bias }
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
)
return {
filters: filtersEntry.tensor as tf.Tensor4D,
bias: biasEntry.tensor as tf.Tensor1D
}
} }
function extractFcParams(prefix: string, mappedPrefix: string): FCParams { function extractFcParams(prefix: string, mappedPrefix: string): FCParams {
const weightsEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 2) const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/kernel`, 2, `${mappedPrefix}/weights`)
const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1) const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1, `${mappedPrefix}/bias`)
paramMappings.push(
{ originalPath: weightsEntry.path, paramPath: `${mappedPrefix}/weights` }, return { weights, bias }
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
)
return {
weights: weightsEntry.tensor as tf.Tensor2D,
bias: biasEntry.tensor as tf.Tensor1D
}
} }
return { return {
...@@ -54,17 +45,19 @@ export async function loadQuantizedParams( ...@@ -54,17 +45,19 @@ export async function loadQuantizedParams(
} = extractorsFactory(weightMap, paramMappings) } = extractorsFactory(weightMap, paramMappings)
const params = { const params = {
conv0_params: extractConvParams('conv2d_0', 'conv0_params'), conv0: extractConvParams('conv2d_0', 'conv0'),
conv1_params: extractConvParams('conv2d_1', 'conv1_params'), conv1: extractConvParams('conv2d_1', 'conv1'),
conv2_params: extractConvParams('conv2d_2', 'conv2_params'), conv2: extractConvParams('conv2d_2', 'conv2'),
conv3_params: extractConvParams('conv2d_3', 'conv3_params'), conv3: extractConvParams('conv2d_3', 'conv3'),
conv4_params: extractConvParams('conv2d_4', 'conv4_params'), conv4: extractConvParams('conv2d_4', 'conv4'),
conv5_params: extractConvParams('conv2d_5', 'conv5_params'), conv5: extractConvParams('conv2d_5', 'conv5'),
conv6_params: extractConvParams('conv2d_6', 'conv6_params'), conv6: extractConvParams('conv2d_6', 'conv6'),
conv7_params: extractConvParams('conv2d_7', 'conv7_params'), conv7: extractConvParams('conv2d_7', 'conv7'),
fc0_params: extractFcParams('dense', 'fc0_params'), fc0: extractFcParams('dense', 'fc0'),
fc1_params: extractFcParams('logits', 'fc1_params') fc1: extractFcParams('logits', 'fc1')
} }
disposeUnusedWeightTensors(weightMap, paramMappings)
return { params, paramMappings } return { params, paramMappings }
} }
\ No newline at end of file
...@@ -8,14 +8,14 @@ export type FCParams = { ...@@ -8,14 +8,14 @@ export type FCParams = {
} }
export type NetParams = { export type NetParams = {
conv0_params: ConvParams conv0: ConvParams
conv1_params: ConvParams conv1: ConvParams
conv2_params: ConvParams conv2: ConvParams
conv3_params: ConvParams conv3: ConvParams
conv4_params: ConvParams conv4: ConvParams
conv5_params: ConvParams conv5: ConvParams
conv6_params: ConvParams conv6: ConvParams
conv7_params: ConvParams conv7: ConvParams
fc0_params: FCParams fc0: FCParams
fc1_params: FCParams fc1: FCParams
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { NeuralNetwork } from '../commons/NeuralNetwork';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { toNetInput } from '../toNetInput'; import { toNetInput } from '../toNetInput';
import { TNetInput } from '../types'; import { TNetInput } from '../types';
...@@ -10,28 +11,17 @@ import { normalize } from './normalize'; ...@@ -10,28 +11,17 @@ import { normalize } from './normalize';
import { residual, residualDown } from './residualLayer'; import { residual, residualDown } from './residualLayer';
import { NetParams } from './types'; import { NetParams } from './types';
export class FaceRecognitionNet { export class FaceRecognitionNet extends NeuralNetwork<NetParams> {
private _params: NetParams constructor() {
super('FaceRecognitionNet')
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
if (weightsOrUrl instanceof Float32Array) {
this.extractWeights(weightsOrUrl)
return
}
if (weightsOrUrl && typeof weightsOrUrl !== 'string') {
throw new Error('FaceLandmarkNet.load - expected model uri, or weights as Float32Array')
}
this._params = await loadQuantizedParams(weightsOrUrl)
}
public extractWeights(weights: Float32Array) {
this._params = extractParams(weights)
} }
public forwardInput(input: NetInput): tf.Tensor2D { public forwardInput(input: NetInput): tf.Tensor2D {
if (!this._params) {
const { params } = this
if (!params) {
throw new Error('FaceRecognitionNet - load model before inference') throw new Error('FaceRecognitionNet - load model before inference')
} }
...@@ -40,29 +30,29 @@ export class FaceRecognitionNet { ...@@ -40,29 +30,29 @@ export class FaceRecognitionNet {
const normalized = normalize(batchTensor) const normalized = normalize(batchTensor)
let out = convDown(normalized, this._params.conv32_down) let out = convDown(normalized, params.conv32_down)
out = tf.maxPool(out, 3, 2, 'valid') out = tf.maxPool(out, 3, 2, 'valid')
out = residual(out, this._params.conv32_1) out = residual(out, params.conv32_1)
out = residual(out, this._params.conv32_2) out = residual(out, params.conv32_2)
out = residual(out, this._params.conv32_3) out = residual(out, params.conv32_3)
out = residualDown(out, this._params.conv64_down) out = residualDown(out, params.conv64_down)
out = residual(out, this._params.conv64_1) out = residual(out, params.conv64_1)
out = residual(out, this._params.conv64_2) out = residual(out, params.conv64_2)
out = residual(out, this._params.conv64_3) out = residual(out, params.conv64_3)
out = residualDown(out, this._params.conv128_down) out = residualDown(out, params.conv128_down)
out = residual(out, this._params.conv128_1) out = residual(out, params.conv128_1)
out = residual(out, this._params.conv128_2) out = residual(out, params.conv128_2)
out = residualDown(out, this._params.conv256_down) out = residualDown(out, params.conv256_down)
out = residual(out, this._params.conv256_1) out = residual(out, params.conv256_1)
out = residual(out, this._params.conv256_2) out = residual(out, params.conv256_2)
out = residualDown(out, this._params.conv256_down_out) out = residualDown(out, params.conv256_down_out)
const globalAvg = out.mean([1, 2]) as tf.Tensor2D const globalAvg = out.mean([1, 2]) as tf.Tensor2D
const fullyConnected = tf.matMul(globalAvg, this._params.fc) const fullyConnected = tf.matMul(globalAvg, params.fc)
return fullyConnected return fullyConnected
}) })
...@@ -89,4 +79,12 @@ export class FaceRecognitionNet { ...@@ -89,4 +79,12 @@ export class FaceRecognitionNet {
? faceDescriptorsForBatch ? faceDescriptorsForBatch
: faceDescriptorsForBatch[0] : faceDescriptorsForBatch[0]
} }
protected loadQuantizedParams(uri: string | undefined) {
return loadQuantizedParams(uri)
}
protected extractParams(weights: Float32Array) {
return extractParams(weights)
}
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { extractWeightsFactory } from '../commons/extractWeightsFactory'; import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { ExtractWeightsFunction } from '../commons/types'; import { ConvParams, ExtractWeightsFunction, ParamMapping } from '../commons/types';
import { isFloat } from '../utils'; import { isFloat } from '../utils';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types'; import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
function extractorsFactory(extractWeights: ExtractWeightsFunction) { function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D { function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D {
const weights = extractWeights(numFilterValues) const weights = extractWeights(numFilterValues)
...@@ -15,15 +15,42 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction) { ...@@ -15,15 +15,42 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction) {
throw new Error(`depth has to be an integer: ${depth}, weights.length: ${weights.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`) throw new Error(`depth has to be an integer: ${depth}, weights.length: ${weights.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`)
} }
return tf.transpose( return tf.tidy(
tf.tensor4d(weights, [numFilters, depth, filterSize, filterSize]), () => tf.transpose(
[2, 3, 1, 0] tf.tensor4d(weights, [numFilters, depth, filterSize, filterSize]),
[2, 3, 1, 0]
)
) )
} }
function extractScaleLayerParams(numWeights: number): ScaleLayerParams { function extractConvParams(
numFilterValues: number,
numFilters: number,
filterSize: number,
mappedPrefix: string
): ConvParams {
const filters = extractFilterValues(numFilterValues, numFilters, filterSize)
const bias = tf.tensor1d(extractWeights(numFilters))
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` }
)
return { filters, bias }
}
function extractScaleLayerParams(numWeights: number, mappedPrefix: string): ScaleLayerParams {
const weights = tf.tensor1d(extractWeights(numWeights)) const weights = tf.tensor1d(extractWeights(numWeights))
const biases = tf.tensor1d(extractWeights(numWeights)) const biases = tf.tensor1d(extractWeights(numWeights))
paramMappings.push(
{ paramPath: `${mappedPrefix}/weights` },
{ paramPath: `${mappedPrefix}/biases` }
)
return { return {
weights, weights,
biases biases
...@@ -33,34 +60,28 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction) { ...@@ -33,34 +60,28 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction) {
function extractConvLayerParams( function extractConvLayerParams(
numFilterValues: number, numFilterValues: number,
numFilters: number, numFilters: number,
filterSize: number filterSize: number,
mappedPrefix: string
): ConvLayerParams { ): ConvLayerParams {
const conv_filters = extractFilterValues(numFilterValues, numFilters, filterSize)
const conv_bias = tf.tensor1d(extractWeights(numFilters))
const scale = extractScaleLayerParams(numFilters)
return { const conv = extractConvParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv`)
conv: { const scale = extractScaleLayerParams(numFilters, `${mappedPrefix}/scale`)
filters: conv_filters,
bias: conv_bias return { conv, scale }
},
scale
}
} }
function extractResidualLayerParams( function extractResidualLayerParams(
numFilterValues: number, numFilterValues: number,
numFilters: number, numFilters: number,
filterSize: number, filterSize: number,
mappedPrefix: string,
isDown: boolean = false isDown: boolean = false
): ResidualLayerParams { ): ResidualLayerParams {
const conv1: ConvLayerParams = extractConvLayerParams((isDown ? 0.5 : 1) * numFilterValues, numFilters, filterSize)
const conv2: ConvLayerParams = extractConvLayerParams(numFilterValues, numFilters, filterSize)
return { const conv1 = extractConvLayerParams((isDown ? 0.5 : 1) * numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv1`)
conv1, const conv2 = extractConvLayerParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv2`)
conv2
} return { conv1, conv2 }
} }
return { return {
...@@ -70,43 +91,49 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction) { ...@@ -70,43 +91,49 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction) {
} }
export function extractParams(weights: Float32Array): NetParams { export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
const { const {
extractWeights, extractWeights,
getRemainingWeights getRemainingWeights
} = extractWeightsFactory(weights) } = extractWeightsFactory(weights)
const paramMappings: ParamMapping[] = []
const { const {
extractConvLayerParams, extractConvLayerParams,
extractResidualLayerParams extractResidualLayerParams
} = extractorsFactory(extractWeights) } = extractorsFactory(extractWeights, paramMappings)
const conv32_down = extractConvLayerParams(4704, 32, 7) const conv32_down = extractConvLayerParams(4704, 32, 7, 'conv32_down')
const conv32_1 = extractResidualLayerParams(9216, 32, 3) const conv32_1 = extractResidualLayerParams(9216, 32, 3, 'conv32_1')
const conv32_2 = extractResidualLayerParams(9216, 32, 3) const conv32_2 = extractResidualLayerParams(9216, 32, 3, 'conv32_2')
const conv32_3 = extractResidualLayerParams(9216, 32, 3) const conv32_3 = extractResidualLayerParams(9216, 32, 3, 'conv32_3')
const conv64_down = extractResidualLayerParams(36864, 64, 3, true) const conv64_down = extractResidualLayerParams(36864, 64, 3, 'conv64_down', true)
const conv64_1 = extractResidualLayerParams(36864, 64, 3) const conv64_1 = extractResidualLayerParams(36864, 64, 3, 'conv64_1')
const conv64_2 = extractResidualLayerParams(36864, 64, 3) const conv64_2 = extractResidualLayerParams(36864, 64, 3, 'conv64_2')
const conv64_3 = extractResidualLayerParams(36864, 64, 3) const conv64_3 = extractResidualLayerParams(36864, 64, 3, 'conv64_3')
const conv128_down = extractResidualLayerParams(147456, 128, 3, true) const conv128_down = extractResidualLayerParams(147456, 128, 3, 'conv128_down', true)
const conv128_1 = extractResidualLayerParams(147456, 128, 3) const conv128_1 = extractResidualLayerParams(147456, 128, 3, 'conv128_1')
const conv128_2 = extractResidualLayerParams(147456, 128, 3) const conv128_2 = extractResidualLayerParams(147456, 128, 3, 'conv128_2')
const conv256_down = extractResidualLayerParams(589824, 256, 3, true) const conv256_down = extractResidualLayerParams(589824, 256, 3, 'conv256_down', true)
const conv256_1 = extractResidualLayerParams(589824, 256, 3) const conv256_1 = extractResidualLayerParams(589824, 256, 3, 'conv256_1')
const conv256_2 = extractResidualLayerParams(589824, 256, 3) const conv256_2 = extractResidualLayerParams(589824, 256, 3, 'conv256_2')
const conv256_down_out = extractResidualLayerParams(589824, 256, 3) const conv256_down_out = extractResidualLayerParams(589824, 256, 3, 'conv256_down_out')
const fc = tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0]) const fc = tf.tidy(
() => tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0])
)
paramMappings.push({ paramPath: `fc` })
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
} }
return { const params = {
conv32_down, conv32_down,
conv32_1, conv32_1,
conv32_2, conv32_2,
...@@ -124,4 +151,6 @@ export function extractParams(weights: Float32Array): NetParams { ...@@ -124,4 +151,6 @@ export function extractParams(weights: Float32Array): NetParams {
conv256_down_out, conv256_down_out,
fc fc
} }
return { params, paramMappings }
} }
\ No newline at end of file
import { isTensor1D, isTensor2D, isTensor4D } from '../commons/isTensor'; import * as tf from '@tensorflow/tfjs-core';
import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensors';
import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
import { isTensor2D } from '../commons/isTensor';
import { loadWeightMap } from '../commons/loadWeightMap'; import { loadWeightMap } from '../commons/loadWeightMap';
import { ConvLayerParams, ResidualLayerParams, ScaleLayerParams } from './types'; import { ParamMapping } from '../commons/types';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
const DEFAULT_MODEL_NAME = 'face_recognition_model' const DEFAULT_MODEL_NAME = 'face_recognition_model'
function extractorsFactory(weightMap: any) { function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
function extractScaleLayerParams(prefix: string): ScaleLayerParams { const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
const params = {
weights: weightMap[`${prefix}/scale/weights`],
biases: weightMap[`${prefix}/scale/biases`]
}
if (!isTensor1D(params.weights)) { function extractScaleLayerParams(prefix: string): ScaleLayerParams {
throw new Error(`expected weightMap[${prefix}/scale/weights] to be a Tensor1D, instead have ${params.weights}`)
}
if (!isTensor1D(params.biases)) { const weights = extractWeightEntry<tf.Tensor1D>(`${prefix}/scale/weights`, 1)
throw new Error(`expected weightMap[${prefix}/scale/biases] to be a Tensor1D, instead have ${params.biases}`) const biases = extractWeightEntry<tf.Tensor1D>(`${prefix}/scale/biases`, 1)
}
return params return { weights, biases }
} }
function extractConvLayerParams(prefix: string): ConvLayerParams { function extractConvLayerParams(prefix: string): ConvLayerParams {
const params = {
filters: weightMap[`${prefix}/conv/filters`],
bias: weightMap[`${prefix}/conv/bias`]
}
if (!isTensor4D(params.filters)) {
throw new Error(`expected weightMap[${prefix}/conv/filters] to be a Tensor1D, instead have ${params.filters}`)
}
if (!isTensor1D(params.bias)) { const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/conv/filters`, 4)
throw new Error(`expected weightMap[${prefix}/conv/bias] to be a Tensor1D, instead have ${params.bias}`) const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/conv/bias`, 1)
} const scale = extractScaleLayerParams(prefix)
return { return { conv: { filters, bias }, scale }
conv: params,
scale: extractScaleLayerParams(prefix)
}
} }
function extractResidualLayerParams(prefix: string): ResidualLayerParams { function extractResidualLayerParams(prefix: string): ResidualLayerParams {
...@@ -57,13 +44,17 @@ function extractorsFactory(weightMap: any) { ...@@ -57,13 +44,17 @@ function extractorsFactory(weightMap: any) {
} }
export async function loadQuantizedParams(uri: string | undefined): Promise<any> { export async function loadQuantizedParams(
uri: string | undefined
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME) const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
const paramMappings: ParamMapping[] = []
const { const {
extractConvLayerParams, extractConvLayerParams,
extractResidualLayerParams extractResidualLayerParams
} = extractorsFactory(weightMap) } = extractorsFactory(weightMap, paramMappings)
const conv32_down = extractConvLayerParams('conv32_down') const conv32_down = extractConvLayerParams('conv32_down')
const conv32_1 = extractResidualLayerParams('conv32_1') const conv32_1 = extractResidualLayerParams('conv32_1')
...@@ -85,12 +76,13 @@ export async function loadQuantizedParams(uri: string | undefined): Promise<any> ...@@ -85,12 +76,13 @@ export async function loadQuantizedParams(uri: string | undefined): Promise<any>
const conv256_down_out = extractResidualLayerParams('conv256_down_out') const conv256_down_out = extractResidualLayerParams('conv256_down_out')
const fc = weightMap['fc'] const fc = weightMap['fc']
paramMappings.push({ originalPath: 'fc', paramPath: 'fc' })
if (!isTensor2D(fc)) { if (!isTensor2D(fc)) {
throw new Error(`expected weightMap[fc] to be a Tensor2D, instead have ${fc}`) throw new Error(`expected weightMap[fc] to be a Tensor2D, instead have ${fc}`)
} }
return { const params = {
conv32_down, conv32_down,
conv32_1, conv32_1,
conv32_2, conv32_2,
...@@ -108,4 +100,8 @@ export async function loadQuantizedParams(uri: string | undefined): Promise<any> ...@@ -108,4 +100,8 @@ export async function loadQuantizedParams(uri: string | undefined): Promise<any>
conv256_down_out, conv256_down_out,
fc fc
} }
disposeUnusedWeightTensors(weightMap, paramMappings)
return { params, paramMappings }
} }
\ No newline at end of file
import { NeuralNetwork } from '../../../src/commons/NeuralNetwork';
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { NeuralNetwork } from '../../../src/commons/NeuralNetwork';
class FakeNeuralNetwork extends NeuralNetwork<any> { class FakeNeuralNetwork extends NeuralNetwork<any> {
constructor( constructor(
convFilter: tf.Tensor = tf.tensor(0), convFilter: tf.Tensor = tf.tensor(0),
convBias: tf.Tensor = tf.tensor(0), convBias: tf.Tensor = tf.tensor(0),
fcWeights: tf.Tensor = tf.tensor(0) fcWeights: tf.Tensor = tf.tensor(0)
) { ) {
super() super('FakeNeuralNetwork')
this._params = { this._params = {
conv: { conv: {
filter: convFilter, filter: convFilter,
......
import * as faceapi from '../../../src'; import * as faceapi from '../../../src';
import { FaceDetection } from '../../../src/faceDetectionNet/FaceDetection'; import { FaceDetection } from '../../../src/faceDetectionNet/FaceDetection';
import { IRect } from '../../../src/Rect'; import { IRect } from '../../../src/Rect';
import { expectMaxDelta } from '../../utils'; import { expectAllTensorsReleased, expectMaxDelta } from '../../utils';
function expectRectClose( function expectRectClose(
result: IRect, result: IRect,
...@@ -110,4 +110,33 @@ describe('faceDetectionNet', () => { ...@@ -110,4 +110,33 @@ describe('faceDetectionNet', () => {
}) })
describe('no memory leaks', () => {
describe('NeuralNetwork, uncompressed model', () => {
it('disposes all param tensors', async () => {
await expectAllTensorsReleased(async () => {
const res = await fetch('base/weights/uncompressed/face_detection_model.weights')
const weights = new Float32Array(await res.arrayBuffer())
const net = faceapi.faceDetectionNet(weights)
net.dispose()
})
})
})
describe('NeuralNetwork, quantized model', () => {
it('disposes all param tensors', async () => {
await expectAllTensorsReleased(async () => {
const net = new faceapi.FaceDetectionNet()
await net.load('base/weights')
net.dispose()
})
})
})
})
}) })
\ No newline at end of file
...@@ -238,6 +238,31 @@ describe('faceLandmarkNet', () => { ...@@ -238,6 +238,31 @@ describe('faceLandmarkNet', () => {
await faceLandmarkNet.load('base/weights') await faceLandmarkNet.load('base/weights')
}) })
describe('NeuralNetwork, uncompressed model', () => {
it('disposes all param tensors', async () => {
await expectAllTensorsReleased(async () => {
const res = await fetch('base/weights/uncompressed/face_landmark_68_model.weights')
const weights = new Float32Array(await res.arrayBuffer())
const net = faceapi.faceLandmarkNet(weights)
net.dispose()
})
})
})
describe('NeuralNetwork, quantized model', () => {
it('disposes all param tensors', async () => {
await expectAllTensorsReleased(async () => {
const net = new faceapi.FaceLandmarkNet()
await net.load('base/weights')
net.dispose()
})
})
})
describe('forwardInput', () => { describe('forwardInput', () => {
it('single image element', async () => { it('single image element', async () => {
......
...@@ -166,6 +166,35 @@ describe('faceRecognitionNet', () => { ...@@ -166,6 +166,35 @@ describe('faceRecognitionNet', () => {
faceRecognitionNet = faceapi.faceRecognitionNet(weights) faceRecognitionNet = faceapi.faceRecognitionNet(weights)
}) })
afterAll(async () => {
faceRecognitionNet.dispose()
})
describe('NeuralNetwork, uncompressed model', () => {
it('disposes all param tensors', async () => {
await expectAllTensorsReleased(async () => {
const res = await fetch('base/weights/uncompressed/face_recognition_model.weights')
const weights = new Float32Array(await res.arrayBuffer())
const net = faceapi.faceRecognitionNet(weights)
net.dispose()
})
})
})
describe('NeuralNetwork, quantized model', () => {
it('disposes all param tensors', async () => {
await expectAllTensorsReleased(async () => {
const net = new faceapi.FaceRecognitionNet()
await net.load('base/weights')
net.dispose()
})
})
})
describe('forwardInput', () => { describe('forwardInput', () => {
it('single image element', async () => { it('single image element', async () => {
...@@ -292,5 +321,4 @@ describe('faceRecognitionNet', () => { ...@@ -292,5 +321,4 @@ describe('faceRecognitionNet', () => {
}) })
}) })
}) })
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment