Michael.W基于Foundry精读Openzeppelin第46期——ERC20Snapshot.sol

发布时间:2024年01月02日

0. 版本

[openzeppelin]:v4.8.3,[forge-std]:v1.5.6

0.1 ERC20Snapshot.sol

Github: https://github.com/OpenZeppelin/openzeppelin-contracts/blob/v4.8.3/contracts/token/ERC20/extensions/ERC20Snapshot.sol

ERC20Snapshot库是ERC20的拓展,增加了各账户余额及总流通量的快照机制。如果涉及到根据账户ERC20余额进行分红、投票等业务可以使用该库,其可有效防御在不同地址间转账进行“一币多用”的攻击。在一个快照横截面数据上进行分红、投票甚至是ERC20分叉都是最有效的解决方案。本库具有高效性,创建快照、快照上查询地址余额及总流通量的时间复杂度分别是O(1)O(log n)。但快照功能的存在会增加ERC20发生转移时的gas成本。

注:重写_getCurrentSnapshotId()方法可自定义快照id逻辑,如:使用区块高度作为快照id。但是要保证快照id在时间维度上具有单调性。每个区块都打快照将承担巨大的gas成本。

1. 目标合约

继承ERC20Snapshot合约:

Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/src/token/ERC20/extensions/MockERC20Snapshot.sol

// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;

import "openzeppelin-contracts/contracts/token/ERC20/extensions/ERC20Snapshot.sol";

contract MockERC20Snapshot is ERC20Snapshot {
    constructor(
        string memory name,
        string memory symbol,
        address richer,
        uint totalSupply
    )
    ERC20(name, symbol)
    {
        _mint(richer, totalSupply);
    }

    function mint(address account, uint amount) external {
        _mint(account, amount);
    }

    function burn(address account, uint amount) external {
        _burn(account, amount);
    }

    function snapshot() external {
        _snapshot();
    }

    function getCurrentSnapshotId() external view returns (uint){
        return _getCurrentSnapshotId();
    }
}

全部foundry测试合约:

Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/test/token/ERC20/extensions/ERC20Snapshot.t.sol

2. 代码精读

2.1 _snapshot() && _getCurrentSnapshotId()
  • _snapshot():创建一个新的快照数据并返回其id。该函数可见性为internal,开发者需要自行决定调用它的位置和权限;
  • _getCurrentSnapshotId():返回快照id计数器的当前值。

注:由于快照数据的存在,开发者需要知道本ERC20可能面临的风险点:

  1. 从快照中检索数据消耗的gas呈对数级增长(二分法)。快照越多,检索成本越高;
  2. 一个用户的快照数据越多,该用户作为转账from和to的gas消耗就越高(可能会被攻击者利用)。
	// 对类型uint256[]使用OpenZeppelin中的utils/Arrays.sol库
    using Arrays for uint256[];
    // 对OpenZeppelin中utils/Counters.sol库中的Counter结构体使用utils/Counters.sol库
    using Counters for Counters.Counter;

    // 快照数据的底层实际上是两个动态uint256数组。一个数组记录的是一系列id,另一个数组记录的是每个id对应的值(业务上是账户余额或总流通量)
    struct Snapshots {
        uint256[] ids;
        uint256[] values;
    }

    // 每个账户余额的快照数据
    // key: 账户地址, value: 该地址的快照数据
    mapping(address => Snapshots) private _accountBalanceSnapshots;
    // total supply的快照数据
    Snapshots private _totalSupplySnapshots;

    // 快照id的计数器。从1开始递增,所以id=0是无效id
    Counters.Counter private _currentSnapshotId;

    // 调用_snapshot()时,会抛出
    event Snapshot(uint256 id);

    function _snapshot() internal virtual returns (uint256) {
        // 快照id计数器自加1
        _currentSnapshotId.increment();
        // 获取当前可用的快照id
        uint256 currentId = _getCurrentSnapshotId();
        // 抛出事件
        emit Snapshot(currentId);
        // 返回上面获取到的当前可用快照id
        return currentId;
    }
    
    function _getCurrentSnapshotId() internal view virtual returns (uint256) {
        // 返回快照id计数器的当前值
        return _currentSnapshotId.current();
    }

foundry代码验证:

contract ERC20SnapshotTest is Test {
    MockERC20Snapshot private _testing = new MockERC20Snapshot("test name", "test symbol", address(this), 10000);

    event Snapshot(uint256 id);

    function test_SnapshotAndGetCurrentSnapshotId() external {
        assertEq(_testing.getCurrentSnapshotId(), 0);
        // increasing snapshots ids from 1
        for (uint id = 1; id < 10; ++id) {
            vm.expectEmit(address(_testing));
            emit Snapshot(id);
            _testing.snapshot();
            // check getCurrentSnapshotId()
            assertEq(_testing.getCurrentSnapshotId(), id);
        }
    }
}
2.2 _beforeTokenTransfer(address from, address to, uint256 amount)

重写ERC20._beforeTokenTransfer()方法。每当ERC20内部调用_mint(), _burn()_transfer()时,会对应进行快照的更新。

    function _beforeTokenTransfer(
        address from,
        address to,
        uint256 amount
    ) internal virtual override {
        // 调用ERC20._beforeTokenTransfer()
        super._beforeTokenTransfer(from, to, amount);
        
        if (from == address(0)) {
            // 如果本次操作是由_mint()引起(会引起total supply变化)
            // 更新to地址的快照数据(记录token转移之前的数据)
            _updateAccountSnapshot(to);
            // 更新total supply快照数据(记录token转移之前的数据)
            _updateTotalSupplySnapshot();
        } else if (to == address(0)) {
            // 如果本次操作是由_burn()引起(会引起total supply变化)
            // 更新from地址的快照数据(记录token转移之前的数据)
            _updateAccountSnapshot(from);
            // 更新total supply快照数据(记录token转移之前的数据)
            _updateTotalSupplySnapshot();
        } else {
            // 如果本次操作是由_transfer()引起(不会引起total supply变化)
            // 更新from地址的快照数据(记录token转移之前的数据)
            _updateAccountSnapshot(from);
            // 更新to地址的快照数据(记录token转移之前的数据)
            _updateAccountSnapshot(to);
        }
    }
    
    // 更新account地址对应的快照数据
    function _updateAccountSnapshot(address account) private {
        // 将account当前余额更新到account对应的快照数据中
        _updateSnapshot(_accountBalanceSnapshots[account], balanceOf(account));
    }

    // 更新total supply的快照数据
    function _updateTotalSupplySnapshot() private {
        // 将当前的total supply更新到total supply的快照数据中
        _updateSnapshot(_totalSupplySnapshots, totalSupply());
    }

    // 向指定快照数据中增添新的值
    // - snapshots: 指定的快照数据
    // - currentValue: 待增添的值
    function _updateSnapshot(Snapshots storage snapshots, uint256 currentValue) private {
        // currentId为获取快照id计数器的当前值
        uint256 currentId = _getCurrentSnapshotId();
        if (_lastSnapshotId(snapshots.ids) < currentId) {
            // 如果该快照数据中记录的最近快照数据id小于currentId时:
            // 向该快照数据的ids数组尾部增添currentId
            snapshots.ids.push(currentId);
            // 向该快照数据的values数组尾部增添currentValue
            snapshots.values.push(currentValue);
        }

        // 注:如果该快照数据中记录的最近快照数据id等于currentId时,不会做任何操作
        // 即如果在同一snapshot上多次更新,只会记录第一次更新的数据
        // 具体逻辑流程:
        // 假设A转给10个token给B,并触发snapshot。这时在A和B对应的Snapshots中并没有记录上面这笔交易。而在触发snapshot后,A又转给1个token给C。这时,在_beforeTokenTransfer()中,会将A在此次转账前的余额push到snapshots.ids和snapshots.values中
    }
    
    // 获取一个ids序列(uint256数组)中的最后一个元素
    function _lastSnapshotId(uint256[] storage ids) private view returns (uint256) {
        if (ids.length == 0) {
            // 如果传入ids为空数组,直接返回0
            return 0;
        } else {
            // 如果传入ids不是空数组,返回其尾部元素
            return ids[ids.length - 1];
        }
    }
2.3 totalSupplyAt(uint256 snapshotId)

获取在snapshotId快照上的total supply。

    function totalSupplyAt(uint256 snapshotId) public view virtual returns (uint256) {
        // 调用_valueAt()方法,从total supply快照数据中查找snapshotId对应的值
        // snapshotted: snapshotId是否存在于total supply快照数据中
        // value: snapshotId在total supply快照数据中对应的值。如果不存在,则为0
        (bool snapshotted, uint256 value) = _valueAt(snapshotId, _totalSupplySnapshots);
        // 如果snapshotId存在于total supply快照数据中,返回对应value。如果不存在,直接返回当前的total supply
        return snapshotted ? value : totalSupply();
    }

    // 在指定快照数据中检索指定快照id对应的值
    // - snapshotId: 要检索的快照id
    // - snapshots: 快照数据(即指定的Snapshots结构体)
    function _valueAt(uint256 snapshotId, Snapshots storage snapshots) private view returns (bool, uint256) {
        // 检验snapshotId的有效性:
        // 1. 待查询snapshotId需要大于0
        require(snapshotId > 0, "ERC20Snapshot: id is 0");
        // 2. 待查询snapshotId需要小于等于快照id计数器的当前值(每产生一个新的快照id,id计数器会自增1)
        require(snapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id");

        // 从一个Snapshots结构体中检索一个有效snapshotId对应的值时会出现以下三种可能性:
        // 1. 在该快照后,该值没有被修改过。所以在该snapshot id上没有任何与该值相关的记录;
        // 2. 在该快照后,该值被修改过。所以在该snapshot id上可以查到与该值记录;
        // 3. 在该快照后,该值没有被修改过。但是在其后面多个快照后被修改。这种情况,取大于该快照id的最小快照id来检索该值
        
        // 使用二分法,从指定快照数据的ids中查找大于等于snapshotId的第一个元素所处在的index
        // 注:如果snapshots.id为空,返回0
        // findUpperBound()函数细节分析参见:https://learnblockchain.cn/article/6111
        uint256 index = snapshots.ids.findUpperBound(snapshotId);
        if (index == snapshots.ids.length) {
            // 如果index为快照数据的ids的长度,表示该snapshotId大于ids中所有元素
            // 即在该snapshotId上没有关于该值的记录
            // 返回false和0
            return (false, 0);
        } else {
            // 如果index不是快照数据的ids的长度,表示该snapshotId存在于snapshots.ids中
            // 或小于等于snapshots.ids的尾部元素值
            // 返回true和第一个大于等于snapshotId的快照上记录的值
            return (true, snapshots.values[index]);
        }
    }

foundry代码验证:

contract ERC20SnapshotTest is Test {
    MockERC20Snapshot private _testing = new MockERC20Snapshot("test name", "test symbol", address(this), 10000);
    address private holder1 = address(1);
    address private holder2 = address(2);

    function test_TotalSupplyAt() external {
        // revert if snapshot id is 0
        vm.expectRevert("ERC20Snapshot: id is 0");
        _testing.totalSupplyAt(0);

        // revert if snapshot id is not created
        vm.expectRevert("ERC20Snapshot: nonexistent id");
        _testing.totalSupplyAt(1);

        uint totalSupply = _testing.totalSupply();
        assertEq(totalSupply, 10000);
        _testing.snapshot();
        assertEq(_testing.totalSupplyAt(1), totalSupply);

        // mint
        _testing.mint(address(this), 1);
        _testing.mint(address(this), 2);
        _testing.mint(address(this), 3);
        totalSupply += 1 + 2 + 3;
        _testing.snapshot();
        // mint after snapshot
        _testing.mint(address(this), 4);
        assertEq(_testing.totalSupplyAt(2), totalSupply);
        totalSupply += 4;

        // burn
        _testing.burn(address(this), 5);
        _testing.burn(address(this), 6);
        _testing.burn(address(this), 7);
        totalSupply -= 5 + 6 + 7;
        _testing.snapshot();
        // burn after snapshot
        _testing.burn(address(this), 8);
        assertEq(_testing.totalSupplyAt(3), totalSupply);
        totalSupply -= 8;

        // transfer
        _testing.transfer(holder1, 9);
        _testing.transfer(holder2, 10);
        _testing.snapshot();
        // transfer after snapshot
        vm.prank(holder1);
        _testing.transfer(holder2, 1);
        // totalSupplyAt(4) not change
        assertEq(_testing.totalSupplyAt(4), totalSupply);
    }
}
2.4 balanceOfAt(address account, uint256 snapshotId)

获取account地址在snapshotId快照上的余额。

    function balanceOfAt(address account, uint256 snapshotId) public view virtual returns (uint256) {
        // 调用_valueAt()方法,从account地址的快照数据中查找snapshotId对应的值
        // snapshotted: snapshotId是否存在于account地址的快照数据中
        // value: snapshotId在account地址的快照数据中对应的值。如果不存在,则为0
        (bool snapshotted, uint256 value) = _valueAt(snapshotId, _accountBalanceSnapshots[account]);
        // 如果snapshotId存在于account地址的快照数据中,返回对应value。如果不存在,直接返回account当前的余额
        return snapshotted ? value : balanceOf(account);
    }

foundry代码验证:

contract ERC20SnapshotTest is Test {
    MockERC20Snapshot private _testing = new MockERC20Snapshot("test name", "test symbol", address(this), 10000);
    address private holder1 = address(1);
    address private holder2 = address(2);

    function test_BalanceOfAt() external {
        // revert if snapshot id is 0
        vm.expectRevert("ERC20Snapshot: id is 0");
        _testing.balanceOfAt(address(this), 0);

        // revert if snapshot id is not created
        vm.expectRevert("ERC20Snapshot: nonexistent id");
        _testing.balanceOfAt(address(this), 1);

        uint balance = _testing.balanceOf(address(this));
        assertEq(balance, 10000);
        _testing.snapshot();
        assertEq(_testing.balanceOfAt(address(this), 1), balance);
        assertEq(_testing.balanceOfAt(holder1, 1), 0);
        assertEq(_testing.balanceOfAt(holder2, 1), 0);

        // mint
        _testing.mint(address(this), 1);
        _testing.mint(address(this), 2);
        _testing.mint(address(this), 3);
        balance += 1 + 2 + 3;
        _testing.snapshot();
        // mint after snapshot
        _testing.mint(address(this), 4);
        assertEq(_testing.balanceOfAt(address(this), 2), balance);
        assertEq(_testing.balanceOfAt(holder1, 2), 0);
        assertEq(_testing.balanceOfAt(holder2, 2), 0);
        balance += 4;

        // burn
        _testing.burn(address(this), 5);
        _testing.burn(address(this), 6);
        _testing.burn(address(this), 7);
        balance -= 5 + 6 + 7;
        _testing.snapshot();
        // burn after snapshot
        _testing.burn(address(this), 8);
        assertEq(_testing.balanceOfAt(address(this), 3), balance);
        assertEq(_testing.balanceOfAt(holder1, 3), 0);
        assertEq(_testing.balanceOfAt(holder2, 3), 0);
        balance -= 8;

        // transfer
        _testing.transfer(holder1, 9);
        _testing.transfer(holder2, 10);
        _testing.snapshot();
        // transfer after snapshot
        vm.prank(holder1);
        _testing.transfer(address(this), 1);
        assertEq(_testing.balanceOfAt(address(this), 4), balance - 9 - 10);
        assertEq(_testing.balanceOfAt(holder1, 4), 9);
        assertEq(_testing.balanceOfAt(holder2, 4), 10);
    }

    function test_MintAndBurnAndTransferInASnapshot() external {
        uint totalSupply = _testing.totalSupply();
        uint balanceHolder1;
        uint balanceHolder2;
        uint balanceHolderThis = _testing.balanceOf(address(this));

        // snapshot 1
        _testing.transfer(holder1, 1);
        _testing.mint(holder2, 2);
        _testing.burn(address(this), 3);
        _testing.snapshot();
        totalSupply = totalSupply + 2 - 3;
        balanceHolder1 += 1;
        balanceHolder2 += 2;
        balanceHolderThis -= 1 + 3;

        assertEq(_testing.totalSupplyAt(1), totalSupply);
        assertEq(_testing.balanceOfAt(holder1, 1), balanceHolder1);
        assertEq(_testing.balanceOfAt(holder2, 1), balanceHolder2);
        assertEq(_testing.balanceOfAt(address(this), 1), balanceHolderThis);

        // snapshot 2
        _testing.burn(holder1, 1);
        _testing.transfer(holder2, 4);
        _testing.mint(address(this), 5);
        _testing.snapshot();
        totalSupply = totalSupply + 5 - 1;
        balanceHolder1 -= 1;
        balanceHolder2 += 4;
        balanceHolderThis = balanceHolderThis - 4 + 5;

        assertEq(_testing.totalSupplyAt(2), totalSupply);
        assertEq(_testing.balanceOfAt(holder1, 2), balanceHolder1);
        assertEq(_testing.balanceOfAt(holder2, 2), balanceHolder2);
        assertEq(_testing.balanceOfAt(address(this), 2), balanceHolderThis);

        // snapshot 3
        _testing.mint(holder1, 6);
        _testing.burn(holder2, 2);
        vm.prank(holder2);
        _testing.transfer(address(this), 3);
        _testing.snapshot();
        totalSupply = totalSupply + 6 - 2;
        balanceHolder1 += 6;
        balanceHolder2 -= 2 + 3;
        balanceHolderThis += 3;

        assertEq(_testing.totalSupplyAt(3), totalSupply);
        assertEq(_testing.balanceOfAt(holder1, 3), balanceHolder1);
        assertEq(_testing.balanceOfAt(holder2, 3), balanceHolder2);
        assertEq(_testing.balanceOfAt(address(this), 3), balanceHolderThis);
    }
}

ps:
本人热爱图灵,热爱中本聪,热爱V神。
以下是我个人的公众号,如果有技术问题可以关注我的公众号来跟我交流。
同时我也会在这个公众号上每周更新我的原创文章,喜欢的小伙伴或者老伙计可以支持一下!
如果需要转发,麻烦注明作者。十分感谢!

在这里插入图片描述

公众号名称:后现代泼痞浪漫主义奠基人

文章来源:https://blog.csdn.net/michael_wgy_/article/details/135350118
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。