Commit e04770cb by vincent

finished mtcnn implementation

parent 5e957ec7
import { Point } from './Point';
import { Dimensions } from './types';
/**
 * Base container for the landmark points of a single face.
 *
 * The constructor receives positions relative to the image dimensions and
 * stores them after applying `Point.mul` with (width, height) and `Point.add`
 * with `shift`. `getRelativePositions` reverses this with `Point.sub` and
 * `Point.div`.
 */
export class FaceLandmarks {
  protected _imageWidth: number
  protected _imageHeight: number
  protected _shift: Point
  protected _faceLandmarks: Point[]

  /**
   * @param relativeFaceLandmarkPositions landmark positions relative to the image size
   * @param imageDims width and height the relative positions are scaled by
   * @param shift offset added to every scaled position, defaults to (0, 0)
   */
  constructor(
    relativeFaceLandmarkPositions: Point[],
    imageDims: Dimensions,
    shift: Point = new Point(0, 0)
  ) {
    const imgWidth = imageDims.width
    const imgHeight = imageDims.height

    this._imageWidth = imgWidth
    this._imageHeight = imgHeight
    this._shift = shift

    this._faceLandmarks = relativeFaceLandmarkPositions.map(relPt => {
      const scaled = relPt.mul(new Point(imgWidth, imgHeight))
      return scaled.add(shift)
    })
  }

  /** Returns a new Point carrying the x and y of the stored shift. */
  public getShift(): Point {
    const { x, y } = this._shift
    return new Point(x, y)
  }

  /** Returns the image width passed to the constructor. */
  public getImageWidth(): number {
    return this._imageWidth
  }

  /** Returns the image height passed to the constructor. */
  public getImageHeight(): number {
    return this._imageHeight
  }

  /** Returns the stored (scaled and shifted) landmark array itself, not a copy. */
  public getPositions(): Point[] {
    return this._faceLandmarks
  }

  /** Returns the stored positions with the shift subtracted and divided by the image size. */
  public getRelativePositions(): Point[] {
    return this._faceLandmarks.map(absPt => {
      const unshifted = absPt.sub(this._shift)
      return unshifted.div(new Point(this._imageWidth, this._imageHeight))
    })
  }
}
\ No newline at end of file
import { FaceDetection } from './faceDetectionNet/FaceDetection'; import { FaceDetection } from './faceDetectionNet/FaceDetection';
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks'; import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
export class FullFaceDescription { export class FullFaceDescription {
constructor( constructor(
private _detection: FaceDetection, private _detection: FaceDetection,
private _landmarks: FaceLandmarks, private _landmarks: FaceLandmarks68,
private _descriptor: Float32Array private _descriptor: Float32Array
) {} ) {}
...@@ -12,7 +12,7 @@ export class FullFaceDescription { ...@@ -12,7 +12,7 @@ export class FullFaceDescription {
return this._detection return this._detection
} }
public get landmarks(): FaceLandmarks { public get landmarks(): FaceLandmarks68 {
return this._landmarks return this._landmarks
} }
......
import { extractFaceTensors } from './extractFaceTensors'; import { extractFaceTensors } from './extractFaceTensors';
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet'; import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet'; import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks'; import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet'; import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
import { FullFaceDescription } from './FullFaceDescription'; import { FullFaceDescription } from './FullFaceDescription';
import { TNetInput } from './types'; import { TNetInput } from './types';
...@@ -22,10 +22,10 @@ export function allFacesFactory( ...@@ -22,10 +22,10 @@ export function allFacesFactory(
const faceTensors = await extractFaceTensors(input, detections) const faceTensors = await extractFaceTensors(input, detections)
const faceLandmarksByFace = useBatchProcessing const faceLandmarksByFace = useBatchProcessing
? await landmarkNet.detectLandmarks(faceTensors) as FaceLandmarks[] ? await landmarkNet.detectLandmarks(faceTensors) as FaceLandmarks68[]
: await Promise.all(faceTensors.map( : await Promise.all(faceTensors.map(
faceTensor => landmarkNet.detectLandmarks(faceTensor) faceTensor => landmarkNet.detectLandmarks(faceTensor)
)) as FaceLandmarks[] )) as FaceLandmarks68[]
faceTensors.forEach(t => t.dispose()) faceTensors.forEach(t => t.dispose())
......
import { FaceDetection } from '../faceDetectionNet/FaceDetection'; import { FaceDetection } from '../faceDetectionNet/FaceDetection';
import { FaceLandmarks } from '../faceLandmarkNet/FaceLandmarks'; import { FaceLandmarks68 } from '../faceLandmarkNet';
import { FaceLandmarks } from '../FaceLandmarks';
import { Point } from '../Point'; import { Point } from '../Point';
import { getContext2dOrThrow, resolveInput, round } from '../utils'; import { getContext2dOrThrow, resolveInput, round } from '../utils';
import { DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types'; import { DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
...@@ -150,7 +151,7 @@ export function drawLandmarks( ...@@ -150,7 +151,7 @@ export function drawLandmarks(
const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks] const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks]
faceLandmarksArray.forEach(landmarks => { faceLandmarksArray.forEach(landmarks => {
if (drawLines) { if (drawLines && landmarks instanceof FaceLandmarks68) {
ctx.strokeStyle = color ctx.strokeStyle = color
ctx.lineWidth = lineWidth ctx.lineWidth = lineWidth
drawContour(ctx, landmarks.getJawOutline()) drawContour(ctx, landmarks.getJawOutline())
......
...@@ -9,7 +9,7 @@ import { toNetInput } from '../toNetInput'; ...@@ -9,7 +9,7 @@ import { toNetInput } from '../toNetInput';
import { TNetInput } from '../types'; import { TNetInput } from '../types';
import { isEven } from '../utils'; import { isEven } from '../utils';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { FaceLandmarks } from './FaceLandmarks'; import { FaceLandmarks68 } from './FaceLandmarks68';
import { fullyConnectedLayer } from './fullyConnectedLayer'; import { fullyConnectedLayer } from './fullyConnectedLayer';
import { loadQuantizedParams } from './loadQuantizedParams'; import { loadQuantizedParams } from './loadQuantizedParams';
import { NetParams } from './types'; import { NetParams } from './types';
...@@ -93,7 +93,7 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> { ...@@ -93,7 +93,7 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
return this.forwardInput(await toNetInput(input, true)) return this.forwardInput(await toNetInput(input, true))
} }
public async detectLandmarks(input: TNetInput): Promise<FaceLandmarks | FaceLandmarks[]> { public async detectLandmarks(input: TNetInput): Promise<FaceLandmarks68 | FaceLandmarks68[]> {
const netInput = await toNetInput(input, true) const netInput = await toNetInput(input, true)
const landmarkTensors = tf.tidy( const landmarkTensors = tf.tidy(
...@@ -106,7 +106,7 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> { ...@@ -106,7 +106,7 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
const xCoords = landmarksArray.filter((_, i) => isEven(i)) const xCoords = landmarksArray.filter((_, i) => isEven(i))
const yCoords = landmarksArray.filter((_, i) => !isEven(i)) const yCoords = landmarksArray.filter((_, i) => !isEven(i))
return new FaceLandmarks( return new FaceLandmarks68(
Array(68).fill(0).map((_, i) => new Point(xCoords[i], yCoords[i])), Array(68).fill(0).map((_, i) => new Point(xCoords[i], yCoords[i])),
{ {
height: netInput.getInputHeight(batchIdx), height: netInput.getInputHeight(batchIdx),
......
import { getCenterPoint } from '../commons/getCenterPoint'; import { getCenterPoint } from '../commons/getCenterPoint';
import { FaceDetection } from '../faceDetectionNet/FaceDetection'; import { FaceDetection } from '../faceDetectionNet/FaceDetection';
import { FaceLandmarks } from '../FaceLandmarks';
import { IPoint, Point } from '../Point'; import { IPoint, Point } from '../Point';
import { Rect } from '../Rect'; import { Rect } from '../Rect';
import { Dimensions } from '../types'; import { Dimensions } from '../types';
...@@ -9,48 +10,7 @@ const relX = 0.5 ...@@ -9,48 +10,7 @@ const relX = 0.5
const relY = 0.43 const relY = 0.43
const relScale = 0.45 const relScale = 0.45
export class FaceLandmarks { export class FaceLandmarks68 extends FaceLandmarks {
private _imageWidth: number
private _imageHeight: number
private _shift: Point
private _faceLandmarks: Point[]
constructor(
relativeFaceLandmarkPositions: Point[],
imageDims: Dimensions,
shift: Point = new Point(0, 0)
) {
const { width, height } = imageDims
this._imageWidth = width
this._imageHeight = height
this._shift = shift
this._faceLandmarks = relativeFaceLandmarkPositions.map(
pt => pt.mul(new Point(width, height)).add(shift)
)
}
public getShift(): Point {
return new Point(this._shift.x, this._shift.y)
}
public getImageWidth(): number {
return this._imageWidth
}
public getImageHeight(): number {
return this._imageHeight
}
public getPositions(): Point[] {
return this._faceLandmarks
}
public getRelativePositions(): Point[] {
return this._faceLandmarks.map(
pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
)
}
public getJawOutline(): Point[] { public getJawOutline(): Point[] {
return this._faceLandmarks.slice(0, 17) return this._faceLandmarks.slice(0, 17)
} }
...@@ -79,22 +39,22 @@ export class FaceLandmarks { ...@@ -79,22 +39,22 @@ export class FaceLandmarks {
return this._faceLandmarks.slice(48, 68) return this._faceLandmarks.slice(48, 68)
} }
public forSize(width: number, height: number): FaceLandmarks { public forSize(width: number, height: number): FaceLandmarks68 {
return new FaceLandmarks( return new FaceLandmarks68(
this.getRelativePositions(), this.getRelativePositions(),
{ width, height } { width, height }
) )
} }
public shift(x: number, y: number): FaceLandmarks { public shift(x: number, y: number): FaceLandmarks68 {
return new FaceLandmarks( return new FaceLandmarks68(
this.getRelativePositions(), this.getRelativePositions(),
{ width: this._imageWidth, height: this._imageHeight }, { width: this._imageWidth, height: this._imageHeight },
new Point(x, y) new Point(x, y)
) )
} }
public shiftByPoint(pt: IPoint): FaceLandmarks { public shiftByPoint(pt: IPoint): FaceLandmarks68 {
return this.shift(pt.x, pt.y) return this.shift(pt.x, pt.y)
} }
......
import { FaceLandmarkNet } from './FaceLandmarkNet'; import { FaceLandmarkNet } from './FaceLandmarkNet';
export * from './FaceLandmarkNet'; export * from './FaceLandmarkNet';
export * from './FaceLandmarks'; export * from './FaceLandmarks68';
export function faceLandmarkNet(weights: Float32Array) { export function faceLandmarkNet(weights: Float32Array) {
const net = new FaceLandmarkNet() const net = new FaceLandmarkNet()
......
...@@ -4,7 +4,7 @@ import { allFacesFactory } from './allFacesFactory'; ...@@ -4,7 +4,7 @@ import { allFacesFactory } from './allFacesFactory';
import { FaceDetection } from './faceDetectionNet/FaceDetection'; import { FaceDetection } from './faceDetectionNet/FaceDetection';
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet'; import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet'; import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks'; import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet'; import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
import { FullFaceDescription } from './FullFaceDescription'; import { FullFaceDescription } from './FullFaceDescription';
import { NetInput } from './NetInput'; import { NetInput } from './NetInput';
...@@ -44,7 +44,7 @@ export function locateFaces( ...@@ -44,7 +44,7 @@ export function locateFaces(
export function detectLandmarks( export function detectLandmarks(
input: TNetInput input: TNetInput
): Promise<FaceLandmarks | FaceLandmarks[]> { ): Promise<FaceLandmarks68 | FaceLandmarks68[]> {
return landmarkNet.detectLandmarks(input) return landmarkNet.detectLandmarks(input)
} }
......
...@@ -89,4 +89,13 @@ export class BoundingBox { ...@@ -89,4 +89,13 @@ export class BoundingBox {
return { dy, edy, dx, edx, y, ey, x, ex, w, h } return { dy, edy, dx, edx, y, ey, x, ex, w, h }
} }
/**
 * Creates a new box whose edges are this box's edges offset by the matching
 * edge of `region`, scaled by this box's width (left/right) or height
 * (top/bottom). The result is then passed through `toSquare()` and `round()`.
 */
public calibrate(region: BoundingBox) {
  const { left, top, right, bottom, width, height } = this

  const adjusted = new BoundingBox(
    left + (region.left * width),
    top + (region.top * height),
    right + (region.right * width),
    bottom + (region.bottom * height)
  )

  return adjusted.toSquare().round()
}
} }
\ No newline at end of file
import { FaceLandmarks } from '../FaceLandmarks';
import { IPoint, Point } from '../Point';
/**
 * FaceLandmarks subclass whose transformation helpers return FaceLandmarks5
 * instances. Every helper rebuilds the instance from the current relative
 * positions, so the receiver itself is left untouched.
 */
export class FaceLandmarks5 extends FaceLandmarks {
  /** Returns a copy of these landmarks scaled to the given image size, with the default shift. */
  public forSize(width: number, height: number): FaceLandmarks5 {
    const targetDims = { width, height }
    return new FaceLandmarks5(this.getRelativePositions(), targetDims)
  }

  /** Returns a copy of these landmarks for the current image size, using (x, y) as the shift. */
  public shift(x: number, y: number): FaceLandmarks5 {
    const currentDims = { width: this._imageWidth, height: this._imageHeight }
    const offset = new Point(x, y)
    return new FaceLandmarks5(this.getRelativePositions(), currentDims, offset)
  }

  /** Same as `shift`, taking the offset from the x and y of `pt`. */
  public shiftByPoint(pt: IPoint): FaceLandmarks5 {
    const { x, y } = pt
    return this.shift(x, y)
  }
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
import { NeuralNetwork } from '../commons/NeuralNetwork'; import { NeuralNetwork } from '../commons/NeuralNetwork';
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
import { NetInput } from '../NetInput'; import { NetInput } from '../NetInput';
import { Point } from '../Point';
import { Rect } from '../Rect';
import { toNetInput } from '../toNetInput'; import { toNetInput } from '../toNetInput';
import { TNetInput } from '../types'; import { TNetInput } from '../types';
import { bgrToRgbTensor } from './bgrToRgbTensor'; import { bgrToRgbTensor } from './bgrToRgbTensor';
import { extractParams } from './extractParams'; import { extractParams } from './extractParams';
import { FaceLandmarks5 } from './FaceLandmarks5';
import { pyramidDown } from './pyramidDown'; import { pyramidDown } from './pyramidDown';
import { stage1 } from './stage1'; import { stage1 } from './stage1';
import { stage2 } from './stage2'; import { stage2 } from './stage2';
import { stage3 } from './stage3';
import { NetParams } from './types'; import { NetParams } from './types';
export class Mtcnn extends NeuralNetwork<NetParams> { export class Mtcnn extends NeuralNetwork<NetParams> {
...@@ -22,7 +27,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> { ...@@ -22,7 +27,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
minFaceSize: number = 20, minFaceSize: number = 20,
scaleFactor: number = 0.709, scaleFactor: number = 0.709,
scoreThresholds: number[] = [0.6, 0.7, 0.7] scoreThresholds: number[] = [0.6, 0.7, 0.7]
): Promise<tf.Tensor2D> { ): Promise<any> {
const { params } = this const { params } = this
...@@ -43,19 +48,46 @@ export class Mtcnn extends NeuralNetwork<NetParams> { ...@@ -43,19 +48,46 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
) )
) )
const scales = pyramidDown(minFaceSize, scaleFactor, imgTensor.shape.slice(1)) const [height, width] = imgTensor.shape.slice(1)
const scales = pyramidDown(minFaceSize, scaleFactor, [height, width])
const out1 = await stage1(imgTensor, scales, scoreThresholds[0], params.pnet) const out1 = await stage1(imgTensor, scales, scoreThresholds[0], params.pnet)
// using the inputCanvas to extract and resize the image patches, since it is faster // using the inputCanvas to extract and resize the image patches, since it is faster
// than doing this on the gpu // than doing this on the gpu
const out2 = await stage2(inputCanvas, out1, scoreThresholds[1], params.rnet) const out2 = await stage2(inputCanvas, out1.boxes, scoreThresholds[1], params.rnet)
const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet)
imgTensor.dispose() imgTensor.dispose()
input.dispose() input.dispose()
return tf.tensor2d([0], [1, 1]) const faceDetections = out3.boxes.map((box, idx) =>
new FaceDetection(
out3.scores[idx],
new Rect(
box.left / width,
box.top / height,
box.width / width,
box.height / height
),
{
height,
width
}
)
)
const faceLandmarks = out3.points.map(pts =>
new FaceLandmarks5(
pts.map(pt => pt.div(new Point(width, height))),
{ width, height }
)
)
return {
faceDetections,
faceLandmarks
}
} }
public async forward( public async forward(
......
import * as tf from '@tensorflow/tfjs-core';
import { convLayer } from '../commons/convLayer';
import { fullyConnectedLayer } from '../faceLandmarkNet/fullyConnectedLayer';
import { prelu } from './prelu';
import { sharedLayer } from './sharedLayers';
import { ONetParams } from './types';
/**
 * Runs the ONet forward pass on a batch of image patches.
 *
 * The whole computation is wrapped in `tf.tidy`; only the three returned
 * tensors are handed back to the caller, who is responsible for disposing them.
 *
 * @param x batch of input patches
 * @param params ONet weights
 * @returns `scores` (column 1 of the softmax over the fc2_1 output),
 *          `regions` (fc2_2 output) and `points` (fc2_3 output)
 */
export function ONet(x: tf.Tensor4D, params: ONetParams): { scores: tf.Tensor1D, regions: tf.Tensor2D, points: tf.Tensor2D } {
  return tf.tidy(() => {
    const shared = sharedLayer(x, params)
    const pooled = tf.maxPool(shared, [2, 2], [2, 2], 'same')
    const conv4 = convLayer(pooled, params.conv4, 'valid')
    const prelu4 = prelu<tf.Tensor4D>(conv4, params.prelu4_alpha)

    // flatten each batch entry to the input size expected by fc1
    const batchSize = prelu4.shape[0]
    const fc1InputSize = params.fc1.weights.shape[0]
    const flattened = tf.reshape(prelu4, [batchSize, fc1InputSize]) as tf.Tensor2D

    const fc1Out = fullyConnectedLayer(flattened, params.fc1)
    const prelu5 = prelu<tf.Tensor2D>(fc1Out, params.prelu5_alpha)

    // the row-wise maximum is subtracted before the softmax
    // (presumably for numerical stability)
    const classLogits = fullyConnectedLayer(prelu5, params.fc2_1)
    const rowMax = tf.expandDims(tf.max(classLogits, 1), 1)
    const prob = tf.softmax(tf.sub(classLogits, rowMax), 1) as tf.Tensor2D

    const regions = fullyConnectedLayer(prelu5, params.fc2_2)
    const points = fullyConnectedLayer(prelu5, params.fc2_3)

    const probColumns = tf.unstack(prob, 1)
    const scores = probColumns[1] as tf.Tensor1D

    return { scores, regions, points }
  })
}
\ No newline at end of file
...@@ -5,7 +5,7 @@ import { prelu } from './prelu'; ...@@ -5,7 +5,7 @@ import { prelu } from './prelu';
import { sharedLayer } from './sharedLayers'; import { sharedLayer } from './sharedLayers';
import { RNetParams } from './types'; import { RNetParams } from './types';
export function RNet(x: tf.Tensor4D, params: RNetParams): { prob: tf.Tensor2D, regions: tf.Tensor2D } { export function RNet(x: tf.Tensor4D, params: RNetParams): { scores: tf.Tensor1D, regions: tf.Tensor2D } {
return tf.tidy(() => { return tf.tidy(() => {
const convOut = sharedLayer(x, params) const convOut = sharedLayer(x, params)
...@@ -17,6 +17,7 @@ export function RNet(x: tf.Tensor4D, params: RNetParams): { prob: tf.Tensor2D, r ...@@ -17,6 +17,7 @@ export function RNet(x: tf.Tensor4D, params: RNetParams): { prob: tf.Tensor2D, r
const prob = tf.softmax(tf.sub(fc2_1, max), 1) as tf.Tensor2D const prob = tf.softmax(tf.sub(fc2_1, max), 1) as tf.Tensor2D
const regions = fullyConnectedLayer(prelu4, params.fc2_2) const regions = fullyConnectedLayer(prelu4, params.fc2_2)
return { prob, regions } const scores = tf.unstack(prob, 1)[1] as tf.Tensor1D
return { scores, regions }
}) })
} }
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { Dimensions } from '../types';
import { createCanvas, getContext2dOrThrow } from '../utils';
import { bgrToRgbTensor } from './bgrToRgbTensor';
import { BoundingBox } from './BoundingBox';
import { normalize } from './normalize';
/**
 * Cuts the given boxes out of a canvas, resizes every cut-out to
 * `width` x `height` and returns them as one normalized batch tensor.
 *
 * The cropping and resizing is done with canvas operations; the resulting
 * pixel data is only turned into a tensor at the very end.
 *
 * @param img canvas to read the pixels from
 * @param boxes regions to extract; each one is passed through `padAtBorders` first
 * @param param2 target size of every patch
 * @returns the result of `normalize` applied to the patch batch
 */
export async function extractImagePatches(
  img: HTMLCanvasElement,
  boxes: BoundingBox[],
  { width, height }: Dimensions
): Promise<tf.Tensor4D> {
  const imgCtx = getContext2dOrThrow(img)

  const bitmaps = await Promise.all(boxes.map(async box => {
    // TODO: correct padding
    const { y, ey, x, ex } = box.padAtBorders(img.height, img.width)

    const fromX = x - 1
    const fromY = y - 1
    const imgData = imgCtx.getImageData(fromX, fromY, (ex - fromX), (ey - fromY))

    return createImageBitmap(imgData)
  }))

  const imagePatchesData: number[] = []

  try {
    bitmaps.forEach(bmp => {
      const patch = createCanvas({ width, height })
      const patchCtx = getContext2dOrThrow(patch)
      // drawing with explicit target dimensions resizes the cut-out to the patch size
      patchCtx.drawImage(bmp, 0, 0, width, height)
      const { data } = patchCtx.getImageData(0, 0, width, height)

      for (let i = 0; i < data.length; i++) {
        // getImageData yields RGBA bytes; every 4th value (alpha) is dropped
        if ((i + 1) % 4 === 0) continue
        imagePatchesData.push(data[i])
      }
    })
  } finally {
    // ImageBitmaps hold graphics resources until closed or garbage collected;
    // release them explicitly, also when building a patch failed
    bitmaps.forEach(bmp => bmp.close())
  }

  return tf.tidy(() => {
    // the flat data is interpreted as [batch, width, height, 3] and axes 1 and 2
    // are swapped afterwards
    // NOTE(review): the callers visible here only pass square sizes (24x24, 48x48);
    // confirm this layout is intended before using non square patches
    const imagePatchTensor = bgrToRgbTensor(tf.transpose(
      tf.tensor4d(imagePatchesData, [boxes.length, width, height, 3]),
      [0, 2, 1, 3]
    ).toFloat()) as tf.Tensor4D

    return normalize(imagePatchTensor)
  })
}
\ No newline at end of file
...@@ -115,6 +115,9 @@ export function stage1( ...@@ -115,6 +115,9 @@ export function stage1(
(all, boxes) => all.concat(boxes) (all, boxes) => all.concat(boxes)
) )
let finalBoxes: BoundingBox[] = []
let finalScores: number[] = []
if (allBoxes.length > 0) { if (allBoxes.length > 0) {
const indices = nms( const indices = nms(
allBoxes.map(bbox => bbox.cell), allBoxes.map(bbox => bbox.cell),
...@@ -122,21 +125,23 @@ export function stage1( ...@@ -122,21 +125,23 @@ export function stage1(
0.7 0.7
) )
const finalBoxes = indices finalScores = indices.map(idx => allBoxes[idx].score)
finalBoxes = indices
.map(idx => allBoxes[idx]) .map(idx => allBoxes[idx])
.map(({ cell, region, score }) => ({ .map(({ cell, region }) =>
box: new BoundingBox( new BoundingBox(
cell.left + (region.left * cell.width), cell.left + (region.left * cell.width),
cell.top + (region.top * cell.height), cell.top + (region.top * cell.height),
cell.right + (region.right * cell.width), cell.right + (region.right * cell.width),
cell.bottom + (region.bottom * cell.height) cell.bottom + (region.bottom * cell.height)
).toSquare().round(), ).toSquare().round()
score )
}))
return finalBoxes
} }
return [] return {
boxes: finalBoxes,
scores: finalScores
}
} }
import * as tf from '@tensorflow/tfjs-core';
import { createCanvas, getContext2dOrThrow } from '../utils';
import { bgrToRgbTensor } from './bgrToRgbTensor';
import { BoundingBox } from './BoundingBox'; import { BoundingBox } from './BoundingBox';
import { extractImagePatches } from './extractImagePatches';
import { nms } from './nms'; import { nms } from './nms';
import { normalize } from './normalize';
import { RNet } from './RNet'; import { RNet } from './RNet';
import { RNetParams } from './types'; import { RNetParams } from './types';
export async function stage2( export async function stage2(
img: HTMLCanvasElement, img: HTMLCanvasElement,
boxes: { box: BoundingBox, score: number }[], inputBoxes: BoundingBox[],
scoreThreshold: number, scoreThreshold: number,
params: RNetParams params: RNetParams
) { ) {
const { height, width } = img const rnetInput = await extractImagePatches(img, inputBoxes, { width: 24, height: 24 })
const rnetOut = RNet(rnetInput, params)
const imgCtx = getContext2dOrThrow(img)
const bitmaps = await Promise.all(boxes.map(async ({ box }) => {
// TODO: correct padding
const { y, ey, x, ex } = box.padAtBorders(height, width)
const fromX = x - 1
const fromY = y - 1
const imgData = imgCtx.getImageData(fromX, fromY, (ex - fromX), (ey - fromY))
return createImageBitmap(imgData)
}))
const imagePatchesData: number[] = []
bitmaps.forEach(bmp => {
const patch = createCanvas({ width: 24, height: 24 })
const patchCtx = getContext2dOrThrow(patch)
patchCtx.drawImage(bmp, 0, 0, 24, 24)
const { data } = patchCtx.getImageData(0, 0, 24, 24)
for(let i = 0; i < data.length; i++) { rnetInput.dispose()
if ((i + 1) % 4 === 0) continue
imagePatchesData.push(data[i])
}
})
const rnetOut = tf.tidy(() => {
const imagePatchTensor = bgrToRgbTensor(tf.transpose(
tf.tensor4d(imagePatchesData, [boxes.length, 24, 24, 3]),
[0, 2, 1, 3]
).toFloat()) as tf.Tensor4D
const normalized = normalize(imagePatchTensor)
const { prob, regions } = RNet(normalized, params)
return {
scores: tf.unstack(prob, 1)[1],
regions
}
})
const scores = Array.from(await rnetOut.scores.data()) const scores = Array.from(await rnetOut.scores.data())
const indices = scores const indices = scores
.map((score, idx) => ({ score, idx })) .map((score, idx) => ({ score, idx }))
.filter(c => c.score > scoreThreshold) .filter(c => c.score > scoreThreshold)
.map(({ idx }) => idx) .map(({ idx }) => idx)
const filteredBoxes = indices.map(idx => boxes[idx].box) const filteredBoxes = indices.map(idx => inputBoxes[idx])
const filteredScores = indices.map(idx => scores[idx]) const filteredScores = indices.map(idx => scores[idx])
let finalBoxes: BoundingBox[] = [] let finalBoxes: BoundingBox[] = []
...@@ -79,31 +35,24 @@ export async function stage2( ...@@ -79,31 +35,24 @@ export async function stage2(
0.7 0.7
) )
finalScores = indicesNms.map(idx => filteredScores[idx]) const regions = indicesNms.map(idx =>
finalBoxes = indicesNms new BoundingBox(
.map(idx => { rnetOut.regions.get(indices[idx], 0),
const box = filteredBoxes[idx] rnetOut.regions.get(indices[idx], 1),
const [rleft, rtop, right, rbottom] = [ rnetOut.regions.get(indices[idx], 2),
rnetOut.regions.get(indices[idx], 0), rnetOut.regions.get(indices[idx], 3)
rnetOut.regions.get(indices[idx], 1), )
rnetOut.regions.get(indices[idx], 2), )
rnetOut.regions.get(indices[idx], 3)
]
return new BoundingBox( finalScores = indicesNms.map(idx => filteredScores[idx])
box.left + (rleft * box.width), finalBoxes = indicesNms.map((idx, i) => filteredBoxes[idx].calibrate(regions[i]))
box.top + (rtop * box.height),
box.right + (right * box.width),
box.bottom + (rbottom * box.height)
).toSquare().round()
})
} }
rnetOut.regions.dispose() rnetOut.regions.dispose()
rnetOut.scores.dispose() rnetOut.scores.dispose()
return { return {
finalBoxes, boxes: finalBoxes,
finalScores scores: finalScores
} }
} }
\ No newline at end of file
import { Point } from '../Point';
import { BoundingBox } from './BoundingBox';
import { extractImagePatches } from './extractImagePatches';
import { nms } from './nms';
import { ONet } from './ONet';
import { ONetParams } from './types';
/**
 * Third MTCNN stage: runs ONet on 48x48 patches of the input boxes, keeps the
 * boxes scoring above `scoreThreshold`, calibrates them with the predicted
 * regions, applies non maximum suppression and computes 5 landmark points for
 * every remaining box.
 *
 * @param img canvas the patches are extracted from
 * @param inputBoxes candidate boxes (the boxes returned by stage2 in Mtcnn)
 * @param scoreThreshold minimum ONet score (exclusive) for a box to be kept
 * @param params ONet weights
 * @returns the final `boxes`, their `scores` and the landmark `points` per box;
 *          all three arrays are empty if no box passes the threshold
 */
export async function stage3(
  img: HTMLCanvasElement,
  inputBoxes: BoundingBox[],
  scoreThreshold: number,
  params: ONetParams
) {
  const onetInput = await extractImagePatches(img, inputBoxes, { width: 48, height: 48 })
  const onetOut = ONet(onetInput, params)
  onetInput.dispose()

  try {
    const scores = Array.from(await onetOut.scores.data())

    // row indices into the ONet output (same order as inputBoxes) which pass the threshold
    const indices = scores
      .map((score, idx) => ({ score, idx }))
      .filter(c => c.score > scoreThreshold)
      .map(({ idx }) => idx)

    const filteredRegions = indices.map(idx => new BoundingBox(
      onetOut.regions.get(idx, 0),
      onetOut.regions.get(idx, 1),
      onetOut.regions.get(idx, 2),
      onetOut.regions.get(idx, 3)
    ))
    const filteredBoxes = indices
      .map((idx, i) => inputBoxes[idx].calibrate(filteredRegions[i]))
    const filteredScores = indices.map(idx => scores[idx])

    let finalBoxes: BoundingBox[] = []
    let finalScores: number[] = []
    let points: Point[][] = []

    if (filteredBoxes.length > 0) {
      const indicesNms = nms(
        filteredBoxes,
        filteredScores,
        0.7,
        false
      )

      finalBoxes = indicesNms.map(idx => filteredBoxes[idx])
      finalScores = indicesNms.map(idx => filteredScores[idx])
      points = indicesNms.map((idx, i) => {
        // idx addresses the score filtered arrays, whereas the rows of the ONet
        // output are ordered like inputBoxes, so map back through `indices`
        const row = indices[idx]
        const box = finalBoxes[i]

        // columns 0-4 are scaled by the box width, columns 5-9 by the box height
        return Array(5).fill(0).map((_, ptIdx) =>
          new Point(
            (onetOut.points.get(row, ptIdx) * (box.width + 1)) + box.left,
            (onetOut.points.get(row, ptIdx + 5) * (box.height + 1)) + box.top
          )
        )
      })
    }

    return {
      boxes: finalBoxes,
      scores: finalScores,
      points
    }
  } finally {
    // release the ONet outputs, also if reading or post processing them failed
    onetOut.regions.dispose()
    onetOut.scores.dispose()
    onetOut.points.dispose()
  }
}
import { ConvParams, FCParams } from '../commons/types';
import { tf } from '..'; import { tf } from '..';
import { ConvParams, FCParams } from '../commons/types';
import { BoundingBox } from './BoundingBox';
export type SharedParams = { export type SharedParams = {
conv1: ConvParams conv1: ConvParams
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment