Commit fa5a59a6 by vincent

fix weight reading

parent 62163b6d
import * as tf from '@tensorflow/tfjs-core';
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
import { FaceLandmarkNet } from './types';
import * as tf from '@tensorflow/tfjs-core';
export function extractParams(weights: Float32Array): FaceLandmarkNet.NetParams {
const {
......@@ -20,20 +21,31 @@ export function extractParams(weights: Float32Array): FaceLandmarkNet.NetParams
}
}
const conv0_params = extractConvParams(3, 32, 3)
const conv1_params = extractConvParams(32, 64, 3)
const conv2_params = extractConvParams(64, 64, 3)
const conv3_params = extractConvParams(64, 64, 3)
const conv4_params = extractConvParams(64, 64, 3)
const conv5_params = extractConvParams(64, 128, 3)
const conv6_params = extractConvParams(128, 128, 3)
const conv7_params = extractConvParams(128, 256, 3)
const fc0_params = extractFcParams(6400, 1024)
const fc1_params = extractFcParams(1024, 136)
if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
}
return {
conv0_params: extractConvParams(3, 32, 3),
conv1_params: extractConvParams(32, 64, 3),
conv2_params: extractConvParams(64, 64, 3),
conv3_params: extractConvParams(64, 64, 3),
conv4_params: extractConvParams(64, 64, 3),
conv5_params: extractConvParams(64, 128, 3),
conv6_params: extractConvParams(128, 128, 3),
conv7_params: extractConvParams(128, 256, 3),
fc0_params: extractFcParams(6400, 1024),
fc1_params:extractFcParams(1024, 136)
conv0_params,
conv1_params,
conv2_params,
conv3_params,
conv4_params,
conv5_params,
conv6_params,
conv7_params,
fc0_params,
fc1_params
}
}
\ No newline at end of file
......@@ -2,6 +2,7 @@ import * as tf from '@tensorflow/tfjs-core';
import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet';
import { faceLandmarkNet } from './faceLandmarkNet';
import { faceRecognitionNet } from './faceRecognitionNet';
import { NetInput } from './NetInput';
import { padToSquare } from './padToSquare';
......@@ -9,6 +10,7 @@ import { padToSquare } from './padToSquare';
export {
euclideanDistance,
faceDetectionNet,
faceLandmarkNet,
faceRecognitionNet,
NetInput,
tf,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment