Commit 160fbccd by vincent

added NeuralNetwork base class which provides some common functionality, such as making networks trainable
parent 57b51bd4
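
A minimal usage sketch of the new base class API (not part of the diff; the import path and model uri are placeholders), showing how a subclass such as FaceLandmarkNet can be switched between frozen and trainable params:

import { FaceLandmarkNet } from 'face-api.js' // import path assumed for illustration

async function toggleTrainable() {
  const net = new FaceLandmarkNet()
  await net.load('/models') // placeholder model uri

  // after loading, all params are plain tensors, i.e. frozen
  console.log(net.getFrozenParams().map(p => p.path))

  // promote every param tensor to a tf.Variable to make the net trainable
  net.variable()
  console.log(net.getTrainableParams().map(p => p.path))

  // convert the variables back into plain tensors
  net.freeze()

  // dispose all param tensors when done
  net.dispose()
}
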
import * as tf from '@tensorflow/tfjs-core';
import { ParamMapping } from './types';
export class NeuralNetwork<TNetParams> {
protected _params: TNetParams | undefined = undefined
protected _paramMappings: ParamMapping[] = []
public get params(): TNetParams | undefined {
return this._params
}
public get paramMappings(): ParamMapping[] {
return this._paramMappings
}
public getParamFromPath(paramPath: string): tf.Tensor {
const { obj, objProp } = this.traversePropertyPath(paramPath)
return obj[objProp]
}
public reassignParamFromPath(paramPath: string, tensor: tf.Tensor) {
const { obj, objProp } = this.traversePropertyPath(paramPath)
obj[objProp].dispose()
obj[objProp] = tensor
}
public getParamList() {
return this._paramMappings.map(({ paramPath }) => ({
path: paramPath,
tensor: this.getParamFromPath(paramPath)
}))
}
public getTrainableParams() {
return this.getParamList().filter(param => param.tensor instanceof tf.Variable)
}
public getFrozenParams() {
return this.getParamList().filter(param => !(param.tensor instanceof tf.Variable))
}
public variable() {
this.getFrozenParams().forEach(({ path, tensor }) => {
this.reassignParamFromPath(path, tf.variable(tensor))
})
}
public freeze() {
this.getTrainableParams().forEach(({ path, tensor }) => {
this.reassignParamFromPath(path, tf.tensor(tensor as any))
})
}
public dispose() {
this.getParamList().forEach(param => param.tensor.dispose())
this._params = undefined
}
private traversePropertyPath(paramPath: string) {
if (!this.params) {
throw new Error(`traversePropertyPath - model has no loaded params`)
}
const result = paramPath.split('/').reduce((res: { nextObj: any, obj?: any, objProp?: string }, objProp) => {
if (!res.nextObj.hasOwnProperty(objProp)) {
throw new Error(`traversePropertyPath - object does not have property ${objProp}, for path ${paramPath}`)
}
return { obj: res.nextObj, objProp, nextObj: res.nextObj[objProp] }
}, { nextObj: this.params })
const { obj, objProp } = result
if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
throw new Error(`traversePropertyPath - parameter is not a tensor, for path ${paramPath}`)
}
return { obj, objProp }
}
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { ConvParams, ExtractWeightsFunction } from './types';
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction) {
export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
return function (
channelsIn: number,
channelsOut: number,
filterSize: number
filterSize: number,
mappedPrefix: string
): ConvParams {
const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
@@ -14,6 +15,11 @@ export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction)
)
const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` }
)
return {
filters,
bias
...
import { isTensor } from './isTensor';
export function extractWeightEntry(weightMap: any, path: string, paramRank: number) {
const tensor = weightMap[path]
if (!isTensor(tensor, paramRank)) {
throw new Error(`expected weightMap[${path}] to be a Tensor${paramRank}D, instead have ${tensor}`)
}
return { path, tensor }
}
\ No newline at end of file
@@ -13,3 +13,8 @@ export type BatchReshapeInfo = {
paddingX: number
paddingY: number
}
export type ParamMapping = {
originalPath?: string
paramPath: string
}
\ No newline at end of file
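
For illustration, entries of the new ParamMapping type pair the path of a tensor in the original weight map with its path inside the net's params object. The values below are examples taken from the mappings pushed in loadQuantizedParams and extractParams further down; originalPath is omitted when weights come from a raw Float32Array:

const exampleMappings: ParamMapping[] = [
  { originalPath: 'conv2d_0/kernel', paramPath: 'conv0_params/filters' },
  { originalPath: 'conv2d_0/bias', paramPath: 'conv0_params/bias' },
  // no originalPath when params were extracted from a Float32Array
  { paramPath: 'fc0_params/weights' }
]
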
import * as tf from '@tensorflow/tfjs-core';
import { convLayer } from '../commons/convLayer';
import { NeuralNetwork } from '../commons/NeuralNetwork';
import { ConvParams } from '../commons/types';
import { NetInput } from '../NetInput';
import { Point } from '../Point';
@@ -21,9 +22,7 @@ function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4
return tf.maxPool(x, [2, 2], strides, 'valid')
}
export class FaceLandmarkNet {
private _params: NetParams
export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
if (weightsOrUrl instanceof Float32Array) {
@@ -34,11 +33,23 @@ export class FaceLandmarkNet {
if (weightsOrUrl && typeof weightsOrUrl !== 'string') {
throw new Error('FaceLandmarkNet.load - expected model uri, or weights as Float32Array')
}
this._params = await loadQuantizedParams(weightsOrUrl)
const {
paramMappings,
params
} = await loadQuantizedParams(weightsOrUrl)
this._paramMappings = paramMappings
this._params = params
}
public extractWeights(weights: Float32Array) {
this._params = extractParams(weights)
const {
paramMappings,
params
} = extractParams(weights)
this._paramMappings = paramMappings
this._params = params
}
public forwardInput(input: NetInput): tf.Tensor2D {
...
@@ -2,41 +2,52 @@ import * as tf from '@tensorflow/tfjs-core';
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { ParamMapping } from '../commons/types';
import { FCParams, NetParams } from './types';
export function extractParams(weights: Float32Array): NetParams {
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = []
const {
extractWeights,
getRemainingWeights
} = extractWeightsFactory(weights)
const extractConvParams = extractConvParamsFactory(extractWeights)
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
function extractFcParams(channelsIn: number, channelsOut: number,): FCParams {
function extractFcParams(channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams {
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/weights` },
{ paramPath: `${mappedPrefix}/bias` }
)
return {
weights: fc_weights,
bias: fc_bias
}
}
const conv0_params = extractConvParams(3, 32, 3)
const conv1_params = extractConvParams(32, 64, 3)
const conv2_params = extractConvParams(64, 64, 3)
const conv3_params = extractConvParams(64, 64, 3)
const conv4_params = extractConvParams(64, 64, 3)
const conv5_params = extractConvParams(64, 128, 3)
const conv6_params = extractConvParams(128, 128, 3)
const conv7_params = extractConvParams(128, 256, 3)
const fc0_params = extractFcParams(6400, 1024)
const fc1_params = extractFcParams(1024, 136)
const conv0_params = extractConvParams(3, 32, 3, 'conv0_params')
const conv1_params = extractConvParams(32, 64, 3, 'conv1_params')
const conv2_params = extractConvParams(64, 64, 3, 'conv2_params')
const conv3_params = extractConvParams(64, 64, 3, 'conv3_params')
const conv4_params = extractConvParams(64, 64, 3, 'conv4_params')
const conv5_params = extractConvParams(64, 128, 3, 'conv5_params')
const conv6_params = extractConvParams(128, 128, 3, 'conv6_params')
const conv7_params = extractConvParams(128, 256, 3, 'conv7_params')
const fc0_params = extractFcParams(6400, 1024, 'fc0_params')
const fc1_params = extractFcParams(1024, 136, 'fc1_params')
if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaining after extract: ${getRemainingWeights().length}`)
}
return {
paramMappings,
params: {
conv0_params,
conv1_params,
conv2_params,
@@ -48,4 +59,5 @@ export function extractParams(weights: Float32Array): NetParams {
fc0_params,
fc1_params
}
}
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { extractWeightEntry } from '../commons/extractWeightEntry';
import { loadWeightMap } from '../commons/loadWeightMap';
import { ConvParams } from '../commons/types';
import { ConvParams, ParamMapping } from '../commons/types';
import { FCParams, NetParams } from './types';
import { isTensor4D, isTensor1D, isTensor2D } from '../commons/isTensor';
const DEFAULT_MODEL_NAME = 'face_landmark_68_model'
function extractorsFactory(weightMap: any) {
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
function extractConvParams(prefix: string): ConvParams {
const params = {
filters: weightMap[`${prefix}/kernel`] as tf.Tensor4D,
bias: weightMap[`${prefix}/bias`] as tf.Tensor1D
}
if (!isTensor4D(params.filters)) {
throw new Error(`expected weightMap[${prefix}/kernel] to be a Tensor4D, instead have ${params.filters}`)
}
if (!isTensor1D(params.bias)) {
throw new Error(`expected weightMap[${prefix}/bias] to be a Tensor1D, instead have ${params.bias}`)
}
return params
}
function extractFcParams(prefix: string): FCParams {
const params = {
weights: weightMap[`${prefix}/kernel`] as tf.Tensor2D,
bias: weightMap[`${prefix}/bias`] as tf.Tensor1D
function extractConvParams(prefix: string, mappedPrefix: string): ConvParams {
const filtersEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 4)
const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1)
paramMappings.push(
{ originalPath: filtersEntry.path, paramPath: `${mappedPrefix}/filters` },
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
)
return {
filters: filtersEntry.tensor as tf.Tensor4D,
bias: biasEntry.tensor as tf.Tensor1D
}
if (!isTensor2D(params.weights)) {
throw new Error(`expected weightMap[${prefix}/kernel] to be a Tensor2D, instead have ${params.weights}`)
}
if (!isTensor1D(params.bias)) {
throw new Error(`expected weightMap[${prefix}/bias] to be a Tensor1D, instead have ${params.bias}`)
function extractFcParams(prefix: string, mappedPrefix: string): FCParams {
const weightsEntry = extractWeightEntry(weightMap, `${prefix}/kernel`, 2)
const biasEntry = extractWeightEntry(weightMap, `${prefix}/bias`, 1)
paramMappings.push(
{ originalPath: weightsEntry.path, paramPath: `${mappedPrefix}/weights` },
{ originalPath: biasEntry.path, paramPath: `${mappedPrefix}/bias` }
)
return {
weights: weightsEntry.tensor as tf.Tensor2D,
bias: biasEntry.tensor as tf.Tensor1D
}
return params
}
return {
@@ -49,24 +41,30 @@ function extractorsFactory(weightMap: any) {
}
}
export async function loadQuantizedParams(uri: string | undefined): Promise<NetParams> {
export async function loadQuantizedParams(
uri: string | undefined
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
const paramMappings: ParamMapping[] = []
const {
extractConvParams,
extractFcParams
} = extractorsFactory(weightMap)
} = extractorsFactory(weightMap, paramMappings)
return {
conv0_params: extractConvParams('conv2d_0'),
conv1_params: extractConvParams('conv2d_1'),
conv2_params: extractConvParams('conv2d_2'),
conv3_params: extractConvParams('conv2d_3'),
conv4_params: extractConvParams('conv2d_4'),
conv5_params: extractConvParams('conv2d_5'),
conv6_params: extractConvParams('conv2d_6'),
conv7_params: extractConvParams('conv2d_7'),
fc0_params: extractFcParams('dense'),
fc1_params: extractFcParams('logits')
const params = {
conv0_params: extractConvParams('conv2d_0', 'conv0_params'),
conv1_params: extractConvParams('conv2d_1', 'conv1_params'),
conv2_params: extractConvParams('conv2d_2', 'conv2_params'),
conv3_params: extractConvParams('conv2d_3', 'conv3_params'),
conv4_params: extractConvParams('conv2d_4', 'conv4_params'),
conv5_params: extractConvParams('conv2d_5', 'conv5_params'),
conv6_params: extractConvParams('conv2d_6', 'conv6_params'),
conv7_params: extractConvParams('conv2d_7', 'conv7_params'),
fc0_params: extractFcParams('dense', 'fc0_params'),
fc1_params: extractFcParams('logits', 'fc1_params')
}
return { params, paramMappings }
}
\ No newline at end of file
import { NeuralNetwork } from '../../../src/commons/NeuralNetwork';
import * as tf from '@tensorflow/tfjs-core';
class FakeNeuralNetwork extends NeuralNetwork<any> {
constructor(
convFilter: tf.Tensor = tf.tensor(0),
convBias: tf.Tensor = tf.tensor(0),
fcWeights: tf.Tensor = tf.tensor(0)
) {
super()
this._params = {
conv: {
filter: convFilter,
bias: convBias,
},
fc: fcWeights
}
this._paramMappings = [
{ originalPath: 'conv2d/filter', paramPath: 'conv/filter' },
{ originalPath: 'conv2d/bias', paramPath: 'conv/bias' },
{ originalPath: 'dense/weights', paramPath: 'fc' }
]
}
}
describe('NeuralNetwork', () => {
describe('getParamFromPath', () => {
it('returns correct params', () => tf.tidy(() => {
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
const fcWeights = tf.tensor(0)
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
expect(net.getParamFromPath('conv/filter')).toEqual(convFilter)
expect(net.getParamFromPath('conv/bias')).toEqual(convBias)
expect(net.getParamFromPath('fc')).toEqual(fcWeights)
}))
it('throws if param is not a tensor', () => tf.tidy(() => {
const net = new FakeNeuralNetwork(null as any)
const fakePath = 'conv/filter'
expect(
() => net.getParamFromPath(fakePath)
).toThrowError(`traversePropertyPath - parameter is not a tensor, for path ${fakePath}`)
}))
it('throws if key path invalid', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
const fakePath = 'conv2d/foo'
expect(
() => net.getParamFromPath(fakePath)
).toThrowError(`traversePropertyPath - object does not have property conv2d, for path ${fakePath}`)
}))
})
describe('reassignParamFromPath', () => {
it('sets correct params', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
const fcWeights = tf.tensor(0)
net.reassignParamFromPath('conv/filter', convFilter)
net.reassignParamFromPath('conv/bias', convBias)
net.reassignParamFromPath('fc', fcWeights)
expect(net.params.conv.filter).toEqual(convFilter)
expect(net.params.conv.bias).toEqual(convBias)
expect(net.params.fc).toEqual(fcWeights)
}))
it('throws if param is not a tensor', () => tf.tidy(() => {
const net = new FakeNeuralNetwork(null as any)
const fakePath = 'conv/filter'
expect(
() => net.reassignParamFromPath(fakePath, tf.tensor(0))
).toThrowError(`traversePropertyPath - parameter is not a tensor, for path ${fakePath}`)
}))
it('throws if key path invalid', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
const fakePath = 'conv2d/foo'
expect(
() => net.reassignParamFromPath(fakePath, tf.tensor(0))
).toThrowError(`traversePropertyPath - object does not have property conv2d, for path ${fakePath}`)
}))
})
describe('getParamList', () => {
it('returns param tensors with path', () => tf.tidy(() => {
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
const fcWeights = tf.tensor(0)
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
const paramList = net.getParamList()
expect(paramList.length).toEqual(3)
expect(paramList[0].path).toEqual('conv/filter')
expect(paramList[1].path).toEqual('conv/bias')
expect(paramList[2].path).toEqual('fc')
expect(paramList[0].tensor).toEqual(convFilter)
expect(paramList[1].tensor).toEqual(convBias)
expect(paramList[2].tensor).toEqual(fcWeights)
}))
})
describe('getFrozenParams', () => {
it('returns all frozen params', () => tf.tidy(() => {
const convFilter = tf.tensor(0)
const convBias = tf.tensor(0)
const fcWeights = tf.variable(tf.scalar(0))
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
const frozenParams = net.getFrozenParams()
expect(frozenParams.length).toEqual(2)
expect(frozenParams[0].path).toEqual('conv/filter')
expect(frozenParams[1].path).toEqual('conv/bias')
expect(frozenParams[0].tensor).toEqual(convFilter)
expect(frozenParams[1].tensor).toEqual(convBias)
}))
})
describe('getTrainableParams', () => {
it('returns all trainable params', () => tf.tidy(() => {
const convFilter = tf.variable(tf.scalar(0))
const convBias = tf.variable(tf.scalar(0))
const fcWeights = tf.tensor(0)
const net = new FakeNeuralNetwork(convFilter, convBias, fcWeights)
const trainableParams = net.getTrainableParams()
expect(trainableParams.length).toEqual(2)
expect(trainableParams[0].path).toEqual('conv/filter')
expect(trainableParams[1].path).toEqual('conv/bias')
expect(trainableParams[0].tensor).toEqual(convFilter)
expect(trainableParams[1].tensor).toEqual(convBias)
}))
})
describe('dispose', () => {
it('disposes all param tensors', () => tf.tidy(() => {
const numTensors = tf.memory().numTensors
const net = new FakeNeuralNetwork()
net.dispose()
expect(net.params).toBe(undefined)
expect(tf.memory().numTensors - numTensors).toEqual(0)
}))
})
describe('variable', () => {
it('make all param tensors trainable', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
net.variable()
expect(net.params.conv.filter instanceof tf.Variable).toBe(true)
expect(net.params.conv.bias instanceof tf.Variable).toBe(true)
expect(net.params.fc instanceof tf.Variable).toBe(true)
}))
it('disposes old tensors', () => tf.tidy(() => {
const net = new FakeNeuralNetwork()
const numTensors = tf.memory().numTensors
net.variable()
expect(tf.memory().numTensors - numTensors).toEqual(0)
}))
})
describe('freeze', () => {
it('freezes all param variables', () => tf.tidy(() => {
const net = new FakeNeuralNetwork(
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0))
)
net.freeze()
expect(net.params.conv.filter instanceof tf.Variable).toBe(false)
expect(net.params.conv.bias instanceof tf.Variable).toBe(false)
expect(net.params.fc instanceof tf.Variable).toBe(false)
}))
it('disposes old tensors', () => tf.tidy(() => {
const net = new FakeNeuralNetwork(
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0)),
tf.variable(tf.scalar(0))
)
const numTensors = tf.memory().numTensors
net.freeze()
expect(tf.memory().numTensors - numTensors).toEqual(0)
}))
})
})