Commit ec095c07 by vincent

final landmark net implementation

parent 72d280cf
import { Dimensions, TMediaElement, TNetInput } from './types'; import { Dimensions, TMediaElement, TNetInput } from './types';
import { createCanvas, getContext2dOrThrow, getElement, getMediaDimensions } from './utils'; import { createCanvasFromMedia, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
export class NetInput { export class NetInput {
private _canvases: HTMLCanvasElement[] private _canvases: HTMLCanvasElement[]
...@@ -40,11 +40,8 @@ export class NetInput { ...@@ -40,11 +40,8 @@ export class NetInput {
} }
// if input is batch type, make sure every canvas has the same dimensions // if input is batch type, make sure every canvas has the same dimensions
const { width, height } = this.dims || dims || getMediaDimensions(media) const canvasDims = this.dims || dims
this._canvases.push(createCanvasFromMedia(media, canvasDims))
const canvas = createCanvas({ width, height })
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
this._canvases.push(canvas)
} }
public get canvases() : HTMLCanvasElement[] { public get canvases() : HTMLCanvasElement[] {
......
export class Point {
public x: number
public y: number
constructor(x: number, y: number) {
this.x = x
this.y = y
}
}
\ No newline at end of file
...@@ -5,12 +5,15 @@ import { ConvParams } from './types'; ...@@ -5,12 +5,15 @@ import { ConvParams } from './types';
export function convLayer( export function convLayer(
x: tf.Tensor4D, x: tf.Tensor4D,
params: ConvParams, params: ConvParams,
padding: 'valid' | 'same' = 'same' padding: 'valid' | 'same' = 'same',
withRelu: boolean = false
): tf.Tensor4D { ): tf.Tensor4D {
return tf.tidy(() => return tf.tidy(() => {
tf.add( const out = tf.add(
tf.conv2d(x, params.filters, [1, 1], padding), tf.conv2d(x, params.filters, [1, 1], padding),
params.bias params.bias
) ) as tf.Tensor4D
)
return withRelu ? tf.relu(out) : out
})
} }
\ No newline at end of file
import { Point } from '../Point';
import { Dimensions } from '../types';
export class FaceLandmarks {
private _faceLandmarks: Point[]
private _imageWidth: number
private _imageHeight: number
constructor(
relativeFaceLandmarkPositions: Point[],
imageDims: Dimensions
) {
const { width, height } = imageDims
this._imageWidth = width
this._imageHeight = height
this._faceLandmarks = relativeFaceLandmarkPositions.map(
pt => new Point(pt.x * width, pt.y * height)
)
}
public getPositions() {
return this._faceLandmarks
}
public getRelativePositions() {
return this._faceLandmarks.map(
pt => new Point(pt.x / this._imageWidth, pt.y / this._imageHeight)
)
}
public forSize(width: number, height: number): FaceLandmarks {
return new FaceLandmarks(this.getRelativePositions(), { width, height })
}
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { convLayer } from '../commons/convLayer';
import { ConvParams } from '../commons/types';
import { getImageTensor } from '../getImageTensor'; import { getImageTensor } from '../getImageTensor';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { padToSquare } from '../padToSquare'; import { padToSquare } from '../padToSquare';
import { TNetInput } from '../types'; import { Dimensions, TNetInput } from '../types';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { convLayer } from '../commons/convLayer'; import { FaceLandmarks } from './FaceLandmarks';
import { fullyConnectedLayer } from './fullyConnectedLayer'; import { fullyConnectedLayer } from './fullyConnectedLayer';
function conv(x: tf.Tensor4D, params: ConvParams): tf.Tensor4D {
return convLayer(x, params, 'valid', true)
}
function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4D {
return tf.maxPool(x, [2, 2], strides, 'valid')
}
export function faceLandmarkNet(weights: Float32Array) { export function faceLandmarkNet(weights: Float32Array) {
const params = extractParams(weights) const params = extractParams(weights)
function forward(input: tf.Tensor | NetInput | TNetInput) { async function detectLandmarks(input: tf.Tensor | NetInput | TNetInput) {
return tf.tidy(() => { let adjustRelativeX = 0
let adjustRelativeY = 0
let imageDimensions: Dimensions | undefined
const outTensor = tf.tidy(() => {
let imgTensor = getImageTensor(input)
const [height, width] = imgTensor.shape.slice(1)
imageDimensions = { width, height }
imgTensor = padToSquare(imgTensor, true)
adjustRelativeX = (height > width) ? imgTensor.shape[2] / (2 * width) : 0
adjustRelativeY = (width > height) ? imgTensor.shape[1] / (2 * height) : 0
let x = padToSquare(getImageTensor(input), true)
// work with 128 x 128 sized face images // work with 128 x 128 sized face images
if (x.shape[1] !== 128 || x.shape[2] !== 128) { if (imgTensor.shape[1] !== 128 || imgTensor.shape[2] !== 128) {
x = tf.image.resizeBilinear(x, [128, 128]) imgTensor = tf.image.resizeBilinear(imgTensor, [128, 128])
} }
let out = convLayer(x, params.conv0_params, 'valid') let out = conv(imgTensor, params.conv0_params)
out = tf.maxPool(out, [2, 2], [2, 2], 'valid') out = maxPool(out)
out = convLayer(out, params.conv1_params, 'valid') out = conv(out, params.conv1_params)
out = convLayer(out, params.conv2_params, 'valid') out = conv(out, params.conv2_params)
out = tf.maxPool(out, [2, 2], [2, 2], 'valid') out = maxPool(out)
out = convLayer(out, params.conv3_params, 'valid') out = conv(out, params.conv3_params)
out = convLayer(out, params.conv4_params, 'valid') out = conv(out, params.conv4_params)
out = tf.maxPool(out, [2, 2], [2, 2], 'valid') out = maxPool(out)
out = convLayer(out, params.conv5_params, 'valid') out = conv(out, params.conv5_params)
out = convLayer(out, params.conv6_params, 'valid') out = conv(out, params.conv6_params)
out = tf.maxPool(out, [2, 2], [1, 1], 'valid') out = maxPool(out, [1, 1])
out = convLayer(out, params.conv7_params, 'valid') out = conv(out, params.conv7_params)
const fc0 = fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params) const fc0 = tf.relu(fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params))
const fc1 = fullyConnectedLayer(fc0, params.fc1_params) const fc1 = fullyConnectedLayer(fc0, params.fc1_params)
return fc1 return fc1
}) })
const faceLandmarksArray = Array.from(await outTensor.data())
const xCoords = faceLandmarksArray.filter((c, i) => (i - 1) % 2).map(x => x + adjustRelativeX)
const yCoords = faceLandmarksArray.filter((c, i) => i % 2).map(y => y + adjustRelativeY)
outTensor.dispose()
return new FaceLandmarks(
Array(68).fill(0).map((_, i) => ({ x: xCoords[i], y: yCoords[i] })),
imageDimensions as Dimensions
)
} }
return { return {
forward detectLandmarks
} }
} }
\ No newline at end of file
...@@ -10,15 +10,25 @@ export type Dimensions = { ...@@ -10,15 +10,25 @@ export type Dimensions = {
} }
export type DrawBoxOptions = { export type DrawBoxOptions = {
lineWidth: number lineWidth?: number
color: string color?: string
} }
export type DrawTextOptions = { export type DrawTextOptions = {
lineWidth?: number
fontSize?: number
fontStyle?: string
color?: string
}
export type DrawLandmarksOptions = {
lineWidth?: number
color?: string
}
export type DrawOptions = {
lineWidth: number lineWidth: number
fontSize: number fontSize: number
fontStyle: string fontStyle: string
color: string color: string
} }
\ No newline at end of file
export type DrawOptions = DrawBoxOptions & DrawTextOptions
\ No newline at end of file
import { FaceDetectionNet } from './faceDetectionNet/types'; import { FaceDetectionNet } from './faceDetectionNet/types';
import { Dimensions, DrawBoxOptions, DrawOptions, DrawTextOptions } from './types'; import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
export function isFloat(num: number) { export function isFloat(num: number) {
return num % 1 !== 0 return num % 1 !== 0
...@@ -24,16 +25,17 @@ export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingC ...@@ -24,16 +25,17 @@ export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingC
return ctx return ctx
} }
export function createCanvas({ width, height}: Dimensions): HTMLCanvasElement { export function createCanvas({ width, height }: Dimensions): HTMLCanvasElement {
const canvas = document.createElement('canvas') const canvas = document.createElement('canvas')
canvas.width = width canvas.width = width
canvas.height = height canvas.height = height
return canvas return canvas
} }
export function createCanvasWithImageData({ width, height}: Dimensions, buf: Uint8ClampedArray): HTMLCanvasElement { export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement, dims?: Dimensions): HTMLCanvasElement {
const { width, height } = dims || getMediaDimensions(media)
const canvas = createCanvas({ width, height }) const canvas = createCanvas({ width, height })
getContext2dOrThrow(canvas).putImageData(new ImageData(buf, width, height), 0, 0) getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
return canvas return canvas
} }
...@@ -82,8 +84,13 @@ export function drawBox( ...@@ -82,8 +84,13 @@ export function drawBox(
h: number, h: number,
options: DrawBoxOptions options: DrawBoxOptions
) { ) {
ctx.strokeStyle = options.color const drawOptions = Object.assign(
ctx.lineWidth = options.lineWidth getDefaultDrawOptions(),
(options || {})
)
ctx.strokeStyle = drawOptions.color
ctx.lineWidth = drawOptions.lineWidth
ctx.strokeRect(x, y, w, h) ctx.strokeRect(x, y, w, h)
} }
...@@ -94,11 +101,16 @@ export function drawText( ...@@ -94,11 +101,16 @@ export function drawText(
text: string, text: string,
options: DrawTextOptions options: DrawTextOptions
) { ) {
const padText = 2 + options.lineWidth const drawOptions = Object.assign(
getDefaultDrawOptions(),
(options || {})
)
const padText = 2 + drawOptions.lineWidth
ctx.fillStyle = options.color ctx.fillStyle = drawOptions.color
ctx.font = `${options.fontSize}px ${options.fontStyle}` ctx.font = `${drawOptions.fontSize}px ${drawOptions.fontStyle}`
ctx.fillText(text, x + padText, y + padText + (options.fontSize * 0.6)) ctx.fillText(text, x + padText, y + padText + (drawOptions.fontSize * 0.6))
} }
export function drawDetection( export function drawDetection(
...@@ -154,4 +166,28 @@ export function drawDetection( ...@@ -154,4 +166,28 @@ export function drawDetection(
) )
} }
}) })
}
export function drawLandmarks(
canvasArg: string | HTMLCanvasElement,
faceLandmarks: FaceLandmarks,
options?: DrawLandmarksOptions & { drawLines: boolean }
) {
const canvas = getElement(canvasArg)
if (!(canvas instanceof HTMLCanvasElement)) {
throw new Error('drawLandmarks - expected canvas to be of type: HTMLCanvasElement')
}
const drawOptions = Object.assign(
getDefaultDrawOptions(),
(options || {})
)
const { drawLines } = Object.assign({ drawLines: false }, (options || {}))
const ctx = getContext2dOrThrow(canvas)
const { lineWidth,color } = drawOptions
ctx.fillStyle = color
const ptOffset = lineWidth / 2
faceLandmarks.getPositions().forEach(pt => ctx.fillRect(pt.x - ptOffset, pt.y - ptOffset, lineWidth, lineWidth))
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment