diff --git a/package.json b/package.json index eebd27d2..a55e8d93 100644 --- a/package.json +++ b/package.json @@ -299,6 +299,20 @@ "items": { "type": "string" } + }, + "azure-api-center.tenant": { + "type": "object", + "description": "A specific tenant to sign in to", + "properties": { + "name": { + "type": "string", + "description": "tenant name" + }, + "id": { + "type": "string", + "description": "tenant id" + } + } } } } diff --git a/src/azure/azureLogin/azureAccount.ts b/src/azure/azureLogin/azureAccount.ts index 852ec0a4..5202d78b 100644 --- a/src/azure/azureLogin/azureAccount.ts +++ b/src/azure/azureLogin/azureAccount.ts @@ -2,7 +2,8 @@ // Licensed under the MIT license. import { SubscriptionClient, TenantIdDescription } from "@azure/arm-resources-subscriptions"; import { TokenCredential } from "@azure/core-auth"; -import { AuthenticationSession, QuickPickItem, Uri, env, window } from "vscode"; +import { AuthenticationSession, ConfigurationTarget, QuickPickItem, Uri, env, window, workspace } from "vscode"; +import { extensionName, tenantSetting } from "../../constants"; import { UiStrings } from "../../uiStrings"; import { GeneralUtils } from "../../utils/generalUtils"; import { SelectionType, SignInStatus, SubscriptionFilter, Tenant } from "./authTypes"; @@ -14,6 +15,14 @@ export namespace AzureAccount { await AzureSessionProviderHelper.getSessionProvider().signIn(); } + export function getSelectedTenant(): Tenant | undefined { + return workspace.getConfiguration(extensionName).get(tenantSetting); + } + + export async function updateSelectedTenant(value?: Tenant): Promise { + await workspace.getConfiguration(extensionName).update(tenantSetting, value, ConfigurationTarget.Global, true); + } + export async function selectTenant(): Promise { const sessionProvider = AzureSessionProviderHelper.getSessionProvider(); if (sessionProvider.signInStatus !== SignInStatus.SignedIn) { @@ -42,6 +51,7 @@ export namespace AzureAccount { } sessionProvider.selectedTenant = selectedTenant; + await updateSelectedTenant(selectedTenant); } type SubscriptionQuickPickItem = QuickPickItem & { subscription: SubscriptionFilter }; diff --git a/src/azure/azureLogin/azureSessionProvider.ts b/src/azure/azureLogin/azureSessionProvider.ts index a37df6e8..aef3d037 100644 --- a/src/azure/azureLogin/azureSessionProvider.ts +++ b/src/azure/azureLogin/azureSessionProvider.ts @@ -143,6 +143,9 @@ export namespace AzureSessionProviderHelper { this.tenants = newTenants; this.signInStatusValue = newSignInStatus; if (signInStatusChanged || tenantsChanged || selectedTenantChanged) { + if (newSignInStatus === SignInStatus.SignedOut) { + await AzureAccount.updateSelectedTenant(); + } this.onSignInStatusChangeEmitter.fire(this.signInStatusValue); } } @@ -220,7 +223,11 @@ export namespace AzureSessionProviderHelper { ); const results = await Promise.all(getSessionPromises); const accessibleTenants = results.filter(GeneralUtils.succeeded).map((r) => r.result); - return accessibleTenants.length === 1 ? AzureAccount.findTenant(tenants, accessibleTenants[0].tenantId) : null; + if (accessibleTenants.length === 1) { + return AzureAccount.findTenant(tenants, accessibleTenants[0].tenantId); + } + const lastTenant = AzureAccount.getSelectedTenant(); + return lastTenant && accessibleTenants.some(item => item.tenantId === lastTenant.id) ? lastTenant : null; } private async getArmSession( diff --git a/src/constants.ts b/src/constants.ts index 21917fb3..36cde156 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -16,6 +16,7 @@ export const azureApiGuidelineRulesetFile = "https://raw.githubusercontent.com/a export const spectralOwaspRulesetFile = "https://unpkg.com/@stoplight/spectral-owasp-ruleset/dist/ruleset.mjs"; export const MODEL_SELECTOR: vscode.LanguageModelChatSelector = { vendor: 'copilot', family: 'gpt-4' }; export const ExceedTokenLimit = "Message exceeds token limit"; +export const tenantSetting: string = 'tenant'; export const AzureAccountType = { createAzureAccount: "azureapicenterCreateAzureAccount", diff --git a/src/test/unit/azure/azureLogin/azureAccount.test.ts b/src/test/unit/azure/azureLogin/azureAccount.test.ts new file mode 100644 index 00000000..d7726ad9 --- /dev/null +++ b/src/test/unit/azure/azureLogin/azureAccount.test.ts @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +import * as assert from "assert"; +import * as sinon from "sinon"; +import * as vscode from "vscode"; +import { AzureAccount } from "../../../../azure/azureLogin/azureAccount"; + +describe("Azure Account test case", () => { + let sandbox = null as any; + before(() => { + sandbox = sinon.createSandbox(); + }); + afterEach(() => { + sandbox.restore(); + }); + it("getSelectedTenant happy path", async () => { + let spyConf = sandbox.stub(vscode.workspace, "getConfiguration").returns({ + get: () => { + return "test"; + }, + } as any); + let res = await AzureAccount.getSelectedTenant(); + sandbox.assert.calledOnce(spyConf); + assert.strictEqual(res, "test"); + }); +});