Commit cdd2c49d by vincent

extract methods for face tensors and face canvas + pad and resize face recognition net input to 150x150 + some fixes

extract methods for face tensors and face canvas + pad and resize face recognition net input to 150x150 + some fixes
parent f7c90389
import { Dimensions, TMediaElement, TNetInput } from './types';
import { getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
import { createCanvas, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
export class NetInput {
private _canvases: HTMLCanvasElement[]
......@@ -42,10 +42,7 @@ export class NetInput {
// if input is batch type, make sure every canvas has the same dimensions
const { width, height } = this.dims || dims || getMediaDimensions(media)
const canvas = document.createElement('canvas')
canvas.width = width
canvas.height = height
const canvas = createCanvas({ width, height })
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
this._canvases.push(canvas)
}
......
import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
import { NetInput } from './NetInput';
import { getImageTensor } from './transformInputs';
import { TNetInput } from './types';
/**
 * Extracts the tensors of the image regions containing the detected faces.
 * Returned tensors have to be disposed manually once you don't need them anymore!
 * Useful if you want to compute the face descriptors for the face
 * images. Using this method is faster than extracting a canvas for each face and
 * converting them to tensors individually.
 *
 * @param image The image that face detection has been performed on.
 * @param detections The face detection results for that image.
 * @returns Tensors of the corresponding image region for each detected face.
 * @throws Error if the input resolves to a batch of more than one image.
 */
export function extractFaceTensors(
  image: tf.Tensor | NetInput | TNetInput,
  detections: FaceDetectionResult[]
): tf.Tensor4D[] {
  return tf.tidy(() => {
    const imgTensor = getImageTensor(image)
    const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape

    // TODO handle batches - until then, fail loudly instead of silently
    // slicing faces out of only the first image of a batch
    if (batchSize !== 1) {
      throw new Error(`extractFaceTensors - batch size of ${batchSize} not supported, only single image input allowed`)
    }

    // slice one (1 x height x width x channels) tensor per detection,
    // with the relative box scaled to absolute pixel coordinates
    const faceTensors = detections.map(det => {
      const { x, y, width, height } = det.forSize(imgWidth, imgHeight).box
      return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
    })

    return faceTensors
  })
}
\ No newline at end of file
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
import { createCanvas, getContext2dOrThrow } from './utils';
/**
 * Extracts the image regions containing the detected faces.
 *
 * @param image The canvas that face detection has been performed on.
 * @param detections The face detection results for that image.
 * @returns The Canvases of the corresponding image region for each detected face.
 */
export function extractFaces(
  image: HTMLCanvasElement,
  detections: FaceDetectionResult[]
): HTMLCanvasElement[] {
  const srcCtx = getContext2dOrThrow(image)

  return detections.map(det => {
    // scale the relative detection box to absolute pixel coordinates
    const { x, y, width, height } = det.forSize(image.width, image.height).box

    // copy the raw pixels of the face region onto a fresh canvas
    const facePixels = srcCtx.getImageData(x, y, width, height)
    const faceCanvas = createCanvas({ width, height })
    getContext2dOrThrow(faceCanvas).putImageData(facePixels, 0, 0)
    return faceCanvas
  })
}
\ No newline at end of file
import { FaceDetectionNet } from './types';
export class FaceDetectionResult {
private score: number
private top: number
private left: number
private bottom: number
private right: number
private _score: number
private _topRelative: number
private _leftRelative: number
private _bottomRelative: number
private _rightRelative: number
constructor(
score: number,
top: number,
left: number,
bottom: number,
right: number
topRelative: number,
leftRelative: number,
bottomRelative: number,
rightRelative: number
) {
this.score = score
this.top = Math.max(0, top),
this.left = Math.max(0, left),
this.bottom = Math.min(1.0, bottom),
this.right = Math.min(1.0, right)
this._score = score
this._topRelative = Math.max(0, topRelative),
this._leftRelative = Math.max(0, leftRelative),
this._bottomRelative = Math.min(1.0, bottomRelative),
this._rightRelative = Math.min(1.0, rightRelative)
}
public forSize(width: number, height: number): FaceDetectionNet.Detection {
const x = Math.floor(this._leftRelative * width)
const y = Math.floor(this._topRelative * height)
return {
score: this.score,
score: this._score,
box: {
top: this.top * height,
left: this.left * width,
bottom: this.bottom * height,
right: this.right * width
x,
y,
width: Math.floor(this._rightRelative * width) - x,
height: Math.floor(this._bottomRelative * height) - y
}
}
}
......
......@@ -69,10 +69,10 @@ export namespace FaceDetectionNet {
// A detection with its box already scaled to absolute pixel coordinates
// (produced by FaceDetectionResult.forSize).
export type Detection = {
  score: number
  box: {
    x: number,
    y: number,
    width: number,
    height: number
  }
}
......
......@@ -14,7 +14,13 @@ export function faceRecognitionNet(weights: Float32Array) {
function forward(input: tf.Tensor | NetInput | TNetInput) {
return tf.tidy(() => {
const x = normalize(padToSquare(getImageTensor(input)))
// TODO pad on both sides, to keep face centered
let x = padToSquare(getImageTensor(input))
// work with 150 x 150 sized face images
if (x.shape[1] !== 150 || x.shape[2] !== 150) {
x = tf.image.resizeBilinear(x, [150, 150])
}
x = normalize(x)
let out = convDown(x, params.conv32_down)
out = tf.maxPool(out, 3, 2, 'valid')
......
......@@ -2,12 +2,16 @@ import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet';
import { faceRecognitionNet } from './faceRecognitionNet';
import { NetInput } from './NetInput';
import * as tf from '@tensorflow/tfjs-core';
export {
euclideanDistance,
faceDetectionNet,
faceRecognitionNet,
NetInput
NetInput,
tf
}
export * from './extractFaces'
export * from './extractFaceTensors'
export * from './utils'
\ No newline at end of file
......@@ -15,7 +15,10 @@ export type DrawBoxOptions = {
}
// styling options consumed by drawText()
export type DrawTextOptions = {
  lineWidth: number
  fontSize: number
  fontStyle: string
  color: string
}

// combined box + text styling accepted by drawDetection()
export type DrawOptions = DrawBoxOptions & DrawTextOptions
\ No newline at end of file
import { FaceDetectionNet } from './faceDetectionNet/types';
import { DrawBoxOptions, DrawTextOptions } from './types';
import { Dimensions, DrawBoxOptions, DrawOptions, DrawTextOptions } from './types';
export function isFloat(num: number) {
return num % 1 !== 0
......@@ -24,7 +24,23 @@ export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingC
return ctx
}
/**
 * Creates a new, blank canvas element of the given size.
 */
export function createCanvas({ width, height}: Dimensions): HTMLCanvasElement {
  // assign the dimensions directly onto the freshly created element
  return Object.assign(
    document.createElement('canvas'),
    { width, height }
  )
}
/**
 * Creates a canvas of the given size and fills it with the given RGBA pixel buffer.
 */
export function createCanvasWithImageData({ width, height}: Dimensions, buf: Uint8ClampedArray): HTMLCanvasElement {
  const canvas = createCanvas({ width, height })
  const imageData = new ImageData(buf, width, height)
  const ctx = getContext2dOrThrow(canvas)
  ctx.putImageData(imageData, 0, 0)
  return canvas
}
export function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) {
if (media instanceof HTMLImageElement) {
return { width: media.naturalWidth, height: media.naturalHeight }
}
if (media instanceof HTMLVideoElement) {
return { width: media.videoWidth, height: media.videoHeight }
}
......@@ -49,6 +65,15 @@ export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
})
}
/**
 * Returns the default styling used for drawing detections
 * when the caller does not supply its own options.
 */
export function getDefaultDrawOptions(): DrawOptions {
  const defaults: DrawOptions = {
    color: 'blue',
    lineWidth: 2,
    fontSize: 20,
    fontStyle: 'Georgia'
  }
  return defaults
}
export function drawBox(
ctx: CanvasRenderingContext2D,
x: number,
......@@ -69,9 +94,11 @@ export function drawText(
text: string,
options: DrawTextOptions
) {
const padText = 2 + options.lineWidth
ctx.fillStyle = options.color
ctx.font = `${options.fontSize}px ${options.fontStyle}`
ctx.fillText(text, x, y)
ctx.fillText(text, x + padText, y + padText + (options.fontSize * 0.6))
}
export function drawDetection(
......@@ -95,38 +122,35 @@ export function drawDetection(
} = det
const {
left,
right,
top,
bottom
x,
y,
width,
height
} = box
const {
color = 'blue',
lineWidth = 2,
fontSize = 20,
fontStyle = 'Georgia',
withScore = true
} = (options || {})
const drawOptions = Object.assign(
getDefaultDrawOptions(),
(options || {})
)
const padText = 2 + lineWidth
const { withScore } = Object.assign({ withScore: true }, (options || {}))
const ctx = getContext2dOrThrow(canvas)
drawBox(
ctx,
left,
top,
right - left,
bottom - top,
{ lineWidth, color }
x,
y,
width,
height,
drawOptions
)
if (withScore) {
drawText(
ctx,
left + padText,
top + (fontSize * 0.6) + padText,
x,
y,
`${round(score)}`,
{ fontSize, fontStyle, color }
drawOptions
)
}
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment