Unverified commit cb52dab4 by Vincent Mühler, committed by GitHub

Merge pull request #504 from justadudewhohacks/remove-tfjs-image-recognition-base

remove tfjs image recognition base
parents c804dc68 0b4470af
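
This PR removes the external tfjs-image-recognition-base dependency and inlines its classes, ops, and dom utilities into the face-api.js source tree. For most call sites the migration is a mechanical import swap; a minimal sketch of the pattern, with module paths as they appear in the diff below:

// before: symbols came from the external package
// import { NeuralNetwork, NetInput, TNetInput, toNetInput } from 'tfjs-image-recognition-base';

// after: the same symbols live in local modules
import { NeuralNetwork } from '../NeuralNetwork';
import { NetInput, TNetInput, toNetInput } from '../dom';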
@@ -26,13 +26,13 @@ function getFaceDetectorOptions() {
}
function onIncreaseMinConfidence() {
- minConfidence = Math.min(faceapi.round(minConfidence + 0.1), 1.0)
+ minConfidence = Math.min(faceapi.utils.round(minConfidence + 0.1), 1.0)
$('#minConfidence').val(minConfidence)
updateResults()
}
function onDecreaseMinConfidence() {
- minConfidence = Math.max(faceapi.round(minConfidence - 0.1), 0.1)
+ minConfidence = Math.max(faceapi.utils.round(minConfidence - 0.1), 0.1)
$('#minConfidence').val(minConfidence)
updateResults()
}
@@ -51,24 +51,24 @@ function changeInputSize(size) {
}
function onIncreaseScoreThreshold() {
- scoreThreshold = Math.min(faceapi.round(scoreThreshold + 0.1), 1.0)
+ scoreThreshold = Math.min(faceapi.utils.round(scoreThreshold + 0.1), 1.0)
$('#scoreThreshold').val(scoreThreshold)
updateResults()
}
function onDecreaseScoreThreshold() {
- scoreThreshold = Math.max(faceapi.round(scoreThreshold - 0.1), 0.1)
+ scoreThreshold = Math.max(faceapi.utils.round(scoreThreshold - 0.1), 0.1)
$('#scoreThreshold').val(scoreThreshold)
updateResults()
}
function onIncreaseMinFaceSize() {
- minFaceSize = Math.min(faceapi.round(minFaceSize + 20), 300)
+ minFaceSize = Math.min(faceapi.utils.round(minFaceSize + 20), 300)
$('#minFaceSize').val(minFaceSize)
}
function onDecreaseMinFaceSize() {
- minFaceSize = Math.max(faceapi.round(minFaceSize - 20), 50)
+ minFaceSize = Math.max(faceapi.utils.round(minFaceSize - 20), 50)
$('#minFaceSize').val(minFaceSize)
}
...
@@ -161,8 +161,8 @@
const { age, gender, genderProbability } = result
new faceapi.draw.DrawTextField(
[
- `${faceapi.round(age, 0)} years`,
- `${gender} (${faceapi.round(genderProbability)})`
+ `${faceapi.utils.round(age, 0)} years`,
+ `${gender} (${faceapi.utils.round(genderProbability)})`
],
result.detection.box.bottomLeft
).draw(canvas)
...
@@ -96,7 +96,7 @@
function displayTimeStats(timeInMs) {
$('#time').val(`${timeInMs} ms`)
- $('#fps').val(`${faceapi.round(1000 / timeInMs)}`)
+ $('#fps').val(`${faceapi.utils.round(1000 / timeInMs)}`)
}
function displayImage(src) {
...
@@ -39,7 +39,7 @@
let descriptors = { desc1: null, desc2: null }
function updateResult() {
- const distance = faceapi.round(
+ const distance = faceapi.utils.round(
faceapi.euclideanDistance(descriptors.desc1, descriptors.desc2)
)
let text = distance
...
@@ -156,7 +156,7 @@
forwardTimes = [timeInMs].concat(forwardTimes).slice(0, 30)
const avgTimeInMs = forwardTimes.reduce((total, t) => total + t) / forwardTimes.length
$('#time').val(`${Math.round(avgTimeInMs)} ms`)
- $('#fps').val(`${faceapi.round(1000 / avgTimeInMs)}`)
+ $('#fps').val(`${faceapi.utils.round(1000 / avgTimeInMs)}`)
}
async function onPlay(videoEl) {
...
@@ -152,7 +152,7 @@
forwardTimes = [timeInMs].concat(forwardTimes).slice(0, 30)
const avgTimeInMs = forwardTimes.reduce((total, t) => total + t) / forwardTimes.length
$('#time').val(`${Math.round(avgTimeInMs)} ms`)
- $('#fps').val(`${faceapi.round(1000 / avgTimeInMs)}`)
+ $('#fps').val(`${faceapi.utils.round(1000 / avgTimeInMs)}`)
}
function interpolateAgePredictions(age) {
@@ -192,8 +192,8 @@
const interpolatedAge = interpolateAgePredictions(age)
new faceapi.draw.DrawTextField(
[
- `${faceapi.round(interpolatedAge, 0)} years`,
- `${gender} (${faceapi.round(genderProbability)})`
+ `${faceapi.utils.round(interpolatedAge, 0)} years`,
+ `${gender} (${faceapi.utils.round(genderProbability)})`
],
result.detection.box.bottomLeft
).draw(canvas)
...
@@ -139,7 +139,7 @@
forwardTimes = [timeInMs].concat(forwardTimes).slice(0, 30)
const avgTimeInMs = forwardTimes.reduce((total, t) => total + t) / forwardTimes.length
$('#time').val(`${Math.round(avgTimeInMs)} ms`)
- $('#fps').val(`${faceapi.round(1000 / avgTimeInMs)}`)
+ $('#fps').val(`${faceapi.utils.round(1000 / avgTimeInMs)}`)
}
async function onPlay() {
...
@@ -151,7 +151,7 @@
forwardTimes = [timeInMs].concat(forwardTimes).slice(0, 30)
const avgTimeInMs = forwardTimes.reduce((total, t) => total + t) / forwardTimes.length
$('#time').val(`${Math.round(avgTimeInMs)} ms`)
- $('#fps').val(`${faceapi.round(1000 / avgTimeInMs)}`)
+ $('#fps').val(`${faceapi.utils.round(1000 / avgTimeInMs)}`)
}
async function onPlay() {
...
@@ -151,7 +151,7 @@
forwardTimes = [timeInMs].concat(forwardTimes).slice(0, 30)
const avgTimeInMs = forwardTimes.reduce((total, t) => total + t) / forwardTimes.length
$('#time').val(`${Math.round(avgTimeInMs)} ms`)
- $('#fps').val(`${faceapi.round(1000 / avgTimeInMs)}`)
+ $('#fps').val(`${faceapi.utils.round(1000 / avgTimeInMs)}`)
}
async function onPlay() {
...
@@ -19,8 +19,8 @@ async function run() {
const { age, gender, genderProbability } = result
new faceapi.draw.DrawTextField(
[
- `${faceapi.round(age, 0)} years`,
- `${gender} (${faceapi.round(genderProbability)})`
+ `${faceapi.utils.round(age, 0)} years`,
+ `${gender} (${faceapi.utils.round(genderProbability)})`
],
result.detection.box.bottomLeft
).draw(out)
...
- let spec_files = ['**/*.test.ts'].concat(
-   process.env.EXCLUDE_UNCOMPRESSED
-     ? ['!**/*.uncompressed.test.ts']
-     : []
- )
+ let spec_files = ['**/*.test.ts']
// exclude browser tests
spec_files = spec_files.concat(['!**/*.browser.test.ts'])
- spec_files = spec_files.concat(['!test/tests.legacy/*'])
module.exports = {
spec_dir: 'test',
...
@@ -2,6 +2,7 @@ const dataFiles = [
'test/images/*.jpg',
'test/images/*.png',
'test/data/*.json',
+ 'test/data/*.weights',
'test/media/*.mp4',
'weights/**/*',
'weights_uncompressed/**/*',
@@ -21,24 +22,17 @@ let exclude = (
'faceLandmarkNet',
'faceRecognitionNet',
'ssdMobilenetv1',
- 'tinyFaceDetector',
- 'mtcnn'
+ 'tinyFaceDetector'
]
: []
)
.filter(ex => ex !== process.env.UUT)
.map(ex => `test/tests/${ex}/*.ts`)
- exclude = exclude.concat(
-   process.env.EXCLUDE_UNCOMPRESSED
-     ? ['**/*.uncompressed.test.ts']
-     : []
- )
// exclude nodejs tests
exclude = exclude.concat(['**/*.node.test.ts'])
exclude = exclude.concat(['test/env.node.ts'])
- exclude = exclude.concat(['test/tests-legacy/**/*.ts'])
module.exports = function(config) {
...
@@ -1487,9 +1487,9 @@
"dev": true
},
"fsevents": {
- "version": "2.0.7",
- "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.0.7.tgz",
- "integrity": "sha512-a7YT0SV3RB+DjYcppwVDLtn13UQnmg0SWZS7ezZD0UjnLwXmy8Zm21GMVGLaFGimIqcvyMQaOJBrop8MyOp1kQ==",
+ "version": "2.1.2",
+ "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.1.2.tgz",
+ "integrity": "sha512-R4wDiBwZ0KzpgOWetKDug1FZcYhqYnUYKtfZYt4mD5SBz76q0KR4Q9o7GIPamsVPGmW3EYPPJ0dOOjvx32ldZA==",
"dev": true,
"optional": true
},
@@ -3917,15 +3917,6 @@
"yallist": "^3.0.3"
}
},
- "tfjs-image-recognition-base": {
-   "version": "0.6.2",
-   "resolved": "https://registry.npmjs.org/tfjs-image-recognition-base/-/tfjs-image-recognition-base-0.6.2.tgz",
-   "integrity": "sha512-ukxViVfAPw7s0KiGhwr3zrwsm+EVa2Z+4aEwKBWO43Rt48nbPyVvrHdL+WbxRynZYjklEE69ft66C8zzea7vFw==",
-   "requires": {
-     "@tensorflow/tfjs-core": "^1.2.9",
-     "tslib": "^1.10.0"
-   }
- },
"through2": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/through2/-/through2-3.0.0.tgz",
...
@@ -11,22 +11,16 @@
"tsc": "tsc",
"tsc-es6": "tsc --p tsconfig.es6.json",
"build": "rm -rf ./build && rm -rf ./dist && npm run rollup && npm run rollup-min && npm run tsc && npm run tsc-es6",
- "test": "karma start",
+ "test": "npm run test-browser && npm run test-node",
"test-browser": "karma start --single-run",
"test-node": "ts-node -r ./test/env.node.ts node_modules/jasmine/bin/jasmine --config=jasmine-node.js",
- "test-all": "npm run test-browser-exclude-uncompressed && npm run test-node-exclude-uncompressed",
- "test-all-include-uncompressed": "npm run test-browser && npm run test-node",
"test-facelandmarknets": "set UUT=faceLandmarkNet&& karma start",
"test-facerecognitionnet": "set UUT=faceRecognitionNet&& karma start",
"test-agegendernet": "set UUT=ageGenderNet&& karma start",
"test-ssdmobilenetv1": "set UUT=ssdMobilenetv1&& karma start",
"test-tinyfacedetector": "set UUT=tinyFaceDetector&& karma start",
"test-globalapi": "set UUT=globalApi&& karma start",
- "test-mtcnn": "set UUT=mtcnn&& karma start",
"test-cpu": "set BACKEND_CPU=true&& karma start",
- "test-exclude-uncompressed": "set EXCLUDE_UNCOMPRESSED=true&& karma start",
- "test-browser-exclude-uncompressed": "set EXCLUDE_UNCOMPRESSED=true&& karma start --single-run",
- "test-node-exclude-uncompressed": "set EXCLUDE_UNCOMPRESSED=true&& npm run test-node",
"docs": "typedoc --options ./typedoc.config.js ./src"
},
"keywords": [
@@ -40,7 +34,6 @@
"license": "MIT",
"dependencies": {
"@tensorflow/tfjs-core": "1.2.9",
- "tfjs-image-recognition-base": "^0.6.2",
"tslib": "^1.10.0"
},
"devDependencies": {
...
import * as tf from '@tensorflow/tfjs-core';
import { ParamMapping } from './common';
import { getModelUris } from './common/getModelUris';
import { loadWeightMap } from './dom';
import { env } from './env';
export abstract class NeuralNetwork<TNetParams> {
protected _params: TNetParams | undefined = undefined
protected _paramMappings: ParamMapping[] = []
constructor(protected _name: string) {}
public get params(): TNetParams | undefined { return this._params }
public get paramMappings(): ParamMapping[] { return this._paramMappings }
public get isLoaded(): boolean { return !!this.params }
public getParamFromPath(paramPath: string): tf.Tensor {
const { obj, objProp } = this.traversePropertyPath(paramPath)
return obj[objProp]
}
public reassignParamFromPath(paramPath: string, tensor: tf.Tensor) {
const { obj, objProp } = this.traversePropertyPath(paramPath)
obj[objProp].dispose()
obj[objProp] = tensor
}
public getParamList() {
return this._paramMappings.map(({ paramPath }) => ({
path: paramPath,
tensor: this.getParamFromPath(paramPath)
}))
}
public getTrainableParams() {
return this.getParamList().filter(param => param.tensor instanceof tf.Variable)
}
public getFrozenParams() {
return this.getParamList().filter(param => !(param.tensor instanceof tf.Variable))
}
public variable() {
this.getFrozenParams().forEach(({ path, tensor }) => {
this.reassignParamFromPath(path, tensor.variable())
})
}
public freeze() {
this.getTrainableParams().forEach(({ path, tensor: variable }) => {
const tensor = tf.tensor(variable.dataSync())
variable.dispose()
this.reassignParamFromPath(path, tensor)
})
}
public dispose(throwOnRedispose: boolean = true) {
this.getParamList().forEach(param => {
if (throwOnRedispose && param.tensor.isDisposed) {
throw new Error(`param tensor has already been disposed for path ${param.path}`)
}
param.tensor.dispose()
})
this._params = undefined
}
public serializeParams(): Float32Array {
return new Float32Array(
this.getParamList()
.map(({ tensor }) => Array.from(tensor.dataSync()) as number[])
.reduce((flat, arr) => flat.concat(arr))
)
}
public async load(weightsOrUrl: Float32Array | string | undefined): Promise<void> {
if (weightsOrUrl instanceof Float32Array) {
this.extractWeights(weightsOrUrl)
return
}
await this.loadFromUri(weightsOrUrl)
}
public async loadFromUri(uri: string | undefined) {
if (uri && typeof uri !== 'string') {
throw new Error(`${this._name}.loadFromUri - expected model uri`)
}
const weightMap = await loadWeightMap(uri, this.getDefaultModelName())
this.loadFromWeightMap(weightMap)
}
public async loadFromDisk(filePath: string | undefined) {
if (filePath && typeof filePath !== 'string') {
throw new Error(`${this._name}.loadFromDisk - expected model file path`)
}
const { readFile } = env.getEnv()
const { manifestUri, modelBaseUri } = getModelUris(filePath, this.getDefaultModelName())
const fetchWeightsFromDisk = (filePaths: string[]) => Promise.all(
filePaths.map(filePath => readFile(filePath).then(buf => buf.buffer))
)
const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
const manifest = JSON.parse((await readFile(manifestUri)).toString())
const weightMap = await loadWeights(manifest, modelBaseUri)
this.loadFromWeightMap(weightMap)
}
public loadFromWeightMap(weightMap: tf.NamedTensorMap) {
const {
paramMappings,
params
} = this.extractParamsFromWeigthMap(weightMap)
this._paramMappings = paramMappings
this._params = params
}
public extractWeights(weights: Float32Array) {
const {
paramMappings,
params
} = this.extractParams(weights)
this._paramMappings = paramMappings
this._params = params
}
private traversePropertyPath(paramPath: string) {
if (!this.params) {
throw new Error(`traversePropertyPath - model has no loaded params`)
}
const result = paramPath.split('/').reduce((res: { nextObj: any, obj?: any, objProp?: string }, objProp) => {
if (!res.nextObj.hasOwnProperty(objProp)) {
throw new Error(`traversePropertyPath - object does not have property ${objProp}, for path ${paramPath}`)
}
return { obj: res.nextObj, objProp, nextObj: res.nextObj[objProp] }
}, { nextObj: this.params })
const { obj, objProp } = result
if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
throw new Error(`traversePropertyPath - parameter is not a tensor, for path ${paramPath}`)
}
return { obj, objProp }
}
protected abstract getDefaultModelName(): string
protected abstract extractParamsFromWeigthMap(weightMap: tf.NamedTensorMap): { params: TNetParams, paramMappings: ParamMapping[] }
protected abstract extractParams(weights: Float32Array): { params: TNetParams, paramMappings: ParamMapping[] }
}
\ No newline at end of file
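
For orientation, a minimal usage sketch of the inlined NeuralNetwork base class; the concrete subclass, its import path, and the model path are assumptions for illustration:

// Hypothetical consumer code; every net in this PR (e.g. AgeGenderNet) extends NeuralNetwork.
import { AgeGenderNet } from './ageGenderNet/AgeGenderNet';

async function inspectNet() {
  const net = new AgeGenderNet()
  await net.load('/models')             // resolves /models/<defaultModelName>-weights_manifest.json
  console.log(net.isLoaded)             // true once params are extracted
  net.getParamList().forEach(({ path, tensor }) => console.log(path, tensor.shape))
  net.dispose()                         // disposes every param tensor
}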
import * as tf from '@tensorflow/tfjs-core';
- import { NetInput, NeuralNetwork, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
import { seperateWeightMaps } from '../faceProcessor/util';
@@ -7,6 +6,8 @@ import { TinyXception } from '../xception/TinyXception';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { AgeAndGenderPrediction, Gender, NetOutput, NetParams } from './types';
+ import { NeuralNetwork } from '../NeuralNetwork';
+ import { NetInput, TNetInput, toNetInput } from '../dom';
export class AgeGenderNet extends NeuralNetwork<NetParams> {
...
- import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+ import { extractFCParamsFactory, extractWeightsFactory, ParamMapping } from '../common';
import { NetParams } from './types';
- export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-   const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
+ export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
+   const paramMappings: ParamMapping[] = []
const {
extractWeights,
getRemainingWeights
- } = TfjsImageRecognitionBase.extractWeightsFactory(weights)
- const extractFCParams = TfjsImageRecognitionBase.extractFCParamsFactory(extractWeights, paramMappings)
+ } = extractWeightsFactory(weights)
+ const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings)
const age = extractFCParams(512, 1, 'fc/age')
const gender = extractFCParams(512, 2, 'fc/gender')
...
import * as tf from '@tensorflow/tfjs-core';
- import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+ import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common';
import { NetParams } from './types';
export function extractParamsFromWeigthMap(
weightMap: tf.NamedTensorMap
- ): { params: NetParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-   const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
-   const extractWeightEntry = TfjsImageRecognitionBase.extractWeightEntryFactory(weightMap, paramMappings)
-   function extractFcParams(prefix: string): TfjsImageRecognitionBase.FCParams {
+ ): { params: NetParams, paramMappings: ParamMapping[] } {
+   const paramMappings: ParamMapping[] = []
+   const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
+   function extractFcParams(prefix: string): FCParams {
const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/weights`, 2)
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
return { weights, bias }
@@ -24,7 +24,7 @@ export function extractParamsFromWeigthMap(
}
}
- TfjsImageRecognitionBase.disposeUnusedWeightTensors(weightMap, paramMappings)
+ disposeUnusedWeightTensors(weightMap, paramMappings)
return { params, paramMappings }
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
- import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+ import { FCParams } from '../common';
export type AgeAndGenderPrediction = {
age: number
@@ -16,7 +17,7 @@ export type NetOutput = { age: tf.Tensor1D, gender: tf.Tensor2D }
export type NetParams = {
fc: {
- age: TfjsImageRecognitionBase.FCParams
- gender: TfjsImageRecognitionBase.FCParams
+ age: FCParams
+ gender: FCParams
}
}
\ No newline at end of file
import { Box } from './Box';
export interface IBoundingBox {
left: number
top: number
right: number
bottom: number
}
export class BoundingBox extends Box<BoundingBox> implements IBoundingBox {
constructor(left: number, top: number, right: number, bottom: number, allowNegativeDimensions: boolean = false) {
super({ left, top, right, bottom }, allowNegativeDimensions)
}
}
\ No newline at end of file
import { isDimensions, isValidNumber } from '../utils';
import { IBoundingBox } from './BoundingBox';
import { IDimensions } from './Dimensions';
import { Point } from './Point';
import { IRect } from './Rect';
export class Box<BoxType = any> implements IBoundingBox, IRect {
public static isRect(rect: any): boolean {
return !!rect && [rect.x, rect.y, rect.width, rect.height].every(isValidNumber)
}
public static assertIsValidBox(box: any, callee: string, allowNegativeDimensions: boolean = false) {
if (!Box.isRect(box)) {
throw new Error(`${callee} - invalid box: ${JSON.stringify(box)}, expected object with properties x, y, width, height`)
}
if (!allowNegativeDimensions && (box.width < 0 || box.height < 0)) {
throw new Error(`${callee} - width (${box.width}) and height (${box.height}) must be positive numbers`)
}
}
private _x: number
private _y: number
private _width: number
private _height: number
constructor(_box: IBoundingBox | IRect, allowNegativeDimensions: boolean = true) {
const box = (_box || {}) as any
const isBbox = [box.left, box.top, box.right, box.bottom].every(isValidNumber)
const isRect = [box.x, box.y, box.width, box.height].every(isValidNumber)
if (!isRect && !isBbox) {
throw new Error(`Box.constructor - expected box to be IBoundingBox | IRect, instead have ${JSON.stringify(box)}`)
}
const [x, y, width, height] = isRect
? [box.x, box.y, box.width, box.height]
: [box.left, box.top, box.right - box.left, box.bottom - box.top]
Box.assertIsValidBox({ x, y, width, height }, 'Box.constructor', allowNegativeDimensions)
this._x = x
this._y = y
this._width = width
this._height = height
}
public get x(): number { return this._x }
public get y(): number { return this._y }
public get width(): number { return this._width }
public get height(): number { return this._height }
public get left(): number { return this.x }
public get top(): number { return this.y }
public get right(): number { return this.x + this.width }
public get bottom(): number { return this.y + this.height }
public get area(): number { return this.width * this.height }
public get topLeft(): Point { return new Point(this.left, this.top) }
public get topRight(): Point { return new Point(this.right, this.top) }
public get bottomLeft(): Point { return new Point(this.left, this.bottom) }
public get bottomRight(): Point { return new Point(this.right, this.bottom) }
public round(): Box<BoxType> {
const [x, y, width, height] = [this.x, this.y, this.width, this.height]
.map(val => Math.round(val))
return new Box({ x, y, width, height })
}
public floor(): Box<BoxType> {
const [x, y, width, height] = [this.x, this.y, this.width, this.height]
.map(val => Math.floor(val))
return new Box({ x, y, width, height })
}
public toSquare(): Box<BoxType> {
let { x, y, width, height } = this
const diff = Math.abs(width - height)
if (width < height) {
x -= (diff / 2)
width += diff
}
if (height < width) {
y -= (diff / 2)
height += diff
}
return new Box({ x, y, width, height })
}
public rescale(s: IDimensions | number): Box<BoxType> {
const scaleX = isDimensions(s) ? (s as IDimensions).width : s as number
const scaleY = isDimensions(s) ? (s as IDimensions).height : s as number
return new Box({
x: this.x * scaleX,
y: this.y * scaleY,
width: this.width * scaleX,
height: this.height * scaleY
})
}
public pad(padX: number, padY: number): Box<BoxType> {
let [x, y, width, height] = [
this.x - (padX / 2),
this.y - (padY / 2),
this.width + padX,
this.height + padY
]
return new Box({ x, y, width, height })
}
public clipAtImageBorders(imgWidth: number, imgHeight: number): Box<BoxType> {
const { x, y, right, bottom } = this
const clippedX = Math.max(x, 0)
const clippedY = Math.max(y, 0)
const newWidth = right - clippedX
const newHeight = bottom - clippedY
const clippedWidth = Math.min(newWidth, imgWidth - clippedX)
const clippedHeight = Math.min(newHeight, imgHeight - clippedY)
return (new Box({ x: clippedX, y: clippedY, width: clippedWidth, height: clippedHeight})).floor()
}
public shift(sx: number, sy: number): Box<BoxType> {
const { width, height } = this
const x = this.x + sx
const y = this.y + sy
return new Box({ x, y, width, height })
}
public padAtBorders(imageHeight: number, imageWidth: number) {
const w = this.width + 1
const h = this.height + 1
let dx = 1
let dy = 1
let edx = w
let edy = h
let x = this.left
let y = this.top
let ex = this.right
let ey = this.bottom
if (ex > imageWidth) {
edx = -ex + imageWidth + w
ex = imageWidth
}
if (ey > imageHeight) {
edy = -ey + imageHeight + h
ey = imageHeight
}
if (x < 1) {
dx = 2 - x
x = 1
}
if (y < 1) {
dy = 2 - y
y = 1
}
return { dy, edy, dx, edx, y, ey, x, ex, w, h }
}
public calibrate(region: Box) {
return new Box({
left: this.left + (region.left * this.width),
top: this.top + (region.top * this.height),
right: this.right + (region.right * this.width),
bottom: this.bottom + (region.bottom * this.height)
}).toSquare().round()
}
}
\ No newline at end of file
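
A usage sketch for the inlined box geometry (values are illustrative; Rect is the x/y/width/height subclass defined later in this diff):

import { Rect } from './Rect';

const box = new Rect(10, 20, 100, 50)               // x, y, width, height
const squared = box.toSquare()                      // pads the minor dimension: 100 x 100
const padded = squared.pad(20, 20)                  // grows by 20px per dimension, keeping the center
const clipped = padded.clipAtImageBorders(640, 480) // clamps to the image and floors the result
console.log(clipped.topLeft, clipped.area)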
import { isValidNumber } from '../utils';
export interface IDimensions {
width: number
height: number
}
export class Dimensions implements IDimensions {
private _width: number
private _height: number
constructor(width: number, height: number) {
if (!isValidNumber(width) || !isValidNumber(height)) {
throw new Error(`Dimensions.constructor - expected width and height to be valid numbers, instead have ${JSON.stringify({ width, height })}`)
}
this._width = width
this._height = height
}
public get width(): number { return this._width }
public get height(): number { return this._height }
public reverse(): Dimensions {
return new Dimensions(1 / this.width, 1 / this.height)
}
}
\ No newline at end of file
- import { Box, IDimensions, ObjectDetection, Rect } from 'tfjs-image-recognition-base';
+ import { Box } from './Box';
+ import { IDimensions } from './Dimensions';
+ import { ObjectDetection } from './ObjectDetection';
+ import { Rect } from './Rect';
export interface IFaceDetecion {
score: number
...
- import { Box, Dimensions, getCenterPoint, IBoundingBox, IDimensions, IRect, Point, Rect } from 'tfjs-image-recognition-base';
- import { minBbox } from '../minBbox';
+ import { minBbox } from '../ops';
+ import { getCenterPoint } from '../utils';
+ import { IBoundingBox } from './BoundingBox';
+ import { Box } from './Box';
+ import { Dimensions, IDimensions } from './Dimensions';
import { FaceDetection } from './FaceDetection';
+ import { Point } from './Point';
+ import { IRect, Rect } from './Rect';
// face alignment constants
const relX = 0.5
...
- import { getCenterPoint, Point } from 'tfjs-image-recognition-base';
+ import { getCenterPoint } from '../utils';
import { FaceLandmarks } from './FaceLandmarks';
+ import { Point } from './Point';
export class FaceLandmarks5 extends FaceLandmarks {
...
- import { getCenterPoint, Point } from 'tfjs-image-recognition-base';
- import { FaceLandmarks } from '../classes/FaceLandmarks';
+ import { getCenterPoint } from '../utils';
+ import { FaceLandmarks } from './FaceLandmarks';
+ import { Point } from './Point';
export class FaceLandmarks68 extends FaceLandmarks {
public getJawOutline(): Point[] {
...
- import { round } from 'tfjs-image-recognition-base';
+ import { round } from '../utils';
export interface IFaceMatch {
label: string
...
import { isValidNumber } from '../utils';
import { IBoundingBox } from './BoundingBox';
import { Box } from './Box';
import { IRect } from './Rect';
export class LabeledBox extends Box<LabeledBox> {
public static assertIsValidLabeledBox(box: any, callee: string) {
Box.assertIsValidBox(box, callee)
if (!isValidNumber(box.label)) {
throw new Error(`${callee} - expected property label (${box.label}) to be a number`)
}
}
private _label: number
constructor(box: IBoundingBox | IRect | any, label: number) {
super(box)
this._label = label
}
public get label(): number { return this._label }
}
\ No newline at end of file
import { Box } from './Box';
import { Dimensions, IDimensions } from './Dimensions';
import { IRect, Rect } from './Rect';
export class ObjectDetection {
private _score: number
private _classScore: number
private _className: string
private _box: Rect
private _imageDims: Dimensions
constructor(
score: number,
classScore: number,
className: string,
relativeBox: IRect,
imageDims: IDimensions
) {
this._imageDims = new Dimensions(imageDims.width, imageDims.height)
this._score = score
this._classScore = classScore
this._className = className
this._box = new Box(relativeBox).rescale(this._imageDims)
}
public get score(): number { return this._score }
public get classScore(): number { return this._classScore }
public get className(): string { return this._className }
public get box(): Box { return this._box }
public get imageDims(): Dimensions { return this._imageDims }
public get imageWidth(): number { return this.imageDims.width }
public get imageHeight(): number { return this.imageDims.height }
public get relativeBox(): Box { return new Box(this._box).rescale(this.imageDims.reverse()) }
public forSize(width: number, height: number): ObjectDetection {
return new ObjectDetection(
this.score,
this.classScore,
this.className,
this.relativeBox,
{ width, height}
)
}
}
\ No newline at end of file
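
A sketch of the absolute/relative box round trip that ObjectDetection implements (values are illustrative):

const det = new ObjectDetection(
  0.9, 0.9, 'face',
  { x: 0.1, y: 0.1, width: 0.5, height: 0.5 },  // box relative to image dimensions
  { width: 100, height: 100 }
)
console.log(det.box.width)                     // 50 - the relative box rescaled to absolute pixels
console.log(det.relativeBox.width)             // 0.5 - rescaled back via imageDims.reverse()
console.log(det.forSize(200, 200).box.width)   // 100 - same detection projected onto a larger image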
export interface IPoint {
x: number
y: number
}
export class Point implements IPoint {
private _x: number
private _y: number
constructor(x: number, y: number) {
this._x = x
this._y = y
}
get x(): number { return this._x }
get y(): number { return this._y }
public add(pt: IPoint): Point {
return new Point(this.x + pt.x, this.y + pt.y)
}
public sub(pt: IPoint): Point {
return new Point(this.x - pt.x, this.y - pt.y)
}
public mul(pt: IPoint): Point {
return new Point(this.x * pt.x, this.y * pt.y)
}
public div(pt: IPoint): Point {
return new Point(this.x / pt.x, this.y / pt.y)
}
public abs(): Point {
return new Point(Math.abs(this.x), Math.abs(this.y))
}
public magnitude(): number {
return Math.sqrt(Math.pow(this.x, 2) + Math.pow(this.y, 2))
}
public floor(): Point {
return new Point(Math.floor(this.x), Math.floor(this.y))
}
}
\ No newline at end of file
import { isValidProbablitiy } from '../utils';
import { IBoundingBox } from './BoundingBox';
import { LabeledBox } from './LabeledBox';
import { IRect } from './Rect';
export class PredictedBox extends LabeledBox {
public static assertIsValidPredictedBox(box: any, callee: string) {
LabeledBox.assertIsValidLabeledBox(box, callee)
if (
!isValidProbablitiy(box.score)
|| !isValidProbablitiy(box.classScore)
) {
throw new Error(`${callee} - expected properties score (${box.score}) and classScore (${box.classScore}) to be numbers between [0, 1]`)
}
}
private _score: number
private _classScore: number
constructor(box: IBoundingBox | IRect | any, label: number, score: number, classScore: number) {
super(box, label)
this._score = score
this._classScore = classScore
}
public get score(): number { return this._score }
public get classScore(): number { return this._classScore }
}
\ No newline at end of file
import { Box } from './Box';
export interface IRect {
x: number
y: number
width: number
height: number
}
export class Rect extends Box<Rect> implements IRect {
constructor(x: number, y: number, width: number, height: number, allowNegativeDimensions: boolean = false) {
super({ x, y, width, height }, allowNegativeDimensions)
}
}
\ No newline at end of file
+ export * from './BoundingBox'
+ export * from './Box'
+ export * from './Dimensions'
export * from './FaceDetection';
export * from './FaceLandmarks';
export * from './FaceLandmarks5';
export * from './FaceLandmarks68';
export * from './FaceMatch';
+ export * from './LabeledBox'
export * from './LabeledFaceDescriptors';
+ export * from './ObjectDetection'
+ export * from './Point'
+ export * from './PredictedBox'
+ export * from './Rect'
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { ConvParams } from './types';
export function convLayer(
x: tf.Tensor4D,
params: ConvParams,
padding: 'valid' | 'same' = 'same',
withRelu: boolean = false
): tf.Tensor4D {
return tf.tidy(() => {
const out = tf.add(
tf.conv2d(x, params.filters, [1, 1], padding),
params.bias
) as tf.Tensor4D
return withRelu ? tf.relu(out) : out
})
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
- import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+ import { SeparableConvParams } from './types';
export function depthwiseSeparableConv(
x: tf.Tensor4D,
- params: TfjsImageRecognitionBase.SeparableConvParams,
+ params: SeparableConvParams,
stride: [number, number]
): tf.Tensor4D {
return tf.tidy(() => {
...
import { ParamMapping } from './types';
export function disposeUnusedWeightTensors(weightMap: any, paramMappings: ParamMapping[]) {
Object.keys(weightMap).forEach(path => {
if (!paramMappings.some(pm => pm.originalPath === path)) {
weightMap[path].dispose()
}
})
}
import * as tf from '@tensorflow/tfjs-core';
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
export function extractConvParamsFactory(
extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[]
) {
return function(
channelsIn: number,
channelsOut: number,
filterSize: number,
mappedPrefix: string
): ConvParams {
const filters = tf.tensor4d(
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
[filterSize, filterSize, channelsIn, channelsOut]
)
const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/filters` },
{ paramPath: `${mappedPrefix}/bias` }
)
return { filters, bias }
}
}
import * as tf from '@tensorflow/tfjs-core';
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
export function extractFCParamsFactory(
extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[]
) {
return function(
channelsIn: number,
channelsOut: number,
mappedPrefix: string
): FCParams {
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/weights` },
{ paramPath: `${mappedPrefix}/bias` }
)
return {
weights: fc_weights,
bias: fc_bias
}
}
}
import * as tf from '@tensorflow/tfjs-core';
import { ExtractWeightsFunction, ParamMapping, SeparableConvParams } from './types';
export function extractSeparableConvParamsFactory(
extractWeights: ExtractWeightsFunction,
paramMappings: ParamMapping[]
) {
return function(channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams {
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1])
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
const bias = tf.tensor1d(extractWeights(channelsOut))
paramMappings.push(
{ paramPath: `${mappedPrefix}/depthwise_filter` },
{ paramPath: `${mappedPrefix}/pointwise_filter` },
{ paramPath: `${mappedPrefix}/bias` }
)
return new SeparableConvParams(
depthwise_filter,
pointwise_filter,
bias
)
}
}
export function loadSeparableConvParamsFactory(
extractWeightEntry: <T>(originalPath: string, paramRank: number) => T
) {
return function (prefix: string): SeparableConvParams {
const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4)
const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4)
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
return new SeparableConvParams(
depthwise_filter,
pointwise_filter,
bias
)
}
}
import { isTensor } from '../utils';
import { ParamMapping } from './types';
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
return function<T> (originalPath: string, paramRank: number, mappedPath?: string): T {
const tensor = weightMap[originalPath]
if (!isTensor(tensor, paramRank)) {
throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${tensor}`)
}
paramMappings.push(
{ originalPath, paramPath: mappedPath || originalPath }
)
return tensor
}
}
export function extractWeightsFactory(weights: Float32Array) {
let remainingWeights = weights
function extractWeights(numWeights: number): Float32Array {
const ret = remainingWeights.slice(0, numWeights)
remainingWeights = remainingWeights.slice(numWeights)
return ret
}
function getRemainingWeights(): Float32Array {
return remainingWeights
}
return {
extractWeights,
getRemainingWeights
}
}
\ No newline at end of file
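
A short sketch of how the extract*Factory helpers above compose when parsing a raw Float32Array weight blob (channel sizes are illustrative):

import { extractWeightsFactory } from './extractWeightsFactory';
import { extractConvParamsFactory } from './extractConvParamsFactory';
import { ParamMapping } from './types';

const paramMappings: ParamMapping[] = []
const weights = new Float32Array(3 * 3 * 3 * 16 + 16)  // one 3x3 conv layer: filters + bias
const { extractWeights, getRemainingWeights } = extractWeightsFactory(weights)
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)

const conv0 = extractConvParams(3, 16, 3, 'conv0')     // (channelsIn, channelsOut, filterSize, mappedPrefix)
console.log(conv0.filters.shape)                       // [3, 3, 3, 16]
console.log(paramMappings)                             // [{ paramPath: 'conv0/filters' }, { paramPath: 'conv0/bias' }]
console.log(getRemainingWeights().length)              // 0 -> blob fully consumed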
import * as tf from '@tensorflow/tfjs-core';
- import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+ import { FCParams } from './types';
export function fullyConnectedLayer(
x: tf.Tensor2D,
- params: TfjsImageRecognitionBase.FCParams
+ params: FCParams
): tf.Tensor2D {
return tf.tidy(() =>
tf.add(
...
export function getModelUris(uri: string | undefined, defaultModelName: string) {
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`
if (!uri) {
return {
modelBaseUri: '',
manifestUri: defaultManifestFilename
}
}
if (uri === '/') {
return {
modelBaseUri: '/',
manifestUri: `/${defaultManifestFilename}`
}
}
const protocol = uri.startsWith('http://') ? 'http://' : uri.startsWith('https://') ? 'https://' : '';
uri = uri.replace(protocol, '');
const parts = uri.split('/').filter(s => s)
const manifestFile = uri.endsWith('.json')
? parts[parts.length - 1]
: defaultManifestFilename
let modelBaseUri = protocol + (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/')
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri
return {
modelBaseUri,
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`
}
}
\ No newline at end of file
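
Illustrative input/output pairs for getModelUris, derived from the logic above (the model name is an example):

getModelUris(undefined, 'age_gender_model')
// -> { modelBaseUri: '', manifestUri: 'age_gender_model-weights_manifest.json' }
getModelUris('/models', 'age_gender_model')
// -> { modelBaseUri: '/models', manifestUri: '/models/age_gender_model-weights_manifest.json' }
getModelUris('https://example.com/models/custom-weights_manifest.json', 'age_gender_model')
// -> { modelBaseUri: 'https://example.com/models', manifestUri: 'https://example.com/models/custom-weights_manifest.json' }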
export * from './convLayer'
export * from './depthwiseSeparableConv'
export * from './disposeUnusedWeightTensors'
export * from './extractConvParamsFactory'
export * from './extractFCParamsFactory'
export * from './extractSeparableConvParamsFactory'
export * from './extractWeightEntryFactory'
export * from './extractWeightsFactory'
export * from './getModelUris'
export * from './types'
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
- import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+ import { ConvParams } from './types';
export function loadConvParamsFactory(extractWeightEntry: <T>(originalPath: string, paramRank: number) => T) {
- return function(prefix: string): TfjsImageRecognitionBase.ConvParams {
+ return function(prefix: string): ConvParams {
const filters = extractWeightEntry<tf.Tensor4D>(`${prefix}/filters`, 4)
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
...
import * as tf from '@tensorflow/tfjs-core';
export type ExtractWeightsFunction = (numWeights: number) => Float32Array
export type ParamMapping = {
originalPath?: string
paramPath: string
}
export type ConvParams = {
filters: tf.Tensor4D
bias: tf.Tensor1D
}
export type FCParams = {
weights: tf.Tensor2D
bias: tf.Tensor1D
}
export class SeparableConvParams {
constructor(
public depthwise_filter: tf.Tensor4D,
public pointwise_filter: tf.Tensor4D,
public bias: tf.Tensor1D
) {}
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { Dimensions } from '../classes/Dimensions';
import { env } from '../env';
import { padToSquare } from '../ops/padToSquare';
import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils';
import { createCanvasFromMedia } from './createCanvas';
import { imageToSquare } from './imageToSquare';
import { TResolvedNetInput } from './types';
export class NetInput {
private _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = []
private _canvases: HTMLCanvasElement[] = []
private _batchSize: number
private _treatAsBatchInput: boolean = false
private _inputDimensions: number[][] = []
private _inputSize: number
constructor(
inputs: Array<TResolvedNetInput>,
treatAsBatchInput: boolean = false
) {
if (!Array.isArray(inputs)) {
throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`)
}
this._treatAsBatchInput = treatAsBatchInput
this._batchSize = inputs.length
inputs.forEach((input, idx) => {
if (isTensor3D(input)) {
this._imageTensors[idx] = input
this._inputDimensions[idx] = input.shape
return
}
if (isTensor4D(input)) {
const batchSize = input.shape[0]
if (batchSize !== 1) {
throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
}
this._imageTensors[idx] = input
this._inputDimensions[idx] = input.shape.slice(1)
return
}
const canvas = input instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input)
this._canvases[idx] = canvas
this._inputDimensions[idx] = [canvas.height, canvas.width, 3]
})
}
public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> {
return this._imageTensors
}
public get canvases(): HTMLCanvasElement[] {
return this._canvases
}
public get isBatchInput(): boolean {
return this.batchSize > 1 || this._treatAsBatchInput
}
public get batchSize(): number {
return this._batchSize
}
public get inputDimensions(): number[][] {
return this._inputDimensions
}
public get inputSize(): number | undefined {
return this._inputSize
}
public get reshapedInputDimensions(): Dimensions[] {
return range(this.batchSize, 0, 1).map(
(_, batchIdx) => this.getReshapedInputDimensions(batchIdx)
)
}
public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement {
return this.canvases[batchIdx] || this.imageTensors[batchIdx]
}
public getInputDimensions(batchIdx: number): number[] {
return this._inputDimensions[batchIdx]
}
public getInputHeight(batchIdx: number): number {
return this._inputDimensions[batchIdx][0]
}
public getInputWidth(batchIdx: number): number {
return this._inputDimensions[batchIdx][1]
}
public getReshapedInputDimensions(batchIdx: number): Dimensions {
if (typeof this.inputSize !== 'number') {
throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet')
}
const width = this.getInputWidth(batchIdx)
const height = this.getInputHeight(batchIdx)
return computeReshapedDimensions({ width, height }, this.inputSize)
}
/**
* Create a batch tensor from all input canvases and tensors
* with size [batchSize, inputSize, inputSize, 3].
*
* @param inputSize Height and width of the tensor.
* @param isCenterInputs (optional, default: true) If true, add an equal amount of padding on
* both sides of the minor dimension of the image.
* @returns The batch tensor.
*/
public toBatchTensor(inputSize: number, isCenterInputs: boolean = true): tf.Tensor4D {
this._inputSize = inputSize
return tf.tidy(() => {
const inputTensors = range(this.batchSize, 0, 1).map(batchIdx => {
const input = this.getInput(batchIdx)
if (input instanceof tf.Tensor) {
let imgTensor = isTensor4D(input) ? input : input.expandDims<tf.Rank.R4>()
imgTensor = padToSquare(imgTensor, isCenterInputs)
if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize])
}
return imgTensor.as3D(inputSize, inputSize, 3)
}
if (input instanceof env.getEnv().Canvas) {
return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs))
}
throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`)
})
const batchTensor = tf.stack(inputTensors.map(t => t.toFloat())).as4D(this.batchSize, inputSize, inputSize, 3)
return batchTensor
})
}
}
\ No newline at end of file
import { env } from '../env';
import { isMediaLoaded } from './isMediaLoaded';
export function awaitMediaLoaded(media: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement) {
return new Promise((resolve, reject) => {
if (media instanceof env.getEnv().Canvas || isMediaLoaded(media)) {
return resolve()
}
function onLoad(e: Event) {
if (!e.currentTarget) return
e.currentTarget.removeEventListener('load', onLoad)
e.currentTarget.removeEventListener('error', onError)
resolve(e)
}
function onError(e: Event) {
if (!e.currentTarget) return
e.currentTarget.removeEventListener('load', onLoad)
e.currentTarget.removeEventListener('error', onError)
reject(e)
}
media.addEventListener('load', onLoad)
media.addEventListener('error', onError)
})
}
\ No newline at end of file
import { env } from '../env';
export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
return new Promise((resolve, reject) => {
if (!(buf instanceof Blob)) {
return reject('bufferToImage - expected buf to be of type: Blob')
}
const reader = new FileReader()
reader.onload = () => {
if (typeof reader.result !== 'string') {
return reject('bufferToImage - expected reader.result to be a string, in onload')
}
const img = env.getEnv().createImageElement()
img.onload = () => resolve(img)
img.onerror = reject
img.src = reader.result
}
reader.onerror = reject
reader.readAsDataURL(buf)
})
}
\ No newline at end of file
import { IDimensions } from '../classes/Dimensions';
import { env } from '../env';
import { getContext2dOrThrow } from './getContext2dOrThrow';
import { getMediaDimensions } from './getMediaDimensions';
import { isMediaLoaded } from './isMediaLoaded';
export function createCanvas({ width, height }: IDimensions): HTMLCanvasElement {
const { createCanvasElement } = env.getEnv()
const canvas = createCanvasElement()
canvas.width = width
canvas.height = height
return canvas
}
export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement | ImageData, dims?: IDimensions): HTMLCanvasElement {
const { ImageData } = env.getEnv()
if (!(media instanceof ImageData) && !isMediaLoaded(media)) {
throw new Error('createCanvasFromMedia - media has not finished loading yet')
}
const { width, height } = dims || getMediaDimensions(media)
const canvas = createCanvas({ width, height })
if (media instanceof ImageData) {
getContext2dOrThrow(canvas).putImageData(media, 0, 0)
} else {
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
}
return canvas
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
- import { isTensor4D, Rect, isTensor3D } from 'tfjs-image-recognition-base';
+ import { Rect } from '../classes';
import { FaceDetection } from '../classes/FaceDetection';
+ import { isTensor3D, isTensor4D } from '../utils';
/**
* Extracts the tensors of the image regions containing the detected faces.
...
- import {
-   createCanvas,
-   env,
-   getContext2dOrThrow,
-   imageTensorToCanvas,
-   Rect,
-   TNetInput,
-   toNetInput,
- } from 'tfjs-image-recognition-base';
import { FaceDetection } from '../classes/FaceDetection';
+ import { Rect } from '../classes/Rect';
+ import { env } from '../env';
+ import { createCanvas } from './createCanvas';
+ import { getContext2dOrThrow } from './getContext2dOrThrow';
+ import { imageTensorToCanvas } from './imageTensorToCanvas';
+ import { toNetInput } from './toNetInput';
+ import { TNetInput } from './types';
/**
* Extracts the image regions containing the detected faces.
...
import { bufferToImage } from './bufferToImage';
import { fetchOrThrow } from './fetchOrThrow';
export async function fetchImage(uri: string): Promise<HTMLImageElement> {
const res = await fetchOrThrow(uri)
const blob = await (res).blob()
if (!blob.type.startsWith('image/')) {
throw new Error(`fetchImage - expected blob type to be of type image/*, instead have: ${blob.type}, for url: ${res.url}`)
}
return bufferToImage(blob)
}
import { fetchOrThrow } from './fetchOrThrow';
export async function fetchJson<T>(uri: string): Promise<T> {
return (await fetchOrThrow(uri)).json()
}
import { fetchOrThrow } from './fetchOrThrow';
export async function fetchNetWeights(uri: string): Promise<Float32Array> {
return new Float32Array(await (await fetchOrThrow(uri)).arrayBuffer())
}
import { env } from '../env';
export async function fetchOrThrow(
url: string,
init?: RequestInit
): Promise<Response> {
const fetch = env.getEnv().fetch
const res = await fetch(url, init)
if (!(res.status < 400)) {
throw new Error(`failed to fetch: (${res.status}) ${res.statusText}, from url: ${res.url}`)
}
return res
}
\ No newline at end of file
import { env } from '../env';
import { resolveInput } from './resolveInput';
export function getContext2dOrThrow(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D): CanvasRenderingContext2D {
const { Canvas, CanvasRenderingContext2D } = env.getEnv()
if (canvasArg instanceof CanvasRenderingContext2D) {
return canvasArg
}
const canvas = resolveInput(canvasArg)
if (!(canvas instanceof Canvas)) {
throw new Error('resolveContext2d - expected canvas to be an instance of Canvas')
}
const ctx = canvas.getContext('2d')
if (!ctx) {
throw new Error('resolveContext2d - canvas 2d context is null')
}
return ctx
}
\ No newline at end of file
import { Dimensions, IDimensions } from '../classes/Dimensions';
import { env } from '../env';
export function getMediaDimensions(input: HTMLImageElement | HTMLCanvasElement | HTMLVideoElement | IDimensions): Dimensions {
const { Image, Video } = env.getEnv()
if (input instanceof Image) {
return new Dimensions(input.naturalWidth, input.naturalHeight)
}
if (input instanceof Video) {
return new Dimensions(input.videoWidth, input.videoHeight)
}
return new Dimensions(input.width, input.height)
}
import * as tf from '@tensorflow/tfjs-core';
import { env } from '../env';
import { isTensor4D } from '../utils';
export async function imageTensorToCanvas(
imgTensor: tf.Tensor,
canvas?: HTMLCanvasElement
): Promise<HTMLCanvasElement> {
const targetCanvas = canvas || env.getEnv().createCanvasElement()
const [height, width, numChannels] = imgTensor.shape.slice(isTensor4D(imgTensor) ? 1 : 0)
const imgTensor3D = tf.tidy(() => imgTensor.as3D(height, width, numChannels).toInt())
await tf.browser.toPixels(imgTensor3D, targetCanvas)
imgTensor3D.dispose()
return targetCanvas
}
\ No newline at end of file
import { env } from '../env';
import { createCanvas, createCanvasFromMedia } from './createCanvas';
import { getContext2dOrThrow } from './getContext2dOrThrow';
import { getMediaDimensions } from './getMediaDimensions';
export function imageToSquare(input: HTMLImageElement | HTMLCanvasElement, inputSize: number, centerImage: boolean = false) {
const { Image, Canvas } = env.getEnv()
if (!(input instanceof Image || input instanceof Canvas)) {
throw new Error('imageToSquare - expected arg0 to be HTMLImageElement | HTMLCanvasElement')
}
const dims = getMediaDimensions(input)
const scale = inputSize / Math.max(dims.height, dims.width)
const width = scale * dims.width
const height = scale * dims.height
const targetCanvas = createCanvas({ width: inputSize, height: inputSize })
const inputCanvas = input instanceof Canvas ? input : createCanvasFromMedia(input)
const offset = Math.abs(width - height) / 2
const dx = centerImage && width < height ? offset : 0
const dy = centerImage && height < width ? offset : 0
getContext2dOrThrow(targetCanvas).drawImage(inputCanvas, dx, dy, width, height)
return targetCanvas
}
\ No newline at end of file
+ export * from './awaitMediaLoaded'
+ export * from './bufferToImage'
+ export * from './createCanvas'
export * from './extractFaces'
export * from './extractFaceTensors'
+ export * from './fetchImage'
+ export * from './fetchJson'
+ export * from './fetchNetWeights'
+ export * from './fetchOrThrow'
+ export * from './getContext2dOrThrow'
+ export * from './getMediaDimensions'
+ export * from './imageTensorToCanvas'
+ export * from './imageToSquare'
+ export * from './isMediaElement'
+ export * from './isMediaLoaded'
+ export * from './loadWeightMap'
+ export * from './matchDimensions'
+ export * from './NetInput'
+ export * from './resolveInput'
+ export * from './toNetInput'
+ export * from './types'
\ No newline at end of file
import { env } from '../env';
export function isMediaElement(input: any) {
const { Image, Canvas, Video } = env.getEnv()
return input instanceof Image
|| input instanceof Canvas
|| input instanceof Video
}
\ No newline at end of file
import { env } from '../env';
export function isMediaLoaded(media: HTMLImageElement | HTMLVideoElement) : boolean {
const { Image, Video } = env.getEnv()
return (media instanceof Image && media.complete)
|| (media instanceof Video && media.readyState >= 3)
}
import * as tf from '@tensorflow/tfjs-core';
import { getModelUris } from '../common/getModelUris';
import { fetchJson } from './fetchJson';
export async function loadWeightMap(
uri: string | undefined,
defaultModelName: string,
): Promise<tf.NamedTensorMap> {
const { manifestUri, modelBaseUri } = getModelUris(uri, defaultModelName)
const manifest = await fetchJson<tf.io.WeightsManifestConfig>(manifestUri)
return tf.io.loadWeights(manifest, modelBaseUri)
}
\ No newline at end of file
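
A hedged usage sketch; the uri and model name are examples that follow the <name>-weights_manifest.json convention resolved by getModelUris:

// expects /models/face_landmark_68_model-weights_manifest.json plus the shard files it lists
const weightMap = await loadWeightMap('/models', 'face_landmark_68_model')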
import { IDimensions } from '../classes';
import { getMediaDimensions } from './getMediaDimensions';
export function matchDimensions(input: IDimensions, reference: IDimensions, useMediaDimensions: boolean = false) {
const { width, height } = useMediaDimensions
? getMediaDimensions(reference)
: reference
input.width = width
input.height = height
return { width, height }
}
\ No newline at end of file
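
For example, sizing an overlay canvas to a video element (both variables are assumed to exist):

// use the media dimensions (videoWidth/videoHeight) rather than the element's CSS size
const { width, height } = matchDimensions(canvas, video, true)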
import { env } from '../env';
export function resolveInput(arg: string | any) {
if (!env.isNodejs() && typeof arg === 'string') {
return document.getElementById(arg)
}
return arg
}
\ No newline at end of file
import { isTensor3D, isTensor4D } from '../utils';
import { awaitMediaLoaded } from './awaitMediaLoaded';
import { isMediaElement } from './isMediaElement';
import { NetInput } from './NetInput';
import { resolveInput } from './resolveInput';
import { TNetInput } from './types';
/**
 * Validates the inputs to make sure they are valid net inputs and awaits all media elements
 * to finish loading.
*
* @param input The input, which can be a media element or an array of different media elements.
* @returns A NetInput instance, which can be passed into one of the neural networks.
*/
export async function toNetInput(inputs: TNetInput): Promise<NetInput> {
if (inputs instanceof NetInput) {
return inputs
}
let inputArgArray = Array.isArray(inputs)
? inputs
: [inputs]
if (!inputArgArray.length) {
throw new Error('toNetInput - empty array passed as input')
}
const getIdxHint = (idx: number) => Array.isArray(inputs) ? ` at input index ${idx}:` : ''
const inputArray = inputArgArray.map(resolveInput)
inputArray.forEach((input, i) => {
if (!isMediaElement(input) && !isTensor3D(input) && !isTensor4D(input)) {
if (typeof inputArgArray[i] === 'string') {
throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`)
}
throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`)
}
if (isTensor4D(input)) {
// if tf.Tensor4D is passed in the input array, the batch size has to be 1
const batchSize = input.shape[0]
if (batchSize !== 1) {
throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
}
}
})
  // wait for all media elements to finish loading
await Promise.all(
inputArray.map(input => isMediaElement(input) && awaitMediaLoaded(input))
)
return new NetInput(inputArray, Array.isArray(inputs))
}
\ No newline at end of file
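
A sketch of the accepted input shapes; the element id and image variables are hypothetical:

const fromId = await toNetInput('inputImg')    // string id resolved via document.getElementById
const batched = await toNetInput([img1, img2]) // array input produces a batched NetInput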
import * as tf from '@tensorflow/tfjs-core';
import { NetInput } from './NetInput';
export type TMediaElement = HTMLImageElement | HTMLVideoElement | HTMLCanvasElement
export type TResolvedNetInput = TMediaElement | tf.Tensor3D | tf.Tensor4D
export type TNetInputArg = string | TResolvedNetInput
export type TNetInput = TNetInputArg | Array<TNetInputArg> | NetInput | tf.Tensor4D
\ No newline at end of file
import { Box, IBoundingBox, IRect } from '../classes';
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { AnchorPosition, DrawTextField, DrawTextFieldOptions, IDrawTextFieldOptions } from './DrawTextField';
export interface IDrawBoxOptions {
boxColor?: string
lineWidth?: number
drawLabelOptions?: IDrawTextFieldOptions
label?: string
}
export class DrawBoxOptions {
public boxColor: string
public lineWidth: number
public drawLabelOptions: DrawTextFieldOptions
public label?: string
constructor(options: IDrawBoxOptions = {}) {
const { boxColor, lineWidth, label, drawLabelOptions } = options
this.boxColor = boxColor || 'rgba(0, 0, 255, 1)'
this.lineWidth = lineWidth || 2
this.label = label
const defaultDrawLabelOptions = {
anchorPosition: AnchorPosition.BOTTOM_LEFT,
backgroundColor: this.boxColor
}
this.drawLabelOptions = new DrawTextFieldOptions(Object.assign({}, defaultDrawLabelOptions, drawLabelOptions))
}
}
export class DrawBox {
public box: Box
public options: DrawBoxOptions
constructor(
box: IBoundingBox | IRect,
options: IDrawBoxOptions = {}
) {
this.box = new Box(box)
this.options = new DrawBoxOptions(options)
}
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const ctx = getContext2dOrThrow(canvasArg)
const { boxColor, lineWidth } = this.options
const { x, y, width, height } = this.box
ctx.strokeStyle = boxColor
ctx.lineWidth = lineWidth
ctx.strokeRect(x, y, width, height)
const { label } = this.options
if (label) {
new DrawTextField([label], { x: x - (lineWidth / 2), y }, this.options.drawLabelOptions).draw(canvasArg)
}
}
}
\ No newline at end of file
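
A minimal sketch of drawing a labeled box; 'overlay' is a hypothetical canvas id that getContext2dOrThrow resolves:

new DrawBox(
  { x: 50, y: 50, width: 100, height: 100 },
  { label: 'face', boxColor: 'rgba(255, 0, 0, 1)' }
).draw('overlay')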
-import { getContext2dOrThrow, IPoint } from 'tfjs-image-recognition-base';
+import { IPoint } from '../classes';
import { FaceLandmarks } from '../classes/FaceLandmarks';
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
+import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { WithFaceDetection } from '../factories/WithFaceDetection';
import { isWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFaceLandmarks';
import { drawContour } from './drawContour';
...
import { IDimensions, IPoint } from '../classes';
import { getContext2dOrThrow } from '../dom/getContext2dOrThrow';
import { resolveInput } from '../dom/resolveInput';
export enum AnchorPosition {
TOP_LEFT = 'TOP_LEFT',
TOP_RIGHT = 'TOP_RIGHT',
BOTTOM_LEFT = 'BOTTOM_LEFT',
BOTTOM_RIGHT = 'BOTTOM_RIGHT'
}
export interface IDrawTextFieldOptions {
anchorPosition?: AnchorPosition
backgroundColor?: string
fontColor?: string
fontSize?: number
fontStyle?: string
padding?: number
}
export class DrawTextFieldOptions implements IDrawTextFieldOptions {
public anchorPosition: AnchorPosition
public backgroundColor: string
public fontColor: string
public fontSize: number
public fontStyle: string
public padding: number
constructor(options: IDrawTextFieldOptions = {}) {
const { anchorPosition, backgroundColor, fontColor, fontSize, fontStyle, padding } = options
this.anchorPosition = anchorPosition || AnchorPosition.TOP_LEFT
this.backgroundColor = backgroundColor || 'rgba(0, 0, 0, 0.5)'
this.fontColor = fontColor || 'rgba(255, 255, 255, 1)'
this.fontSize = fontSize || 14
this.fontStyle = fontStyle || 'Georgia'
this.padding = padding || 4
}
}
export class DrawTextField {
public text: string[]
  public anchor: IPoint
public options: DrawTextFieldOptions
constructor(
text: string | string[] | DrawTextField,
anchor: IPoint,
options: IDrawTextFieldOptions = {}
) {
this.text = typeof text === 'string'
? [text]
: (text instanceof DrawTextField ? text.text : text)
this.anchor = anchor
this.options = new DrawTextFieldOptions(options)
}
measureWidth(ctx: CanvasRenderingContext2D): number {
const { padding } = this.options
return this.text.map(l => ctx.measureText(l).width).reduce((w0, w1) => w0 < w1 ? w1 : w0, 0) + (2 * padding)
}
measureHeight(): number {
const { fontSize, padding } = this.options
return this.text.length * fontSize + (2 * padding)
}
getUpperLeft(ctx: CanvasRenderingContext2D, canvasDims?: IDimensions): IPoint {
const { anchorPosition } = this.options
const isShiftLeft = anchorPosition === AnchorPosition.BOTTOM_RIGHT || anchorPosition === AnchorPosition.TOP_RIGHT
const isShiftTop = anchorPosition === AnchorPosition.BOTTOM_LEFT || anchorPosition === AnchorPosition.BOTTOM_RIGHT
const textFieldWidth = this.measureWidth(ctx)
const textFieldHeight = this.measureHeight()
const x = (isShiftLeft ? this.anchor.x - textFieldWidth : this.anchor.x)
const y = isShiftTop ? this.anchor.y - textFieldHeight : this.anchor.y
// adjust anchor if text box exceeds canvas borders
if (canvasDims) {
const { width, height } = canvasDims
const newX = Math.max(Math.min(x, width - textFieldWidth), 0)
const newY = Math.max(Math.min(y, height - textFieldHeight), 0)
return { x: newX, y: newY }
}
return { x, y }
}
draw(canvasArg: string | HTMLCanvasElement | CanvasRenderingContext2D) {
const canvas = resolveInput(canvasArg)
const ctx = getContext2dOrThrow(canvas)
const { backgroundColor, fontColor, fontSize, fontStyle, padding } = this.options
ctx.font = `${fontSize}px ${fontStyle}`
const maxTextWidth = this.measureWidth(ctx)
const textHeight = this.measureHeight()
ctx.fillStyle = backgroundColor
const upperLeft = this.getUpperLeft(ctx, canvas)
ctx.fillRect(upperLeft.x, upperLeft.y, maxTextWidth, textHeight)
ctx.fillStyle = fontColor;
this.text.forEach((textLine, i) => {
const x = padding + upperLeft.x
const y = padding + upperLeft.y + ((i + 1) * fontSize)
ctx.fillText(textLine, x, y)
})
}
}
\ No newline at end of file
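
And a corresponding DrawTextField sketch (same hypothetical canvas id); with BOTTOM_LEFT the anchor marks the bottom-left corner, so the field is drawn above it:

new DrawTextField(
  ['happy (0.98)'],
  { x: 60, y: 200 },
  { anchorPosition: AnchorPosition.BOTTOM_LEFT }
).draw('overlay')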
-import { Point } from 'tfjs-image-recognition-base';
+import { Point } from '../classes';
export function drawContour(
  ctx: CanvasRenderingContext2D,
...
-import { Box, draw, IBoundingBox, IRect, round } from 'tfjs-image-recognition-base';
+import { Box, IBoundingBox, IRect } from '../classes';
import { FaceDetection } from '../classes/FaceDetection';
import { isWithFaceDetection, WithFaceDetection } from '../factories/WithFaceDetection';
+import { round } from '../utils';
+import { DrawBox } from './DrawBox';
export type TDrawDetectionsInput = IRect | IBoundingBox | FaceDetection | WithFaceDetection<{}>
@@ -21,6 +22,6 @@ export function drawDetections(
      : (isWithFaceDetection(det) ? det.detection.box : new Box(det))
    const label = score ? `${round(score)}` : undefined
-   new draw.DrawBox(box, { label }).draw(canvasArg)
+   new DrawBox(box, { label }).draw(canvasArg)
  })
}
\ No newline at end of file
-import { draw, IPoint, Point, round } from 'tfjs-image-recognition-base';
+import { IPoint, Point } from '../classes';
import { FaceExpressions } from '../faceExpressionNet';
import { isWithFaceDetection } from '../factories/WithFaceDetection';
import { isWithFaceExpressions, WithFaceExpressions } from '../factories/WithFaceExpressions';
+import { round } from '../utils';
+import { DrawTextField } from './DrawTextField';
export type DrawFaceExpressionsInput = FaceExpressions | WithFaceExpressions<{}>
@@ -29,7 +30,7 @@ export function drawFaceExpressions(
      ? e.detection.box.bottomLeft
      : (textFieldAnchor || new Point(0, 0))
-   const drawTextField = new draw.DrawTextField(
+   const drawTextField = new DrawTextField(
      resultsToDisplay.map(expr => `${expr.expression} (${round(expr.probability)})`),
      anchor
    )
...
export * from './drawContour'
export * from './drawDetections'
export * from './drawFaceExpressions'
+export * from './DrawBox'
export * from './DrawFaceLandmarks'
+export * from './DrawTextField'
\ No newline at end of file
import { Environment } from './types';
export function createBrowserEnv(): Environment {
const fetch = window['fetch'] || function() {
throw new Error('fetch - missing fetch implementation for browser environment')
}
const readFile = function() {
throw new Error('readFile - filesystem not available for browser environment')
}
return {
Canvas: HTMLCanvasElement,
CanvasRenderingContext2D: CanvasRenderingContext2D,
Image: HTMLImageElement,
ImageData: ImageData,
Video: HTMLVideoElement,
createCanvasElement: () => document.createElement('canvas'),
createImageElement: () => document.createElement('img'),
fetch,
readFile
}
}
\ No newline at end of file
import { FileSystem } from './types';
export function createFileSystem(fs?: any): FileSystem {
let requireFsError = ''
if (!fs) {
try {
fs = require('fs')
} catch (err) {
requireFsError = err.toString()
}
}
const readFile = fs
? function(filePath: string) {
return new Promise<Buffer>((res, rej) => {
fs.readFile(filePath, function(err: any, buffer: Buffer) {
return err ? rej(err) : res(buffer)
})
})
}
: function() {
throw new Error(`readFile - failed to require fs in nodejs environment with error: ${requireFsError}`)
}
return {
readFile
}
}
\ No newline at end of file
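
A usage sketch; the file path is hypothetical, and readFile throws if fs could not be required:

const { readFile } = createFileSystem()
const buffer = await readFile('./weights/model.weights') // resolves to a Buffer in nodejs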
import { createFileSystem } from './createFileSystem';
import { Environment } from './types';
export function createNodejsEnv(): Environment {
const Canvas = global['Canvas'] || global['HTMLCanvasElement']
const Image = global['Image'] || global['HTMLImageElement']
const createCanvasElement = function() {
if (Canvas) {
return new Canvas()
}
throw new Error('createCanvasElement - missing Canvas implementation for nodejs environment')
}
const createImageElement = function() {
if (Image) {
return new Image()
}
throw new Error('createImageElement - missing Image implementation for nodejs environment')
}
const fetch = global['fetch'] || function() {
throw new Error('fetch - missing fetch implementation for nodejs environment')
}
const fileSystem = createFileSystem()
return {
Canvas: Canvas || class {},
CanvasRenderingContext2D: global['CanvasRenderingContext2D'] || class {},
Image: Image || class {},
ImageData: global['ImageData'] || class {},
Video: global['HTMLVideoElement'] || class {},
createCanvasElement,
createImageElement,
fetch,
...fileSystem
}
}
\ No newline at end of file
import { createBrowserEnv } from './createBrowserEnv';
import { createFileSystem } from './createFileSystem';
import { createNodejsEnv } from './createNodejsEnv';
import { isBrowser } from './isBrowser';
import { isNodejs } from './isNodejs';
import { Environment } from './types';
let environment: Environment | null
function getEnv(): Environment {
if (!environment) {
throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()')
}
return environment
}
function setEnv(env: Environment) {
environment = env
}
function initialize() {
  // check for isBrowser() first, to prevent the electron renderer process
  // from being initialized with the wrong environment due to isNodejs() returning true
if (isBrowser()) {
setEnv(createBrowserEnv())
}
if (isNodejs()) {
setEnv(createNodejsEnv())
}
}
function monkeyPatch(env: Partial<Environment>) {
if (!environment) {
initialize()
}
if (!environment) {
throw new Error('monkeyPatch - environment is not defined, check isNodejs() and isBrowser()')
}
const { Canvas = environment.Canvas, Image = environment.Image } = env
environment.Canvas = Canvas
environment.Image = Image
environment.createCanvasElement = env.createCanvasElement || (() => new Canvas())
environment.createImageElement = env.createImageElement || (() => new Image())
environment.ImageData = env.ImageData || environment.ImageData
environment.Video = env.Video || environment.Video
environment.fetch = env.fetch || environment.fetch
environment.readFile = env.readFile || environment.readFile
}
export const env = {
getEnv,
setEnv,
initialize,
createBrowserEnv,
createFileSystem,
createNodejsEnv,
monkeyPatch,
isBrowser,
isNodejs
}
initialize()
export * from './types'
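
In nodejs this is typically paired with a canvas implementation; a sketch assuming the node-canvas package (the cast sidesteps the DOM-typed Environment fields):

import * as canvas from 'canvas'

const { Canvas, Image, ImageData } = canvas
env.monkeyPatch({ Canvas, Image, ImageData } as any)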
export function isBrowser(): boolean {
return typeof window === 'object'
&& typeof document !== 'undefined'
&& typeof HTMLImageElement !== 'undefined'
&& typeof HTMLCanvasElement !== 'undefined'
&& typeof HTMLVideoElement !== 'undefined'
&& typeof ImageData !== 'undefined'
&& typeof CanvasRenderingContext2D !== 'undefined'
}
\ No newline at end of file
export function isNodejs(): boolean {
return typeof global === 'object'
&& typeof require === 'function'
&& typeof module !== 'undefined'
// issues with gatsby.js: module.exports is undefined
// && !!module.exports
&& typeof process !== 'undefined' && !!process.version
}
\ No newline at end of file
export type FileSystem = {
readFile: (filePath: string) => Promise<Buffer>
}
export type Environment = FileSystem & {
Canvas: typeof HTMLCanvasElement
CanvasRenderingContext2D: typeof CanvasRenderingContext2D
Image: typeof HTMLImageElement
ImageData: typeof ImageData
Video: typeof HTMLVideoElement
createCanvasElement: () => HTMLCanvasElement
createImageElement: () => HTMLImageElement
fetch: (url: string, init?: RequestInit) => Promise<Response>
}
import * as tf from '@tensorflow/tfjs-core';
-import { NetInput, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
+import { NetInput, TNetInput, toNetInput } from '../dom';
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
...
import * as tf from '@tensorflow/tfjs-core';
-import { NetInput, NeuralNetwork, normalize, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
+import { NetInput, TNetInput, toNetInput } from '../dom';
+import { NeuralNetwork } from '../NeuralNetwork';
+import { normalize } from '../ops';
import { denseBlock4 } from './denseBlock';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
...
import * as tf from '@tensorflow/tfjs-core';
-import { NetInput, NeuralNetwork, normalize, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
+import { NetInput, TNetInput, toNetInput } from '../dom';
+import { NeuralNetwork } from '../NeuralNetwork';
+import { normalize } from '../ops';
import { denseBlock3 } from './denseBlock';
import { extractParamsFromWeigthMapTiny } from './extractParamsFromWeigthMapTiny';
import { extractParamsTiny } from './extractParamsTiny';
...
import * as tf from '@tensorflow/tfjs-core';
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { ConvParams, SeparableConvParams } from '../common';
import { depthwiseSeparableConv } from '../common/depthwiseSeparableConv';
import { DenseBlock3Params, DenseBlock4Params } from './types';
@@ -13,10 +13,10 @@ export function denseBlock3(
  const out1 = tf.relu(
    isFirstLayer
      ? tf.add(
-         tf.conv2d(x, (denseBlockParams.conv0 as TfjsImageRecognitionBase.ConvParams).filters, [2, 2], 'same'),
+         tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, [2, 2], 'same'),
        denseBlockParams.conv0.bias
      )
-     : depthwiseSeparableConv(x, denseBlockParams.conv0 as TfjsImageRecognitionBase.SeparableConvParams, [2, 2])
+     : depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, [2, 2])
  ) as tf.Tensor4D
  const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1])
@@ -37,10 +37,10 @@ export function denseBlock4(
  const out1 = tf.relu(
    isFirstLayer
      ? tf.add(
-         tf.conv2d(x, (denseBlockParams.conv0 as TfjsImageRecognitionBase.ConvParams).filters, isScaleDown ? [2, 2] : [1, 1], 'same'),
+         tf.conv2d(x, (denseBlockParams.conv0 as ConvParams).filters, isScaleDown ? [2, 2] : [1, 1], 'same'),
        denseBlockParams.conv0.bias
      )
-     : depthwiseSeparableConv(x, denseBlockParams.conv0 as TfjsImageRecognitionBase.SeparableConvParams, isScaleDown ? [2, 2] : [1, 1])
+     : depthwiseSeparableConv(x, denseBlockParams.conv0 as SeparableConvParams, isScaleDown ? [2, 2] : [1, 1])
  ) as tf.Tensor4D
  const out2 = depthwiseSeparableConv(out1, denseBlockParams.conv1, [1, 1])
...
+import { extractWeightsFactory, ParamMapping } from '../common';
import { extractorsFactory } from './extractorsFactory';
import { FaceFeatureExtractorParams } from './types';
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
-export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-  const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
+export function extractParams(weights: Float32Array): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
+  const paramMappings: ParamMapping[] = []
  const {
    extractWeights,
    getRemainingWeights
- } = TfjsImageRecognitionBase.extractWeightsFactory(weights)
+ } = extractWeightsFactory(weights)
  const {
    extractDenseBlock4Params
...
import * as tf from '@tensorflow/tfjs-core';
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { disposeUnusedWeightTensors, ParamMapping } from '../common';
import { loadParamsFactory } from './loadParamsFactory';
import { FaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMap(
  weightMap: tf.NamedTensorMap
-): { params: FaceFeatureExtractorParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-  const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
+): { params: FaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
+  const paramMappings: ParamMapping[] = []
  const {
    extractDenseBlock4Params
@@ -21,7 +21,7 @@ export function extractParamsFromWeigthMap(
    dense3: extractDenseBlock4Params('dense3')
  }
- TfjsImageRecognitionBase.disposeUnusedWeightTensors(weightMap, paramMappings)
+ disposeUnusedWeightTensors(weightMap, paramMappings)
  return { params, paramMappings }
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { disposeUnusedWeightTensors, ParamMapping } from '../common';
import { loadParamsFactory } from './loadParamsFactory';
import { TinyFaceFeatureExtractorParams } from './types';
export function extractParamsFromWeigthMapTiny(
  weightMap: tf.NamedTensorMap
-): { params: TinyFaceFeatureExtractorParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-  const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
+): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
+  const paramMappings: ParamMapping[] = []
  const {
    extractDenseBlock3Params
@@ -20,7 +20,7 @@ export function extractParamsFromWeigthMapTiny(
    dense2: extractDenseBlock3Params('dense2')
  }
- TfjsImageRecognitionBase.disposeUnusedWeightTensors(weightMap, paramMappings)
+ disposeUnusedWeightTensors(weightMap, paramMappings)
  return { params, paramMappings }
}
\ No newline at end of file
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { extractWeightsFactory, ParamMapping } from '../common';
import { extractorsFactory } from './extractorsFactory';
import { TinyFaceFeatureExtractorParams } from './types';
-export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-  const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
+export function extractParamsTiny(weights: Float32Array): { params: TinyFaceFeatureExtractorParams, paramMappings: ParamMapping[] } {
+  const paramMappings: ParamMapping[] = []
  const {
    extractWeights,
    getRemainingWeights
- } = TfjsImageRecognitionBase.extractWeightsFactory(weights)
+ } = extractWeightsFactory(weights)
  const {
    extractDenseBlock3Params
...
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import {
+  extractConvParamsFactory,
+  extractSeparableConvParamsFactory,
+  ExtractWeightsFunction,
+  ParamMapping,
+} from '../common';
import { DenseBlock3Params, DenseBlock4Params } from './types';
-export function extractorsFactory(extractWeights: TfjsImageRecognitionBase.ExtractWeightsFunction, paramMappings: TfjsImageRecognitionBase.ParamMapping[]) {
-  const extractConvParams = TfjsImageRecognitionBase.extractConvParamsFactory(extractWeights, paramMappings)
-  const extractSeparableConvParams = TfjsImageRecognitionBase.extractSeparableConvParamsFactory(extractWeights, paramMappings)
+export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
+  const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
+  const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings)
  function extractDenseBlock3Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer: boolean = false): DenseBlock3Params {
...
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { extractWeightEntryFactory, loadSeparableConvParamsFactory, ParamMapping } from '../common';
import { loadConvParamsFactory } from '../common/loadConvParamsFactory';
import { DenseBlock3Params, DenseBlock4Params } from './types';
-export function loadParamsFactory(weightMap: any, paramMappings: TfjsImageRecognitionBase.ParamMapping[]) {
-  const extractWeightEntry = TfjsImageRecognitionBase.extractWeightEntryFactory(weightMap, paramMappings)
+export function loadParamsFactory(weightMap: any, paramMappings: ParamMapping[]) {
+  const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
  const extractConvParams = loadConvParamsFactory(extractWeightEntry)
- const extractSeparableConvParams = TfjsImageRecognitionBase.loadSeparableConvParamsFactory(extractWeightEntry)
+ const extractSeparableConvParams = loadSeparableConvParamsFactory(extractWeightEntry)
  function extractDenseBlock3Params(prefix: string, isFirstLayer: boolean = false): DenseBlock3Params {
    const conv0 = isFirstLayer
...
import * as tf from '@tensorflow/tfjs-core';
-import { NetInput, NeuralNetwork, TNetInput, TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { NetInput, TNetInput } from '..';
+import { ConvParams, SeparableConvParams } from '../common';
+import { NeuralNetwork } from '../NeuralNetwork';
export type ConvWithBatchNormParams = BatchNormParams & {
  filter: tf.Tensor4D
@@ -18,13 +21,13 @@ export type SeparableConvWithBatchNormParams = {
}
export type DenseBlock3Params = {
- conv0: TfjsImageRecognitionBase.SeparableConvParams | TfjsImageRecognitionBase.ConvParams
- conv1: TfjsImageRecognitionBase.SeparableConvParams
- conv2: TfjsImageRecognitionBase.SeparableConvParams
+ conv0: SeparableConvParams | ConvParams
+ conv1: SeparableConvParams
+ conv2: SeparableConvParams
}
export type DenseBlock4Params = DenseBlock3Params & {
- conv3: TfjsImageRecognitionBase.SeparableConvParams
+ conv3: SeparableConvParams
}
export type TinyFaceFeatureExtractorParams = {
...
import * as tf from '@tensorflow/tfjs-core';
-import { IDimensions, isEven, NetInput, Point, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
+import { IDimensions, Point } from '../classes';
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
+import { NetInput, TNetInput, toNetInput } from '../dom';
import { FaceFeatureExtractorParams, TinyFaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
+import { isEven } from '../utils';
export abstract class FaceLandmark68NetBase<
  TExtractorParams extends FaceFeatureExtractorParams | TinyFaceFeatureExtractorParams
...
import * as tf from '@tensorflow/tfjs-core';
-import { NetInput, NeuralNetwork } from 'tfjs-image-recognition-base';
import { fullyConnectedLayer } from '../common/fullyConnectedLayer';
+import { NetInput } from '../dom';
import {
  FaceFeatureExtractorParams,
  IFaceFeatureExtractor,
  TinyFaceFeatureExtractorParams,
} from '../faceFeatureExtractor/types';
+import { NeuralNetwork } from '../NeuralNetwork';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
import { NetParams } from './types';
...
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { extractFCParamsFactory, extractWeightsFactory, ParamMapping } from '../common';
import { NetParams } from './types';
-export function extractParams(weights: Float32Array, channelsIn: number, channelsOut: number): { params: NetParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-  const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
+export function extractParams(weights: Float32Array, channelsIn: number, channelsOut: number): { params: NetParams, paramMappings: ParamMapping[] } {
+  const paramMappings: ParamMapping[] = []
  const {
    extractWeights,
    getRemainingWeights
- } = TfjsImageRecognitionBase.extractWeightsFactory(weights)
- const extractFCParams = TfjsImageRecognitionBase.extractFCParamsFactory(extractWeights, paramMappings)
+ } = extractWeightsFactory(weights)
+ const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings)
  const fc = extractFCParams(channelsIn, channelsOut, 'fc')
...
import * as tf from '@tensorflow/tfjs-core';
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common';
import { NetParams } from './types';
export function extractParamsFromWeigthMap(
  weightMap: tf.NamedTensorMap
-): { params: NetParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
-  const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
-  const extractWeightEntry = TfjsImageRecognitionBase.extractWeightEntryFactory(weightMap, paramMappings)
-  function extractFcParams(prefix: string): TfjsImageRecognitionBase.FCParams {
+): { params: NetParams, paramMappings: ParamMapping[] } {
+  const paramMappings: ParamMapping[] = []
+  const extractWeightEntry = extractWeightEntryFactory(weightMap, paramMappings)
+  function extractFcParams(prefix: string): FCParams {
    const weights = extractWeightEntry<tf.Tensor2D>(`${prefix}/weights`, 2)
    const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
    return { weights, bias }
@@ -21,7 +21,7 @@ export function extractParamsFromWeigthMap(
    fc: extractFcParams('fc')
  }
- TfjsImageRecognitionBase.disposeUnusedWeightTensors(weightMap, paramMappings)
+ disposeUnusedWeightTensors(weightMap, paramMappings)
  return { params, paramMappings }
}
\ No newline at end of file
-import { TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { FCParams } from '../common';
export type NetParams = {
- fc: TfjsImageRecognitionBase.FCParams
+ fc: FCParams
}
import * as tf from '@tensorflow/tfjs-core';
-import { NetInput, NeuralNetwork, normalize, TNetInput, toNetInput } from 'tfjs-image-recognition-base';
+import { NetInput, TNetInput, toNetInput } from '../dom';
+import { NeuralNetwork } from '../NeuralNetwork';
+import { normalize } from '../ops';
import { convDown } from './convLayer';
import { extractParams } from './extractParams';
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
...
import * as tf from '@tensorflow/tfjs-core';
-import { isFloat, TfjsImageRecognitionBase } from 'tfjs-image-recognition-base';
+import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common';
+import { isFloat } from '../utils';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
-function extractorsFactory(extractWeights: TfjsImageRecognitionBase.ExtractWeightsFunction, paramMappings: TfjsImageRecognitionBase.ParamMapping[]) {
+function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
  function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D {
    const weights = extractWeights(numFilterValues)
@@ -26,7 +27,7 @@ function extractorsFactory(
    numFilters: number,
    filterSize: number,
    mappedPrefix: string
- ): TfjsImageRecognitionBase.ConvParams {
+ ): ConvParams {
    const filters = extractFilterValues(numFilterValues, numFilters, filterSize)
    const bias = tf.tensor1d(extractWeights(numFilters))
@@ -89,14 +90,14 @@ function extractorsFactory(
}
-export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: TfjsImageRecognitionBase.ParamMapping[] } {
+export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
  const {
    extractWeights,
    getRemainingWeights
- } = TfjsImageRecognitionBase.extractWeightsFactory(weights)
- const paramMappings: TfjsImageRecognitionBase.ParamMapping[] = []
+ } = extractWeightsFactory(weights)
+ const paramMappings: ParamMapping[] = []
  const {
    extractConvLayerParams,
...