Commit 160fbccd by vincent

added NeuralNetwork base class which provides some common functionality, such as…

added NeuralNetwork base class which provides some common functionality, such as making networks trainable
parent 57b51bd4
import * as tf from '@tensorflow/tfjs-core';
import { ParamMapping } from './types';
/**
 * Base class for networks whose weights are stored in a nested params object.
 *
 * `_params` holds the (possibly nested) object of weight tensors and
 * `_paramMappings` lists one slash separated `paramPath` per tensor in it.
 * Based on these two fields the class offers lookup and replacement of single
 * tensors by path, conversion of all parameters between trainable
 * `tf.Variable`s and plain tensors, and disposal of all parameter tensors.
 */
export class NeuralNetwork<TNetParams> {

  protected _params: TNetParams | undefined = undefined
  protected _paramMappings: ParamMapping[] = []

  /** The currently assigned params object, `undefined` if there is none. */
  public get params(): TNetParams | undefined {
    return this._params
  }

  /** The mappings describing where each parameter tensor lives in `params`. */
  public get paramMappings(): ParamMapping[] {
    return this._paramMappings
  }

  /**
   * Looks up a single parameter tensor.
   *
   * @param paramPath Slash separated property path into `params`, e.g. `conv/filter`.
   * @returns The tensor stored at that path.
   * @throws Error if no params are loaded, the path does not exist or the value
   * at the path is not a `tf.Tensor`.
   */
  public getParamFromPath(paramPath: string): tf.Tensor {
    const { obj, objProp } = this.traversePropertyPath(paramPath)
    return obj[objProp]
  }

  /**
   * Replaces the tensor stored at `paramPath` and disposes the tensor that was
   * stored there before.
   *
   * @param paramPath Slash separated property path into `params`.
   * @param tensor The tensor to store at that path.
   * @throws Error under the same conditions as `getParamFromPath`.
   */
  public reassignParamFromPath(paramPath: string, tensor: tf.Tensor): void {
    const { obj, objProp } = this.traversePropertyPath(paramPath)
    const previous: tf.Tensor = obj[objProp]

    // reassigning the tensor that is already stored must not dispose it,
    // otherwise the params object would end up holding a disposed tensor
    if (previous === tensor) {
      return
    }

    previous.dispose()
    obj[objProp] = tensor
  }

  /**
   * @returns One `{ path, tensor }` entry per param mapping, in mapping order.
   * @throws Error if any mapped path can not be resolved to a tensor.
   */
  public getParamList(): Array<{ path: string, tensor: tf.Tensor }> {
    return this._paramMappings.map(({ paramPath }) => ({
      path: paramPath,
      tensor: this.getParamFromPath(paramPath)
    }))
  }

  /** @returns The entries of `getParamList()` whose tensor is a `tf.Variable`. */
  public getTrainableParams(): Array<{ path: string, tensor: tf.Tensor }> {
    return this.getParamList().filter(param => param.tensor instanceof tf.Variable)
  }

  /** @returns The entries of `getParamList()` whose tensor is not a `tf.Variable`. */
  public getFrozenParams(): Array<{ path: string, tensor: tf.Tensor }> {
    return this.getParamList().filter(param => !(param.tensor instanceof tf.Variable))
  }

  /**
   * Makes the network trainable: every parameter that is not yet a
   * `tf.Variable` is replaced by a variable created from it.
   */
  public variable(): void {
    this.getFrozenParams().forEach(({ path, tensor }) => {
      this.reassignParamFromPath(path, tf.variable(tensor))
    })
  }

  /**
   * Freezes the network: every `tf.Variable` parameter is replaced by a plain
   * tensor holding the same values.
   */
  public freeze(): void {
    this.getTrainableParams().forEach(({ path, tensor: variable }) => {
      // tf.tensor expects raw values rather than a tensor object, so the data
      // of the variable is copied explicitly, keeping its shape and dtype
      const frozen = tf.tensor(variable.dataSync(), variable.shape, variable.dtype)
      this.reassignParamFromPath(path, frozen)
    })
  }

  /**
   * Disposes every mapped parameter tensor and unsets `params`. The param
   * mappings are kept, thus resolving them afterwards (including a second
   * `dispose()` call) throws until new params are assigned.
   */
  public dispose(): void {
    this.getParamList().forEach(param => param.tensor.dispose())
    this._params = undefined
  }

  /**
   * Walks `params` along the segments of `paramPath`.
   *
   * @returns The object holding the tensor (`obj`) and the property name of
   * the tensor on that object (`objProp`).
   * @throws Error if no params are loaded, a segment is missing or the final
   * value is not a `tf.Tensor`.
   */
  private traversePropertyPath(paramPath: string): { obj: any, objProp: string } {
    if (!this.params) {
      throw new Error(`traversePropertyPath - model has no loaded params`)
    }

    const result = paramPath.split('/').reduce(
      (res: { nextObj: any, obj?: any, objProp?: string }, objProp) => {
        const current = res.nextObj

        // a null / undefined node can not have properties, report it with the
        // same descriptive error as a missing key instead of a raw TypeError
        if (current == null || !Object.prototype.hasOwnProperty.call(current, objProp)) {
          throw new Error(`traversePropertyPath - object does not have property ${objProp}, for path ${paramPath}`)
        }

        return { obj: current, objProp, nextObj: current[objProp] }
      },
      { nextObj: this.params }
    )

    const { obj, objProp } = result
    if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
      throw new Error(`traversePropertyPath - parameter is not a tensor, for path ${paramPath}`)
    }

    return { obj, objProp }
  }
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { ConvParams, ExtractWeightsFunction } from './types'; import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction) { export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
return function ( return function (
channelsIn: number, channelsIn: number,
channelsOut: number, channelsOut: number,
filterSize: number filterSize: number,
mappedPrefix: string
): ConvParams { ): ConvParams {
const filters = tf.tensor4d( const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize), extractWeights(channelsIn * channelsOut * filterSize * filterSize),
...@@ -14,6 +15,11 @@ export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction) ...@@ -14,6 +15,11 @@ export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction)
) )
const bias = tf.tensor1d(extractWeights(channelsOut)) const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` }
)
return { return {
filters, filters,
bias bias
......
import { isTensor } from './isTensor';
/**
 * Reads the entry stored under `path` in `weightMap` and verifies it with
 * `isTensor` against the expected rank.
 *
 * @param weightMap Object mapping weight paths to their values.
 * @param path Key of the entry to read.
 * @param paramRank Rank the entry is checked against.
 * @returns The given `path` together with the value found at that key.
 * @throws Error if the value does not pass the `isTensor` check.
 */
export function extractWeightEntry(weightMap: any, path: string, paramRank: number) {
  const entry = { path, tensor: weightMap[path] }

  if (isTensor(entry.tensor, paramRank)) {
    return entry
  }

  throw new Error(`expected weightMap[${path}] to be a Tensor${paramRank}D, instead have ${entry.tensor}`)
}
\ No newline at end of file
...@@ -13,3 +13,8 @@ export type BatchReshapeInfo = { ...@@ -13,3 +13,8 @@ export type BatchReshapeInfo = {
paddingX: number paddingX: number
paddingY: number paddingY: number
} }
/**
 * Describes where a single weight tensor lives inside a network's params object.
 */
export type ParamMapping = {
// key of the tensor in the weight map it was loaded from; omitted when the
// params are extracted from a flat weights array
originalPath?: string
// slash separated property path of the tensor inside the params object,
// e.g. `conv0_params/filters`
paramPath: string
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { convLayer } from '../commons/convLayer'; import { convLayer } from '../commons/convLayer';
import { NeuralNetwork } from '../commons/NeuralNetwork';
import { ConvParams } from '../commons/types'; import { ConvParams } from '../commons/types';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { Point } from '../Point'; import { Point } from '../Point';
...@@ -21,9 +22,7 @@ function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4 ...@@ -21,9 +22,7 @@ function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4
return tf.maxPool(x, [2, 2], strides, 'valid') return tf.maxPool(x, [2, 2], strides, 'valid')
} }
export class FaceLandmarkNet { export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
private _params: NetParams
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> { public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
if (weightsOrUrl instanceof Float32Array) { if (weightsOrUrl instanceof Float32Array) {
...@@ -34,11 +33,23 @@ export class FaceLandmarkNet { ...@@ -34,11 +33,23 @@ export class FaceLandmarkNet {
if (weightsOrUrl && typeof weightsOrUrl !== 'string') { if (weightsOrUrl && typeof weightsOrUrl !== 'string') {
throw new Error('FaceLandmarkNet.load - expected model uri, or weights as Float32Array') throw new Error('FaceLandmarkNet.load - expected model uri, or weights as Float32Array')
} }
this._params = await loadQuantizedParams(weightsOrUrl) const {
paramMappings,
params
} = await loadQuantizedParams(weightsOrUrl)
this._paramMappings = paramMappings
this._params = params
} }
public extractWeights(weights: Float32Array) { public extractWeights(weights: Float32Array) {
this._params = extractParams(weights) const {
paramMappings,
params
} = extractParams(weights)
this._paramMappings = paramMappings
this._params = params
} }
public forwardInput(input: NetInput): tf.Tensor2D { public forwardInput(input: NetInput): tf.Tensor2D {
......
...@@ -2,50 +2,62 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,50 +2,62 @@ import * as tf from '@tensorflow/tfjs-core';
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory'; import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
import { extractWeightsFactory } from '../commons/extractWeightsFactory'; import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { ParamMapping } from '../commons/types';
import { FCParams, NetParams } from './types'; import { FCParams, NetParams } from './types';
export function extractParams(weights: Float32Array): NetParams { export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const { const {
extractWeights, extractWeights,
getRemainingWeights getRemainingWeights
} = extractWeightsFactory(weights) } = extractWeightsFactory(weights)
const extractConvParams = extractConvParamsFactory(extractWeights) const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
function extractFcParams(channelsIn: number, channelsOut: number,): FCParams { function extractFcParams(channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams {
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]) const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
const fc_bias = tf.tensor1d(extractWeights(channelsOut)) const fc_bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/weights` },
{ paramPath: `${mappedPrefix}/bias` }
)
return { return {
weights: fc_weights, weights: fc_weights,
bias: fc_bias bias: fc_bias
} }
} }
const conv0_params = extractConvParams(3, 32, 3) const conv0_params = extractConvParams(3, 32, 3, 'conv0_params')
const conv1_params = extractConvParams(32, 64, 3) const conv1_params = extractConvParams(32, 64, 3, 'conv1_params')
const conv2_params = extractConvParams(64, 64, 3) const conv2_params = extractConvParams(64, 64, 3, 'conv2_params')
const conv3_params = extractConvParams(64, 64, 3) const conv3_params = extractConvParams(64, 64, 3, 'conv3_params')
const conv4_params = extractConvParams(64, 64, 3) const conv4_params = extractConvParams(64, 64, 3, 'conv4_params')
const conv5_params = extractConvParams(64, 128, 3) const conv5_params = extractConvParams(64, 128, 3, 'conv5_params')
const conv6_params = extractConvParams(128, 128, 3) const conv6_params = extractConvParams(128, 128, 3, 'conv6_params')
const conv7_params = extractConvParams(128, 256, 3) const conv7_params = extractConvParams(128, 256, 3, 'conv7_params')
const fc0_params = extractFcParams(6400, 1024) const fc0_params = extractFcParams(6400, 1024, 'fc0_params')
const fc1_params = extractFcParams(1024, 136) const fc1_params = extractFcParams(1024, 136, 'fc1_params')
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
} }
return { return {
conv0_params, paramMappings,
conv1_params, params: {
conv2_params, conv0_params,
conv3_params, conv1_params,
conv4_params, conv2_params,
conv5_params, conv3_params,
conv6_params, conv4_params,
conv7_params, conv5_params,
fc0_params, conv6_params,
fc1_params conv7_params,
fc0_params,
fc1_params
}
} }
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { extractWeightEntry } from '../commons/extractWeightEntry';
import { loadWeightMap } from '../commons/loadWeightMap'; import { loadWeightMap } from '../commons/loadWeightMap';
import { ConvParams } from '../commons/types'; import { ConvParams, ParamMapping } from '../commons/types';
import { FCParams, NetParams } from './types'; import { FCParams, NetParams } from './types';
import { isTensor4D, isTensor1D, isTensor2D } from '../commons/isTensor';
const DEFAULT_MODEL_NAME = 'face_landmark_68_model' const DEFAULT_MODEL_NAME = 'face_landmark_68_model'
function extractorsFactory(weightMap: any) { function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
function extractConvParams(prefix: string): ConvParams { function extractConvParams(prefix: string, mappedPrefix: string): ConvParams {
const params = { const filtersEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 4)
filters: weightMap[`${prefix}/kernel`] as tf.Tensor4D, const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1)
bias: weightMap[`${prefix}/bias`] as tf.Tensor1D paramMappings.push(
{ originalPath: filtersEntry.path, paramPath: `${mappedPrefix}/filters` },
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
)
return {
filters: filtersEntry.tensor as tf.Tensor4D,
bias: biasEntry.tensor as tf.Tensor1D
} }
if (!isTensor4D(params.filters)) {
throw new Error(`expected weightMap[${prefix}/kernel] to be a Tensor4D, instead have ${params.filters}`)
}
if (!isTensor1D(params.bias)) {
throw new Error(`expected weightMap[${prefix}/bias] to be a Tensor1D, instead have ${params.bias}`)
}
return params
} }
function extractFcParams(prefix: string): FCParams { function extractFcParams(prefix: string, mappedPrefix: string): FCParams {
const params = { const weightsEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 2)
weights: weightMap[`${prefix}/kernel`] as tf.Tensor2D, const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1)
bias: weightMap[`${prefix}/bias`] as tf.Tensor1D paramMappings.push(
} { originalPath: weightsEntry.path, paramPath: `${mappedPrefix}/weights` },
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
if (!isTensor2D(params.weights)) { )
throw new Error(`expected weightMap[${prefix}/kernel] to be a Tensor2D, instead have ${params.weights}`) return {
weights: weightsEntry.tensor as tf.Tensor2D,
bias: biasEntry.tensor as tf.Tensor1D
} }
if (!isTensor1D(params.bias)) {
throw new Error(`expected weightMap[${prefix}/bias] to be a Tensor1D, instead have ${params.bias}`)
}
return params
} }
return { return {
...@@ -49,24 +41,30 @@ function extractorsFactory(weightMap: any) { ...@@ -49,24 +41,30 @@ function extractorsFactory(weightMap: any) {
} }
} }
export async function loadQuantizedParams(uri: string | undefined): Promise<NetParams> { export async function loadQuantizedParams(
uri: string | undefined
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME) const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
const paramMappings: ParamMapping[] = []
const { const {
extractConvParams, extractConvParams,
extractFcParams extractFcParams
} = extractorsFactory(weightMap) } = extractorsFactory(weightMap, paramMappings)
return { const params = {
conv0_params: extractConvParams('conv2d_0'), conv0_params: extractConvParams('conv2d_0', 'conv0_params'),
conv1_params: extractConvParams('conv2d_1'), conv1_params: extractConvParams('conv2d_1', 'conv1_params'),
conv2_params: extractConvParams('conv2d_2'), conv2_params: extractConvParams('conv2d_2', 'conv2_params'),
conv3_params: extractConvParams('conv2d_3'), conv3_params: extractConvParams('conv2d_3', 'conv3_params'),
conv4_params: extractConvParams('conv2d_4'), conv4_params: extractConvParams('conv2d_4', 'conv4_params'),
conv5_params: extractConvParams('conv2d_5'), conv5_params: extractConvParams('conv2d_5', 'conv5_params'),
conv6_params: extractConvParams('conv2d_6'), conv6_params: extractConvParams('conv2d_6', 'conv6_params'),
conv7_params: extractConvParams('conv2d_7'), conv7_params: extractConvParams('conv2d_7', 'conv7_params'),
fc0_params: extractFcParams('dense'), fc0_params: extractFcParams('dense', 'fc0_params'),
fc1_params: extractFcParams('logits') fc1_params: extractFcParams('logits', 'fc1_params')
} }
return { params, paramMappings }
} }
\ No newline at end of file
import { NeuralNetwork } from '../../../src/commons/NeuralNetwork';
import * as tf from '@tensorflow/tfjs-core';
/**
 * Minimal NeuralNetwork used by the specs: a params object with a nested
 * `conv` entry (filter and bias) and a top level `fc` entry, together with one
 * param mapping for each of the three tensors. Arguments that are left out
 * default to a fresh `tf.tensor(0)`.
 */
class FakeNeuralNetwork extends NeuralNetwork<any> {
  constructor(
    convFilter: tf.Tensor = tf.tensor(0),
    convBias: tf.Tensor = tf.tensor(0),
    fcWeights: tf.Tensor = tf.tensor(0)
  ) {
    super()

    const conv = { filter: convFilter, bias: convBias }
    this._params = { conv, fc: fcWeights }

    // builds a single mapping entry from its original path and its param path
    const mapping = (originalPath: string, paramPath: string) => ({ originalPath, paramPath })
    this._paramMappings = [
      mapping('conv2d/filter', 'conv/filter'),
      mapping('conv2d/bias', 'conv/bias'),
      mapping('dense/weights', 'fc')
    ]
  }
}
// Specs for the NeuralNetwork base class, exercised through FakeNeuralNetwork.
// Every spec body runs inside tf.tidy.
describe('NeuralNetwork', () => {
// lookup of single tensors by their slash separated param path
describe('getParamFromPath', () => {
it('returns correct params', () => tf.tidy(() => {
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
const fcWeights = tf.tensor(0)
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
expect(net.getParamFromPath('conv/filter')).toEqual(convFilter)
expect(net.getParamFromPath('conv/bias')).toEqual(convBias)
expect(net.getParamFromPath('fc')).toEqual(fcWeights)
}))
it('throws if param is not a tensor', () => tf.tidy(() => {
// null is passed explicitly, so the default tensor of the constructor is not used
const net = new FakeNeuralNetwork(null as any)
const fakePath = 'conv/filter'
expect(
() => net.getParamFromPath(fakePath)
).toThrowError(`traversePropertyPath - parameter is not a tensor, for path ${fakePath}`)
}))
it('throws if key path invalid', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
// 'conv2d' only occurs in the originalPath of the mappings, it is not a key of the params object
const fakePath = 'conv2d/foo'
expect(
() => net.getParamFromPath(fakePath)
).toThrowError(`traversePropertyPath - object does not have property conv2d, for path ${fakePath}`)
}))
})
// replacement of single tensors by their param path
describe('reassignParamFromPath', () => {
it('sets correct params', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
const fcWeights = tf.tensor(0)
net.reassignParamFromPath('conv/filter', convFilter)
net.reassignParamFromPath('conv/bias', convBias)
net.reassignParamFromPath('fc', fcWeights)
expect(net.params.conv.filter).toEqual(convFilter)
expect(net.params.conv.bias).toEqual(convBias)
expect(net.params.fc).toEqual(fcWeights)
}))
it('throws if param is not a tensor', () => tf.tidy(() => {
// the value currently stored at the path has to be a tensor as well
const net = new FakeNeuralNetwork(null as any)
const fakePath = 'conv/filter'
expect(
() => net.reassignParamFromPath(fakePath, tf.tensor(0))
).toThrowError(`traversePropertyPath - parameter is not a tensor, for path ${fakePath}`)
}))
it('throws if key path invalid', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
const fakePath = 'conv2d/foo'
expect(
() => net.reassignParamFromPath(fakePath, tf.tensor(0))
).toThrowError(`traversePropertyPath - object does not have property conv2d, for path ${fakePath}`)
}))
})
// the param list is expected in the order of the param mappings
describe('getParamList', () => {
it('returns param tensors with path', () => tf.tidy(() => {
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
const fcWeights = tf.tensor(0)
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
const paramList = net.getParamList()
expect(paramList.length).toEqual(3)
expect(paramList[0].path).toEqual('conv/filter')
expect(paramList[1].path).toEqual('conv/bias')
expect(paramList[2].path).toEqual('fc')
expect(paramList[0].tensor).toEqual(convFilter)
expect(paramList[1].tensor).toEqual(convBias)
expect(paramList[2].tensor).toEqual(fcWeights)
}))
})
describe('getFrozenParams', () => {
it('returns all frozen params', () => tf.tidy(() => {
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
// the only variable of the network, it must not show up in the frozen params
const fcWeights = tf.variable(tf.scalar(0))
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
const frozenParams = net.getFrozenParams()
expect(frozenParams.length).toEqual(2)
expect(frozenParams[0].path).toEqual('conv/filter')
expect(frozenParams[1].path).toEqual('conv/bias')
expect(frozenParams[0].tensor).toEqual(convFilter)
expect(frozenParams[1].tensor).toEqual(convBias)
}))
})
describe('getTrainableParams', () => {
it('returns all trainable params', () => tf.tidy(() => {
const convFilter = tf.variable(tf.scalar(0))
const convBias = tf.variable(tf.scalar(0))
// the only plain tensor of the network, it must not show up in the trainable params
const fcWeights = tf.tensor(0)
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
const trainableParams = net.getTrainableParams()
expect(trainableParams.length).toEqual(2)
expect(trainableParams[0].path).toEqual('conv/filter')
expect(trainableParams[1].path).toEqual('conv/bias')
expect(trainableParams[0].tensor).toEqual(convFilter)
expect(trainableParams[1].tensor).toEqual(convBias)
}))
})
describe('dispose', () => {
it('disposes all param tensors', () => tf.tidy(() => {
// the count is taken before the network is created, so all tensors
// allocated by the constructor defaults have to be released again
const numTensors = tf.memory().numTensors
const net = new FakeNeuralNetwork()
net.dispose()
expect(net.params).toBe(undefined)
expect(tf.memory().numTensors - numTensors).toEqual(0)
}))
})
describe('variable', () => {
it('make all param tensors trainable', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
net.variable()
expect(net.params.conv.filter instanceof tf.Variable).toBe(true)
expect(net.params.conv.bias instanceof tf.Variable).toBe(true)
expect(net.params.fc instanceof tf.Variable).toBe(true)
}))
it('disposes old tensors', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
// the count is taken after construction: the conversion must not change
// the number of tensors reported by tf.memory()
const numTensors = tf.memory().numTensors
net.variable()
expect(tf.memory().numTensors - numTensors).toEqual(0)
}))
})
describe('freeze', () => {
it('freezes all param variables', () => tf.tidy(() => {
const net = new FakeNeuralNetwork(
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0))
)
net.freeze()
expect(net.params.conv.filter instanceof tf.Variable).toBe(false)
expect(net.params.conv.bias instanceof tf.Variable).toBe(false)
expect(net.params.fc instanceof tf.Variable).toBe(false)
}))
it('disposes old tensors', () => tf.tidy(() => {
const net = new FakeNeuralNetwork(
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0))
)
// the count is taken after construction: freezing must not change
// the number of tensors reported by tf.memory()
const numTensors = tf.memory().numTensors
net.freeze()
expect(tf.memory().numTensors - numTensors).toEqual(0)
}))
})
})
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment