Commit 656602f7 by vincent

fixed loadweightmap always fetching from relative url

parent 7c3c58f9
import * as tf from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs-core';
export function getModelUris(uri: string | undefined, defaultModelName: string) { export function getModelUris(uri: string | undefined, defaultModelName: string) {
const parts = (uri || '').split('/')
const modelBaseUri = (
(uri || '').endsWith('.json')
? parts.slice(0, parts.length - 1)
: parts
).filter(s => s).join('/')
const defaultManifestFilename = `${defaultModelName}-weights_manifest.json` const defaultManifestFilename = `${defaultModelName}-weights_manifest.json`
const manifestUri = !uri || !modelBaseUri
? defaultManifestFilename if (!uri) {
: ( return {
uri.endsWith('.json') modelBaseUri: '',
? uri manifestUri: defaultManifestFilename
: `${modelBaseUri}/${defaultManifestFilename}` }
) }
return { manifestUri, modelBaseUri } if (uri === '/') {
return {
modelBaseUri: '/',
manifestUri: `/${defaultManifestFilename}`
}
}
const parts = uri.split('/').filter(s => s)
const manifestFile = uri.endsWith('.json')
? parts[parts.length - 1]
: defaultManifestFilename
let modelBaseUri = (uri.endsWith('.json') ? parts.slice(0, parts.length - 1) : parts).join('/')
modelBaseUri = uri.startsWith('/') ? `/${modelBaseUri}` : modelBaseUri
return {
modelBaseUri,
manifestUri: modelBaseUri === '/' ? `/${manifestFile}` : `${modelBaseUri}/${manifestFile}`
}
} }
export async function loadWeightMap( export async function loadWeightMap(
......
import { getModelUris } from '../../../src/commons/loadWeightMap'; import { getModelUris } from '../../../src/commons/loadWeightMap';
const FAKE_DEFAULT_MODEL_NAME = 'default_model_name' const FAKE_DEFAULT_MODEL_NAME = 'fake_model_name'
describe('loadWeightMap', () => { describe('loadWeightMap', () => {
describe('getModelUris', () => { describe('getModelUris', () => {
it('returns uris from top level url if no argument passed', () => { it('returns uris from relative url if no argument passed', () => {
const result = getModelUris(undefined, FAKE_DEFAULT_MODEL_NAME) const result = getModelUris(undefined, FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
expect(result.modelBaseUri).toEqual('') expect(result.modelBaseUri).toEqual('')
}) })
it('returns uris from top level url for empty string', () => { it('returns uris from relative url for empty string', () => {
const result = getModelUris('', FAKE_DEFAULT_MODEL_NAME) const result = getModelUris('', FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
expect(result.modelBaseUri).toEqual('') expect(result.modelBaseUri).toEqual('')
}) })
it('returns uris for top level url', () => { it('returns uris for top level url, leading slash preserved', () => {
const result = getModelUris('/', FAKE_DEFAULT_MODEL_NAME) const result = getModelUris('/', FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
expect(result.modelBaseUri).toEqual('') expect(result.modelBaseUri).toEqual('/')
}) })
it('returns uris, given url path', () => { it('returns uris, given url path', () => {
...@@ -35,8 +35,8 @@ describe('loadWeightMap', () => { ...@@ -35,8 +35,8 @@ describe('loadWeightMap', () => {
expect(result.modelBaseUri).toEqual(uri) expect(result.modelBaseUri).toEqual(uri)
}) })
it('returns uris, given url path, leading slash', () => { it('returns uris, given url path, leading slash preserved', () => {
const uri = 'path/to/modelfiles' const uri = '/path/to/modelfiles'
const result = getModelUris(`/${uri}`, FAKE_DEFAULT_MODEL_NAME) const result = getModelUris(`/${uri}`, FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(`${uri}/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`) expect(result.manifestUri).toEqual(`${uri}/${FAKE_DEFAULT_MODEL_NAME}-weights_manifest.json`)
...@@ -51,6 +51,14 @@ describe('loadWeightMap', () => { ...@@ -51,6 +51,14 @@ describe('loadWeightMap', () => {
expect(result.modelBaseUri).toEqual('path/to/modelfiles') expect(result.modelBaseUri).toEqual('path/to/modelfiles')
}) })
it('returns uris, given manifest uri, leading slash preserved', () => {
const uri = '/path/to/modelfiles/model-weights_manifest.json'
const result = getModelUris(uri, FAKE_DEFAULT_MODEL_NAME)
expect(result.manifestUri).toEqual(uri)
expect(result.modelBaseUri).toEqual('/path/to/modelfiles')
})
}) })
}) })
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment