Commit bcb98830 by vincent

check in latest training script

parent 6c3c55e1
......@@ -21,10 +21,12 @@ app.use(express.static(imagesPath))
app.use(express.static(detectionsPath))
const detectionFilenames = fs.readdirSync(detectionsPath)
const detectionFilenamesMultibox = JSON.parse(fs.readFileSync(path.join(__dirname, './tinyYolov2/multibox.json')))
app.use(express.static(trainDataPath))
app.get('/detection_filenames', (req, res) => res.status(202).send(detectionFilenames))
app.get('/detection_filenames_multibox', (req, res) => res.status(202).send(detectionFilenamesMultibox))
app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'train.html')))
app.get('/verify', (req, res) => res.sendFile(path.join(publicDir, 'verify.html')))
......
......@@ -17,14 +17,17 @@
<script>
tf = faceapi.tf
const startIdx224 = 3220
const startIdx320 = 20688
const startIdx416 = 950
const startIdx608 = 15220
const startIdx224 = 35060
const startIdx320 = 41188
const startIdx416 = 31050
const startIdx608 = 16520
const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
//const weightsUrl = '/tmp/tmp_2_count_41000.weights'
const fromEpoch = 0
const trainOnlyMultibox = false
window.debug = false
window.logTrainSteps = true
......@@ -42,6 +45,8 @@
window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8)
// all samples
//const dataStartIdx = 8000
const dataStartIdx = 0
const numTrainSamples = Infinity
async function loadNetWeights(uri) {
......@@ -52,18 +57,28 @@
return fetch('/detection_filenames').then(res => res.json())
}
async function fetchDetectionFilenamesMultibox() {
return fetch('/detection_filenames_multibox').then(res => res.json())
}
async function run() {
const weights = await loadNetWeights(weightsUrl)
window.net = new faceapi.TinyYolov2(true)
window.net.load(weights)
window.net.variable()
window.detectionFilenames = (await fetchDetectionFilenames()).slice(0, numTrainSamples)
const fetchDetectionsFn = trainOnlyMultibox
? fetchDetectionFilenamesMultibox
: fetchDetectionFilenames
window.detectionFilenames = (await fetchDetectionsFn()).slice(dataStartIdx, dataStartIdx + numTrainSamples)
window.lossMap = {}
console.log('ready')
}
//const trainSizes = [224, 320, 416, 608]
//const trainSizes = [224, 320, 416]
const trainSizes = [608]
function logLossChange(lossType) {
......@@ -71,9 +86,22 @@
log(`${lossType} : ${faceapi.round(currentLoss[lossType])} (avg: ${faceapi.round(currentLoss[lossType] / detectionFilenames.length)}) (delta: ${currentLoss[lossType] - prevLoss[lossType]})`)
}
window.count = 0
function _onBatchProcessed(dataIdx, inputSize) {
window.count++
const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
console.log('dataIdx', dataIdx)
if ((window.count % saveEveryNthDataIdx) === 0) {
saveWeights(window.net, `tmp_2_count_${window.count}.weights`)
}
}
function onBatchProcessed(dataIdx, inputSize) {
if (((dataIdx + 1) % saveEveryNthDataIdx) === 0) {
saveWeights(window.net, `tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608 + dataIdx + 1}.weights`)
const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
console.log('idx', idx)
if ((idx % saveEveryNthDataIdx) === 0) {
saveWeights(window.net, `tmp__224_${startIdx224 + (inputSize === 224 ? idx : 0)}__320_${startIdx320 + (inputSize === 320 ? idx : 0)}__416_${startIdx416 + (inputSize === 416 ? idx : 0)}__608_${startIdx608 + (inputSize === 608 ? idx : 0)}.weights`)
}
}
......@@ -82,6 +110,7 @@
const batchSize = 1
for (let i = fromEpoch; i < trainSteps; i++) {
window.epoch = i
log('step', i)
let ts2 = Date.now()
......
......@@ -73,15 +73,17 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
.rescale({ height: imgHeight, width: imgWidth })
.rescale(scaleFactor)
const isTooTiny = box.width < 50 || box.height < 50
if (isTooTiny) {
const isTooTiny = box.width < 40 || box.height < 40
if (isTooTiny && window.debug) {
log(`skipping box for input size ${inputSize}: (${Math.floor(box.width)} x ${Math.floor(box.height)})`)
}
return !isTooTiny
})
if (!filteredGroundTruthBoxes.length) {
if (window.debug) {
log(`no boxes for input size ${inputSize}, ${groundTruthBoxes[batchIdx].length} boxes were too small`)
}
batchInput.dispose()
onBatchProcessed(dataIdx, inputSize)
return
......@@ -89,7 +91,6 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
let ts = Date.now()
const loss = minimize(filteredGroundTruthBoxes, batchInput, inputSize, batch)
ts = Date.now() - ts
if (window.logTrainSteps) {
log(`trainStep time for dataIdx ${dataIdx} (${inputSize}): ${ts} ms`)
......@@ -109,7 +110,15 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
await step(batchCreators.next(rescaleEveryNthBatch))
}
function createBatchCreators(detectionFilenames, batchSize, ) {
async function fetchGroundTruthBoxesForFile(file) {
const boxes = await fetch(file).then(res => res.json())
return {
file,
boxes
}
}
function createBatchCreators(detectionFilenames, batchSize) {
if (batchSize < 1) {
throw new Error('invalid batch size: ' + batchSize)
}
......@@ -126,9 +135,8 @@ function createBatchCreators(detectionFilenames, batchSize, ) {
pushToBatch(detectionFilenames)
const batchCreators = batches.map((filenamesForBatch, dataIdx) => async () => {
const groundTruthBoxes = await Promise.all(filenamesForBatch.map(
file => fetch(file).then(res => res.json())
))
const groundTruthBoxes = (await Promise.all(filenamesForBatch.map(fetchGroundTruthBoxesForFile)))
.map(({ boxes }) => boxes)
const imgs = await Promise.all(filenamesForBatch.map(
async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', '')))
......
......@@ -138,10 +138,11 @@
async function run() {
$('#imgByNr').keydown(onKeyDown)
const startIdx224 = 3220
const startIdx320 = 20688
const startIdx416 = 950
const startIdx608 = 15220
const startIdx224 = 35060
const startIdx320 = 41188
const startIdx416 = 31050
const startIdx608 = 16520
const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
......@@ -149,7 +150,7 @@
window.net = new faceapi.TinyYolov2(true)
await window.net.load(weights)
window.imgs = (await fetchDetectionFilenames()).slice(0, 100).map(f => f.replace('.json', ''))
window.imgs = (await fetchDetectionFilenames()).map(f => f.replace('.json', ''))
$('#loader').hide()
onSelectionChanged($('#selectList select').val())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment