Unverified Commit 9c30d353 by justadudewhohacks Committed by GitHub

Merge pull request #60 from justadudewhohacks/fixes

fixed loadweightmap always fetching from relative url + bump tfjs-core version
parents 7c3c58f9 3eebacf0
export declare function getModelUris(uri: string | undefined, defaultModelName: string): { export declare function getModelUris(uri: string | undefined, defaultModelName: string): {
manifestUri: string;
modelBaseUri: string; modelBaseUri: string;
manifestUri: string;
}; };
export declare function loadWeightMap(uri: string | undefined, defaultModelName: string): Promise<any>; export declare function loadWeightMap(uri: string | undefined, defaultModelName: string): Promise<any>;
...@@ -3,17 +3,29 @@ Object.defineProperty(exports, "__esModule", { value: true }); ...@@ -3,17 +3,29 @@ Object.defineProperty(exports, "__esModule", { value: true });
var tslib_1 = require("tslib"); var tslib_1 = require("tslib");
var tf = require("@tensorflow/tfjs-core"); var tf = require("@tensorflow/tfjs-core");
function getModelUris(uri, defaultModelName) { function getModelUris(uri, defaultModelName) {
var parts = (uri || '').split('/');
var modelBaseUri = ((uri || '').endsWith('.json')
? parts.slice(0, parts.length - 1)
: parts).filter(function (s) { return s; }).join('/');
var defaultManifestFilename = defaultModelName + "-weights_manifest.json"; var defaultManifestFilename = defaultModelName + "-weights_manifest.json";
var manifestUri = !uri || !modelBaseUri if (!uri) {
? defaultManifestFilename return {
: (uri.endsWith('.json') modelBaseUri: '',
? uri manifestUri: defaultManifestFilename
: modelBaseUri + "/" + defaultManifestFilename); };
return { manifestUri: manifestUri, modelBaseUri: modelBaseUri }; }
if (uri === '/') {
return {
modelBaseUri: '/',
manifestUri: "/" + defaultManifestFilename
};
}
var parts = uri.split('/').filter(function (s) { return s; });
var manifestFile = uri.endsWith('.json')
? parts[parts.length - 1]
: defaultManifestFilename;
var modelBaseUri = (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/');
modelBaseUri = uri.startsWith('/') ? "/" + modelBaseUri : modelBaseUri;
return {
modelBaseUri: modelBaseUri,
manifestUri: modelBaseUri === '/' ? "/" + manifestFile : modelBaseUri + "/" + manifestFile
};
} }
exports.getModelUris = getModelUris; exports.getModelUris = getModelUris;
function loadWeightMap(uri, defaultModelName) { function loadWeightMap(uri, defaultModelName) {
......
{"version":3,"file":"loadWeightMap.js","sourceRoot":"","sources":["../../src/commons/loadWeightMap.ts"],"names":[],"mappings":";;;AAAA,0CAA4C;AAE5C,sBAA6B,GAAuB,EAAE,gBAAwB;IAC5E,IAAM,KAAK,GAAG,CAAC,GAAG,IAAI,EAAE,CAAC,CAAC,KAAK,CAAC,GAAG,CAAC,CAAA;IAEpC,IAAM,YAAY,GAAG,CACnB,CAAC,GAAG,IAAI,EAAE,CAAC,CAAC,QAAQ,CAAC,OAAO,CAAC;QAC3B,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;QAClC,CAAC,CAAC,KAAK,CACV,CAAC,MAAM,CAAC,UAAA,CAAC,IAAI,OAAA,CAAC,EAAD,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;IAE1B,IAAM,uBAAuB,GAAM,gBAAgB,2BAAwB,CAAA;IAC3E,IAAM,WAAW,GAAG,CAAC,GAAG,IAAI,CAAC,YAAY;QACvC,CAAC,CAAC,uBAAuB;QACzB,CAAC,CAAC,CACA,GAAG,CAAC,QAAQ,CAAC,OAAO,CAAC;YACnB,CAAC,CAAC,GAAG;YACL,CAAC,CAAI,YAAY,SAAI,uBAAyB,CACjD,CAAA;IAEH,OAAO,EAAE,WAAW,aAAA,EAAE,YAAY,cAAA,EAAE,CAAA;AACtC,CAAC;AAnBD,oCAmBC;AAED,uBACE,GAAuB,EACvB,gBAAwB;;;;;;oBAGlB,KAAgC,YAAY,CAAC,GAAG,EAAE,gBAAgB,CAAC,EAAjE,WAAW,iBAAA,EAAE,YAAY,kBAAA,CAAwC;oBAEjD,qBAAM,KAAK,CAAC,WAAW,CAAC,EAAA;wBAA/B,qBAAM,CAAC,SAAwB,CAAC,CAAC,IAAI,EAAE,EAAA;;oBAAlD,QAAQ,GAAG,SAAuC;oBAExD,sBAAO,EAAE,CAAC,EAAE,CAAC,WAAW,CAAC,QAAQ,EAAE,YAAY,CAAC,EAAA;;;;CACjD;AAVD,sCAUC"} {"version":3,"file":"loadWeightMap.js","sourceRoot":"","sources":["../../src/commons/loadWeightMap.ts"],"names":[],"mappings":";;;AAAA,0CAA4C;AAE5C,sBAA6B,GAAuB,EAAE,gBAAwB;IAC5E,IAAM,uBAAuB,GAAM,gBAAgB,2BAAwB,CAAA;IAE3E,IAAI,CAAC,GAAG,EAAE;QACR,OAAO;YACL,YAAY,EAAE,EAAE;YAChB,WAAW,EAAE,uBAAuB;SACrC,CAAA;KACF;IAED,IAAI,GAAG,KAAK,GAAG,EAAE;QACf,OAAO;YACL,YAAY,EAAE,GAAG;YACjB,WAAW,EAAE,MAAI,uBAAyB;SAC3C,CAAA;KACF;IAED,IAAM,KAAK,GAAG,GAAG,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,MAAM,CAAC,UAAA,CAAC,IAAI,OAAA,CAAC,EAAD,CAAC,CAAC,CAAA;IAE3C,IAAM,YAAY,GAAG,GAAG,CAAC,QAAQ,CAAC,OAAO,CAAC;QACxC,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC;QACzB,CAAC,CAAC,uBAAuB,CAAA;IAE3B,IAAI,YAAY,GAAG,CAAC,GAAG,CAAC,QAAQ,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAA;IAC/F,YAAY,GAAG,GAAG,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,MAAI,YAAc,CAAC,CAAC,CAAC,YAAY,CAAA;IAEtE,OAAO;QACL,YAAY,cAAA;QACZ,WAAW,EAAE,YAAY,KAAK,GAAG,CAAC,CAAC,CAAC,MAAI,YAAc,CAAC,CAAC,CAAI,YAAY,SAAI,YAAc;KAC3F,CAAA;AACH,CAAC;AA9BD,oCA8BC;AAED,uBACE,GAAuB,EACvB,gBAAwB;;;;;;oBAGlB,KAAgC,YAAY,CAAC,GAAG,EAAE,gBAAgB,CAAC,EAAjE,WAAW,iBAAA,EAAE,YAAY,kBAAA,CAAwC;oBAEjD,qBAAM,KAAK,CAAC,WAAW,CAAC,EAAA;wBAA/B,qBAAM,CAAC,SAAwB,CAAC,CAAC,IAAI,EAAE,EAAA;;oBAAlD,QAAQ,GAAG,SAAuC;oBAExD,sBAAO,EAAE,CAAC,EAAE,CAAC,WAAW,CAAC,QAAQ,EAAE,YAAY,CAAC,EAAA;;;;CACjD;AAVD,sCAUC"}
\ No newline at end of file \ No newline at end of file
...@@ -1165,17 +1165,29 @@ ...@@ -1165,17 +1165,29 @@
} }
function getModelUris(uri, defaultModelName) { function getModelUris(uri, defaultModelName) {
var parts = (uri || '').split('/');
var modelBaseUri = ((uri || '').endsWith('.json')
? parts.slice(0, parts.length - 1)
: parts).filter(function (s) { return s; }).join('/');
var defaultManifestFilename = defaultModelName + "-weights_manifest.json"; var defaultManifestFilename = defaultModelName + "-weights_manifest.json";
var manifestUri = !uri || !modelBaseUri if (!uri) {
? defaultManifestFilename return {
: (uri.endsWith('.json') modelBaseUri: '',
? uri manifestUri: defaultManifestFilename
: modelBaseUri + "/" + defaultManifestFilename); };
return { manifestUri: manifestUri, modelBaseUri: modelBaseUri }; }
if (uri === '/') {
return {
modelBaseUri: '/',
manifestUri: "/" + defaultManifestFilename
};
}
var parts = uri.split('/').filter(function (s) { return s; });
var manifestFile = uri.endsWith('.json')
? parts[parts.length - 1]
: defaultManifestFilename;
var modelBaseUri = (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/');
modelBaseUri = uri.startsWith('/') ? "/" + modelBaseUri : modelBaseUri;
return {
modelBaseUri: modelBaseUri,
manifestUri: modelBaseUri === '/' ? "/" + manifestFile : modelBaseUri + "/" + manifestFile
};
} }
function loadWeightMap(uri, defaultModelName) { function loadWeightMap(uri, defaultModelName) {
return __awaiter$1(this, void 0, void 0, function () { return __awaiter$1(this, void 0, void 0, function () {
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -2,6 +2,32 @@ ...@@ -2,6 +2,32 @@
"requires": true, "requires": true,
"lockfileVersion": 1, "lockfileVersion": 1,
"dependencies": { "dependencies": {
"@tensorflow/tfjs-core": {
"version": "0.12.7",
"resolved": "https://registry.npmjs.org/@tensorflow/tfjs-core/-/tfjs-core-0.12.7.tgz",
"integrity": "sha512-rKwPygxC1hRyCcR7lX9PMcXCwcQ1lNnDMvQd43C/f7QaMQL090H1l7zNyPh8HD4dX28W2f/A0tCpkGRr2fZ48w==",
"requires": {
"@types/seedrandom": "2.4.27",
"@types/webgl-ext": "0.0.29",
"@types/webgl2": "0.0.4",
"seedrandom": "2.4.3"
}
},
"@types/seedrandom": {
"version": "2.4.27",
"resolved": "https://registry.npmjs.org/@types/seedrandom/-/seedrandom-2.4.27.tgz",
"integrity": "sha1-nbVjk33YaRX2kJK8QyWdL0hXjkE="
},
"@types/webgl-ext": {
"version": "0.0.29",
"resolved": "https://registry.npmjs.org/@types/webgl-ext/-/webgl-ext-0.0.29.tgz",
"integrity": "sha512-ZlVjDQU5Vlc9hF4LGdDldujZUf0amwlwGv1RI2bfvdrEHIl6X/7MZVpemJUjS7NxD9XaKfE8SlFrxsfXpUkt/A=="
},
"@types/webgl2": {
"version": "0.0.4",
"resolved": "https://registry.npmjs.org/@types/webgl2/-/webgl2-0.0.4.tgz",
"integrity": "sha512-PACt1xdErJbMUOUweSrbVM7gSIYm1vTncW2hF6Os/EeWi6TXYAYMPp+8v6rzHmypE5gHrxaxZNXgMkJVIdZpHw=="
},
"accepts": { "accepts": {
"version": "1.3.5", "version": "1.3.5",
"resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz", "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz",
...@@ -550,6 +576,11 @@ ...@@ -550,6 +576,11 @@
"resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
"integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==" "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg=="
}, },
"seedrandom": {
"version": "2.4.3",
"resolved": "https://registry.npmjs.org/seedrandom/-/seedrandom-2.4.3.tgz",
"integrity": "sha1-JDhQTa0zkXMUv/GKxNeU8W1qrsw="
},
"send": { "send": {
"version": "0.16.2", "version": "0.16.2",
"resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz", "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz",
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"author": "justadudewhohacks", "author": "justadudewhohacks",
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"@tensorflow/tfjs-core": "^0.12.7",
"express": "^4.16.3", "express": "^4.16.3",
"request": "^2.87.0" "request": "^2.87.0"
} }
......
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
export function getModelUris(uri: string | undefined, defaultModelName: string) { export function getModelUris(uri: string | undefined, defaultModelName: string) {
const parts = (uri || '').split('/')
const modelBaseUri = (
(uri || '').endsWith('.json')
? parts.slice(0, parts.length - 1)
: parts
).filter(s => s).join('/')
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json` const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`
const manifestUri = !uri || !modelBaseUri
? defaultManifestFilename if (!uri) {
: ( return {
uri.endsWith('.json') modelBaseUri: '',
? uri manifestUri: defaultManifestFilename
: `${modelBaseUri}/${defaultManifestFilename}` }
) }
return { manifestUri, modelBaseUri } if (uri === '/') {
return {
modelBaseUri: '/',
manifestUri: `/${defaultManifestFilename}`
}
}
const parts = uri.split('/').filter(s => s)
const manifestFile = uri.endsWith('.json')
? parts[parts.length - 1]
: defaultManifestFilename
let modelBaseUri = (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/')
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri
return {
modelBaseUri,
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`
}
} }
export async function loadWeightMap( export async function loadWeightMap(
......
import { getModelUris } from '../../../src/commons/loadWeightMap'; import { getModelUris } from '../../../src/commons/loadWeightMap';
const FAKE_DEFAULT_MODEL_NAME = 'default_model_name' const FAKE_DEFAULT_MODEL_NAME = 'fake_model_name'
describe('loadWeightMap', () => { describe('loadWeightMap', () => {
describe('getModelUris', () => { describe('getModelUris', () => {
it('returns uris from top level url if no argument passed', () => { it('returns uris from relative url if no argument passed', () => {
const result = getModelUris(undefined, FAKE_DEFAULT_MODEL_NAME) const result = getModelUris(undefined, FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
expect(result.modelBaseUri).toEqual('') expect(result.modelBaseUri).toEqual('')
}) })
it('returns uris from top level url for empty string', () => { it('returns uris from relative url for empty string', () => {
const result = getModelUris('', FAKE_DEFAULT_MODEL_NAME) const result = getModelUris('', FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
expect(result.modelBaseUri).toEqual('') expect(result.modelBaseUri).toEqual('')
}) })
it('returns uris for top level url', () => { it('returns uris for top level url, leading slash preserved', () => {
const result = getModelUris('/', FAKE_DEFAULT_MODEL_NAME) const result = getModelUris('/', FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
expect(result.modelBaseUri).toEqual('') expect(result.modelBaseUri).toEqual('/')
}) })
it('returns uris, given url path', () => { it('returns uris, given url path', () => {
...@@ -35,8 +35,8 @@ describe('loadWeightMap', () => { ...@@ -35,8 +35,8 @@ describe('loadWeightMap', () => {
expect(result.modelBaseUri).toEqual(uri) expect(result.modelBaseUri).toEqual(uri)
}) })
it('returns uris, given url path, leading slash', () => { it('returns uris, given url path, leading slash preserved', () => {
const uri = 'path/to/modelfiles' const uri = '/path/to/modelfiles'
const result = getModelUris(`/${uri}`, FAKE_DEFAULT_MODEL_NAME) const result = getModelUris(`/${uri}`, FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${uri}/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`${uri}/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
...@@ -51,6 +51,14 @@ describe('loadWeightMap', () => { ...@@ -51,6 +51,14 @@ describe('loadWeightMap', () => {
expect(result.modelBaseUri).toEqual('path/to/modelfiles') expect(result.modelBaseUri).toEqual('path/to/modelfiles')
}) })
it('returns uris, given manifest uri, leading slash preserved', () => {
const uri = '/path/to/modelfiles/model-weights_manifest.json'
const result = getModelUris(uri, FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(uri)
expect(result.modelBaseUri).toEqual('/path/to/modelfiles')
})
}) })
}) })
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment