Skip to content

Commit

Permalink
refactor(SwitchNetwork): revamp RPC starkNet_switchNetwork (#368)
Browse files Browse the repository at this point in the history
* chore: revamp switch network

* chore: remove legacy code for switch network

* fix: util `getCurrentNetwork `

* chore: update network state mgr with default network config

* chore: lint fix

* chore: lint fix

* chore: rebase

* chore: update comment

* chore: update comment

* chore: use new error format

* chore: rollback snapstate change

---------

Co-authored-by: khanti42 <[email protected]>
  • Loading branch information
stanleyyconsensys and khanti42 authored Oct 10, 2024
1 parent e61eb8b commit d0384bf
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 164 deletions.
7 changes: 5 additions & 2 deletions packages/starknet-snap/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import type {
SignTransactionParams,
SignDeclareTransactionParams,
VerifySignatureParams,
SwitchNetworkParams,
} from './rpcs';
import {
displayPrivateKey,
Expand All @@ -43,10 +44,10 @@ import {
signTransaction,
signDeclareTransaction,
verifySignature,
switchNetwork,
} from './rpcs';
import { sendTransaction } from './sendTransaction';
import { signDeployAccountTransaction } from './signDeployAccountTransaction';
import { switchNetwork } from './switchNetwork';
import type {
ApiParams,
ApiParamsWithKeyDeriver,
Expand Down Expand Up @@ -230,7 +231,9 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => {
return await addNetwork(apiParams);

case 'starkNet_switchNetwork':
return await switchNetwork(apiParams);
return await switchNetwork.execute(
apiParams.requestParams as unknown as SwitchNetworkParams,
);

case 'starkNet_getCurrentNetwork':
return await getCurrentNetwork(apiParams);
Expand Down
1 change: 1 addition & 0 deletions packages/starknet-snap/src/rpcs/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ export * from './signMessage';
export * from './signTransaction';
export * from './sign-declare-transaction';
export * from './verify-signature';
export * from './switch-network';
187 changes: 187 additions & 0 deletions packages/starknet-snap/src/rpcs/switch-network.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import type { constants } from 'starknet';

import { Config } from '../config';
import { NetworkStateManager } from '../state/network-state-manager';
import type { Network } from '../types/snapState';
import {
STARKNET_SEPOLIA_TESTNET_NETWORK,
STARKNET_MAINNET_NETWORK,
} from '../utils/constants';
import {
InvalidNetworkError,
InvalidRequestParamsError,
UserRejectedOpError,
} from '../utils/exceptions';
import { prepareConfirmDialog } from './__tests__/helper';
import { switchNetwork } from './switch-network';
import type { SwitchNetworkParams } from './switch-network';

jest.mock('../utils/logger');

describe('switchNetwork', () => {
const createRequestParam = (
chainId: constants.StarknetChainId | string,
enableAuthorize?: boolean,
): SwitchNetworkParams => {
const request: SwitchNetworkParams = {
chainId: chainId as constants.StarknetChainId,
};
if (enableAuthorize) {
request.enableAuthorize = enableAuthorize;
}
return request;
};

const mockNetworkStateManager = ({
network = STARKNET_SEPOLIA_TESTNET_NETWORK,
currentNetwork = STARKNET_MAINNET_NETWORK,
}: {
network?: Network | null;
currentNetwork?: Network;
}) => {
const txStateSpy = jest.spyOn(
NetworkStateManager.prototype,
'withTransaction',
);
const getNetworkSpy = jest.spyOn(
NetworkStateManager.prototype,
'getNetwork',
);
const setCurrentNetworkSpy = jest.spyOn(
NetworkStateManager.prototype,
'setCurrentNetwork',
);
const getCurrentNetworkSpy = jest.spyOn(
NetworkStateManager.prototype,
'getCurrentNetwork',
);

getNetworkSpy.mockResolvedValue(network);
getCurrentNetworkSpy.mockResolvedValue(currentNetwork);
txStateSpy.mockImplementation(async (fn) => {
return await fn({
accContracts: [],
erc20Tokens: [],
networks: Config.availableNetworks,
transactions: [],
});
});

return { getNetworkSpy, setCurrentNetworkSpy, getCurrentNetworkSpy };
};

it('switchs a network correctly', async () => {
const currentNetwork = STARKNET_MAINNET_NETWORK;
const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK;
const { getNetworkSpy, setCurrentNetworkSpy, getCurrentNetworkSpy } =
mockNetworkStateManager({
currentNetwork,
network: requestNetwork,
});
const request = createRequestParam(requestNetwork.chainId);

const result = await switchNetwork.execute(request);

expect(result).toBe(true);
expect(getCurrentNetworkSpy).toHaveBeenCalled();
expect(getNetworkSpy).toHaveBeenCalledWith(
{ chainId: requestNetwork.chainId },
expect.anything(),
);
expect(setCurrentNetworkSpy).toHaveBeenCalledWith(requestNetwork);
});

it('returns `true` if the request chainId is the same with current network', async () => {
const currentNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK;
const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK;
const { getNetworkSpy, setCurrentNetworkSpy, getCurrentNetworkSpy } =
mockNetworkStateManager({
currentNetwork,
network: requestNetwork,
});
const request = createRequestParam(requestNetwork.chainId);

const result = await switchNetwork.execute(request);

expect(result).toBe(true);
expect(getCurrentNetworkSpy).toHaveBeenCalled();
expect(getNetworkSpy).not.toHaveBeenCalled();
expect(setCurrentNetworkSpy).not.toHaveBeenCalled();
});

it('renders confirmation dialog', async () => {
const currentNetwork = STARKNET_MAINNET_NETWORK;
const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK;
mockNetworkStateManager({
currentNetwork,
network: requestNetwork,
});
const { confirmDialogSpy } = prepareConfirmDialog();
const request = createRequestParam(requestNetwork.chainId, true);

await switchNetwork.execute(request);

expect(confirmDialogSpy).toHaveBeenCalledWith([
{ type: 'heading', value: 'Do you want to switch to this network?' },
{
type: 'row',
label: 'Chain Name',
value: {
value: requestNetwork.name,
markdown: false,
type: 'text',
},
},
{
type: 'divider',
},
{
type: 'row',
label: 'Chain ID',
value: {
value: requestNetwork.chainId,
markdown: false,
type: 'text',
},
},
]);
});

it('throws `UserRejectedRequestError` if user denied the operation', async () => {
const currentNetwork = STARKNET_MAINNET_NETWORK;
const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK;
mockNetworkStateManager({
currentNetwork,
network: requestNetwork,
});
const { confirmDialogSpy } = prepareConfirmDialog();
confirmDialogSpy.mockResolvedValue(false);
const request = createRequestParam(requestNetwork.chainId, true);

await expect(switchNetwork.execute(request)).rejects.toThrow(
UserRejectedOpError,
);
});

it('throws `Network not supported` error if the request network is not support', async () => {
const currentNetwork = STARKNET_MAINNET_NETWORK;
const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK;
// Mock the network state manager to return null network
// even if the request chain id is not block by the superstruct
mockNetworkStateManager({
currentNetwork,
network: null,
});
const request = createRequestParam(requestNetwork.chainId);

await expect(switchNetwork.execute(request)).rejects.toThrow(
InvalidNetworkError,
);
});

it('throws `InvalidRequestParamsError` when request parameter is not correct', async () => {
await expect(
switchNetwork.execute({} as unknown as SwitchNetworkParams),
).rejects.toThrow(InvalidRequestParamsError);
});
});
126 changes: 126 additions & 0 deletions packages/starknet-snap/src/rpcs/switch-network.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import type { Component } from '@metamask/snaps-sdk';
import { divider, heading, row, text } from '@metamask/snaps-sdk';
import type { Infer } from 'superstruct';
import { assign, boolean } from 'superstruct';

import { NetworkStateManager } from '../state/network-state-manager';
import {
confirmDialog,
AuthorizableStruct,
BaseRequestStruct,
RpcController,
} from '../utils';
import { InvalidNetworkError, UserRejectedOpError } from '../utils/exceptions';

export const SwitchNetworkRequestStruct = assign(
AuthorizableStruct,
BaseRequestStruct,
);

export const SwitchNetworkResponseStruct = boolean();

export type SwitchNetworkParams = Infer<typeof SwitchNetworkRequestStruct>;

export type SwitchNetworkResponse = Infer<typeof SwitchNetworkResponseStruct>;

/**
* The RPC handler to switch the network.
*/
export class SwitchNetworkRpc extends RpcController<
SwitchNetworkParams,
SwitchNetworkResponse
> {
protected requestStruct = SwitchNetworkRequestStruct;

protected responseStruct = SwitchNetworkResponseStruct;

/**
* Execute the switching network request handler.
* It switch to a supported network based on the chain id.
* It will show a confirmation dialog to the user before switching a network.
*
* @param params - The parameters of the request.
* @param [params.enableAuthorize] - Optional, a flag to enable or display the confirmation dialog to the user.
* @param params.chainId - The chain id of the network to switch.
* @returns the response of the switching a network in boolean.
* @throws {UserRejectedRequestError} If the user rejects the request.
* @throws {Error} If the network with the chain id is not supported.
*/
async execute(params: SwitchNetworkParams): Promise<SwitchNetworkResponse> {
return super.execute(params);
}

protected async handleRequest(
params: SwitchNetworkParams,
): Promise<SwitchNetworkResponse> {
const { enableAuthorize, chainId } = params;
const networkStateMgr = new NetworkStateManager();

// Using transactional state interaction to ensure that the state is updated atomically
// To avoid a use case while 2 requests are trying to update/read the state at the same time
return await networkStateMgr.withTransaction<boolean>(async (state) => {
const currentNetwork = await networkStateMgr.getCurrentNetwork(state);

// Return early if the current network is the same as the requested network
if (currentNetwork.chainId === chainId) {
return true;
}

const network = await networkStateMgr.getNetwork(
{
chainId,
},
state,
);

// if the network is not in the list of networks that we support, we throw an error
if (!network) {
throw new InvalidNetworkError() as unknown as Error;
}

if (
// Get Starknet expected show the confirm dialog, while the companion doesnt needed,
// therefore, `enableAuthorize` is to enable/disable the confirmation
enableAuthorize &&
!(await this.getSwitchNetworkConsensus(network.name, network.chainId))
) {
throw new UserRejectedOpError() as unknown as Error;
}

await networkStateMgr.setCurrentNetwork(network);

return true;
});
}

protected async getSwitchNetworkConsensus(
networkName: string,
networkChainId: string,
) {
const components: Component[] = [];
components.push(heading('Do you want to switch to this network?'));
components.push(
row(
'Chain Name',
text({
value: networkName,
markdown: false,
}),
),
);
components.push(divider());
components.push(
row(
'Chain ID',
text({
value: networkChainId,
markdown: false,
}),
),
);

return await confirmDialog(components);
}
}

export const switchNetwork = new SwitchNetworkRpc();
Loading

0 comments on commit d0384bf

Please sign in to comment.