Commit cdd2c49d by vincent

extract methods for face tensors and face canvas + pad and resize face…

extract methods for face tensors and face canvas + pad and resize face recognition net input to 150x150 + some fixes
parent f7c90389
import { Dimensions, TMediaElement, TNetInput } from './types'; import { Dimensions, TMediaElement, TNetInput } from './types';
import { getContext2dOrThrow, getElement, getMediaDimensions } from './utils'; import { createCanvas, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
export class NetInput { export class NetInput {
private _canvases: HTMLCanvasElement[] private _canvases: HTMLCanvasElement[]
...@@ -42,10 +42,7 @@ export class NetInput { ...@@ -42,10 +42,7 @@ export class NetInput {
// if input is batch type, make sure every canvas has the same dimensions // if input is batch type, make sure every canvas has the same dimensions
const { width, height } = this.dims || dims || getMediaDimensions(media) const { width, height } = this.dims || dims || getMediaDimensions(media)
const canvas = document.createElement('canvas') const canvas = createCanvas({ width, height })
canvas.width = width
canvas.height = height
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height) getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
this._canvases.push(canvas) this._canvases.push(canvas)
} }
......
import * as tf from '@tensorflow/tfjs-core';
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
import { NetInput } from './NetInput';
import { getImageTensor } from './transformInputs';
import { TNetInput } from './types';
/**
 * Extracts the tensors of the image regions containing the detected faces.
 * Returned tensors have to be disposed manually once you don't need them anymore!
 * Useful if you want to compute the face descriptors for the face
 * images. Using this method is faster then extracting a canvas for each face and
 * converting them to tensors individually.
 *
 * @param image The image that face detection has been performed on.
 * @param detections The face detection results for that image.
 * @returns Tensors of the corresponding image region for each detected face.
 */
export function extractFaceTensors(
  image: tf.Tensor | NetInput | TNetInput,
  detections: FaceDetectionResult[]
): tf.Tensor4D[] {
  return tf.tidy(() => {
    const imgTensor = getImageTensor(image)

    // TODO handle batches — for now only the first image of the batch is sliced
    const [, imgHeight, imgWidth, numChannels] = imgTensor.shape

    return detections.map(det => {
      // forSize floors the box coordinates to integer pixel positions and the
      // relative coordinates are clamped to [0, 1], so the slice stays in bounds
      const { x, y, width, height } = det.forSize(imgWidth, imgHeight).box
      return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
    })
  })
}
\ No newline at end of file
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
import { createCanvas, getContext2dOrThrow } from './utils';
/**
 * Extracts the image regions containing the detected faces.
 *
 * @param image The canvas that face detection has been performed on.
 * @param detections The face detection results for that image.
 * @returns The canvases of the corresponding image region for each detected face.
 */
export function extractFaces(
  image: HTMLCanvasElement,
  detections: FaceDetectionResult[]
): HTMLCanvasElement[] {
  const ctx = getContext2dOrThrow(image)

  return detections.map(det => {
    const { x, y, width, height } = det.forSize(image.width, image.height).box

    // copy the raw pixel data of the face region onto a fresh canvas
    const faceImg = createCanvas({ width, height })
    getContext2dOrThrow(faceImg)
      .putImageData(ctx.getImageData(x, y, width, height), 0, 0)
    return faceImg
  })
}
\ No newline at end of file
import { FaceDetectionNet } from './types'; import { FaceDetectionNet } from './types';
export class FaceDetectionResult { export class FaceDetectionResult {
private score: number private _score: number
private top: number private _topRelative: number
private left: number private _leftRelative: number
private bottom: number private _bottomRelative: number
private right: number private _rightRelative: number
constructor( constructor(
score: number, score: number,
top: number, topRelative: number,
left: number, leftRelative: number,
bottom: number, bottomRelative: number,
right: number rightRelative: number
) { ) {
this.score = score this._score = score
this.top = Math.max(0, top), this._topRelative = Math.max(0, topRelative),
this.left = Math.max(0, left), this._leftRelative = Math.max(0, leftRelative),
this.bottom = Math.min(1.0, bottom), this._bottomRelative = Math.min(1.0, bottomRelative),
this.right = Math.min(1.0, right) this._rightRelative = Math.min(1.0, rightRelative)
} }
public forSize(width: number, height: number): FaceDetectionNet.Detection { public forSize(width: number, height: number): FaceDetectionNet.Detection {
const x = Math.floor(this._leftRelative * width)
const y = Math.floor(this._topRelative * height)
return { return {
score: this.score, score: this._score,
box: { box: {
top: this.top * height, x,
left: this.left * width, y,
bottom: this.bottom * height, width: Math.floor(this._rightRelative * width) - x,
right: this.right * width height: Math.floor(this._bottomRelative * height) - y
} }
} }
} }
......
...@@ -69,10 +69,10 @@ export namespace FaceDetectionNet { ...@@ -69,10 +69,10 @@ export namespace FaceDetectionNet {
export type Detection = { export type Detection = {
score: number score: number
box: { box: {
top: number, x: number,
left: number, y: number,
right: number, width: number,
bottom: number height: number
} }
} }
......
...@@ -14,7 +14,13 @@ export function faceRecognitionNet(weights: Float32Array) { ...@@ -14,7 +14,13 @@ export function faceRecognitionNet(weights: Float32Array) {
function forward(input: tf.Tensor | NetInput | TNetInput) { function forward(input: tf.Tensor | NetInput | TNetInput) {
return tf.tidy(() => { return tf.tidy(() => {
const x = normalize(padToSquare(getImageTensor(input))) // TODO pad on both sides, to keep face centered
let x = padToSquare(getImageTensor(input))
// work with 150 x 150 sized face images
if (x.shape[1] !== 150 || x.shape[2] !== 150) {
x = tf.image.resizeBilinear(x, [150, 150])
}
x = normalize(x)
let out = convDown(x, params.conv32_down) let out = convDown(x, params.conv32_down)
out = tf.maxPool(out, 3, 2, 'valid') out = tf.maxPool(out, 3, 2, 'valid')
......
...@@ -2,12 +2,16 @@ import { euclideanDistance } from './euclideanDistance'; ...@@ -2,12 +2,16 @@ import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet'; import { faceDetectionNet } from './faceDetectionNet';
import { faceRecognitionNet } from './faceRecognitionNet'; import { faceRecognitionNet } from './faceRecognitionNet';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
import * as tf from '@tensorflow/tfjs-core';
export { export {
euclideanDistance, euclideanDistance,
faceDetectionNet, faceDetectionNet,
faceRecognitionNet, faceRecognitionNet,
NetInput NetInput,
tf
} }
export * from './extractFaces'
export * from './extractFaceTensors'
export * from './utils' export * from './utils'
\ No newline at end of file
...@@ -15,7 +15,10 @@ export type DrawBoxOptions = { ...@@ -15,7 +15,10 @@ export type DrawBoxOptions = {
} }
export type DrawTextOptions = { export type DrawTextOptions = {
lineWidth: number
fontSize: number fontSize: number
fontStyle: string fontStyle: string
color: string color: string
} }
\ No newline at end of file
export type DrawOptions = DrawBoxOptions & DrawTextOptions
\ No newline at end of file
import { FaceDetectionNet } from './faceDetectionNet/types'; import { FaceDetectionNet } from './faceDetectionNet/types';
import { DrawBoxOptions, DrawTextOptions } from './types'; import { Dimensions, DrawBoxOptions, DrawOptions, DrawTextOptions } from './types';
export function isFloat(num: number) { export function isFloat(num: number) {
return num % 1 !== 0 return num % 1 !== 0
...@@ -24,7 +24,23 @@ export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingC ...@@ -24,7 +24,23 @@ export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingC
return ctx return ctx
} }
/**
 * Creates a new canvas element with the given dimensions.
 *
 * @param dimensions The width and height to assign to the canvas.
 * @returns A fresh, detached HTMLCanvasElement of that size.
 */
export function createCanvas({ width, height }: Dimensions): HTMLCanvasElement {
  return Object.assign(document.createElement('canvas'), { width, height })
}
/**
 * Creates a canvas of the given dimensions, initialized with the given pixel data.
 *
 * @param dimensions The width and height of the canvas.
 * @param buf The RGBA pixel data to draw onto the canvas (length must be width * height * 4).
 * @returns A fresh HTMLCanvasElement displaying the given image data.
 */
export function createCanvasWithImageData({ width, height }: Dimensions, buf: Uint8ClampedArray): HTMLCanvasElement {
  const canvas = createCanvas({ width, height })
  const imageData = new ImageData(buf, width, height)
  getContext2dOrThrow(canvas).putImageData(imageData, 0, 0)
  return canvas
}
export function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) { export function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) {
if (media instanceof HTMLImageElement) {
return { width: media.naturalWidth, height: media.naturalHeight }
}
if (media instanceof HTMLVideoElement) { if (media instanceof HTMLVideoElement) {
return { width: media.videoWidth, height: media.videoHeight } return { width: media.videoWidth, height: media.videoHeight }
} }
...@@ -49,6 +65,15 @@ export function bufferToImage(buf: Blob): Promise<HTMLImageElement> { ...@@ -49,6 +65,15 @@ export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
}) })
} }
export function getDefaultDrawOptions(): DrawOptions {
return {
color: 'blue',
lineWidth: 2,
fontSize: 20,
fontStyle: 'Georgia'
}
}
export function drawBox( export function drawBox(
ctx: CanvasRenderingContext2D, ctx: CanvasRenderingContext2D,
x: number, x: number,
...@@ -69,9 +94,11 @@ export function drawText( ...@@ -69,9 +94,11 @@ export function drawText(
text: string, text: string,
options: DrawTextOptions options: DrawTextOptions
) { ) {
const padText = 2 + options.lineWidth
ctx.fillStyle = options.color ctx.fillStyle = options.color
ctx.font = `${options.fontSize}px ${options.fontStyle}` ctx.font = `${options.fontSize}px ${options.fontStyle}`
ctx.fillText(text, x, y) ctx.fillText(text, x + padText, y + padText + (options.fontSize * 0.6))
} }
export function drawDetection( export function drawDetection(
...@@ -95,38 +122,35 @@ export function drawDetection( ...@@ -95,38 +122,35 @@ export function drawDetection(
} = det } = det
const { const {
left, x,
right, y,
top, width,
bottom height
} = box } = box
const { const drawOptions = Object.assign(
color = 'blue', getDefaultDrawOptions(),
lineWidth = 2, (options || {})
fontSize = 20, )
fontStyle = 'Georgia',
withScore = true
} = (options || {})
const padText = 2 + lineWidth const { withScore } = Object.assign({ withScore: true }, (options || {}))
const ctx = getContext2dOrThrow(canvas) const ctx = getContext2dOrThrow(canvas)
drawBox( drawBox(
ctx, ctx,
left, x,
top, y,
right - left, width,
bottom - top, height,
{ lineWidth, color } drawOptions
) )
if (withScore) { if (withScore) {
drawText( drawText(
ctx, ctx,
left + padText, x,
top + (fontSize * 0.6) + padText, y,
`${round(score)}`, `${round(score)}`,
{ fontSize, fontStyle, color } drawOptions
) )
} }
}) })
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment