Commit 459067ce by vincent

rewrite loss function with tf ops, such that it's differentiable

parent 7b651ee7
import { BoundingBox } from './BoundingBox';
export interface IRect {
x: number
y: number
......@@ -45,4 +46,8 @@ export class Rect implements IRect {
Math.floor(this.height)
)
}
public toBoundingBox(): BoundingBox {
  // convert from x/y/width/height representation to left/top/right/bottom
  const { x, y, width, height } = this
  return new BoundingBox(x, y, x + width, y + height)
}
}
\ No newline at end of file
......@@ -95,8 +95,6 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
const out = await this.forwardInput(netInput, inputSize)
const out0 = tf.tidy(() => tf.unstack(out)[0].expandDims()) as tf.Tensor4D
console.log(out0.shape)
const inputDimensions = {
width: netInput.getInputWidth(0),
height: netInput.getInputHeight(0)
......@@ -147,7 +145,7 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
for (let col = 0; col < numCells; col ++) {
for (let anchor = 0; anchor < NUM_BOXES; anchor ++) {
const score = sigmoid(scoresTensor.get(row, col, anchor, 0))
if (score > scoreThreshold) {
if (!scoreThreshold || score > scoreThreshold) {
const ctX = ((col + sigmoid(boxesTensor.get(row, col, anchor, 0))) / numCells) * paddings.x
const ctY = ((row + sigmoid(boxesTensor.get(row, col, anchor, 1))) / numCells) * paddings.y
const width = ((Math.exp(boxesTensor.get(row, col, anchor, 2)) * this.anchors[anchor].x) / numCells) * paddings.x
......
......@@ -46,6 +46,6 @@ export type TinyYolov2ForwardParams = {
}
export type PostProcessingParams = {
scoreThreshold: number
scoreThreshold?: number
paddings: Point
}
\ No newline at end of file
......@@ -70,20 +70,11 @@
ts = Date.now() - ts
console.log('step %s done (%s ms)', i, ts)
if (((i + 1) % saveEveryNthIteration) === 0) {
saveWeights(i)
saveWeights(window.trainNet, 'landmark_trained_weights_' + idx + '.weights')
}
}
}
// Flattens every parameter tensor of the globally trained net into a single
// Float32Array and triggers a browser download (FileSaver's saveAs) named
// 'landmark_trained_weights_<idx>.weights'.
function saveWeights(idx = 0) {
  const binaryWeights = new Float32Array(
    window.trainNet.getParamList()
      .map(({ tensor }) => Array.from(tensor.dataSync()))
      // seed with [] — reduce without an initial value throws a TypeError
      // when the param list is empty
      .reduce((flat, arr) => flat.concat(arr), [])
  )
  saveAs(new Blob([binaryWeights]), 'landmark_trained_weights_' + idx + '.weights')
}
</script>
</body>
......
......@@ -17,6 +17,9 @@ app.use(express.static(path.join(__dirname, '../../dist')))
// Training data set root (images + json detection results), taken from the
// TRAIN_DATA_PATH environment variable.
const trainDataPath = path.resolve(process.env.TRAIN_DATA_PATH)
const imagesPath = path.join(trainDataPath, './final_images')
const detectionsPath = path.join(trainDataPath, './final_detections')
// NOTE(review): diff artifact — removed and added lines are shown together
// here. The single trainDataPath mount below already serves both subfolders,
// so the two subfolder mounts (and the readdirSync whose result is unused in
// this view) appear to be the pre-change lines — verify against the final file.
app.use(express.static(imagesPath))
app.use(express.static(detectionsPath))
const detectionFilenames = fs.readdirSync(detectionsPath)
app.use(express.static(trainDataPath))
......
......@@ -18,4 +18,13 @@ function shuffle(a) {
a[j] = x;
}
return a;
}
// Serializes all parameters of the given net into one flat Float32Array and
// triggers a browser download of the raw weights (FileSaver's saveAs).
function saveWeights(net, filename = 'train_tmp') {
  const binaryWeights = new Float32Array(
    net.getParamList()
      .map(({ tensor }) => Array.from(tensor.dataSync()))
      // seed with [] — reduce without an initial value throws a TypeError
      // when the param list is empty
      .reduce((flat, arr) => flat.concat(arr), [])
  )
  saveAs(new Blob([binaryWeights]), filename)
}
\ No newline at end of file
......@@ -3,104 +3,195 @@ const objectScale = 1
// Loss term weights (YOLO-style).
const noObjectScale = 0.5
const coordScale = 5

// squares a number
const squared = e => e ** 2

// side length of one YOLO grid cell in pixels
const CELL_SIZE = 32

// true when two entries refer to the same grid cell + anchor slot
const isSameAnchor = (p1, p2) =>
  p1.row === p2.row && p1.col === p2.col && p1.anchor === p2.anchor

// number of grid cells along one axis for a given square input size
const getNumCells = inputSize => inputSize / CELL_SIZE

// sum of an array of numbers
const sum = vals => vals.reduce((total, val) => total + val, 0)
// NOTE(review): diff artifact — the old (pre-tf) computeNoObjectLoss below is
// truncated (its closing brace is missing) and the new getAnchors helper is
// interleaved inside it; as printed this is not valid JavaScript.
function computeNoObjectLoss(negative) {
// old scalar no-object loss: squared distance of the predicted score from 0
return squared(0 - negative.score)
// returns the anchor boxes of the globally loaded TinyYolov2 net
function getAnchors() {
return window.net.anchors
}
// Old (pre-tf) object loss for one matched anchor: squared difference between
// the IoU of the ground-truth and predicted boxes and the predicted score.
function computeObjectLoss({ groundTruth, pred }) {
  const iou = faceapi.iou(groundTruth.box, pred.box)
  return squared(iou - pred.score)
}
// NOTE(review): diff artifact — the next line opens the NEW assignBoxesToAnchors,
// but the view interleaves the REMOVED (pre-tf) computeCoordLoss and computeLoss
// bodies with the new function's lines. As printed this is not valid JavaScript;
// the interleaved regions are marked below. Resolve against the actual commit.
function assignBoxesToAnchors(groundTruthBoxes, reshapedImgDims) {
// NOTE(review): removed old computeCoordLoss starts here (interleaved)
function computeCoordLoss({ groundTruth, pred }, imgDims) {
const anchor = window.net.anchors[groundTruth.anchor]
const getWidthCorrections = box => Math.log((box.width / imgDims.width) / anchor.x)
const getHeightCorrections = box => Math.log((box.height / imgDims.height) / anchor.y)
// NOTE(review): these two lines belong to the new assignBoxesToAnchors
const inputSize = Math.max(reshapedImgDims.width, reshapedImgDims.height)
const numCells = getNumCells(inputSize)
return squared((groundTruth.box.left - pred.box.left) / imgDims.width)
+ squared((groundTruth.box.top - pred.box.top) / imgDims.height)
+ squared(getWidthCorrections(groundTruth.box) - getWidthCorrections(pred.box))
+ squared(getHeightCorrections(groundTruth.box) - getHeightCorrections(pred.box))
}
// NOTE(review): removed old computeLoss header follows (interleaved)
function computeLoss(outBoxesByAnchor, groundTruth, imgDims) {
const { anchors } = window.net
const inputSize = Math.max(imgDims.width, imgDims.height)
const numCells = inputSize / 32
const groundTruthByAnchor = groundTruth.map(rect => {
const x = rect.x * imgDims.width
const y = rect.y * imgDims.height
const width = rect.width * imgDims.width
const height = rect.height * imgDims.height
// NOTE(review): new assignBoxesToAnchors body resumes here — maps each box
// to the grid cell responsible for it and the best-matching anchor
return groundTruthBoxes.map(box => {
const { left: x, top: y, width, height } = box.rescale(reshapedImgDims)
const row = Math.round((y / inputSize) * numCells)
const col = Math.round((x / inputSize) * numCells)
// NOTE(review): old (removed) and new variants of the anchor-ranking line
// appear back to back
const anchorsByIou = anchors.map((a, idx) => ({
const anchorsByIou = getAnchors().map((anchor, idx) => ({
idx,
iou: faceapi.iou(
new faceapi.BoundingBox(0, 0, a.x * 32, a.y * 32),
new faceapi.BoundingBox(0, 0, anchor.x * CELL_SIZE, anchor.y * CELL_SIZE),
new faceapi.BoundingBox(0, 0, width, height)
)
})).sort((a1, a2) => a2.iou - a1.iou)
console.log('anchorsByIou', anchorsByIou)
// anchor with the highest IoU against the ground-truth box extent
const anchor = anchorsByIou[0].idx
return {
box: new faceapi.BoundingBox(x, y, x + width, y + height),
row,
col,
anchor
// NOTE(review): the new revision returns this compact form instead
return { row, col, anchor, box }
})
}
// Builds a [numCells, numCells, 25] mask tensor with 1s in the 5 channels of
// every ground-truth cell/anchor slot and 0 elsewhere
// (25 = 5 anchors * 5 values per anchor, per the anchor * 5 offset below).
function getGroundTruthMask(groundTruthBoxes, inputSize) {
const numCells = getNumCells(inputSize)
const mask = tf.zeros([numCells, numCells, 25])
const buf = mask.buffer()
groundTruthBoxes.forEach(({ row, col, anchor }) => {
const anchorOffset = anchor * 5
for (let i = 0; i < 5; i++) {
buf.set(1, row, col, anchorOffset + i)
}
})
// NOTE(review): diff artifact — everything from here down to `return mask` is
// the tail of the REMOVED (pre-tf) computeLoss implementation that the diff
// view interleaves into this function; it references parameters of the old
// function (outBoxesByAnchor, imgDims, ...) that do not exist in this scope.
console.log('outBoxesByAnchor', outBoxesByAnchor.filter(o => o.score > 0.5).map(o => o))
console.log('outBoxesByAnchor', outBoxesByAnchor.filter(o => o.score > 0.5).map(o => o.box.rescale(imgDims)))
console.log('groundTruthByAnchor', groundTruthByAnchor)
const negatives = outBoxesByAnchor.filter(pred => !groundTruthByAnchor.find(gt => isSameAnchor(gt, pred)))
const positives = outBoxesByAnchor
.map(pred => ({
groundTruth: groundTruthByAnchor.find(gt => isSameAnchor(gt, pred)),
pred: {
...pred,
box: pred.box.rescale(imgDims)
}
}))
.filter(pos => !!pos.groundTruth)
console.log('negatives', negatives)
console.log('positives', positives)
const noObjectLoss = sum(negatives.map(computeNoObjectLoss))
const objectLoss = sum(positives.map(computeObjectLoss))
const coordLoss = sum(positives.map(positive => computeCoordLoss(positive, imgDims)))
console.log('noObjectLoss', noObjectLoss)
console.log('objectLoss', objectLoss)
console.log('coordLoss', coordLoss)
return noObjectScale * noObjectLoss
+ objectScale * objectLoss
+ coordScale * coordLoss
// we don't compute a class loss, since we only have 1 class
// + class_scale * sum(class_loss)
// NOTE(review): actual return value of getGroundTruthMask
return mask
}
// Builds the regression-target tensor [numCells, numCells, 25]: for every
// ground-truth cell/anchor slot it stores the offset (dx, dy) of the box
// center from the cell center and the log-scale factors (dw, dh) relative
// to the anchor dimensions. All other entries stay 0.
function computeBoxAdjustments(groundTruthBoxes, reshapedImgDims) {
  const inputSize = Math.max(reshapedImgDims.width, reshapedImgDims.height)
  const numCells = getNumCells(inputSize)
  const adjustments = tf.zeros([numCells, numCells, 25])
  const targetBuf = adjustments.buffer()
  for (const { row, col, anchor, box } of groundTruthBoxes) {
    const { left, top, right, bottom, width, height } = box.rescale(reshapedImgDims)
    const centerX = (left + right) / 2
    const centerY = (top + bottom) / 2
    // offset of the box center from the center of its responsible cell,
    // normalized by the input size
    const dx = (centerX - (col * CELL_SIZE + (CELL_SIZE / 2))) / inputSize
    const dy = (centerY - (row * CELL_SIZE + (CELL_SIZE / 2))) / inputSize
    // NOTE(review): width/height are in pixels here while anchors elsewhere
    // are scaled by CELL_SIZE (see assignBoxesToAnchors: anchor.x * CELL_SIZE).
    // Confirm this shouldn't be Math.log(width / (getAnchors()[anchor].x * CELL_SIZE)).
    const dw = Math.log(width / getAnchors()[anchor].x)
    const dh = Math.log(height / getAnchors()[anchor].y)
    const channelOffset = anchor * 5
    targetBuf.set(dx, row, col, channelOffset + 0)
    targetBuf.set(dy, row, col, channelOffset + 1)
    targetBuf.set(dw, row, col, channelOffset + 2)
    targetBuf.set(dh, row, col, channelOffset + 3)
  }
  return adjustments
}
// Builds a [numCells, numCells, 25] tensor carrying, at each ground-truth
// anchor's score channel (offset + 4), the IoU between the ground-truth box
// and the box predicted for the same cell/anchor slot; 0 elsewhere.
// Throws when a ground-truth slot has no corresponding predicted box.
function computeIous(predBoxes, groundTruthBoxes, reshapedImgDims) {
  const numCells = getNumCells(Math.max(reshapedImgDims.width, reshapedImgDims.height))
  const matchesSlot = p1 => p2 =>
    p1.row === p2.row
      && p1.col === p2.col
      && p1.anchor === p2.anchor
  const ious = tf.zeros([numCells, numCells, 25])
  const iouBuf = ious.buffer()
  for (const { row, col, anchor, box } of groundTruthBoxes) {
    const matchingPred = predBoxes.find(matchesSlot({ row, col, anchor }))
    if (!matchingPred) {
      console.log(groundTruthBoxes)
      console.log(predBoxes)
      throw new Error(`no output box found for: row ${row}, col ${col}, anchor ${anchor}`)
    }
    const iou = faceapi.iou(
      box.rescale(reshapedImgDims),
      matchingPred.box.rescale(reshapedImgDims)
    )
    iouBuf.set(iou, row, col, anchor * 5 + 4)
  }
  return ious
}
// No-object loss term: squared sigmoid of the raw output tensor, pushing
// predictions towards 0. The caller masks this to the non-object positions
// (see computeLoss, which multiplies by the inverse ground-truth mask).
function computeNoObjectLoss(outTensor) {
  return tf.tidy(() => {
    const activated = tf.sigmoid(outTensor)
    return tf.square(activated)
  })
}
// Object loss term: squared difference between the IoU targets (ground truth
// vs. matching predicted boxes) and the sigmoid-activated output tensor.
// The caller masks this to the ground-truth anchor positions.
function computeObjectLoss(outTensor, groundTruthBoxes, reshapedImgDims, paddings) {
  return tf.tidy(() => {
    const predBoxes = window.net.postProcess(outTensor, { paddings })
    const ious = computeIous(predBoxes, groundTruthBoxes, reshapedImgDims)
    const predictedScores = tf.sigmoid(outTensor)
    return tf.square(tf.sub(ious, predictedScores))
  })
}
// Coordinate loss term: squared difference between the regression targets
// (computeBoxAdjustments) and the raw output tensor. The caller masks this
// to the ground-truth anchor positions.
function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims) {
  return tf.tidy(() => {
    const targets = computeBoxAdjustments(groundTruthBoxes, reshapedImgDims)
    return tf.square(tf.sub(targets, outTensor))
  })
}
// Total YOLO loss over one image: scaled no-object, object and coordinate
// terms, each masked to the relevant anchor slots.
// Returns the three partial losses plus their sum as scalar tensors; the
// caller is responsible for disposing the returned tensors.
function computeLoss(outTensor, groundTruth, reshapedImgDims, paddings) {
  const inputSize = Math.max(reshapedImgDims.width, reshapedImgDims.height)
  if (!inputSize) {
    throw new Error(`invalid inputSize: ${inputSize}`)
  }

  // relative ground-truth rects -> boxes assigned to grid cell + best anchor
  const groundTruthBoxes = assignBoxesToAnchors(
    groundTruth
      .map(({ x, y, width, height }) => new faceapi.Rect(x, y, width, height))
      .map(rect => rect.toBoundingBox()),
    reshapedImgDims
  )

  // 1 at every ground-truth anchor slot, 0 elsewhere (inverseMask flips it)
  const mask = getGroundTruthMask(groundTruthBoxes, inputSize)
  const inverseMask = tf.tidy(() => tf.sub(tf.scalar(1), mask))

  const noObjectLoss = tf.tidy(() =>
    tf.mul(
      tf.scalar(noObjectScale),
      tf.sum(tf.mul(inverseMask, computeNoObjectLoss(outTensor)))
    )
  )

  const objectLoss = tf.tidy(() =>
    tf.mul(
      tf.scalar(objectScale),
      tf.sum(tf.mul(mask, computeObjectLoss(outTensor, groundTruthBoxes, reshapedImgDims, paddings)))
    )
  )

  const coordLoss = tf.tidy(() =>
    tf.mul(
      tf.scalar(coordScale),
      tf.sum(tf.mul(mask, computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims)))
    )
  )

  const totalLoss = tf.tidy(() => noObjectLoss.add(objectLoss).add(coordLoss))

  // fix: the mask tensors were created outside any tf.tidy and were never
  // released, leaking tensor memory on every training step
  mask.dispose()
  inverseMask.dispose()

  return {
    noObjectLoss,
    objectLoss,
    coordLoss,
    totalLoss
  }
}
\ No newline at end of file
......@@ -9,6 +9,8 @@
<script src="commons.js"></script>
<script src="FileSaver.js"></script>
<script src="trainUtils.js"></script>
<script src="train.js"></script>
<script src="loss.js"></script>
</head>
<body>
......@@ -39,6 +41,8 @@
window.net.load(weights)
window.net.variable()
window.detectionFilenames = await fetchDetectionFilenames()
console.log('ready')
}
const trainSizes = [608, 416, 320, 224]
......@@ -48,7 +52,7 @@
console.log('step', i)
let ts = Date.now()
const batchCreators = createBatchCreators(shuffle(detectionFilenames), batchSize)
const batchCreators = createBatchCreators(shuffle(window.detectionFilenames), batchSize)
for (let s = 0; s < trainSizes.length; s++) {
let ts2 = Date.now()
......@@ -63,7 +67,7 @@
console.log('step %s done (%s ms)', i, ts)
if (((i + 1) % saveEveryNthIteration) === 0) {
saveWeights(i)
saveWeights(window.net, 'tiny_yolov2_separable_model_' + i)
}
}
}
......
......@@ -9,45 +9,43 @@ async function trainStep(batchCreators, inputSize) {
let ts = Date.now()
const loss = optimizer.minimize(() => {
// TBD: batch loss
const batchIdx = 0
const outTensor = window.net.forwardInput(batchInput, inputSize)
const outTensorsByBatch = tf.tidy(() => outTensor.unstack().expandDims())
outTensor.dispose()
const losses = outTensorsByBatch.map(
(out, batchIdx) => {
const outBoxesByAnchor = window.net.postProcess(
out,
{
scoreThreshold: -1,
paddings: batchInput.getRelativePaddings(batchIdx)
}
)
const loss = computeLoss(
outBoxesByAnchor,
groundTruthBoxes[batchIdx],
netInput.getReshapedInputDimensions(batchIdx)
)
console.log(`loss for batch ${batchIdx}: ${loss}`)
return loss
}
const {
noObjectLoss,
objectLoss,
coordLoss,
totalLoss
} = computeLoss(
outTensor,
groundTruthBoxes[batchIdx],
batchInput.getReshapedInputDimensions(batchIdx),
batchInput.getRelativePaddings(batchIdx)
)
outTensorsByBatch.forEach(t => t.dispose())
return losses.reduce((sum, loss) => sum + loss, 0)
console.log('ground truth boxes:', groundTruthBoxes[batchIdx].length)
console.log(`noObjectLoss[${dataIdx}]: ${noObjectLoss.dataSync()}`)
console.log(`objectLoss[${dataIdx}]: ${objectLoss.dataSync()}`)
console.log(`coordLoss[${dataIdx}]: ${coordLoss.dataSync()}`)
console.log(`totalLoss[${dataIdx}]: ${totalLoss.dataSync()}`)
return totalLoss
}, true)
ts = Date.now() - ts
console.log(`loss[${dataIdx}]: ${loss}, ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
console.log(`trainStep time for dataIdx ${dataIdx} (${inputSize}): ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
loss.dispose()
await tf.nextFrame()
}))
}
function createBatchCreators(batchSize) {
function createBatchCreators(detectionFilenames, batchSize) {
if (batchSize < 1) {
throw new Error('invalid batch size: ' + batchSize)
}
......@@ -56,22 +54,21 @@ function createBatchCreators(batchSize) {
const pushToBatch = (remaining) => {
if (remaining.length) {
batches.push(remaining.slice(0, batchSize))
pushToBatch(remaining.
slice(batchSize))
pushToBatch(remaining.slice(batchSize))
}
return batches
}
pushToBatch(window.detectionFilenames)
pushToBatch(detectionFilenames)
const batchCreators = batches.map(detectionFilenames => async () => {
const groundTruthBoxes = detectionFilenames.map(
detectionFilenames.map(file => fetch(file).then(res => res.json()))
)
const batchCreators = batches.map(filenameForBatch => async () => {
const groundTruthBoxes = await Promise.all(filenameForBatch.map(
file => fetch(file).then(res => res.json())
))
const imgs = await Promise.all(
detectionFilenames.map(async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', ''))))
)
const imgs = await Promise.all(filenameForBatch.map(
async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', '')))
))
return {
imgs,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment