Commit 43bf6889 by vincent

move shared logic for nets working on face images to FaceProcessor

parent b3bcb3ef
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
constructor(faceFeatureExtractor: FaceFeatureExtractor) {
super('FaceExpressionNet', faceFeatureExtractor)
}
public dispose(throwOnRedispose: boolean = true) {
this.faceFeatureExtractor.dispose(throwOnRedispose)
super.dispose(throwOnRedispose)
}
protected getDefaultModelName(): string {
return 'face_expression_model'
}
protected getClassifierChannelsIn(): number {
return 256
}
protected getClassifierChannelsOut(): number {
return 7
}
}
\ No newline at end of file
export * from './FaceExpressionNet';
\ No newline at end of file
...@@ -5,7 +5,7 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2'; ...@@ -5,7 +5,7 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
import { depthwiseSeparableConv } from './depthwiseSeparableConv'; import { depthwiseSeparableConv } from './depthwiseSeparableConv';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap'; import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { DenseBlock4Params, NetParams } from './types'; import { DenseBlock4Params, IFaceFeatureExtractor, FaceFeatureExtractorParams } from './types';
function denseBlock( function denseBlock(
x: tf.Tensor4D, x: tf.Tensor4D,
...@@ -33,7 +33,7 @@ function denseBlock( ...@@ -33,7 +33,7 @@ function denseBlock(
}) })
} }
export class FaceFeatureExtractor extends NeuralNetwork<NetParams> { export class FaceFeatureExtractor extends NeuralNetwork<FaceFeatureExtractorParams> implements IFaceFeatureExtractor<FaceFeatureExtractorParams> {
constructor() { constructor() {
super('FaceFeatureExtractor') super('FaceFeatureExtractor')
......
...@@ -5,7 +5,7 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2'; ...@@ -5,7 +5,7 @@ import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
import { depthwiseSeparableConv } from './depthwiseSeparableConv'; import { depthwiseSeparableConv } from './depthwiseSeparableConv';
import { extractParamsFromWeigthMapTiny } from './extractParamsFromWeigthMapTiny'; import { extractParamsFromWeigthMapTiny } from './extractParamsFromWeigthMapTiny';
import { extractParamsTiny } from './extractParamsTiny'; import { extractParamsTiny } from './extractParamsTiny';
import { DenseBlock3Params, TinyNetParams } from './types'; import { DenseBlock3Params, IFaceFeatureExtractor, TinyFaceFeatureExtractorParams } from './types';
function denseBlock( function denseBlock(
x: tf.Tensor4D, x: tf.Tensor4D,
...@@ -30,7 +30,7 @@ function denseBlock( ...@@ -30,7 +30,7 @@ function denseBlock(
}) })
} }
export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyNetParams> { export class TinyFaceFeatureExtractor extends NeuralNetwork<TinyFaceFeatureExtractorParams> implements IFaceFeatureExtractor<TinyFaceFeatureExtractorParams> {
constructor() { constructor() {
super('TinyFaceFeatureExtractor') super('TinyFaceFeatureExtractor')
......
import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base'; import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base';
import { extractorsFactory } from './extractorsFactory'; import { extractorsFactory } from './extractorsFactory';
import { NetParams } from './types'; import { FaceFeatureExtractorParams } from './types';
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } { export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [] const paramMappings: ParamMapping[] = []
......
...@@ -2,11 +2,11 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,11 +2,11 @@ import * as tf from '@tensorflow/tfjs-core';
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base'; import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
import { loadParamsFactory } from './loadParamsFactory'; import { loadParamsFactory } from './loadParamsFactory';
import { NetParams } from './types'; import { FaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMap( export function extractParamsFromWeigthMap(
weightMap: tf.NamedTensorMap weightMap: tf.NamedTensorMap
): { params: NetParams, paramMappings: ParamMapping[] } { ): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [] const paramMappings: ParamMapping[] = []
......
...@@ -2,11 +2,11 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,11 +2,11 @@ import * as tf from '@tensorflow/tfjs-core';
import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base'; import { disposeUnusedWeightTensors, ParamMapping } from 'tfjs-image-recognition-base';
import { loadParamsFactory } from './loadParamsFactory'; import { loadParamsFactory } from './loadParamsFactory';
import { TinyNetParams } from './types'; import { TinyFaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMapTiny( export function extractParamsFromWeigthMapTiny(
weightMap: tf.NamedTensorMap weightMap: tf.NamedTensorMap
): { params: TinyNetParams, paramMappings: ParamMapping[] } { ): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [] const paramMappings: ParamMapping[] = []
......
import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base'; import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base';
import { extractorsFactory } from './extractorsFactory'; import { extractorsFactory } from './extractorsFactory';
import { TinyNetParams } from './types'; import { TinyFaceFeatureExtractorParams } from './types';
export function extractParamsTiny(weights: Float32Array): { params: TinyNetParams, paramMappings: ParamMapping[] } { export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [] const paramMappings: ParamMapping[] = []
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { NetInput, NeuralNetwork } from 'tfjs-image-recognition-base';
import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2'; import { ConvParams, SeparableConvParams } from 'tfjs-tiny-yolov2';
export type ConvWithBatchNormParams = BatchNormParams & { export type ConvWithBatchNormParams = BatchNormParams & {
...@@ -27,16 +28,19 @@ export type DenseBlock4Params = DenseBlock3Params & { ...@@ -27,16 +28,19 @@ export type DenseBlock4Params = DenseBlock3Params & {
conv3: SeparableConvParams conv3: SeparableConvParams
} }
export type TinyNetParams = { export type TinyFaceFeatureExtractorParams = {
dense0: DenseBlock3Params dense0: DenseBlock3Params
dense1: DenseBlock3Params dense1: DenseBlock3Params
dense2: DenseBlock3Params dense2: DenseBlock3Params
} }
export type NetParams = { export type FaceFeatureExtractorParams = {
dense0: DenseBlock4Params dense0: DenseBlock4Params
dense1: DenseBlock4Params dense1: DenseBlock4Params
dense2: DenseBlock4Params dense2: DenseBlock4Params
dense3: DenseBlock4Params dense3: DenseBlock4Params
} }
export interface IFaceFeatureExtractor<TNetParams extends TinyFaceFeatureExtractorParams | FaceFeatureExtractorParams> extends NeuralNetwork<TNetParams> {
forward(input: NetInput): tf.Tensor4D
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { NetInput } from 'tfjs-image-recognition-base';
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor'; import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
import { extractParams } from './extractParams'; import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase'; import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
import { fullyConnectedLayer } from './fullyConnectedLayer';
import { NetParams } from './types';
import { seperateWeightMaps } from './util';
export class FaceLandmark68Net extends FaceLandmark68NetBase<NetParams> {
private static classifierNumFilters: number = 256
private _faceFeatureExtractor: FaceFeatureExtractor
constructor(faceFeatureExtractor: FaceFeatureExtractor) {
super('FaceLandmark68Net')
this._faceFeatureExtractor = faceFeatureExtractor
}
public get faceFeatureExtractor(): FaceFeatureExtractor {
return this._faceFeatureExtractor
}
public runNet(input: NetInput | tf.Tensor4D): tf.Tensor2D {
const { params } = this
if (!params) { export class FaceLandmark68Net extends FaceLandmark68NetBase<FaceFeatureExtractorParams> {
throw new Error('FaceLandmark68Net - load model before inference')
}
if (!this.faceFeatureExtractor.isLoaded) { constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
throw new Error('FaceLandmark68Net - load face feature extractor model before inference') super('FaceLandmark68Net', faceFeatureExtractor)
}
return tf.tidy(() => {
const bottleneckFeatures = input instanceof NetInput
? this.faceFeatureExtractor.forward(input)
: input
return fullyConnectedLayer(bottleneckFeatures.as2D(bottleneckFeatures.shape[0], -1), params.fc)
})
}
public dispose(throwOnRedispose: boolean = true) {
this.faceFeatureExtractor.dispose(throwOnRedispose)
super.dispose(throwOnRedispose)
} }
protected getDefaultModelName(): string { protected getDefaultModelName(): string {
return 'face_landmark_68_model' return 'face_landmark_68_model'
} }
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) { protected getClassifierChannelsIn(): number {
return 256
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
return extractParamsFromWeigthMap(classifierMap)
}
protected extractParams(weights: Float32Array) {
const classifierWeightSize = 136 * FaceLandmark68Net.classifierNumFilters + 136
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
return extractParams(classifierWeights, FaceLandmark68Net.classifierNumFilters)
} }
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { IDimensions, isEven, NetInput, NeuralNetwork, Point, TNetInput, toNetInput } from 'tfjs-image-recognition-base'; import { IDimensions, isEven, NetInput, Point, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
import { FaceLandmarks68 } from '../classes/FaceLandmarks68'; import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
import { FaceFeatureExtractorParams, TinyFaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
export abstract class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<NetParams> { export abstract class FaceLandmark68NetBase<
TExtractorParams extends FaceFeatureExtractorParams | TinyFaceFeatureExtractorParams
// TODO: make super.name protected >
private __name: string extends FaceProcessor<TExtractorParams> {
constructor(_name: string) {
super(_name)
this.__name = _name
}
public abstract runNet(netInput: NetInput): tf.Tensor2D
public postProcess(output: tf.Tensor2D, inputSize: number, originalDimensions: IDimensions[]): tf.Tensor2D { public postProcess(output: tf.Tensor2D, inputSize: number, originalDimensions: IDimensions[]): tf.Tensor2D {
...@@ -103,4 +98,8 @@ export abstract class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<Net ...@@ -103,4 +98,8 @@ export abstract class FaceLandmark68NetBase<NetParams> extends NeuralNetwork<Net
? landmarksForBatch ? landmarksForBatch
: landmarksForBatch[0] : landmarksForBatch[0]
} }
protected getClassifierChannelsOut(): number {
return 136
}
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import { TinyFaceFeatureExtractorParams } from 'src/faceFeatureExtractor/types';
import { NetInput } from 'tfjs-image-recognition-base';
import { TinyFaceFeatureExtractor } from '../faceFeatureExtractor/TinyFaceFeatureExtractor'; import { TinyFaceFeatureExtractor } from '../faceFeatureExtractor/TinyFaceFeatureExtractor';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { FaceLandmark68NetBase } from './FaceLandmark68NetBase'; import { FaceLandmark68NetBase } from './FaceLandmark68NetBase';
import { fullyConnectedLayer } from './fullyConnectedLayer';
import { NetParams } from './types';
import { seperateWeightMaps } from './util';
export class FaceLandmark68TinyNet extends FaceLandmark68NetBase<NetParams> { export class FaceLandmark68TinyNet extends FaceLandmark68NetBase<TinyFaceFeatureExtractorParams> {
private static classifierNumFilters: number = 128 constructor(faceFeatureExtractor: TinyFaceFeatureExtractor = new TinyFaceFeatureExtractor()) {
super('FaceLandmark68TinyNet', faceFeatureExtractor)
private _faceFeatureExtractor: TinyFaceFeatureExtractor
constructor(faceFeatureExtractor: TinyFaceFeatureExtractor) {
super('FaceLandmark68TinyNet')
this._faceFeatureExtractor = faceFeatureExtractor
}
public get faceFeatureExtractor(): TinyFaceFeatureExtractor {
return this._faceFeatureExtractor
}
public runNet(input: NetInput | tf.Tensor4D): tf.Tensor2D {
const { params } = this
if (!params) {
throw new Error('FaceLandmark68TinyNet - load model before inference')
}
if (!this.faceFeatureExtractor.isLoaded) {
throw new Error('FaceLandmark68TinyNet - load face feature extractor model before inference')
}
return tf.tidy(() => {
const bottleneckFeatures = input instanceof NetInput
? this.faceFeatureExtractor.forward(input)
: input
return fullyConnectedLayer(bottleneckFeatures.as2D(bottleneckFeatures.shape[0], -1), params.fc)
})
}
public dispose(throwOnRedispose: boolean = true) {
this.faceFeatureExtractor.dispose(throwOnRedispose)
super.dispose(throwOnRedispose)
} }
protected getDefaultModelName(): string { protected getDefaultModelName(): string {
return 'face_landmark_68_tiny_model' return 'face_landmark_68_tiny_model'
} }
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) { protected getClassifierChannelsIn(): number {
return 128
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
return extractParamsFromWeigthMap(classifierMap)
}
protected extractParams(weights: Float32Array) {
const classifierWeightSize = 136 * FaceLandmark68TinyNet.classifierNumFilters + 136
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
return extractParams(classifierWeights, FaceLandmark68TinyNet.classifierNumFilters)
} }
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { NetInput, NeuralNetwork } from 'tfjs-image-recognition-base';
import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
import {
FaceFeatureExtractorParams,
IFaceFeatureExtractor,
TinyFaceFeatureExtractorParams,
} from '../faceFeatureExtractor/types';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { NetParams } from './types';
import { seperateWeightMaps } from './util';
export abstract class FaceProcessor<
TExtractorParams extends FaceFeatureExtractorParams | TinyFaceFeatureExtractorParams
>
extends NeuralNetwork<NetParams> {
protected _faceFeatureExtractor: IFaceFeatureExtractor<TExtractorParams>
constructor(_name: string, faceFeatureExtractor: IFaceFeatureExtractor<TExtractorParams>) {
super(_name)
this._faceFeatureExtractor = faceFeatureExtractor
}
public get faceFeatureExtractor(): IFaceFeatureExtractor<TExtractorParams> {
return this._faceFeatureExtractor
}
protected abstract getDefaultModelName(): string
protected abstract getClassifierChannelsIn(): number
protected abstract getClassifierChannelsOut(): number
public runNet(input: NetInput | tf.Tensor4D): tf.Tensor2D {
const { params } = this
if (!params) {
throw new Error(`${this._name} - load model before inference`)
}
return tf.tidy(() => {
const bottleneckFeatures = input instanceof NetInput
? this.faceFeatureExtractor.forward(input)
: input
return fullyConnectedLayer(bottleneckFeatures.as2D(bottleneckFeatures.shape[0], -1), params.fc)
})
}
public dispose(throwOnRedispose: boolean = true) {
this.faceFeatureExtractor.dispose(throwOnRedispose)
super.dispose(throwOnRedispose)
}
protected extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap) {
const { featureExtractorMap, classifierMap } = seperateWeightMaps(weightMap)
this.faceFeatureExtractor.loadFromWeightMap(featureExtractorMap)
return extractParamsFromWeigthMap(classifierMap)
}
protected extractParams(weights: Float32Array) {
const cIn = this.getClassifierChannelsIn()
const cOut = this.getClassifierChannelsOut()
const classifierWeightSize = (cOut * cIn )+ cOut
const featureExtractorWeights = weights.slice(0, weights.length - classifierWeightSize)
const classifierWeights = weights.slice(weights.length - classifierWeightSize)
this.faceFeatureExtractor.extractWeights(featureExtractorWeights)
return extractParams(classifierWeights, cIn, cOut)
}
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base'; import { extractWeightsFactory, ParamMapping } from 'tfjs-image-recognition-base';
import { FCParams } from 'tfjs-tiny-yolov2'; import { extractFCParamsFactory } from 'tfjs-tiny-yolov2';
import { NetParams } from './types'; import { NetParams } from './types';
export function extractParams(weights: Float32Array, numFilters: number): { params: NetParams, paramMappings: ParamMapping[] } { export function extractParams(weights: Float32Array, channelsIn: number, channelsOut: number): { params: NetParams, paramMappings: ParamMapping[] } {
const paramMappings: ParamMapping[] = [] const paramMappings: ParamMapping[] = []
...@@ -13,22 +12,9 @@ export function extractParams(weights: Float32Array, numFilters: number): { para ...@@ -13,22 +12,9 @@ export function extractParams(weights: Float32Array, numFilters: number): { para
getRemainingWeights getRemainingWeights
} = extractWeightsFactory(weights) } = extractWeightsFactory(weights)
function extractFCParams(channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams { const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings)
const weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push( const fc = extractFCParams(channelsIn, channelsOut, 'fc')
{ paramPath: `${mappedPrefix}/weights` },
{ paramPath: `${mappedPrefix}/bias` }
)
return {
weights,
bias
}
}
const fc = extractFCParams(numFilters, 136, 'fc')
if (getRemainingWeights().length !== 0) { if (getRemainingWeights().length !== 0) {
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`) throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
......
export * from './FaceProcessor';
\ No newline at end of file
...@@ -18,16 +18,13 @@ import { TinyFaceDetector } from '../tinyFaceDetector/TinyFaceDetector'; ...@@ -18,16 +18,13 @@ import { TinyFaceDetector } from '../tinyFaceDetector/TinyFaceDetector';
import { TinyFaceDetectorOptions } from '../tinyFaceDetector/TinyFaceDetectorOptions'; import { TinyFaceDetectorOptions } from '../tinyFaceDetector/TinyFaceDetectorOptions';
import { TinyYolov2 } from '../tinyYolov2/TinyYolov2'; import { TinyYolov2 } from '../tinyYolov2/TinyYolov2';
const faceFeatureExtractor = new FaceFeatureExtractor()
const tinyFaceFeatureExtractor = new TinyFaceFeatureExtractor()
export const nets = { export const nets = {
ssdMobilenetv1: new SsdMobilenetv1(), ssdMobilenetv1: new SsdMobilenetv1(),
tinyFaceDetector: new TinyFaceDetector(), tinyFaceDetector: new TinyFaceDetector(),
tinyYolov2: new TinyYolov2(), tinyYolov2: new TinyYolov2(),
mtcnn: new Mtcnn(), mtcnn: new Mtcnn(),
faceLandmark68Net: new FaceLandmark68Net(faceFeatureExtractor), faceLandmark68Net: new FaceLandmark68Net(),
faceLandmark68TinyNet: new FaceLandmark68TinyNet(tinyFaceFeatureExtractor), faceLandmark68TinyNet: new FaceLandmark68TinyNet(),
faceRecognitionNet: new FaceRecognitionNet() faceRecognitionNet: new FaceRecognitionNet()
} }
......
...@@ -8,6 +8,7 @@ export * from 'tfjs-image-recognition-base'; ...@@ -8,6 +8,7 @@ export * from 'tfjs-image-recognition-base';
export * from './classes/index'; export * from './classes/index';
export * from './dom/index' export * from './dom/index'
export * from './faceExpressionNet/index';
export * from './faceLandmarkNet/index'; export * from './faceLandmarkNet/index';
export * from './faceRecognitionNet/index'; export * from './faceRecognitionNet/index';
export * from './factories/index'; export * from './factories/index';
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { convLayer } from 'tfjs-tiny-yolov2'; import { convLayer } from 'tfjs-tiny-yolov2';
import { fullyConnectedLayer } from '../faceLandmarkNet/fullyConnectedLayer'; import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
import { prelu } from './prelu'; import { prelu } from './prelu';
import { sharedLayer } from './sharedLayers'; import { sharedLayer } from './sharedLayers';
import { ONetParams } from './types'; import { ONetParams } from './types';
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { fullyConnectedLayer } from '../faceLandmarkNet/fullyConnectedLayer'; import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
import { prelu } from './prelu'; import { prelu } from './prelu';
import { sharedLayer } from './sharedLayers'; import { sharedLayer } from './sharedLayers';
import { RNetParams } from './types'; import { RNetParams } from './types';
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { FaceFeatureExtractor } from '../../../src/faceFeatureExtractor/FaceFeatureExtractor';
import { FaceLandmark68NetBase } from '../../../src/faceLandmarkNet/FaceLandmark68NetBase'; import { FaceLandmark68NetBase } from '../../../src/faceLandmarkNet/FaceLandmark68NetBase';
class FakeFaceLandmark68NetBase extends FaceLandmark68NetBase<any> { class FakeFaceLandmark68NetBase extends FaceLandmark68NetBase<any> {
protected getDefaultModelName(): string { protected getDefaultModelName(): string {
throw new Error('FakeFaceLandmark68NetBase - getDefaultModelName not implemented') throw new Error('FakeFaceLandmark68NetBase - getDefaultModelName not implemented')
} }
protected getClassifierChannelsIn(): number {
throw new Error('FakeFaceLandmark68NetBase - getClassifierChannelsIn not implemented')
}
protected extractParams(_: any): any { protected extractParams(_: any): any {
throw new Error('FakeFaceLandmark68NetBase - extractParams not implemented') throw new Error('FakeFaceLandmark68NetBase - extractParams not implemented')
} }
...@@ -18,13 +24,13 @@ class FakeFaceLandmark68NetBase extends FaceLandmark68NetBase<any> { ...@@ -18,13 +24,13 @@ class FakeFaceLandmark68NetBase extends FaceLandmark68NetBase<any> {
public runNet(): any { public runNet(): any {
throw new Error('FakeFaceLandmark68NetBase - extractParamsFromWeigthMap not implemented') throw new Error('FakeFaceLandmark68NetBase - extractParamsFromWeigthMap not implemented')
} }
} }
describe('FaceLandmark68NetBase', () => { describe('FaceLandmark68NetBase', () => {
describe('postProcess', () => { describe('postProcess', () => {
const net = new FakeFaceLandmark68NetBase('') const net = new FakeFaceLandmark68NetBase('', new FaceFeatureExtractor())
describe('single batch', () => { describe('single batch', () => {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment