Commit 1e2d2616 by vincent

face recognition net now accepts batch inputs

parent 8ef9b662
...@@ -35,7 +35,7 @@ export function allFacesFactory( ...@@ -35,7 +35,7 @@ export function allFacesFactory(
const descriptors = await Promise.all(alignedFaceTensors.map( const descriptors = await Promise.all(alignedFaceTensors.map(
faceTensor => recognitionNet.computeFaceDescriptor(faceTensor) faceTensor => recognitionNet.computeFaceDescriptor(faceTensor)
)) )) as Float32Array[]
alignedFaceTensors.forEach(t => t.dispose()) alignedFaceTensors.forEach(t => t.dispose())
return detections.map((detection, i) => return detections.map((detection, i) =>
......
...@@ -30,12 +30,11 @@ export class FaceRecognitionNet { ...@@ -30,12 +30,11 @@ export class FaceRecognitionNet {
this._params = extractParams(weights) this._params = extractParams(weights)
} }
public async forwardInput(input: NetInput): Promise<tf.Tensor2D> { public forwardInput(input: NetInput): tf.Tensor2D {
if (!this._params) { if (!this._params) {
throw new Error('FaceRecognitionNet - load model before inference') throw new Error('FaceRecognitionNet - load model before inference')
} }
return tf.tidy(() => { return tf.tidy(() => {
const batchTensor = input.toBatchTensor(150, true) const batchTensor = input.toBatchTensor(150, true)
...@@ -68,14 +67,26 @@ export class FaceRecognitionNet { ...@@ -68,14 +67,26 @@ export class FaceRecognitionNet {
return fullyConnected return fullyConnected
}) })
} }
public async forward(input: TNetInput): Promise<tf.Tensor2D> { public async forward(input: TNetInput): Promise<tf.Tensor2D> {
return this.forwardInput(await toNetInput(input, true)) return this.forwardInput(await toNetInput(input, true))
} }
public async computeFaceDescriptor(input: TNetInput) { public async computeFaceDescriptor(input: TNetInput): Promise<Float32Array|Float32Array[]> {
const result = await this.forward(await toNetInput(input, true)) const netInput = await toNetInput(input, true)
const data = await result.data()
result.dispose() const faceDescriptorTensors = tf.tidy(
return data as Float32Array () => tf.unstack(this.forwardInput(netInput))
)
const faceDescriptorsForBatch = await Promise.all(faceDescriptorTensors.map(
t => t.data()
)) as Float32Array[]
faceDescriptorTensors.forEach(t => t.dispose())
return netInput.isBatchInput
? faceDescriptorsForBatch
: faceDescriptorsForBatch[0]
} }
} }
\ No newline at end of file
...@@ -2,9 +2,9 @@ import * as tf from '@tensorflow/tfjs-core'; ...@@ -2,9 +2,9 @@ import * as tf from '@tensorflow/tfjs-core';
export function normalize(x: tf.Tensor4D): tf.Tensor4D { export function normalize(x: tf.Tensor4D): tf.Tensor4D {
return tf.tidy(() => { return tf.tidy(() => {
const avg_r = tf.fill([1, 150, 150, 1], 122.782); const avg_r = tf.fill([...x.shape.slice(0, 3), 1], 122.782);
const avg_g = tf.fill([1, 150, 150, 1], 117.001); const avg_g = tf.fill([...x.shape.slice(0, 3), 1], 117.001);
const avg_b = tf.fill([1, 150, 150, 1], 104.298); const avg_b = tf.fill([...x.shape.slice(0, 3), 1], 104.298);
const avg_rgb = tf.concat([avg_r, avg_g, avg_b], 3) const avg_rgb = tf.concat([avg_r, avg_g, avg_b], 3)
return tf.div(tf.sub(x, avg_rgb), tf.scalar(256)) return tf.div(tf.sub(x, avg_rgb), tf.scalar(256))
......
...@@ -50,7 +50,7 @@ export function detectLandmarks( ...@@ -50,7 +50,7 @@ export function detectLandmarks(
export function computeFaceDescriptor( export function computeFaceDescriptor(
input: TNetInput input: TNetInput
): Promise<Float32Array> { ): Promise<Float32Array | Float32Array[]> {
return recognitionNet.computeFaceDescriptor(input) return recognitionNet.computeFaceDescriptor(input)
} }
......
[-0.08900658041238785, 0.10903996974229813, 0.027176279574632645, 0.04400758072733879, -0.14542895555496216, 0.11051996797323227, -0.04482650384306908, -0.05154910683631897, 0.10313281416893005, -0.09580713510513306, 0.11335672438144684, -0.02723177894949913, -0.2017219066619873, 0.09402787685394287, -0.025814395397901535, 0.07219463586807251, -0.12272300571203232, -0.07349629700183868, -0.1723618507385254, -0.1745331585407257, -0.03420797362923622, 0.10511981695890427, 0.0262751504778862, 0.014430010691285133, -0.2035353034734726, -0.2949812114238739, -0.04833773523569107, -0.10960741341114044, 0.08448510617017746, -0.039910122752189636, -0.03964325413107872, -0.099286288022995, -0.16025686264038086, 0.026379037648439407, 0.09079921245574951, 0.07745557278394699, -0.05415252223610878, -0.017411116510629654, 0.16053830087184906, 0.010681805200874805, -0.11814302206039429, 0.0382964164018631, 0.08098040521144867, 0.29891595244407654, 0.1258186250925064, 0.06479117274284363, 0.02330329827964306, -0.07838230580091476, 0.1363348364830017, -0.21215586364269257, 0.07675530016422272, 0.1447518914937973, 0.14686468243598938, 0.06991209089756012, 0.08843740075826645, -0.11935211718082428, -0.015284902416169643, 0.16930945217609406, -0.044002968817949295, 0.16501764953136444, 0.10481955111026764, -0.013367846608161926, -0.05079612880945206, -0.07971523702144623, 0.2541899085044861, 0.07128541171550751, -0.1458708792924881, -0.15604135394096375, 0.11365226656198502, -0.16018034517765045, -0.034580036997795105, 0.05678928270936012, -0.07191935181617737, -0.15881866216659546, -0.1955043375492096, 0.06456604599952698, 0.5308966040611267, 0.13605228066444397, -0.18340089917182922, -0.054736778140068054, -0.09668046236038208, -0.0006025233305990696, 0.06609033048152924, 0.0835171788930893, -0.13018545508384705, -0.07167276740074158, -0.04313529655337334, 0.08809386193752289, 0.29993879795074463, -0.07008976489305496, 0.005112136714160442, 0.1464609056711197, 0.03064284473657608, 0.005341261625289917, -0.03758316487073898, -0.002741048112511635, -0.19020092487335205, -0.005203879438340664, -0.03693881630897522, 0.017715569585561752, 0.025151528418064117, -0.1393381506204605, 0.04255775362253189, 0.080945685505867, -0.23745450377464294, 0.21049565076828003, -0.01615971140563488, -0.0642223060131073, 0.0915207713842392, 0.10660708695650101, -0.14731745421886444, -0.027426915243268013, 0.2378913164138794, -0.2964036166667938, 0.2034282684326172, 0.2009482979774475, 0.04706001281738281, 0.13964271545410156, 0.05233509838581085, 0.11507777869701385, 0.045886922627687454, 0.12765641510486603, -0.15917260944843292, -0.13223722577095032, -0.023241272196173668, -0.129884734749794, -0.027176398783922195, 0.009421694092452526]
\ No newline at end of file
[-0.13293321430683136, 0.09793781489133835, 0.06550372391939163, 0.02364283800125122, -0.043399304151535034, 0.004586201161146164, -0.09000064432621002, -0.05539097636938095, 0.10467389971017838, -0.09715163707733154, 0.18808841705322266, -0.0205547958612442, -0.23795807361602783, -0.026068881154060364, -0.04790578782558441, 0.10736768692731857, -0.1791372150182724, -0.09754926711320877, -0.08212480694055557, -0.07197146117687225, 0.07512062042951584, 0.06562784314155579, -0.06910805404186249, 0.010537944734096527, -0.1353086233139038, -0.29961100220680237, -0.04597249627113342, -0.09019482880830765, 0.04843198508024216, -0.08456507325172424, -0.06385420262813568, 0.09591938555240631, -0.08721363544464111, 0.0029465071856975555, 0.062499962747097015, 0.08367685973644257, 0.004837760701775551, -0.02126195654273033, 0.18138188123703003, -0.0330311618745327, -0.1149168312549591, -0.014434240758419037, 0.04467501491308212, 0.32643717527389526, 0.13417592644691467, 0.049149081110954285, 0.0002636462450027466, -0.030674105510115623, 0.15085124969482422, -0.25617715716362, 0.007638035342097282, 0.20309507846832275, 0.155135378241539, 0.10535001009702682, 0.09949050843715668, -0.19686023890972137, 0.055761925876140594, 0.10784860700368881, -0.16404221951961517, 0.12705324590206146, 0.06780532747507095, -0.12821750342845917, -0.015174079686403275, -0.08541303128004074, 0.23064906895160675, 0.04403648152947426, -0.16575516760349274, -0.10698974132537842, 0.13322079181671143, -0.10516376793384552, -0.03650324046611786, 0.05603502690792084, -0.1468498408794403, -0.21398313343524933, -0.22947216033935547, 0.022328242659568787, 0.4006509780883789, 0.2338075339794159, -0.1980385184288025, 0.05581464245915413, -0.033158354461193085, -0.047999653965234756, 0.10474226623773575, 0.11267579346895218, -0.0938166156411171, -0.005631402134895325, -0.0698985829949379, 0.06661885231733322, 0.18326956033706665, 0.042940653860569, -0.031386956572532654, 0.2056775838136673, 0.011491281911730766, 0.05759737640619278, -0.029466431587934494, -0.04597870260477066, -0.07393362373113632, -0.037820909172296524, -0.07149908691644669, 0.023783499374985695, 0.016364723443984985, -0.09576655924320221, 0.02455282025039196, 0.11984197050333023, -0.11477060616016388, 0.17211446166038513, -0.008100427687168121, 0.09116753190755844, -0.004660069011151791, 0.029939215630292892, -0.10707360506057739, 0.03878428786993027, 0.15494686365127563, -0.2801153063774109, 0.1764734983444214, 0.1614546924829483, 0.09864784777164459, 0.12133727967739105, 0.05214153230190277, 0.04244184494018555, 0.024142231792211533, -0.019513756036758423, -0.22539466619491577, -0.0927465632557869, 0.06196486949920654, -0.09522707760334015, 0.04965142160654068, 0.023237790912389755]
\ No newline at end of file
...@@ -87,7 +87,7 @@ describe('faceLandmarkNet', () => { ...@@ -87,7 +87,7 @@ describe('faceLandmarkNet', () => {
await faceLandmarkNet.load('base/weights') await faceLandmarkNet.load('base/weights')
}) })
it('computes face landmarks', async () => { it('computes face landmarks for squared input', async () => {
const { width, height } = imgEl1 const { width, height } = imgEl1
const result = await faceLandmarkNet.detectLandmarks(imgEl1) as FaceLandmarks const result = await faceLandmarkNet.detectLandmarks(imgEl1) as FaceLandmarks
......
import * as tf from '@tensorflow/tfjs-core';
import * as faceapi from '../../../src'; import * as faceapi from '../../../src';
describe('faceRecognitionNet', () => { describe('faceRecognitionNet', () => {
let faceRecognitionNet: any, imgEl: HTMLImageElement, faceDescriptor: number[] let imgEl1: HTMLImageElement
let imgEl2: HTMLImageElement
let imgElRect: HTMLImageElement
let faceDescriptor1: number[]
let faceDescriptor2: number[]
let faceDescriptorRect: number[]
beforeAll(async () => {
const img1 = await (await fetch('base/test/images/face1.png')).blob()
imgEl1 = await faceapi.bufferToImage(img1)
const img2 = await (await fetch('base/test/images/face2.png')).blob()
imgEl2 = await faceapi.bufferToImage(img2)
const imgRect = await (await fetch('base/test/images/face_rectangular.png')).blob()
imgElRect = await faceapi.bufferToImage(imgRect)
faceDescriptor1 = await (await fetch('base/test/data/faceDescriptor1.json')).json()
faceDescriptor2 = await (await fetch('base/test/data/faceDescriptor2.json')).json()
faceDescriptorRect = await (await fetch('base/test/data/faceDescriptorRect.json')).json()
})
describe('uncompressed weights', () => {
let faceRecognitionNet: faceapi.FaceRecognitionNet
beforeAll(async () => { beforeAll(async () => {
const res = await fetch('base/weights/uncompressed/face_recognition_model.weights') const res = await fetch('base/weights/uncompressed/face_recognition_model.weights')
const weights = new Float32Array(await res.arrayBuffer()) const weights = new Float32Array(await res.arrayBuffer())
faceRecognitionNet = faceapi.faceRecognitionNet(weights) faceRecognitionNet = faceapi.faceRecognitionNet(weights)
})
const img = await (await fetch('base/test/images/face1.png')).blob() it('computes face descriptor for squared input', async () => {
imgEl = await faceapi.bufferToImage(img) const result = await faceRecognitionNet.computeFaceDescriptor(imgEl1) as Float32Array
faceDescriptor = await (await fetch('base/test/data/faceDescriptor.json')).json() expect(result.length).toEqual(128)
expect(result).toEqual(new Float32Array(faceDescriptor1))
}) })
it('computes face descriptor', async () => { it('computes face descriptor for rectangular input', async () => {
const result = await faceRecognitionNet.computeFaceDescriptor(imgEl) as number[] const result = await faceRecognitionNet.computeFaceDescriptor(imgElRect) as Float32Array
expect(result.length).toEqual(128) expect(result.length).toEqual(128)
expect(result).toEqual(new Float32Array(faceDescriptor)) expect(result).toEqual(new Float32Array(faceDescriptorRect))
})
})
// TODO: figure out why descriptors return NaN in the test cases
/*
describe('quantized weights', () => {
let faceRecognitionNet: faceapi.FaceRecognitionNet
beforeAll(async () => {
faceRecognitionNet = new faceapi.FaceRecognitionNet()
await faceRecognitionNet.load('base/weights')
})
it('computes face descriptor for squared input', async () => {
const result = await faceRecognitionNet.computeFaceDescriptor(imgEl1) as Float32Array
expect(result.length).toEqual(128)
expect(result).toEqual(new Float32Array(faceDescriptor1))
})
it('computes face descriptor for rectangular input', async () => {
const result = await faceRecognitionNet.computeFaceDescriptor(imgElRect) as Float32Array
expect(result.length).toEqual(128)
expect(result).toEqual(new Float32Array(faceDescriptorRect))
})
})
*/
describe('batch inputs', () => {
let faceRecognitionNet: faceapi.FaceRecognitionNet
beforeAll(async () => {
const res = await fetch('base/weights/uncompressed/face_recognition_model.weights')
const weights = new Float32Array(await res.arrayBuffer())
faceRecognitionNet = faceapi.faceRecognitionNet(weights)
})
it('computes face descriptors for batch of image elements', async () => {
const inputs = [imgEl1, imgEl2, imgElRect]
const faceDescriptors = [
faceDescriptor1,
faceDescriptor2,
faceDescriptorRect
]
const results = await faceRecognitionNet.computeFaceDescriptor(inputs) as Float32Array[]
expect(Array.isArray(results)).toBe(true)
expect(results.length).toEqual(3)
results.forEach((result, batchIdx) => {
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
})
})
it('computes face landmarks for batch of tf.Tensor3D', async () => {
const inputs = [imgEl1, imgEl2, imgElRect].map(el => tf.fromPixels(el))
const faceDescriptors = [
faceDescriptor1,
faceDescriptor2,
faceDescriptorRect
]
const results = await faceRecognitionNet.computeFaceDescriptor(inputs) as Float32Array[]
expect(Array.isArray(results)).toBe(true)
expect(results.length).toEqual(3)
results.forEach((result, batchIdx) => {
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
})
}) })
it('computes face landmarks for tf.Tensor4D', async () => {
const inputs = [imgEl1, imgEl2].map(el => tf.fromPixels(el))
const faceDescriptors = [
faceDescriptor1,
faceDescriptor2,
faceDescriptorRect
]
const results = await faceRecognitionNet.computeFaceDescriptor(tf.stack(inputs) as tf.Tensor4D) as Float32Array[]
expect(Array.isArray(results)).toBe(true)
expect(results.length).toEqual(2)
results.forEach((result, batchIdx) => {
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
})
})
it('computes face landmarks for batch of mixed inputs', async () => {
const inputs = [imgEl1, tf.fromPixels(imgEl2), tf.fromPixels(imgElRect)]
const faceDescriptors = [
faceDescriptor1,
faceDescriptor2,
faceDescriptorRect
]
const results = await faceRecognitionNet.computeFaceDescriptor(inputs) as Float32Array[]
expect(Array.isArray(results)).toBe(true)
expect(results.length).toEqual(3)
results.forEach((result, batchIdx) => {
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
})
})
})
}) })
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment