diff --git a/sdks/node/src/__tests__/clientApi.test.ts b/sdks/node/src/__tests__/clientApi.test.ts index 90c36a5c66..701639625f 100644 --- a/sdks/node/src/__tests__/clientApi.test.ts +++ b/sdks/node/src/__tests__/clientApi.test.ts @@ -17,7 +17,7 @@ describe('Client API', () => { defaultEntityName: jest.fn().mockResolvedValue('test-entity'), })); - const client = await init('test-project'); + const client = await init({ project: 'test-project' }); const gottenClient = requireGlobalClient(); expect(gottenClient).toBeDefined(); diff --git a/sdks/node/src/__tests__/dataset.test.ts b/sdks/node/src/__tests__/dataset.test.ts index 9aaba3c6c9..5cf81d33b2 100644 --- a/sdks/node/src/__tests__/dataset.test.ts +++ b/sdks/node/src/__tests__/dataset.test.ts @@ -3,7 +3,7 @@ import { Dataset } from '../dataset'; describe('Dataset', () => { test('should save a dataset', async () => { - const client = await init('test-project'); + const client = await init({ project: 'test-project' }); const data = [ { id: 1, value: 2 }, { id: 2, value: 3 }, diff --git a/sdks/node/src/__tests__/table.test.ts b/sdks/node/src/__tests__/table.test.ts index 6a9d2bcf27..6ecd4c3eb3 100644 --- a/sdks/node/src/__tests__/table.test.ts +++ b/sdks/node/src/__tests__/table.test.ts @@ -19,7 +19,7 @@ describe('table', () => { } // Saving the table generates refs for the table and its rows - const client = await init('test-project'); + const client = await init({ project: 'test-project' }); (client as any).saveTable(table); // TODO: Saving a Table is not public... but maybe it should be? const ref = await table.__savedRef; diff --git a/sdks/node/src/__tests__/weaveObject.test.ts b/sdks/node/src/__tests__/weaveObject.test.ts index a53c00e0de..c035f61c50 100644 --- a/sdks/node/src/__tests__/weaveObject.test.ts +++ b/sdks/node/src/__tests__/weaveObject.test.ts @@ -28,7 +28,7 @@ describe('weaveObject', () => { }); test('class-example', async () => { - const client = await init('test-project'); + const client = await init({ project: 'test-project' }); const obj = new ExampleObject('test', 1); // save an object diff --git a/sdks/node/src/clientApi.ts b/sdks/node/src/clientApi.ts index 15b803cd82..3ce4520ca8 100644 --- a/sdks/node/src/clientApi.ts +++ b/sdks/node/src/clientApi.ts @@ -5,13 +5,23 @@ import { getApiKey } from './wandb/settings'; import { WandbServerApi } from './wandb/wandbServerApi'; import { CallStackEntry, WeaveClient } from './weaveClient'; +export interface InitOptions { + project: string; + entity?: string; + projectName?: string; + host?: string; + apiKey?: string; +} + // Global client instance export let globalClient: WeaveClient | null = null; -export async function init(projectName: string): Promise { - const host = 'https://api.wandb.ai'; - const apiKey = getApiKey(); - +export async function init({ + project, + entity, + host = 'https://api.wandb.ai', + apiKey = getApiKey(), +}: InitOptions): Promise { const headers: Record = { 'User-Agent': `W&B Internal JS Client ${process.env.VERSION || 'unknown'}`, Authorization: `Basic ${Buffer.from(`api:${apiKey}`).toString('base64')}`, @@ -20,7 +30,8 @@ export async function init(projectName: string): Promise { try { const wandbServerApi = new WandbServerApi(host, apiKey); const defaultEntityName = await wandbServerApi.defaultEntityName(); - const projectId = `${defaultEntityName}/${projectName}`; + const entityName = entity ?? defaultEntityName; + const projectId = `${entityName}/${project}`; const retryFetch = createFetchWithRetry({ baseDelay: 1000,