Skip to content

Commit

Permalink
Fix incorrect balance with rebasing tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
xhad committed Mar 24, 2024
1 parent bbc586b commit 8878826
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 53 deletions.
82 changes: 48 additions & 34 deletions src/ynLSD.sol
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
ILSDStakingNode[] public nodes;
uint256 public maxNodeCount;

/// @notice Mapping of asset balances held directly by the contract.
mapping(address => uint256) public balances;

//--------------------------------------------------------------------------------------
//---------------------------------- INITIALIZATION ----------------------------------
//--------------------------------------------------------------------------------------
Expand Down Expand Up @@ -159,8 +156,6 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
return _deposit(asset, amount, receiver, msg.sender);
}

event LogUint (string messgae, uint256 value);

function _deposit(
IERC20 asset,
uint256 amount,
Expand All @@ -176,25 +171,39 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
if (amount == 0) {
revert ZeroAmount();
}

uint256 previousTotalAssets = totalAssets();
uint256 assetAmountInETH = _transferAsset(asset, sender, amount);

asset.safeTransferFrom(sender, address(this), amount);

uint256 newBalance = IERC20(asset).balanceOf(address(this));

uint256 previousBalance = balances[address(asset)];

uint256 userAdjustedDeposit = newBalance > previousBalance
? newBalance - previousBalance
: newBalance < previousBalance
? previousBalance - newBalance
: newBalance;
// Calculate how many shares to be minted using the same formula as ynETH
shares = _convertToShares(assetAmountInETH, previousTotalAssets, Math.Rounding.Floor);

balances[address(asset)] = newBalance;
shares = convertToShares(asset, userAdjustedDeposit);
// Mint the calculated shares to the receiver
_mint(receiver, shares);

emit Deposit(sender, receiver, amount, shares);
}

/**
* @notice safeTransferFrom that returns ETH conversion and handles rebasing token deflation.
* @param asset The ERC20 asset to be transferred.
* @param sender The address of the sender.
* @param amount The amount of the asset to be transferred.
* @return assetAmountInETH The new balance of the users assets converted to ETH.
*/
function _transferAsset(IERC20 asset, address sender, uint256 amount) private returns (uint256 assetAmountInETH) {

uint256 previousBalance = asset.balanceOf(address(this));

asset.safeTransferFrom(sender, address(this), amount);
uint256 balance = asset.balanceOf(address(this));

if (balance < previousBalance + amount) {
uint256 difference = previousBalance + amount - balance;
assetAmountInETH = convertToETH(asset, amount - difference);
} else {
assetAmountInETH = convertToETH(asset, amount);
}
}

/**
* @dev Converts an ETH amount to shares based on the current exchange rate and specified rounding method.
Expand All @@ -203,10 +212,11 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
* This calculation can result in 0 during the bootstrap phase if `totalControlled` and `ynETHSupply` could be
* manipulated independently, which should not be possible.
* @param ethAmount The amount of ETH to convert to shares.
* @param preTotalAssets The total assets before the deposit.
* @param rounding The rounding method to use for the calculation.
* @return The number of shares equivalent to the given ETH amount.
* @return uint256 number of shares equivalent to the given ETH amount.
*/
function _convertToShares(uint256 ethAmount, Math.Rounding rounding) internal view returns (uint256) {
function _convertToShares(uint256 ethAmount, uint256 preTotalAssets, Math.Rounding rounding) internal view returns (uint256) {
// 1:1 exchange rate on the first stake.
// Use totalSupply to see if this is the bootstrap call, not totalAssets
if (totalSupply() == 0) {
Expand All @@ -220,19 +230,20 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
uint256 shares = Math.mulDiv(
ethAmount,
totalSupply(),
totalAssets(),
preTotalAssets,
rounding
);


return shares;
}

}

/// @notice Calculates the amount of shares to be minted for a given deposit.
/// @param asset The asset to be deposited.
/// @param amount The amount of asset to be deposited.
/// @return The amount of shares to be minted.
/**
* @notice Calculates the amount of shares to be minted for a given deposit.
* @param asset The asset to be deposited.
* @param amount The amount of asset to be deposited.
* @return uint256 of shares to be minted.
**/
function previewDeposit(IERC20 asset, uint256 amount) public view virtual returns (uint256) {
return convertToShares(asset, amount);
}
Expand All @@ -241,7 +252,7 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
* @notice This function calculates the total assets of the contract
* @dev It iterates over all the assets in the contract, gets the latest price for each asset from the oracle,
* multiplies it with the balance of the asset and adds it to the total
* @return total The total assets of the contract in the form of uint
* @return uint256 The total assets of the contract in the form of uint
*/
function totalAssets() public view returns (uint256) {
uint256 total = 0;
Expand All @@ -254,7 +265,7 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
return total;
}

/**
/**
* @notice Converts a given amount of a specific asset to shares
* @param asset The ERC-20 asset to be converted
* @param amount The amount of the asset to be converted
Expand All @@ -264,7 +275,7 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
IStrategy strategy = strategies[asset];
if(address(strategy) != address(0)){
uint256 assetAmountInETH = convertToETH(asset, amount);
shares = _convertToShares(assetAmountInETH, Math.Rounding.Floor);
shares = _convertToShares(assetAmountInETH, totalAssets(), Math.Rounding.Floor);
} else {
revert UnsupportedAsset(asset);
}
Expand All @@ -286,7 +297,10 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {

// Add balances for funds held directly in ynLSD.
for (uint256 i = 0; i < assets.length; i++) {
assetBalances[i] = balances[address(assets[i])];
assetStrategies[i] = strategies[assets[i]];

uint256 balanceThis = assets[i].balanceOf(address(this));
assetBalances[i] += balanceThis;
}

// Add balances contained in each LSDStakingNode, including those managed by strategies.
Expand All @@ -310,7 +324,7 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
* @dev This function takes into account the decimal places of the asset to ensure accurate conversion.
* @param asset The ERC20 token to be converted to ETH.
* @param amount The amount of the asset to be converted.
* @return The equivalent amount of the asset in ETH.
* @return uint256 equivalent amount of the asset in ETH.
*/
function convertToETH(IERC20 asset, uint amount) public view returns (uint256) {
uint256 assetPriceInETH = oracle.getLatestPrice(address(asset));
Expand Down Expand Up @@ -448,7 +462,7 @@ contract ynLSD is IynLSD, ynBase, ReentrancyGuardUpgradeable, IynLSDEvents {
}

IERC20(asset).safeTransfer(msg.sender, amount);
emit AssetRetrieved(asset, amount, nodeId, msg.sender);
emit AssetRetrieved(asset, IERC20(asset).balanceOf(msg.sender), nodeId, msg.sender);
}

//--------------------------------------------------------------------------------------
Expand Down
7 changes: 2 additions & 5 deletions test/foundry/integration/LSDStakingNode.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,8 @@ contract LSDStakingNodeTest is IntegrationBaseTest {
(bool success, ) = chainAddresses.lsd.STETH_ADDRESS.call{value: amount}("");
require(success, "ETH transfer failed");
uint256 balance = stETH.balanceOf(address(this));
assertEq(compareWithThreshold(balance, amount, 1), true, "Amount not received");
stETH.approve(address(ynlsd), amount);
ynlsd.deposit(stETH, amount, address(this));

assertEq(ynlsd.balanceOf(address(this)), balance, "User should have staked stETH");
stETH.approve(address(ynlsd), balance);
ynlsd.deposit(stETH, balance, address(this));

// 2. Deposit should fail when paused
IERC20[] memory assets = new IERC20[](1);
Expand Down
31 changes: 17 additions & 14 deletions test/foundry/scenarios/ynLSD.spec.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,37 @@ contract YnLSDScenarioTest1 is IntegrationBaseTest {
User_stETH_deposit(asset, amount3, address(0x03));
}

function User_stETH_deposit(address asset, uint256 random3, address user) public {
function User_stETH_deposit(address asset, uint256 amount, address user) public {

vm.assume(random3 > 0 && random3 < 10_000 ether);
vm.assume(amount > 1 && amount < 10_000 ether);

uint256 previousTotalShares = ynlsd.totalSupply();
uint256 previousTotalDeposited = ynlsd.balances(asset);
// uint256 previousTotalDeposited = ynlsd.totalAssets();
uint256 previousTotalAssets = ynlsd.getTotalAssets()[0];

uint256 userDeposit = random3;
(bool success,) = asset.call{ value: userDeposit }("");
vm.startPrank(user);
vm.deal(user, amount);
(bool success,) = asset.call{ value: amount }("");
require(success, "ETH transfer failed");
IERC20 steth = IERC20(asset);

uint256 userDeposit = IERC20(asset).balanceOf(user);

steth.approve(address(ynlsd), userDeposit);
ynlsd.deposit(steth, userDeposit, user);

uint256 userShares = ynlsd.balanceOf(user);

uint256 currentTotalDeposited = ynlsd.balances(asset);
// uint256 currentTotalDeposited = ynlsd.balances(asset);
uint256 currentTotalAssets = ynlsd.getTotalAssets()[0];
uint256 currentTotalShares = ynlsd.totalSupply();

runInvariants(
user,
previousTotalDeposited,
// previousTotalDeposited,
previousTotalAssets,
previousTotalShares,
currentTotalDeposited,
// currentTotalDeposited,
currentTotalAssets,
currentTotalShares,
userDeposit,
Expand All @@ -67,18 +70,18 @@ contract YnLSDScenarioTest1 is IntegrationBaseTest {

function runInvariants(
address user,
uint256 previousTotalDeposited,
// uint256 previousTotalDeposited,
uint256 previousTotalAssets,
uint256 previousTotalShares,
uint256 currentTotalDeposited,
// uint256 currentTotalDeposited,
uint256 currentTotalAssets,
uint256 currentTotalShares,
uint256 userDeposit,
uint256 userShares
) public view{
Invariants.totalDepositIntegrity(currentTotalDeposited, previousTotalDeposited, userDeposit);
Invariants.totalAssetsIntegrity(currentTotalAssets, previousTotalAssets, userDeposit);
Invariants.shareMintIntegrity(currentTotalShares, previousTotalShares, userShares);
Invariants.userSharesIntegrity(ynlsd.balanceOf(user), 0, userShares);
// Invariants.totalDepositIntegrity(currentTotalDeposited, previousTotalDeposited, userDeposit);
Invariants.totalAssetsIntegrity(currentTotalAssets, previousTotalAssets, userDeposit);
Invariants.shareMintIntegrity(currentTotalShares, previousTotalShares, userShares);
Invariants.userSharesIntegrity(ynlsd.balanceOf(user), 0, userShares);
}
}

0 comments on commit 8878826

Please sign in to comment.