Commit fa5a59a6 by vincent

fix weight reading

parent 62163b6d
import * as tf from '@tensorflow/tfjs-core';
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory'; import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
import { extractWeightsFactory } from '../commons/extractWeightsFactory'; import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { FaceLandmarkNet } from './types'; import { FaceLandmarkNet } from './types';
import * as tf from '@tensorflow/tfjs-core';
export function extractParams(weights: Float32Array): FaceLandmarkNet.NetParams { export function extractParams(weights: Float32Array): FaceLandmarkNet.NetParams {
const { const {
...@@ -20,20 +21,31 @@ export function extractParams(weights: Float32Array): FaceLandmarkNet.NetParams ...@@ -20,20 +21,31 @@ export function extractParams(weights: Float32Array): FaceLandmarkNet.NetParams
} }
} }
const conv0_params = extractConvParams(3, 32, 3)
const conv1_params = extractConvParams(32, 64, 3)
const conv2_params = extractConvParams(64, 64, 3)
const conv3_params = extractConvParams(64, 64, 3)
const conv4_params = extractConvParams(64, 64, 3)
const conv5_params = extractConvParams(64, 128, 3)
const conv6_params = extractConvParams(128, 128, 3)
const conv7_params = extractConvParams(128, 256, 3)
const fc0_params = extractFcParams(6400, 1024)
const fc1_params = extractFcParams(1024, 136)
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
} }
return { return {
conv0_params: extractConvParams(3, 32, 3), conv0_params,
conv1_params: extractConvParams(32, 64, 3), conv1_params,
conv2_params: extractConvParams(64, 64, 3), conv2_params,
conv3_params: extractConvParams(64, 64, 3), conv3_params,
conv4_params: extractConvParams(64, 64, 3), conv4_params,
conv5_params: extractConvParams(64, 128, 3), conv5_params,
conv6_params: extractConvParams(128, 128, 3), conv6_params,
conv7_params: extractConvParams(128, 256, 3), conv7_params,
fc0_params: extractFcParams(6400, 1024), fc0_params,
fc1_params:extractFcParams(1024, 136) fc1_params
} }
} }
\ No newline at end of file
...@@ -2,6 +2,7 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,6 +2,7 @@ import * as tf from '@tensorflow/tfjs-core';
import { euclideanDistance } from './euclideanDistance'; import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet'; import { faceDetectionNet } from './faceDetectionNet';
import { faceLandmarkNet } from './faceLandmarkNet';
import { faceRecognitionNet } from './faceRecognitionNet'; import { faceRecognitionNet } from './faceRecognitionNet';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import { padToSquare } from './padToSquare'; import { padToSquare } from './padToSquare';
...@@ -9,6 +10,7 @@ import { padToSquare } from './padToSquare'; ...@@ -9,6 +10,7 @@ import { padToSquare } from './padToSquare';
export { export {
euclideanDistance, euclideanDistance,
faceDetectionNet, faceDetectionNet,
faceLandmarkNet,
faceRecognitionNet, faceRecognitionNet,
NetInput, NetInput,
tf, tf,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment