Commit 2b127e43 by vincent

add centering option to padToSquare and center input face of face recognition net

parent d7962a58
...@@ -2,7 +2,7 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,7 +2,7 @@ import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult'; import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import { getImageTensor } from './transformInputs'; import { getImageTensor } from './getImageTensor';
import { TNetInput } from './types'; import { TNetInput } from './types';
/** /**
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { getImageTensor } from '../getImageTensor';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { getImageTensor, padToSquare } from '../transformInputs'; import { padToSquare } from '../padToSquare';
import { TNetInput } from '../types'; import { TNetInput } from '../types';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { FaceDetectionResult } from './FaceDetectionResult'; import { FaceDetectionResult } from './FaceDetectionResult';
...@@ -49,7 +50,7 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -49,7 +50,7 @@ export function faceDetectionNet(weights: Float32Array) {
} = tf.tidy(() => { } = tf.tidy(() => {
let imgTensor = getImageTensor(input) let imgTensor = getImageTensor(input)
const [_, height, width] = imgTensor.shape const [height, width] = imgTensor.shape.slice(1)
imgTensor = padToSquare(imgTensor) imgTensor = padToSquare(imgTensor)
paddedHeightRelative = imgTensor.shape[1] / height paddedHeightRelative = imgTensor.shape[1] / height
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { getImageTensor } from '../getImageTensor';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { getImageTensor, padToSquare } from '../transformInputs'; import { padToSquare } from '../padToSquare';
import { TNetInput } from '../types'; import { TNetInput } from '../types';
import { convDown } from './convLayer'; import { convDown } from './convLayer';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
...@@ -14,8 +15,7 @@ export function faceRecognitionNet(weights: Float32Array) { ...@@ -14,8 +15,7 @@ export function faceRecognitionNet(weights: Float32Array) {
function forward(input: tf.Tensor | NetInput | TNetInput) { function forward(input: tf.Tensor | NetInput | TNetInput) {
return tf.tidy(() => { return tf.tidy(() => {
// TODO pad on both sides, to keep face centered let x = padToSquare(getImageTensor(input), true)
let x = padToSquare(getImageTensor(input))
// work with 150 x 150 sized face images // work with 150 x 150 sized face images
if (x.shape[1] !== 150 || x.shape[2] !== 150) { if (x.shape[1] !== 150 || x.shape[2] !== 150) {
x = tf.image.resizeBilinear(x, [150, 150]) x = tf.image.resizeBilinear(x, [150, 150])
......
...@@ -3,23 +3,6 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -3,23 +3,6 @@ import * as tf from '@tensorflow/tfjs-core';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import { TNetInput } from './types'; import { TNetInput } from './types';
export function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
return tf.tidy(() => {
const [_, height, width] = imgTensor.shape
if (height === width) {
return imgTensor
}
if (height > width) {
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
return tf.concat([imgTensor, pad], 2)
}
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
return tf.concat([imgTensor, pad], 1)
})
}
export function getImageTensor(input: tf.Tensor | NetInput | TNetInput): tf.Tensor4D { export function getImageTensor(input: tf.Tensor | NetInput | TNetInput): tf.Tensor4D {
return tf.tidy(() => { return tf.tidy(() => {
if (input instanceof tf.Tensor) { if (input instanceof tf.Tensor) {
......
import * as tf from '@tensorflow/tfjs-core';
import { euclideanDistance } from './euclideanDistance'; import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet'; import { faceDetectionNet } from './faceDetectionNet';
import { faceRecognitionNet } from './faceRecognitionNet'; import { faceRecognitionNet } from './faceRecognitionNet';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import * as tf from '@tensorflow/tfjs-core'; import { padToSquare } from './padToSquare';
export { export {
euclideanDistance, euclideanDistance,
faceDetectionNet, faceDetectionNet,
faceRecognitionNet, faceRecognitionNet,
NetInput, NetInput,
tf tf,
padToSquare
} }
export * from './extractFaces' export * from './extractFaces'
......
import * as tf from '@tensorflow/tfjs-core';
/**
* Pads the smaller dimension of an image tensor with zeros, such that width === height.
*
* @param imgTensor The image tensor.
* @param isCenterImage (optional, default: false) If true, add padding on both sides of the image, such that the image
* @returns The padded tensor with width === height.
*/
export function padToSquare(
imgTensor: tf.Tensor4D,
isCenterImage: boolean = false
): tf.Tensor4D {
return tf.tidy(() => {
const [height, width] = imgTensor.shape.slice(1)
if (height === width) {
return imgTensor
}
const paddingAmount = Math.floor(Math.abs(height - width) * (isCenterImage ? 0.5 : 1))
const paddingAxis = height > width ? 2 : 1
const paddingTensorShape = imgTensor.shape.slice() as [number, number, number, number]
paddingTensorShape[paddingAxis] = paddingAmount
const tensorsToStack = (isCenterImage ? [tf.fill(paddingTensorShape, 0)] : [])
.concat([imgTensor, tf.fill(paddingTensorShape, 0)]) as tf.Tensor4D[]
return tf.concat(tensorsToStack, paddingAxis)
})
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment