Commit bcb98830 by vincent

check in latest training script

parent 6c3c55e1
...@@ -21,10 +21,12 @@ app.use(express.static(imagesPath)) ...@@ -21,10 +21,12 @@ app.use(express.static(imagesPath))
app.use(express.static(detectionsPath)) app.use(express.static(detectionsPath))
const detectionFilenames = fs.readdirSync(detectionsPath) const detectionFilenames = fs.readdirSync(detectionsPath)
const detectionFilenamesMultibox = JSON.parse(fs.readFileSync(path.join(__dirname, './tinyYolov2/multibox.json')))
app.use(express.static(trainDataPath)) app.use(express.static(trainDataPath))
app.get('/detection_filenames', (req, res) => res.status(202).send(detectionFilenames)) app.get('/detection_filenames', (req, res) => res.status(202).send(detectionFilenames))
app.get('/detection_filenames_multibox', (req, res) => res.status(202).send(detectionFilenamesMultibox))
app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'train.html'))) app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'train.html')))
app.get('/verify', (req, res) => res.sendFile(path.join(publicDir, 'verify.html'))) app.get('/verify', (req, res) => res.sendFile(path.join(publicDir, 'verify.html')))
......
...@@ -17,14 +17,17 @@ ...@@ -17,14 +17,17 @@
<script> <script>
tf = faceapi.tf tf = faceapi.tf
const startIdx224 = 3220 const startIdx224 = 35060
const startIdx320 = 20688 const startIdx320 = 41188
const startIdx416 = 950 const startIdx416 = 31050
const startIdx608 = 15220 const startIdx608 = 16520
const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights` const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
//const weightsUrl = '/tmp/tmp_2_count_41000.weights'
const fromEpoch = 0 const fromEpoch = 0
const trainOnlyMultibox = false
window.debug = false window.debug = false
window.logTrainSteps = true window.logTrainSteps = true
...@@ -42,6 +45,8 @@ ...@@ -42,6 +45,8 @@
window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8) window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8)
// all samples // all samples
//const dataStartIdx = 8000
const dataStartIdx = 0
const numTrainSamples = Infinity const numTrainSamples = Infinity
async function loadNetWeights(uri) { async function loadNetWeights(uri) {
...@@ -52,18 +57,28 @@ ...@@ -52,18 +57,28 @@
return fetch('/detection_filenames').then(res => res.json()) return fetch('/detection_filenames').then(res => res.json())
} }
async function fetchDetectionFilenamesMultibox() {
return fetch('/detection_filenames_multibox').then(res => res.json())
}
async function run() { async function run() {
const weights = await loadNetWeights(weightsUrl) const weights = await loadNetWeights(weightsUrl)
window.net = new faceapi.TinyYolov2(true) window.net = new faceapi.TinyYolov2(true)
window.net.load(weights) window.net.load(weights)
window.net.variable() window.net.variable()
window.detectionFilenames = (await fetchDetectionFilenames()).slice(0, numTrainSamples)
const fetchDetectionsFn = trainOnlyMultibox
? fetchDetectionFilenamesMultibox
: fetchDetectionFilenames
window.detectionFilenames = (await fetchDetectionsFn()).slice(dataStartIdx, dataStartIdx + numTrainSamples)
window.lossMap = {} window.lossMap = {}
console.log('ready') console.log('ready')
} }
//const trainSizes = [224, 320, 416, 608] //const trainSizes = [224, 320, 416]
const trainSizes = [608] const trainSizes = [608]
function logLossChange(lossType) { function logLossChange(lossType) {
...@@ -71,9 +86,22 @@ ...@@ -71,9 +86,22 @@
log(`${lossType} : ${faceapi.round(currentLoss[lossType])} (avg: ${faceapi.round(currentLoss[lossType] / detectionFilenames.length)}) (delta: ${currentLoss[lossType] - prevLoss[lossType]})`) log(`${lossType} : ${faceapi.round(currentLoss[lossType])} (avg: ${faceapi.round(currentLoss[lossType] / detectionFilenames.length)}) (delta: ${currentLoss[lossType] - prevLoss[lossType]})`)
} }
window.count = 0
function _onBatchProcessed(dataIdx, inputSize) {
window.count++
const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
console.log('dataIdx', dataIdx)
if ((window.count % saveEveryNthDataIdx) === 0) {
saveWeights(window.net, `tmp_2_count_${window.count}.weights`)
}
}
function onBatchProcessed(dataIdx, inputSize) { function onBatchProcessed(dataIdx, inputSize) {
if (((dataIdx + 1) % saveEveryNthDataIdx) === 0) { const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
saveWeights(window.net, `tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608 + dataIdx + 1}.weights`) console.log('idx', idx)
if ((idx % saveEveryNthDataIdx) === 0) {
saveWeights(window.net, `tmp__224_${startIdx224 + (inputSize === 224 ? idx : 0)}__320_${startIdx320 + (inputSize === 320 ? idx : 0)}__416_${startIdx416 + (inputSize === 416 ? idx : 0)}__608_${startIdx608 + (inputSize === 608 ? idx : 0)}.weights`)
} }
} }
...@@ -82,6 +110,7 @@ ...@@ -82,6 +110,7 @@
const batchSize = 1 const batchSize = 1
for (let i = fromEpoch; i < trainSteps; i++) { for (let i = fromEpoch; i < trainSteps; i++) {
window.epoch = i
log('step', i) log('step', i)
let ts2 = Date.now() let ts2 = Date.now()
......
...@@ -73,15 +73,17 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc ...@@ -73,15 +73,17 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
.rescale({ height: imgHeight, width: imgWidth }) .rescale({ height: imgHeight, width: imgWidth })
.rescale(scaleFactor) .rescale(scaleFactor)
const isTooTiny = box.width < 50 || box.height < 50 const isTooTiny = box.width < 40 || box.height < 40
if (isTooTiny) { if (isTooTiny && window.debug) {
log(`skipping box for input size ${inputSize}: (${Math.floor(box.width)} x ${Math.floor(box.height)})`) log(`skipping box for input size ${inputSize}: (${Math.floor(box.width)} x ${Math.floor(box.height)})`)
} }
return !isTooTiny return !isTooTiny
}) })
if (!filteredGroundTruthBoxes.length) { if (!filteredGroundTruthBoxes.length) {
if (window.debug) {
log(`no boxes for input size ${inputSize}, ${groundTruthBoxes[batchIdx].length} boxes were too small`) log(`no boxes for input size ${inputSize}, ${groundTruthBoxes[batchIdx].length} boxes were too small`)
}
batchInput.dispose() batchInput.dispose()
onBatchProcessed(dataIdx, inputSize) onBatchProcessed(dataIdx, inputSize)
return return
...@@ -89,7 +91,6 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc ...@@ -89,7 +91,6 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
let ts = Date.now() let ts = Date.now()
const loss = minimize(filteredGroundTruthBoxes, batchInput, inputSize, batch) const loss = minimize(filteredGroundTruthBoxes, batchInput, inputSize, batch)
ts = Date.now() - ts ts = Date.now() - ts
if (window.logTrainSteps) { if (window.logTrainSteps) {
log(`trainStep time for dataIdx ${dataIdx} (${inputSize}): ${ts} ms`) log(`trainStep time for dataIdx ${dataIdx} (${inputSize}): ${ts} ms`)
...@@ -109,7 +110,15 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc ...@@ -109,7 +110,15 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
await step(batchCreators.next(rescaleEveryNthBatch)) await step(batchCreators.next(rescaleEveryNthBatch))
} }
function createBatchCreators(detectionFilenames, batchSize, ) { async function fetchGroundTruthBoxesForFile(file) {
const boxes = await fetch(file).then(res => res.json())
return {
file,
boxes
}
}
function createBatchCreators(detectionFilenames, batchSize) {
if (batchSize < 1) { if (batchSize < 1) {
throw new Error('invalid batch size: ' + batchSize) throw new Error('invalid batch size: ' + batchSize)
} }
...@@ -126,9 +135,8 @@ function createBatchCreators(detectionFilenames, batchSize, ) { ...@@ -126,9 +135,8 @@ function createBatchCreators(detectionFilenames, batchSize, ) {
pushToBatch(detectionFilenames) pushToBatch(detectionFilenames)
const batchCreators = batches.map((filenamesForBatch, dataIdx) => async () => { const batchCreators = batches.map((filenamesForBatch, dataIdx) => async () => {
const groundTruthBoxes = await Promise.all(filenamesForBatch.map( const groundTruthBoxes = (await Promise.all(filenamesForBatch.map(fetchGroundTruthBoxesForFile)))
file => fetch(file).then(res => res.json()) .map(({ boxes }) => boxes)
))
const imgs = await Promise.all(filenamesForBatch.map( const imgs = await Promise.all(filenamesForBatch.map(
async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', ''))) async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', '')))
......
...@@ -138,10 +138,11 @@ ...@@ -138,10 +138,11 @@
async function run() { async function run() {
$('#imgByNr').keydown(onKeyDown) $('#imgByNr').keydown(onKeyDown)
const startIdx224 = 3220
const startIdx320 = 20688 const startIdx224 = 35060
const startIdx416 = 950 const startIdx320 = 41188
const startIdx608 = 15220 const startIdx416 = 31050
const startIdx608 = 16520
const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights` const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
...@@ -149,7 +150,7 @@ ...@@ -149,7 +150,7 @@
window.net = new faceapi.TinyYolov2(true) window.net = new faceapi.TinyYolov2(true)
await window.net.load(weights) await window.net.load(weights)
window.imgs = (await fetchDetectionFilenames()).slice(0, 100).map(f => f.replace('.json', '')) window.imgs = (await fetchDetectionFilenames()).map(f => f.replace('.json', ''))
$('#loader').hide() $('#loader').hide()
onSelectionChanged($('#selectList select').val()) onSelectionChanged($('#selectList select').val())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment