Commit 21b3c7bf by vincent

remove padding of input image of face landmark net + helper for shifting face landmarks by offset

parent 68fe7ca5
export class Point { export interface IPoint {
x: number
y: number
}
export class Point implements IPoint {
public x: number public x: number
public y: number public y: number
...@@ -6,4 +11,20 @@ export class Point { ...@@ -6,4 +11,20 @@ export class Point {
this.x = x this.x = x
this.y = y this.y = y
} }
/**
 * Component-wise addition.
 *
 * @param pt - Offset whose `x` / `y` are added to this point's coordinates.
 * @returns A new `Point`; this instance is not modified.
 */
public add(pt: IPoint): Point {
  const { x: dx, y: dy } = pt
  return new Point(this.x + dx, this.y + dy)
}
/**
 * Component-wise subtraction.
 *
 * @param pt - Offset whose `x` / `y` are subtracted from this point's coordinates.
 * @returns A new `Point`; this instance is not modified.
 */
public sub(pt: IPoint): Point {
  const { x: dx, y: dy } = pt
  return new Point(this.x - dx, this.y - dy)
}
/**
 * Component-wise multiplication (x by x, y by y).
 *
 * @param pt - Per-axis factors.
 * @returns A new `Point`; this instance is not modified.
 */
public mul(pt: IPoint): Point {
  const { x: factorX, y: factorY } = pt
  return new Point(this.x * factorX, this.y * factorY)
}
/**
 * Component-wise division (x by x, y by y).
 *
 * Zero components in `pt` are not guarded against: the result follows plain
 * JavaScript number division and may contain `Infinity` or `NaN`.
 *
 * @param pt - Per-axis divisors.
 * @returns A new `Point`; this instance is not modified.
 */
public div(pt: IPoint): Point {
  const { x: divisorX, y: divisorY } = pt
  return new Point(this.x / divisorX, this.y / divisorY)
}
} }
\ No newline at end of file
...@@ -10,4 +10,13 @@ export class Rect { ...@@ -10,4 +10,13 @@ export class Rect {
this.width = width this.width = width
this.height = height this.height = height
} }
/**
 * Rounds every component of the rectangle down to the nearest integer.
 *
 * @returns A new `Rect` built from the floored `x`, `y`, `width` and `height`;
 *          this instance is not modified.
 */
public floor(): Rect {
  const x = Math.floor(this.x)
  const y = Math.floor(this.y)
  const width = Math.floor(this.width)
  const height = Math.floor(this.height)
  return new Rect(x, y, width, height)
}
} }
\ No newline at end of file
...@@ -27,7 +27,7 @@ export function extractFaceTensors( ...@@ -27,7 +27,7 @@ export function extractFaceTensors(
const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape
const faceTensors = detections.map(det => { const faceTensors = detections.map(det => {
const { x, y, width, height } = det.forSize(imgWidth, imgHeight).getBox() const { x, y, width, height } = det.forSize(imgWidth, imgHeight).getBox().floor()
return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels]) return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
}) })
......
...@@ -15,8 +15,7 @@ export function extractFaces( ...@@ -15,8 +15,7 @@ export function extractFaces(
const ctx = getContext2dOrThrow(image) const ctx = getContext2dOrThrow(image)
return detections.map(det => { return detections.map(det => {
const { x, y, width, height } = det.forSize(image.width, image.height).getBox() const { x, y, width, height } = det.forSize(image.width, image.height).getBox().floor()
const faceImg = createCanvas({ width, height }) const faceImg = createCanvas({ width, height })
getContext2dOrThrow(faceImg) getContext2dOrThrow(faceImg)
.putImageData(ctx.getImageData(x, y, width, height), 0, 0) .putImageData(ctx.getImageData(x, y, width, height), 0, 0)
......
...@@ -17,10 +17,10 @@ export class FaceDetection { ...@@ -17,10 +17,10 @@ export class FaceDetection {
this._imageHeight = height this._imageHeight = height
this._score = score this._score = score
this._box = new Rect( this._box = new Rect(
Math.floor(relativeBox.x * width), relativeBox.x * width,
Math.floor(relativeBox.y * height), relativeBox.y * height,
Math.floor(relativeBox.width * width), relativeBox.width * width,
Math.floor(relativeBox.height * height) relativeBox.height * height
) )
} }
...@@ -32,6 +32,14 @@ export class FaceDetection { ...@@ -32,6 +32,14 @@ export class FaceDetection {
return this._box return this._box
} }
/**
 * Returns the image width stored on this detection (`_imageWidth`), i.e. the
 * value `getRelativeBox()` divides the box's x component by.
 */
public getImageWidth() {
return this._imageWidth
}
/**
 * Returns the image height stored on this detection (`_imageHeight`), as
 * assigned from `height` in the constructor.
 */
public getImageHeight() {
return this._imageHeight
}
public getRelativeBox() { public getRelativeBox() {
return new Rect( return new Rect(
this._box.x / this._imageWidth, this._box.x / this._imageWidth,
......
import { Point } from '../Point'; import { Point, IPoint } from '../Point';
import { Dimensions } from '../types'; import { Dimensions } from '../types';
export class FaceLandmarks { export class FaceLandmarks {
private _faceLandmarks: Point[] private _faceLandmarks: Point[]
private _imageWidth: number private _imageWidth: number
private _imageHeight: number private _imageHeight: number
private _shift: Point
constructor( constructor(
relativeFaceLandmarkPositions: Point[], relativeFaceLandmarkPositions: Point[],
imageDims: Dimensions imageDims: Dimensions,
shift: Point = new Point(0, 0)
) { ) {
const { width, height } = imageDims const { width, height } = imageDims
this._imageWidth = width this._imageWidth = width
this._imageHeight = height this._imageHeight = height
this._shift = shift
this._faceLandmarks = relativeFaceLandmarkPositions.map( this._faceLandmarks = relativeFaceLandmarkPositions.map(
pt => new Point(pt.x * width, pt.y * height) pt => pt.mul(new Point(width, height)).add(shift)
) )
} }
...@@ -24,7 +27,7 @@ export class FaceLandmarks { ...@@ -24,7 +27,7 @@ export class FaceLandmarks {
public getRelativePositions() { public getRelativePositions() {
return this._faceLandmarks.map( return this._faceLandmarks.map(
pt => new Point(pt.x / this._imageWidth, pt.y / this._imageHeight) pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
) )
} }
...@@ -57,6 +60,17 @@ export class FaceLandmarks { ...@@ -57,6 +60,17 @@ export class FaceLandmarks {
} }
public forSize(width: number, height: number): FaceLandmarks { public forSize(width: number, height: number): FaceLandmarks {
return new FaceLandmarks(this.getRelativePositions(), { width, height }) return new FaceLandmarks(
this.getRelativePositions(),
{ width, height }
)
}
public shift(x: number, y: number): FaceLandmarks {
return new FaceLandmarks(
this.getRelativePositions(),
{ width: this._imageWidth, height: this._imageHeight },
new Point(x, y)
)
} }
} }
\ No newline at end of file
...@@ -4,7 +4,7 @@ import { convLayer } from '../commons/convLayer'; ...@@ -4,7 +4,7 @@ import { convLayer } from '../commons/convLayer';
import { ConvParams } from '../commons/types'; import { ConvParams } from '../commons/types';
import { getImageTensor } from '../getImageTensor'; import { getImageTensor } from '../getImageTensor';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { padToSquare } from '../padToSquare'; import { Point } from '../Point';
import { Dimensions, TNetInput } from '../types'; import { Dimensions, TNetInput } from '../types';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { FaceLandmarks } from './FaceLandmarks'; import { FaceLandmarks } from './FaceLandmarks';
...@@ -22,8 +22,6 @@ export function faceLandmarkNet(weights: Float32Array) { ...@@ -22,8 +22,6 @@ export function faceLandmarkNet(weights: Float32Array) {
const params = extractParams(weights) const params = extractParams(weights)
async function detectLandmarks(input: tf.Tensor | NetInput | TNetInput) { async function detectLandmarks(input: tf.Tensor | NetInput | TNetInput) {
let adjustRelativeX = 0
let adjustRelativeY = 0
let imageDimensions: Dimensions | undefined let imageDimensions: Dimensions | undefined
const outTensor = tf.tidy(() => { const outTensor = tf.tidy(() => {
...@@ -31,9 +29,6 @@ export function faceLandmarkNet(weights: Float32Array) { ...@@ -31,9 +29,6 @@ export function faceLandmarkNet(weights: Float32Array) {
const [height, width] = imgTensor.shape.slice(1) const [height, width] = imgTensor.shape.slice(1)
imageDimensions = { width, height } imageDimensions = { width, height }
imgTensor = padToSquare(imgTensor, true)
adjustRelativeX = (height > width) ? imgTensor.shape[2] / (2 * width) : 0
adjustRelativeY = (width > height) ? imgTensor.shape[1] / (2 * height) : 0
// work with 128 x 128 sized face images // work with 128 x 128 sized face images
if (imgTensor.shape[1] !== 128 || imgTensor.shape[2] !== 128) { if (imgTensor.shape[1] !== 128 || imgTensor.shape[2] !== 128) {
...@@ -59,12 +54,13 @@ export function faceLandmarkNet(weights: Float32Array) { ...@@ -59,12 +54,13 @@ export function faceLandmarkNet(weights: Float32Array) {
}) })
const faceLandmarksArray = Array.from(await outTensor.data()) const faceLandmarksArray = Array.from(await outTensor.data())
const xCoords = faceLandmarksArray.filter((c, i) => (i - 1) % 2).map(x => x + adjustRelativeX)
const yCoords = faceLandmarksArray.filter((c, i) => i % 2).map(y => y + adjustRelativeY)
outTensor.dispose() outTensor.dispose()
const xCoords = faceLandmarksArray.filter((c, i) => (i - 1) % 2)
const yCoords = faceLandmarksArray.filter((c, i) => i % 2)
return new FaceLandmarks( return new FaceLandmarks(
Array(68).fill(0).map((_, i) => ({ x: xCoords[i], y: yCoords[i] })), Array(68).fill(0).map((_, i) => new Point(xCoords[i], yCoords[i])),
imageDimensions as Dimensions imageDimensions as Dimensions
) )
} }
......
...@@ -201,31 +201,31 @@ export function drawLandmarks( ...@@ -201,31 +201,31 @@ export function drawLandmarks(
throw new Error('drawLandmarks - expected canvas to be of type: HTMLCanvasElement') throw new Error('drawLandmarks - expected canvas to be of type: HTMLCanvasElement')
} }
const drawOptions = Object.assign( const drawOptions = Object.assign(
getDefaultDrawOptions(), getDefaultDrawOptions(),
(options || {}) (options || {})
) )
const { drawLines } = Object.assign({ drawLines: false }, (options || {}))
const ctx = getContext2dOrThrow(canvas) const { drawLines } = Object.assign({ drawLines: false }, (options || {}))
const { lineWidth, color } = drawOptions
const ctx = getContext2dOrThrow(canvas)
if (drawLines) { const { lineWidth, color } = drawOptions
ctx.strokeStyle = color
ctx.lineWidth = lineWidth if (drawLines) {
drawContour(ctx, faceLandmarks.getJawOutline()) ctx.strokeStyle = color
drawContour(ctx, faceLandmarks.getLeftEyeBrow()) ctx.lineWidth = lineWidth
drawContour(ctx, faceLandmarks.getRightEyeBrow()) drawContour(ctx, faceLandmarks.getJawOutline())
drawContour(ctx, faceLandmarks.getNose()) drawContour(ctx, faceLandmarks.getLeftEyeBrow())
drawContour(ctx, faceLandmarks.getLeftEye(), true) drawContour(ctx, faceLandmarks.getRightEyeBrow())
drawContour(ctx, faceLandmarks.getRightEye(), true) drawContour(ctx, faceLandmarks.getNose())
drawContour(ctx, faceLandmarks.getMouth(), true) drawContour(ctx, faceLandmarks.getLeftEye(), true)
return drawContour(ctx, faceLandmarks.getRightEye(), true)
} drawContour(ctx, faceLandmarks.getMouth(), true)
return
}
// else draw points // else draw points
const ptOffset = lineWidth / 2 const ptOffset = lineWidth / 2
ctx.fillStyle = color ctx.fillStyle = color
faceLandmarks.getPositions().forEach(pt => ctx.fillRect(pt.x - ptOffset, pt.y - ptOffset, lineWidth, lineWidth)) faceLandmarks.getPositions().forEach(pt => ctx.fillRect(pt.x - ptOffset, pt.y - ptOffset, lineWidth, lineWidth))
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment