Commit 7b651ee7 by vincent

init training function

parent aa7dd87c
......@@ -37,9 +37,10 @@ function computeCoordLoss({ groundTruth, pred }, imgDims) {
+ squared(getHeightCorrections(groundTruth.box) - getHeightCorrections(pred.box))
}
function computeLoss(outBoxesByAnchor, groundTruth, inputSize, imgDims) {
function computeLoss(outBoxesByAnchor, groundTruth, imgDims) {
const { anchors } = window.net
const inputSize = Math.max(imgDims.width, imgDims.height)
const numCells = inputSize / 32
const groundTruthByAnchor = groundTruth.map(rect => {
......
......@@ -17,7 +17,7 @@
const weightsUrl = '/tmp/initial_tiny_yolov2_glorot_normal.weights'
window.saveEveryNthIteration = 2
window.saveEveryNthIteration = 1
window.trainSteps = 100
window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8)
......@@ -41,38 +41,27 @@
window.detectionFilenames = await fetchDetectionFilenames()
}
/*
const trainSizes = [608, 416, 320, 224]
const outTensor = await window.net.forward(netInput, 608)
const detections = await window.net.locateFaces(netInput, forwardParams)
const outBoxesByAnchor = window.net.postProcess(
outTensor,
{
scoreThreshold: 0,
paddings: netInput.getRelativePaddings(0)
}
)
async function train(batchSize = 1) {
for (let i = 0; i < trainSteps; i++) {
console.log('step', i)
let ts = Date.now()
const groundTruth = detections.map(det => det.forSize(1, 1).box)
const batchCreators = createBatchCreators(shuffle(detectionFilenames), batchSize)
console.log(computeLoss(
outBoxesByAnchor,
groundTruth,
netInput.inputSize,
netInput.getReshapedInputDimensions(0)
))
for (let s = 0; s < trainSizes.length; s++) {
let ts2 = Date.now()
await trainStep(batchCreators, trainSizes[s])
*/
ts2 = Date.now() - ts2
console.log('train for size %s done (%s ms)', trainSizes[s], ts2)
}
async function train(batchSize = 1) {
for (let i = 0; i < trainSteps; i++) {
console.log('step', i)
const batchCreators = createBatchCreators(shuffle(detectionFilenames), batchSize)
let ts = Date.now()
await trainStep(batchCreators)
ts = Date.now() - ts
console.log('step %s done (%s ms)', i, ts)
if (((i + 1) % saveEveryNthIteration) === 0) {
saveWeights(i)
}
......
async function trainStep(batchCreators) {
async function trainStep(batchCreators, inputSize) {
await promiseSequential(batchCreators.map((batchCreator, dataIdx) => async () => {
const { batchInput, groundTruthBoxes } = await batchCreator()
/*
// TODO: skip if groundTruthBoxes are too tiny
const { imgs, groundTruthBoxes } = await batchCreator()
const batchInput = (await faceapi.toNetInput(imgs)).managed()
let ts = Date.now()
const cost = optimizer.minimize(() => {
const out = window.trainNet.forwardInput(batchInput.managed())
const loss = lossFunction(
landmarksBatchTensor,
out
const loss = optimizer.minimize(() => {
const outTensor = window.net.forwardInput(batchInput, inputSize)
const outTensorsByBatch = tf.tidy(() => outTensor.unstack().expandDims())
outTensor.dispose()
const losses = outTensorsByBatch.map(
(out, batchIdx) => {
const outBoxesByAnchor = window.net.postProcess(
out,
{
scoreThreshold: -1,
paddings: batchInput.getRelativePaddings(batchIdx)
}
)
const loss = computeLoss(
outBoxesByAnchor,
groundTruthBoxes[batchIdx],
netInput.getReshapedInputDimensions(batchIdx)
)
console.log(`loss for batch ${batchIdx}: ${loss}`)
return loss
}
)
return loss
outTensorsByBatch.forEach(t => t.dispose())
return losses.reduce((sum, loss) => sum + loss, 0)
}, true)
ts = Date.now() - ts
console.log(`loss[${dataIdx}]: ${await cost.data()}, ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
landmarksBatchTensor.dispose()
cost.dispose()
console.log(`loss[${dataIdx}]: ${loss}, ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
await tf.nextFrame()
}))
*/
}
function createBatchCreators(batchSize) {
......@@ -42,17 +65,16 @@ function createBatchCreators(batchSize) {
pushToBatch(window.detectionFilenames)
const batchCreators = batches.map(detectionFilenames => async () => {
const imgs = detectionFilenames.map(
const groundTruthBoxes = detectionFilenames.map(
detectionFilenames.map(file => fetch(file).then(res => res.json()))
)
const groundTruthBoxes = await Promise.all(
const imgs = await Promise.all(
detectionFilenames.map(async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', ''))))
)
const batchInput = await faceapi.toNetInput(imgs)
return {
batchInput,
imgs,
groundTruthBoxes
}
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment