Commit 2ae76364 by vincent

NetInput to simplify api

parent b23e6376
import { TMediaElement, TNetInput } from './types';
import { Dimensions, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
export class NetInput {
private _canvases: HTMLCanvasElement[]
constructor(
mediaArg: TNetInput,
dims?: Dimensions
) {
const mediaArgArray = Array.isArray(mediaArg)
? mediaArg
: [mediaArg]
if (!mediaArgArray.length) {
throw new Error('NetInput - empty array passed as input')
}
const medias = mediaArgArray.map(getElement)
medias.forEach((media, i) => {
if (!(media instanceof HTMLImageElement || media instanceof HTMLVideoElement || media instanceof HTMLCanvasElement)) {
const idxHint = Array.isArray(mediaArg) ? ` at input index ${i}:` : ''
if (typeof mediaArgArray[i] === 'string') {
throw new Error(`NetInput -${idxHint} string passed, but could not resolve HTMLElement for element id`)
}
throw new Error(`NetInput -${idxHint} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, or to be an element id`)
}
})
this._canvases = []
medias.forEach(m => this.initCanvas(m, dims))
}
private initCanvas(media: TMediaElement, dims?: Dimensions) {
if (media instanceof HTMLCanvasElement) {
this._canvases.push(media)
return
}
// if input is batch type, make sure every canvas has the same dimensions
const { width, height } = this.dims || dims || getMediaDimensions(media)
const canvas = document.createElement('canvas')
canvas.width = width
canvas.height = height
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
this._canvases.push(canvas)
}
public get canvases() : HTMLCanvasElement[] {
return this._canvases
}
public get width() : number {
return (this._canvases[0] || {}).width
}
public get height() : number {
return (this._canvases[0] || {}).height
}
public get dims() : Dimensions | null {
const { width, height } = this
return (width > 0 && height > 0) ? { width, height } : null
}
}
\ No newline at end of file
......@@ -16,8 +16,7 @@ function convWithBias(
export function boxPredictionLayer(
x: tf.Tensor4D,
params: FaceDetectionNet.BoxPredictionParams,
size: number
params: FaceDetectionNet.BoxPredictionParams
) {
return tf.tidy(() => {
......@@ -25,11 +24,11 @@ export function boxPredictionLayer(
const boxPredictionEncoding = tf.reshape(
convWithBias(x, params.box_encoding_predictor_params),
[batchSize, size, 1, 4]
[batchSize, -1, 1, 4]
)
const classPrediction = tf.reshape(
convWithBias(x, params.class_predictor_params),
[batchSize, size, 3]
[batchSize, -1, 3]
)
return {
......
import * as tf from '@tensorflow/tfjs-core';
import { isFloat } from '../utils';
import { NetInput } from '../NetInput';
import { getImageTensor, padToSquare } from '../transformInputs';
import { TNetInput } from '../types';
import { extractParams } from './extractParams';
import { FaceDetectionResult } from './FaceDetectionResult';
import { mobileNetV1 } from './mobileNetV1';
......@@ -9,81 +11,6 @@ import { outputLayer } from './outputLayer';
import { predictionLayer } from './predictionLayer';
import { resizeLayer } from './resizeLayer';
// Interprets a flat pixel array as a single square RGB image tensor of shape
// [1, dim, dim, 3]. Throws when the array length does not describe a square
// RGB image.
function fromData(input: number[]): tf.Tensor4D {
  // a square RGB image of side `dim` holds dim * dim * 3 values
  const dim = Math.sqrt(input.length / 3)

  if (isFloat(dim)) {
    throw new Error(`invalid input size: ${dim}x${dim}x3 (array length: ${input.length})`)
  }

  return tf.tensor4d(input, [1, dim, dim, 3])
}
// Stacks a batch of ImageData objects into a single float32 NHWC tensor.
// Throws when any entry is not an ImageData instance.
function fromImageData(input: ImageData[]) {
  return tf.tidy(() => {
    input.forEach((data, idx) => {
      if (!(data instanceof ImageData)) {
        throw new Error(`expected input at index ${idx} to be instanceof ImageData`)
      }
    })

    const batchTensors = input.map(
      data => tf.expandDims(tf.fromPixels(data), 0)
    ) as tf.Tensor4D[]

    return tf.cast(tf.concat(batchTensors, 0), 'float32')
  })
}
// Pads a [1, H, W, 3] image tensor with zeros on the right (portrait) or at
// the bottom (landscape) so that height === width. Returns the input
// unchanged when it is already square.
// NOTE(review): the pad shapes hard-code batch size 1 and 3 channels — this
// breaks for batched inputs; confirm callers only ever pass a single image.
function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
return tf.tidy(() => {
const [_, height, width] = imgTensor.shape
// already square, nothing to pad
if (height === width) {
return imgTensor
}
if (height > width) {
// taller than wide: append zero columns (axis 2) on the right
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
return tf.concat([imgTensor, pad], 2)
}
// wider than tall: append zero rows (axis 1) at the bottom
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
return tf.concat([imgTensor, pad], 1)
})
}
// Converts any supported input format (canvas, tensor of rank 3/4, single or
// batched ImageData, or a flat pixel array) into a float32 NHWC image tensor.
function getImgTensor(input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[]) {
  return tf.tidy(() => {
    if (input instanceof HTMLCanvasElement) {
      const pixels = tf.expandDims(tf.fromPixels(input), 0)
      return tf.cast(pixels, 'float32') as tf.Tensor4D
    }

    if (input instanceof tf.Tensor) {
      const rank = input.shape.length
      if (rank !== 3 && rank !== 4) {
        throw new Error('input tensor must be of rank 3 or 4')
      }
      // add a batch dimension for rank 3 inputs
      const batched = rank === 3 ? tf.expandDims(input, 0) : input
      return tf.cast(batched, 'float32') as tf.Tensor4D
    }

    if (input instanceof ImageData) {
      return fromImageData([input])
    }

    if (Array.isArray(input) && input[0] instanceof ImageData) {
      return fromImageData(input as ImageData[])
    }

    // remaining case: a flat pixel array
    return fromData(input as number[])
  })
}
export function faceDetectionNet(weights: Float32Array) {
const params = extractParams(weights)
......@@ -102,14 +29,14 @@ export function faceDetectionNet(weights: Float32Array) {
})
}
function forward(input: tf.Tensor|ImageData|ImageData[]|number[]) {
function forward(input: tf.Tensor | NetInput | TNetInput) {
return tf.tidy(
() => forwardTensor(padToSquare(getImgTensor(input)))
() => forwardTensor(padToSquare(getImageTensor(input)))
)
}
async function locateFaces(
input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[],
input: tf.Tensor | NetInput,
minConfidence: number = 0.8,
maxResults: number = 100,
): Promise<FaceDetectionResult[]> {
......@@ -121,7 +48,7 @@ export function faceDetectionNet(weights: Float32Array) {
scores: _scores
} = tf.tidy(() => {
let imgTensor = getImgTensor(input)
let imgTensor = getImageTensor(input)
const [_, height, width] = imgTensor.shape
imgTensor = padToSquare(imgTensor)
......@@ -140,9 +67,7 @@ export function faceDetectionNet(weights: Float32Array) {
}
// TODO find a better way to filter by minConfidence
//const ts = Date.now()
const scoresData = Array.from(await scores.data())
//console.log('await data:', (Date.now() - ts))
const iouThreshold = 0.5
const indices = nonMaxSuppression(
......
......@@ -4,7 +4,11 @@ import { boxPredictionLayer } from './boxPredictionLayer';
import { pointwiseConvLayer } from './pointwiseConvLayer';
import { FaceDetectionNet } from './types';
export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: FaceDetectionNet.PredictionLayerParams) {
export function predictionLayer(
x: tf.Tensor4D,
conv11: tf.Tensor4D,
params: FaceDetectionNet.PredictionLayerParams
) {
return tf.tidy(() => {
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
......@@ -16,12 +20,12 @@ export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: Fac
const conv6 = pointwiseConvLayer(conv5, params.conv_6_params, [1, 1])
const conv7 = pointwiseConvLayer(conv6, params.conv_7_params, [2, 2])
const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params, 3072)
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536)
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384)
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96)
const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4_params, 24)
const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5_params, 6)
const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params)
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params)
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params)
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params)
const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4_params)
const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5_params)
const boxPredictions = tf.concat([
boxPrediction0.boxPredictionEncoding,
......
import * as tf from '@tensorflow/tfjs-core';
import { normalize } from '../normalize';
import { NetInput } from '../NetInput';
import { getImageTensor, padToSquare } from '../transformInputs';
import { TNetInput } from '../types';
import { convDown } from './convLayer';
import { extractParams } from './extractParams';
import { normalize } from './normalize';
import { residual, residualDown } from './residualLayer';
export function faceRecognitionNet(weights: Float32Array) {
const params = extractParams(weights)
function forward(input: number[] | ImageData) {
function forward(input: tf.Tensor | NetInput | TNetInput) {
return tf.tidy(() => {
const x = normalize(input)
const x = normalize(padToSquare(getImageTensor(input)))
let out = convDown(x, params.conv32_down)
out = tf.maxPool(out, 3, 2, 'valid')
......@@ -42,14 +44,14 @@ export function faceRecognitionNet(weights: Float32Array) {
})
}
const computeFaceDescriptor = async (input: number[] | ImageData) => {
const computeFaceDescriptor = async (input: tf.Tensor | NetInput | TNetInput) => {
const result = forward(input)
const data = await result.data()
result.dispose()
return data
}
const computeFaceDescriptorSync = (input: number[] | ImageData) => {
const computeFaceDescriptorSync = (input: tf.Tensor | NetInput | TNetInput) => {
const result = forward(input)
const data = result.dataSync()
result.dispose()
......
import * as tf from '@tensorflow/tfjs-core';
export function normalize(input: number[] | ImageData): tf.Tensor4D {
export function normalize(x: tf.Tensor4D): tf.Tensor4D {
return tf.tidy(() => {
const avg_r = tf.fill([1, 150, 150, 1], 122.782);
const avg_g = tf.fill([1, 150, 150, 1], 117.001);
const avg_b = tf.fill([1, 150, 150, 1], 104.298);
const avg_rgb = tf.concat([avg_r, avg_g, avg_b], 3)
const x = input instanceof ImageData
? tf.cast(tf.reshape(tf.fromPixels(input), [1, 150, 150, 3]), 'float32')
: tf.tensor4d(input, [1, 150, 150, 3])
return tf.div(tf.sub(x, avg_rgb), tf.fill(x.shape, 256))
return tf.div(tf.sub(x, avg_rgb), tf.scalar(256))
})
}
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet';
import { faceRecognitionNet } from './faceRecognitionNet';
import { normalize } from './normalize';
import { NetInput } from './NetInput';
export {
euclideanDistance,
faceDetectionNet,
faceRecognitionNet,
normalize,
tf
NetInput
}
export * from './utils'
\ No newline at end of file
import * as tf from '@tensorflow/tfjs-core';
import { NetInput } from './NetInput';
import { TNetInput } from './types';
/**
 * Pads an NHWC image tensor with zeros on the right (when taller than wide)
 * or at the bottom (when wider than tall) so that height === width. Returns
 * the input unchanged when it is already square.
 *
 * Fix: the pad tensor previously hard-coded a batch size of 1 and 3 channels,
 * which produced a shape mismatch in tf.concat for batched inputs (e.g. the
 * batched tensors getImageTensor builds from multiple canvases). The pad now
 * uses the tensor's actual batch size and channel count.
 *
 * @param imgTensor Image tensor of shape [batch, height, width, channels].
 * @returns A square tensor of shape [batch, side, side, channels].
 */
export function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
  return tf.tidy(() => {
    const [batchSize, height, width, numChannels] = imgTensor.shape

    if (height === width) {
      return imgTensor
    }

    if (height > width) {
      // append zero columns (axis 2) on the right
      const pad = tf.fill([batchSize, height, height - width, numChannels], 0) as tf.Tensor4D
      return tf.concat([imgTensor, pad], 2)
    }

    // append zero rows (axis 1) at the bottom
    const pad = tf.fill([batchSize, width - height, width, numChannels], 0) as tf.Tensor4D
    return tf.concat([imgTensor, pad], 1)
  })
}
/**
 * Turns any accepted input (tensor of rank 3/4, NetInput, or raw media input)
 * into a float32 NHWC image tensor, batching multiple canvases along axis 0.
 */
export function getImageTensor(input: tf.Tensor | NetInput | TNetInput): tf.Tensor4D {
  return tf.tidy(() => {
    if (input instanceof tf.Tensor) {
      const rank = input.shape.length
      if (rank !== 3 && rank !== 4) {
        throw new Error('input tensor must be of rank 3 or 4')
      }
      // add a batch dimension for rank 3 inputs
      const batched = rank === 3 ? input.expandDims(0) : input
      return batched.toFloat() as tf.Tensor4D
    }

    // wrap raw media inputs into a NetInput to obtain canvases
    const netInput = input instanceof NetInput ? input : new NetInput(input)

    // NOTE(review): concat assumes all canvases share the same dimensions —
    // TODO confirm NetInput guarantees this for every input combination
    const batchTensors = netInput.canvases.map(
      canvas => tf.fromPixels(canvas).expandDims(0).toFloat()
    )
    return tf.concat(batchTensors) as tf.Tensor4D
  })
}
\ No newline at end of file
// A media element whose pixels can be used as network input.
export type TMediaElement = HTMLImageElement | HTMLVideoElement | HTMLCanvasElement
// A single input argument: either a media element or the DOM id of one.
export type TNetInputArg = string | TMediaElement
// Network input: one argument or a batch (array) of arguments.
export type TNetInput = TNetInputArg | Array<TNetInputArg>
import { FaceDetectionNet } from './faceDetectionNet/types';
function getElement(arg: string | any) {
export function getElement(arg: string | any) {
if (typeof arg === 'string') {
return document.getElementById(arg)
}
return arg
}
function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingContext2D {
export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingContext2D {
const ctx = canvas.getContext('2d')
if (!ctx) {
throw new Error('canvas 2d context is null')
......@@ -15,7 +15,7 @@ function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingContext2
return ctx
}
function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) {
export function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) {
if (media instanceof HTMLVideoElement) {
return { width: media.videoWidth, height: media.videoHeight }
}
......@@ -35,11 +35,11 @@ export type Dimensions = {
height: number
}
export function drawMediaToCanvas(
export function toNetInput(
canvasArg: string | HTMLCanvasElement,
mediaArg: string | HTMLImageElement | HTMLVideoElement,
dims?: Dimensions
): CanvasRenderingContext2D {
): HTMLCanvasElement {
const canvas = getElement(canvasArg)
const media = getElement(mediaArg)
......@@ -56,7 +56,7 @@ export function drawMediaToCanvas(
const ctx = getContext2dOrThrow(canvas)
ctx.drawImage(media, 0, 0, width, height)
return ctx
return canvas
}
export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement, dims?: Dimensions): ImageData {
......@@ -64,7 +64,8 @@ export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement, dim
throw new Error('mediaToImageData - expected media to be of type: HTMLImageElement | HTMLVideoElement')
}
const ctx = drawMediaToCanvas(document.createElement('canvas'), media)
const canvas = toNetInput(document.createElement('canvas'), media)
const ctx = getContext2dOrThrow(canvas)
const { width, height } = dims || getMediaDimensions(media)
return ctx.getImageData(0, 0, width, height)
......@@ -108,6 +109,24 @@ export async function bufferToImageData(buf: Blob): Promise<ImageData> {
return mediaSrcToImageData(await bufferToImgSrc(buf))
}
/**
 * Loads image data from a Blob into an HTMLImageElement.
 *
 * Fixes: rejections now carry Error objects instead of bare strings (so
 * callers get a stack trace and `instanceof Error` holds), and
 * `reader.result` is narrowed to string before assigning it to `img.src`
 * (its declared type is string | ArrayBuffer | null).
 *
 * @param buf The image file contents as a Blob.
 * @returns Promise resolving to the image once it has finished decoding.
 */
export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
  return new Promise((resolve, reject) => {
    if (!(buf instanceof Blob)) {
      return reject(new Error('bufferToImage - expected buf to be of type: Blob'))
    }

    const reader = new FileReader()
    reader.onload = () => {
      // readAsDataURL always yields a string result on success
      if (typeof reader.result !== 'string') {
        return reject(new Error('bufferToImage - expected reader.result to be a string'))
      }
      const img = new Image()
      img.onload = () => resolve(img)
      img.onerror = reject
      img.src = reader.result
    }
    reader.onerror = reject
    reader.readAsDataURL(buf)
  })
}
export type DrawBoxOptions = {
lineWidth: number
color: string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment