Commit 886cf99e by vincent

init tiny yolov2 + weight loading

parent 9ef63518
src/index.ts
@@ -21,5 +21,6 @@ export * from './faceRecognitionNet';
export * from './globalApi';
export * from './mtcnn';
export * from './padToSquare';
export * from './tinyYolov2';
export * from './toNetInput';
export * from './utils'
\ No newline at end of file
src/tinyYolov2/TinyYolov2.ts

import * as tf from '@tensorflow/tfjs-core';

import { convLayer } from '../commons/convLayer';
import { NeuralNetwork } from '../commons/NeuralNetwork';
import { NetInput } from '../NetInput';
import { toNetInput } from '../toNetInput';
import { TNetInput } from '../types';
import { convWithBatchNorm } from './convWithBatchNorm';
import { extractParams } from './extractParams';
import { NetParams } from './types';

export class TinyYolov2 extends NeuralNetwork<NetParams> {

  constructor() {
    super('TinyYolov2')
  }

  public async forwardInput(input: NetInput): Promise<any> {

    const { params } = this

    if (!params) {
      throw new Error('TinyYolov2 - load model before inference')
    }

    const out = tf.tidy(() => {
      // scale pixel values to [0, 1] and pad the 416x416 batch tensor by one pixel on each side
      const batchTensor = input.toBatchTensor(416).div(tf.scalar(255)).toFloat()
      let out = tf.pad(batchTensor, [[0, 0], [1, 1], [1, 1], [0, 0]]) as tf.Tensor4D

      // Tiny YOLO v2 stack: conv + batch norm blocks interleaved with max pooling
      // (the last pool uses stride 1), followed by a plain 1x1 conv as the output layer
      out = convWithBatchNorm(out, params.conv0)
      out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
      out = convWithBatchNorm(out, params.conv1)
      out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
      out = convWithBatchNorm(out, params.conv2)
      out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
      out = convWithBatchNorm(out, params.conv3)
      out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
      out = convWithBatchNorm(out, params.conv4)
      out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
      out = convWithBatchNorm(out, params.conv5)
      out = tf.maxPool(out, [2, 2], [1, 1], 'valid')
      out = convWithBatchNorm(out, params.conv6)
      out = convWithBatchNorm(out, params.conv7)
      out = convLayer(out, params.conv8, 'valid', false)

      return out
    })

    return out
  }

  public async forward(input: TNetInput): Promise<any> {
    return await this.forwardInput(await toNetInput(input, true, true))
  }

  /* TODO
  protected loadQuantizedParams(uri: string | undefined) {
    return loadQuantizedParams(uri)
  }
  */

  protected extractParams(weights: Float32Array) {
    return extractParams(weights)
  }
}
\ No newline at end of file
src/tinyYolov2/convWithBatchNorm.ts

import * as tf from '@tensorflow/tfjs-core';

import { leaky } from './leaky';
import { ConvWithBatchNorm } from './types';

export function convWithBatchNorm(x: tf.Tensor4D, params: ConvWithBatchNorm): tf.Tensor4D {
  return tf.tidy(() => {
    let out = tf.conv2d(x, params.conv.filters, [1, 1], 'valid')
    // apply the folded batch norm (subtract, then multiply), add the conv bias and activate
    out = tf.sub(out, params.bn.sub)
    out = tf.mul(out, params.bn.truediv)
    out = tf.add(out, params.conv.bias)
    return leaky(out)
  })
}
\ No newline at end of file
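The `sub` and `truediv` tensors are presumably the batch norm statistics already folded into two per-channel constants, so inference only needs a subtract and a multiply; the additive part appears to be carried by `params.conv.bias`, which is added afterwards. A minimal sketch of how such constants could be derived from conventional batch norm parameters (the names `mean`, `variance`, `gamma` and `eps` are assumptions for illustration and are not part of this commit):

import * as tf from '@tensorflow/tfjs-core';

// Hypothetical folding of batch norm into the two constants used above:
// gamma * (x - mean) / sqrt(variance + eps)  ==  (x - sub) * truediv
function foldBatchNorm(mean: tf.Tensor1D, variance: tf.Tensor1D, gamma: tf.Tensor1D, eps = 1e-5) {
  const sub = mean
  const truediv = tf.div(gamma, tf.sqrt(tf.add(variance, tf.scalar(eps)))) as tf.Tensor1D
  return { sub, truediv }
}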
src/tinyYolov2/extractParams.ts

import * as tf from '@tensorflow/tfjs-core';

import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { ExtractWeightsFunction, ParamMapping } from '../commons/types';
import { BatchNorm, ConvWithBatchNorm, NetParams } from './types';

function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {

  const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)

  function extractBatchNormParams(size: number, mappedPrefix: string): BatchNorm {

    const sub = tf.tensor1d(extractWeights(size))
    const truediv = tf.tensor1d(extractWeights(size))

    paramMappings.push(
      { paramPath: `${mappedPrefix}/sub` },
      { paramPath: `${mappedPrefix}/truediv` }
    )

    return { sub, truediv }
  }

  function extractConvWithBatchNormParams(channelsIn: number, channelsOut: number, mappedPrefix: string): ConvWithBatchNorm {

    const conv = extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv`)
    const bn = extractBatchNormParams(channelsOut, `${mappedPrefix}/bn`)

    return { conv, bn }
  }

  return {
    extractConvParams,
    extractConvWithBatchNormParams
  }
}

export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {

  const {
    extractWeights,
    getRemainingWeights
  } = extractWeightsFactory(weights)

  const paramMappings: ParamMapping[] = []

  const {
    extractConvParams,
    extractConvWithBatchNormParams
  } = extractorsFactory(extractWeights, paramMappings)

  // eight 3x3 conv + batch norm blocks, followed by the final 1x1 conv with 30 output channels
  const conv0 = extractConvWithBatchNormParams(3, 16, 'conv0')
  const conv1 = extractConvWithBatchNormParams(16, 32, 'conv1')
  const conv2 = extractConvWithBatchNormParams(32, 64, 'conv2')
  const conv3 = extractConvWithBatchNormParams(64, 128, 'conv3')
  const conv4 = extractConvWithBatchNormParams(128, 256, 'conv4')
  const conv5 = extractConvWithBatchNormParams(256, 512, 'conv5')
  const conv6 = extractConvWithBatchNormParams(512, 1024, 'conv6')
  const conv7 = extractConvWithBatchNormParams(1024, 1024, 'conv7')
  const conv8 = extractConvParams(1024, 30, 1, 'conv8')

  if (getRemainingWeights().length !== 0) {
    throw new Error(`weights remaining after extract: ${getRemainingWeights().length}`)
  }

  const params = { conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8 }

  return { params, paramMappings }
}
\ No newline at end of file
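Since the weights arrive as one flat Float32Array, the expected element count follows directly from the layer sizes above. A rough sanity-check sketch, assuming extractConvParamsFactory (defined in commons, not part of this diff) consumes channelsIn x channelsOut x filterSize^2 filter weights followed by channelsOut biases; this helper is hypothetical and not part of the commit:

// expected number of floats per layer, given the assumed extraction order
const convSize = (cIn: number, cOut: number, filterSize: number) =>
  cIn * cOut * filterSize * filterSize + cOut
const convWithBnSize = (cIn: number, cOut: number) =>
  convSize(cIn, cOut, 3) + 2 * cOut // plus per-channel sub and truediv

const expectedWeightCount =
  convWithBnSize(3, 16) + convWithBnSize(16, 32) + convWithBnSize(32, 64) +
  convWithBnSize(64, 128) + convWithBnSize(128, 256) + convWithBnSize(256, 512) +
  convWithBnSize(512, 1024) + convWithBnSize(1024, 1024) + convSize(1024, 30, 1)

// a buffer of any other length would trip the remaining-weights check in extractParams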
src/tinyYolov2/index.ts

import { TinyYolov2 } from './TinyYolov2';

export * from './TinyYolov2';

export function createTinyYolov2(weights: Float32Array) {
  const net = new TinyYolov2()
  net.extractWeights(weights)
  return net
}
\ No newline at end of file
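A minimal usage sketch for the factory above, assuming the uncompressed weights are served as a binary file in a browser context (the file path and element id are assumptions for illustration, not part of this commit):

import { createTinyYolov2 } from './tinyYolov2';

async function run() {
  // fetch the flat weight file and wrap it as a Float32Array (path is hypothetical)
  const res = await fetch('/models/tiny_yolov2_model.weights')
  const weights = new Float32Array(await res.arrayBuffer())

  const net = createTinyYolov2(weights)

  // run the raw forward pass on an image element; decoding the output tensor
  // into bounding boxes is not part of this commit
  const input = document.getElementById('inputImage') as HTMLImageElement
  const out = await net.forward(input)
  console.log(out.shape)
}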
src/tinyYolov2/leaky.ts

import * as tf from '@tensorflow/tfjs-core';

export function leaky(x: tf.Tensor4D): tf.Tensor4D {
  return tf.tidy(() => {
    return tf.maximum(x, tf.mul(x, tf.scalar(0.10000000149011612)))
  })
}
\ No newline at end of file
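The long constant is simply 0.1 rounded to float32, so this is a standard leaky ReLU with slope 0.1. If the installed tfjs-core version exposes the built-in op, the same activation could be written as below (a sketch, not a change to the commit):

import * as tf from '@tensorflow/tfjs-core';

// presumably equivalent to the hand-rolled max(x, 0.1 * x) above
const leakyAlt = (x: tf.Tensor4D): tf.Tensor4D => tf.leakyRelu(x, 0.1)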
src/tinyYolov2/types.ts

import * as tf from '@tensorflow/tfjs-core';

import { ConvParams } from '../commons/types';

export type BatchNorm = {
  sub: tf.Tensor1D
  truediv: tf.Tensor1D
}

export type ConvWithBatchNorm = {
  conv: ConvParams
  bn: BatchNorm
}

export type NetParams = {
  conv0: ConvWithBatchNorm
  conv1: ConvWithBatchNorm
  conv2: ConvWithBatchNorm
  conv3: ConvWithBatchNorm
  conv4: ConvWithBatchNorm
  conv5: ConvWithBatchNorm
  conv6: ConvWithBatchNorm
  conv7: ConvWithBatchNorm
  conv8: ConvParams
}
\ No newline at end of file