Commit 3ca0f4af by vincent

pad input to square

parent 89e9691e
...@@ -33,6 +33,20 @@ function fromImageData(input: ImageData[]) { ...@@ -33,6 +33,20 @@ function fromImageData(input: ImageData[]) {
return tf.cast(tf.concat(imgTensors, 0), 'float32') return tf.cast(tf.concat(imgTensors, 0), 'float32')
} }
/**
 * Pads an image batch with zeros so that its height and width become equal.
 *
 * The tensor shape is read as [batch, height, width, channels]. If height and
 * width already match, the input tensor itself is returned. Otherwise a
 * zero-filled block is appended after the existing data along the shorter
 * spatial axis, so the original values keep their indices.
 *
 * The padding block takes its batch size and channel count from the input
 * shape, so batches of several images can be padded as well as single images.
 *
 * @param imgTensor - rank 4 image tensor to pad
 * @returns the input tensor when it is already square, otherwise a new tensor
 *          whose height and width both equal max(height, width)
 */
function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
  const [batchSize, height, width, channels] = imgTensor.shape
  if (height === width) {
    return imgTensor
  }

  // Taller than wide: grow the width (axis 2). Wider than tall: grow the height (axis 1).
  const isTall = height > width
  const padShape: [number, number, number, number] = isTall
    ? [batchSize, height, height - width, channels]
    : [batchSize, width - height, width, channels]

  // All dimensions except the concat axis must match the input, which is why
  // batch size and channels are taken from the input shape instead of being fixed.
  const pad = tf.fill(padShape, 0) as tf.Tensor4D
  return tf.concat([imgTensor, pad], isTall ? 2 : 1)
}
function getImgTensor(input: ImageData|ImageData[]|number[]) { function getImgTensor(input: ImageData|ImageData[]|number[]) {
return tf.tidy(() => { return tf.tidy(() => {
...@@ -44,9 +58,11 @@ function getImgTensor(input: ImageData|ImageData[]|number[]) { ...@@ -44,9 +58,11 @@ function getImgTensor(input: ImageData|ImageData[]|number[]) {
: null : null
) )
return imgDataArray !== null return padToSquare(
? fromImageData(imgDataArray) imgDataArray !== null
: fromData(input as number[]) ? fromImageData(imgDataArray)
: fromData(input as number[])
)
}) })
} }
...@@ -71,7 +87,7 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -71,7 +87,7 @@ export function faceDetectionNet(weights: Float32Array) {
function forward(input: ImageData|ImageData[]|number[]) { function forward(input: ImageData|ImageData[]|number[]) {
return tf.tidy( return tf.tidy(
() => forwardTensor(getImgTensor(input)) () => forwardTensor(padToSquare(getImgTensor(input)))
) )
} }
...@@ -81,7 +97,6 @@ export function faceDetectionNet(weights: Float32Array) { ...@@ -81,7 +97,6 @@ export function faceDetectionNet(weights: Float32Array) {
maxResults: number = 100, maxResults: number = 100,
): Promise<FaceDetectionNet.Detection[]> { ): Promise<FaceDetectionNet.Detection[]> {
const imgTensor = getImgTensor(input) const imgTensor = getImgTensor(input)
const [_, height, width] = imgTensor.shape const [_, height, width] = imgTensor.shape
const { const {
......
import * as tf from '@tensorflow/tfjs-core';
import { euclideanDistance } from './euclideanDistance'; import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet'; import { faceDetectionNet } from './faceDetectionNet';
import { faceRecognitionNet } from './faceRecognitionNet'; import { faceRecognitionNet } from './faceRecognitionNet';
...@@ -7,7 +9,8 @@ export { ...@@ -7,7 +9,8 @@ export {
euclideanDistance, euclideanDistance,
faceDetectionNet, faceDetectionNet,
faceRecognitionNet, faceRecognitionNet,
normalize normalize,
tf
} }
export * from './utils' export * from './utils'
\ No newline at end of file
...@@ -23,9 +23,15 @@ export function round(num: number) { ...@@ -23,9 +23,15 @@ export function round(num: number) {
return Math.floor(num * 100) / 100 return Math.floor(num * 100) / 100
} }
/**
 * A width/height pair.
 *
 * Passed as the optional `dims` argument of the media helpers in this file;
 * when it is omitted those helpers read `width` and `height` from the media
 * element instead.
 */
export type Dimensions = {
  /** Vertical extent. */
  height: number
  /** Horizontal extent. */
  width: number
}
export function drawMediaToCanvas( export function drawMediaToCanvas(
canvasArg: string | HTMLCanvasElement, canvasArg: string | HTMLCanvasElement,
mediaArg: string | HTMLImageElement | HTMLVideoElement mediaArg: string | HTMLImageElement | HTMLVideoElement,
dims?: Dimensions
): CanvasRenderingContext2D { ): CanvasRenderingContext2D {
const canvas = getElement(canvasArg) const canvas = getElement(canvasArg)
const media = getElement(mediaArg) const media = getElement(mediaArg)
...@@ -37,21 +43,24 @@ export function drawMediaToCanvas( ...@@ -37,21 +43,24 @@ export function drawMediaToCanvas(
throw new Error('drawMediaToCanvas - expected media to be of type: HTMLImageElement | HTMLVideoElement') throw new Error('drawMediaToCanvas - expected media to be of type: HTMLImageElement | HTMLVideoElement')
} }
canvas.width = media.width const { width, height } = dims || media
canvas.height = media.height canvas.width = width
canvas.height = height
const ctx = getContext2dOrThrow(canvas) const ctx = getContext2dOrThrow(canvas)
ctx.drawImage(media, 0, 0, media.width, media.height) ctx.drawImage(media, 0, 0, width, height)
return ctx return ctx
} }
export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement): ImageData { export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement, dims?: Dimensions): ImageData {
if (!(media instanceof HTMLImageElement || media instanceof HTMLVideoElement)) { if (!(media instanceof HTMLImageElement || media instanceof HTMLVideoElement)) {
throw new Error('mediaToImageData - expected media to be of type: HTMLImageElement | HTMLVideoElement') throw new Error('mediaToImageData - expected media to be of type: HTMLImageElement | HTMLVideoElement')
} }
const ctx = drawMediaToCanvas(document.createElement('canvas'), media) const ctx = drawMediaToCanvas(document.createElement('canvas'), media)
return ctx.getImageData(0, 0, media.width, media.height)
const { width, height } = dims || media
return ctx.getImageData(0, 0, width, height)
} }
export function mediaSrcToImageData( export function mediaSrcToImageData(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment