Commit 3afe611f by vincent

fixed loss function

parent 99cbd7e5
...@@ -18,10 +18,13 @@ function assignBoxesToAnchors(groundTruthBoxes, reshapedImgDims) { ...@@ -18,10 +18,13 @@ function assignBoxesToAnchors(groundTruthBoxes, reshapedImgDims) {
const numCells = getNumCells(inputSize) const numCells = getNumCells(inputSize)
return groundTruthBoxes.map(box => { return groundTruthBoxes.map(box => {
const { left: x, top: y, width, height } = box.rescale(reshapedImgDims) const { left, top, width, height } = box.rescale(reshapedImgDims)
const row = Math.round((y / inputSize) * numCells) const ctX = left + (width / 2)
const col = Math.round((x / inputSize) * numCells) const ctY = top + (height / 2)
const col = Math.floor((ctX / inputSize) * numCells)
const row = Math.floor((ctY / inputSize) * numCells)
const anchorsByIou = getAnchors().map((anchor, idx) => ({ const anchorsByIou = getAnchors().map((anchor, idx) => ({
idx, idx,
...@@ -92,11 +95,13 @@ function computeBoxAdjustments(groundTruthBoxes, reshapedImgDims) { ...@@ -92,11 +95,13 @@ function computeBoxAdjustments(groundTruthBoxes, reshapedImgDims) {
const centerX = (left + right) / 2 const centerX = (left + right) / 2
const centerY = (top + bottom) / 2 const centerY = (top + bottom) / 2
const dCenterX = centerX - (col * CELL_SIZE + (CELL_SIZE / 2)) //const dCenterX = centerX - (col * CELL_SIZE + (CELL_SIZE / 2))
const dCenterY = centerY - (row * CELL_SIZE + (CELL_SIZE / 2)) //const dCenterY = centerY - (row * CELL_SIZE + (CELL_SIZE / 2))
const dCenterX = centerX - (col * CELL_SIZE)
const dCenterY = centerY - (row * CELL_SIZE)
const dx = inverseSigmoid(dCenterX / inputSize) const dx = inverseSigmoid(dCenterX / CELL_SIZE)
const dy = inverseSigmoid(dCenterY / inputSize) const dy = inverseSigmoid(dCenterY / CELL_SIZE)
const dw = Math.log((width / CELL_SIZE) / getAnchors()[anchor].x) const dw = Math.log((width / CELL_SIZE) / getAnchors()[anchor].x)
const dh = Math.log((height / CELL_SIZE) / getAnchors()[anchor].y) const dh = Math.log((height / CELL_SIZE) / getAnchors()[anchor].y)
...@@ -134,13 +139,14 @@ function computeIous(predBoxes, groundTruthBoxes, reshapedImgDims) { ...@@ -134,13 +139,14 @@ function computeIous(predBoxes, groundTruthBoxes, reshapedImgDims) {
const iou = faceapi.iou( const iou = faceapi.iou(
box.rescale(reshapedImgDims), box.rescale(reshapedImgDims),
predBox.box predBox.box.rescale(reshapedImgDims)
) )
if (window.debug) { if (window.debug) {
console.log('ground thruth box:', box.rescale(reshapedImgDims)) console.log('ground thruth box:', box.rescale(reshapedImgDims).toRect())
console.log('predicted box:', predBox.box) console.log('predicted box:', predBox.box.rescale(reshapedImgDims).toRect())
console.log(iou) console.log('predicted score:', predBox.score)
console.log('iou:', iou)
} }
const anchorOffset = anchor * 5 const anchorOffset = anchor * 5
...@@ -164,31 +170,6 @@ function computeObjectLoss(outTensor, groundTruthBoxes, reshapedImgDims, padding ...@@ -164,31 +170,6 @@ function computeObjectLoss(outTensor, groundTruthBoxes, reshapedImgDims, padding
{ paddings } { paddings }
) )
if (window.debug) {
console.log(predBoxes)
console.log(predBoxes.filter(b => b.score > 0.1))
}
// debug
const numCells = getNumCells(Math.max(reshapedImgDims.width, reshapedImgDims.height))
if (predBoxes.length !== (numCells * numCells * getAnchors().length)) {
console.log(predBoxes.length)
throw new Error('predBoxes.length !== (numCells * numCells * 25)')
}
const isInvalid = num => !num && num !== 0
predBoxes.forEach(({ row, col, anchor }) => {
if ([row, col, anchor].some(isInvalid)) {
console.log(row, col, anchor)
throw new Error('row, col, anchor invalid')
}
})
// debug
const ious = computeIous( const ious = computeIous(
predBoxes, predBoxes,
groundTruthBoxes, groundTruthBoxes,
...@@ -208,7 +189,6 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa ...@@ -208,7 +189,6 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa
reshapedImgDims reshapedImgDims
) )
// debug
if (window.debug) { if (window.debug) {
const indToPos = [] const indToPos = []
const numCells = outTensor.shape[1] const numCells = outTensor.shape[1]
...@@ -220,24 +200,22 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa ...@@ -220,24 +200,22 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa
} }
} }
const m = Array.from(mask.dataSync()) const indices = Array.from(mask.dataSync()).map((val, ind) => ({ val, ind })).filter(v => v.val !== 0).map(v => v.ind)
const ind = m.map((val, ind) => ({ val, ind })).filter(v => v.val !== 0).map(v => v.ind)
const gt = Array.from(boxAdjustments.dataSync()) const gt = Array.from(boxAdjustments.dataSync())
const out = Array.from(outTensor.dataSync()) const out = Array.from(outTensor.dataSync())
const comp = ind.map(i => ( const comp = indices.map(i => (
{ {
pos: indToPos[i], pos: indToPos[i],
gt: gt[i], gt: gt[i],
out: out[i] out: out[i]
} }
)) ))
console.log(comp)
console.log(comp.map(c => `gt: ${c.gt}, out: ${c.out}`)) console.log(comp.map(c => `gt: ${c.gt}, out: ${c.out}`))
const printBbox = (which) => { const getBbox = (which) => {
const { col, row, anchor } = comp[0].pos const { row, col, anchor } = comp[0].pos
console.log(col, row, anchor)
const ctX = ((col + faceapi.sigmoid(comp[0][which])) / numCells) * paddings.x const ctX = ((col + faceapi.sigmoid(comp[0][which])) / numCells) * paddings.x
const ctY = ((row + faceapi.sigmoid(comp[1][which])) / numCells) * paddings.y const ctY = ((row + faceapi.sigmoid(comp[1][which])) / numCells) * paddings.y
const width = ((Math.exp(comp[2][which]) * getAnchors()[anchor].x) / numCells) * paddings.x const width = ((Math.exp(comp[2][which]) * getAnchors()[anchor].x) / numCells) * paddings.x
...@@ -245,15 +223,16 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa ...@@ -245,15 +223,16 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa
const x = (ctX - (width / 2)) const x = (ctX - (width / 2))
const y = (ctY - (height / 2)) const y = (ctY - (height / 2))
console.log(which, x * reshapedImgDims.width, y * reshapedImgDims.height, width * reshapedImgDims.width, height * reshapedImgDims.height)
}
printBbox('out') return new faceapi.BoundingBox(x, y, x + width, y + height)
printBbox('gt') }
const outRect = getBbox('out').rescale(reshapedImgDims).toRect()
const gtRect = getBbox('gt').rescale(reshapedImgDims).toRect()
console.log('out', outRect)
console.log('gtRect', gtRect)
} }
// debug
const lossTensor = tf.sub(boxAdjustments, outTensor) const lossTensor = tf.sub(boxAdjustments, outTensor)
......
...@@ -20,9 +20,14 @@ ...@@ -20,9 +20,14 @@
const weightsUrl = '/tmp/test.weights' const weightsUrl = '/tmp/test.weights'
const fromEpoch = 800 const fromEpoch = 800
window.debug = true
window.logTrainSteps = true
// hyper parameters // hyper parameters
window.objectScale = 5 window.objectScale = 5
window.noObjectScale = 0.5 window.noObjectScale = 1
window.coordScale = 1 window.coordScale = 1
window.saveEveryNthIteration = 50 window.saveEveryNthIteration = 50
...@@ -45,7 +50,7 @@ ...@@ -45,7 +50,7 @@
window.net = new faceapi.TinyYolov2(true) window.net = new faceapi.TinyYolov2(true)
window.net.load(weights) window.net.load(weights)
window.net.variable() window.net.variable()
window.detectionFilenames = (await fetchDetectionFilenames()).slice(0, numTrainSamples) window.detectionFilenames = (await fetchDetectionFilenames()).slice(1, numTrainSamples + 1)
window.lossMap = {} window.lossMap = {}
console.log('ready') console.log('ready')
...@@ -90,7 +95,7 @@ ...@@ -90,7 +95,7 @@
if (((i + 1) % saveEveryNthIteration) === 0) { if (((i + 1) % saveEveryNthIteration) === 0) {
saveWeights(window.net, 'adam_511_n1_' + i + '.weights') saveWeights(window.net, 'adam_511_' + i + '.weights')
} }
} }
} }
......
...@@ -138,7 +138,7 @@ ...@@ -138,7 +138,7 @@
async function run() { async function run() {
$('#imgByNr').keydown(onKeyDown) $('#imgByNr').keydown(onKeyDown)
const weights = await loadNetWeights('/tmp/test.weights') const weights = await loadNetWeights('/tmp/test2.weights')
window.net = new faceapi.TinyYolov2(true) window.net = new faceapi.TinyYolov2(true)
await window.net.load(weights) await window.net.load(weights)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment