Commit 3ca0f4af by vincent

pad input to square

parent 89e9691e
......@@ -33,6 +33,20 @@ function fromImageData(input: ImageData[]) {
return tf.cast(tf.concat(imgTensors, 0), 'float32')
}
function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
const [_, height, width] = imgTensor.shape
if (height === width) {
return imgTensor
}
if (height > width) {
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
return tf.concat([imgTensor, pad], 2)
}
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
return tf.concat([imgTensor, pad], 1)
}
function getImgTensor(input: ImageData|ImageData[]|number[]) {
return tf.tidy(() => {
......@@ -44,9 +58,11 @@ function getImgTensor(input: ImageData|ImageData[]|number[]) {
: null
)
return imgDataArray !== null
? fromImageData(imgDataArray)
: fromData(input as number[])
return padToSquare(
imgDataArray !== null
? fromImageData(imgDataArray)
: fromData(input as number[])
)
})
}
......@@ -71,7 +87,7 @@ export function faceDetectionNet(weights: Float32Array) {
function forward(input: ImageData|ImageData[]|number[]) {
return tf.tidy(
() => forwardTensor(getImgTensor(input))
() => forwardTensor(padToSquare(getImgTensor(input)))
)
}
......@@ -81,7 +97,6 @@ export function faceDetectionNet(weights: Float32Array) {
maxResults: number = 100,
): Promise<FaceDetectionNet.Detection[]> {
const imgTensor = getImgTensor(input)
const [_, height, width] = imgTensor.shape
const {
......
import * as tf from '@tensorflow/tfjs-core';
import { euclideanDistance } from './euclideanDistance';
import { faceDetectionNet } from './faceDetectionNet';
import { faceRecognitionNet } from './faceRecognitionNet';
......@@ -7,7 +9,8 @@ export {
euclideanDistance,
faceDetectionNet,
faceRecognitionNet,
normalize
normalize,
tf
}
export * from './utils'
\ No newline at end of file
......@@ -23,9 +23,15 @@ export function round(num: number) {
return Math.floor(num * 100) / 100
}
export type Dimensions = {
width: number
height: number
}
export function drawMediaToCanvas(
canvasArg: string | HTMLCanvasElement,
mediaArg: string | HTMLImageElement | HTMLVideoElement
mediaArg: string | HTMLImageElement | HTMLVideoElement,
dims?: Dimensions
): CanvasRenderingContext2D {
const canvas = getElement(canvasArg)
const media = getElement(mediaArg)
......@@ -37,21 +43,24 @@ export function drawMediaToCanvas(
throw new Error('drawMediaToCanvas - expected media to be of type: HTMLImageElement | HTMLVideoElement')
}
canvas.width = media.width
canvas.height = media.height
const { width, height } = dims || media
canvas.width = width
canvas.height = height
const ctx = getContext2dOrThrow(canvas)
ctx.drawImage(media, 0, 0, media.width, media.height)
ctx.drawImage(media, 0, 0, width, height)
return ctx
}
export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement): ImageData {
export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement, dims?: Dimensions): ImageData {
if (!(media instanceof HTMLImageElement || media instanceof HTMLVideoElement)) {
throw new Error('mediaToImageData - expected media to be of type: HTMLImageElement | HTMLVideoElement')
}
const ctx = drawMediaToCanvas(document.createElement('canvas'), media)
return ctx.getImageData(0, 0, media.width, media.height)
const { width, height } = dims || media
return ctx.getImageData(0, 0, width, height)
}
export function mediaSrcToImageData(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment