Commit e7d1d043 by vincent

concat score tensors before getting their data, apparently it's faster this way

parent 0e899bd2
import { tf } from '..';
import { BoundingBox } from './BoundingBox'; import { BoundingBox } from './BoundingBox';
import { extractImagePatches } from './extractImagePatches'; import { extractImagePatches } from './extractImagePatches';
import { nms } from './nms'; import { nms } from './nms';
...@@ -26,8 +27,12 @@ export async function stage2( ...@@ -26,8 +27,12 @@ export async function stage2(
) )
stats.stage2_rnet = Date.now() - ts stats.stage2_rnet = Date.now() - ts
const scoreDatas = await Promise.all(rnetOuts.map(out => out.scores.data())) const scoresTensor = rnetOuts.length > 1
const scores = scoreDatas.map(arr => Array.from(arr)).reduce((all, arr) => all.concat(arr)) ? tf.concat(rnetOuts.map(out => out.scores))
: rnetOuts[0].scores
const scores = Array.from(await scoresTensor.data())
scoresTensor.dispose()
const indices = scores const indices = scores
.map((score, idx) => ({ score, idx })) .map((score, idx) => ({ score, idx }))
.filter(c => c.score > scoreThreshold) .filter(c => c.score > scoreThreshold)
......
...@@ -4,6 +4,7 @@ import { extractImagePatches } from './extractImagePatches'; ...@@ -4,6 +4,7 @@ import { extractImagePatches } from './extractImagePatches';
import { nms } from './nms'; import { nms } from './nms';
import { ONet } from './ONet'; import { ONet } from './ONet';
import { ONetParams } from './types'; import { ONetParams } from './types';
import { tf } from '..';
export async function stage3( export async function stage3(
img: HTMLCanvasElement, img: HTMLCanvasElement,
...@@ -27,8 +28,12 @@ export async function stage3( ...@@ -27,8 +28,12 @@ export async function stage3(
) )
stats.stage3_onet = Date.now() - ts stats.stage3_onet = Date.now() - ts
const scoreDatas = await Promise.all(onetOuts.map(out => out.scores.data())) const scoresTensor = onetOuts.length > 1
const scores = scoreDatas.map(arr => Array.from(arr)).reduce((all, arr) => all.concat(arr)) ? tf.concat(onetOuts.map(out => out.scores))
: onetOuts[0].scores
const scores = Array.from(await scoresTensor.data())
scoresTensor.dispose()
const indices = scores const indices = scores
.map((score, idx) => ({ score, idx })) .map((score, idx) => ({ score, idx }))
.filter(c => c.score > scoreThreshold) .filter(c => c.score > scoreThreshold)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment