[openzeppelin]:v4.8.3,[forge-std]:v1.5.6
ERC20Snapshot库是ERC20的拓展,增加了各账户余额及总流通量的快照机制。如果涉及到根据账户ERC20余额进行分红、投票等业务可以使用该库,其可有效防御在不同地址间转账进行“一币多用”的攻击。在一个快照横截面数据上进行分红、投票甚至是ERC20分叉都是最有效的解决方案。本库具有高效性,创建快照、快照上查询地址余额及总流通量的时间复杂度分别是O(1)
和O(log n)
。但快照功能的存在会增加ERC20发生转移时的gas成本。
注:重写_getCurrentSnapshotId()
方法可自定义快照id逻辑,如:使用区块高度作为快照id。但是要保证快照id在时间维度上具有单调性。每个区块都打快照将承担巨大的gas成本。
继承ERC20Snapshot合约:
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;
import "openzeppelin-contracts/contracts/token/ERC20/extensions/ERC20Snapshot.sol";
contract MockERC20Snapshot is ERC20Snapshot {
constructor(
string memory name,
string memory symbol,
address richer,
uint totalSupply
)
ERC20(name, symbol)
{
_mint(richer, totalSupply);
}
function mint(address account, uint amount) external {
_mint(account, amount);
}
function burn(address account, uint amount) external {
_burn(account, amount);
}
function snapshot() external {
_snapshot();
}
function getCurrentSnapshotId() external view returns (uint){
return _getCurrentSnapshotId();
}
}
全部foundry测试合约:
_snapshot()
:创建一个新的快照数据并返回其id。该函数可见性为internal,开发者需要自行决定调用它的位置和权限;_getCurrentSnapshotId()
:返回快照id计数器的当前值。注:由于快照数据的存在,开发者需要知道本ERC20可能面临的风险点:
// 对类型uint256[]使用OpenZeppelin中的utils/Arrays.sol库
using Arrays for uint256[];
// 对OpenZeppelin中utils/Counters.sol库中的Counter结构体使用utils/Counters.sol库
using Counters for Counters.Counter;
// 快照数据的底层实际上是两个动态uint256数组。一个数组记录的是一系列id,另一个数组记录的是每个id对应的值(业务上是账户余额或总流通量)
struct Snapshots {
uint256[] ids;
uint256[] values;
}
// 每个账户余额的快照数据
// key: 账户地址, value: 该地址的快照数据
mapping(address => Snapshots) private _accountBalanceSnapshots;
// total supply的快照数据
Snapshots private _totalSupplySnapshots;
// 快照id的计数器。从1开始递增,所以id=0是无效id
Counters.Counter private _currentSnapshotId;
// 调用_snapshot()时,会抛出
event Snapshot(uint256 id);
function _snapshot() internal virtual returns (uint256) {
// 快照id计数器自加1
_currentSnapshotId.increment();
// 获取当前可用的快照id
uint256 currentId = _getCurrentSnapshotId();
// 抛出事件
emit Snapshot(currentId);
// 返回上面获取到的当前可用快照id
return currentId;
}
function _getCurrentSnapshotId() internal view virtual returns (uint256) {
// 返回快照id计数器的当前值
return _currentSnapshotId.current();
}
foundry代码验证:
contract ERC20SnapshotTest is Test {
MockERC20Snapshot private _testing = new MockERC20Snapshot("test name", "test symbol", address(this), 10000);
event Snapshot(uint256 id);
function test_SnapshotAndGetCurrentSnapshotId() external {
assertEq(_testing.getCurrentSnapshotId(), 0);
// increasing snapshots ids from 1
for (uint id = 1; id < 10; ++id) {
vm.expectEmit(address(_testing));
emit Snapshot(id);
_testing.snapshot();
// check getCurrentSnapshotId()
assertEq(_testing.getCurrentSnapshotId(), id);
}
}
}
重写ERC20._beforeTokenTransfer()
方法。每当ERC20内部调用_mint()
, _burn()
或_transfer()
时,会对应进行快照的更新。
function _beforeTokenTransfer(
address from,
address to,
uint256 amount
) internal virtual override {
// 调用ERC20._beforeTokenTransfer()
super._beforeTokenTransfer(from, to, amount);
if (from == address(0)) {
// 如果本次操作是由_mint()引起(会引起total supply变化)
// 更新to地址的快照数据(记录token转移之前的数据)
_updateAccountSnapshot(to);
// 更新total supply快照数据(记录token转移之前的数据)
_updateTotalSupplySnapshot();
} else if (to == address(0)) {
// 如果本次操作是由_burn()引起(会引起total supply变化)
// 更新from地址的快照数据(记录token转移之前的数据)
_updateAccountSnapshot(from);
// 更新total supply快照数据(记录token转移之前的数据)
_updateTotalSupplySnapshot();
} else {
// 如果本次操作是由_transfer()引起(不会引起total supply变化)
// 更新from地址的快照数据(记录token转移之前的数据)
_updateAccountSnapshot(from);
// 更新to地址的快照数据(记录token转移之前的数据)
_updateAccountSnapshot(to);
}
}
// 更新account地址对应的快照数据
function _updateAccountSnapshot(address account) private {
// 将account当前余额更新到account对应的快照数据中
_updateSnapshot(_accountBalanceSnapshots[account], balanceOf(account));
}
// 更新total supply的快照数据
function _updateTotalSupplySnapshot() private {
// 将当前的total supply更新到total supply的快照数据中
_updateSnapshot(_totalSupplySnapshots, totalSupply());
}
// 向指定快照数据中增添新的值
// - snapshots: 指定的快照数据
// - currentValue: 待增添的值
function _updateSnapshot(Snapshots storage snapshots, uint256 currentValue) private {
// currentId为获取快照id计数器的当前值
uint256 currentId = _getCurrentSnapshotId();
if (_lastSnapshotId(snapshots.ids) < currentId) {
// 如果该快照数据中记录的最近快照数据id小于currentId时:
// 向该快照数据的ids数组尾部增添currentId
snapshots.ids.push(currentId);
// 向该快照数据的values数组尾部增添currentValue
snapshots.values.push(currentValue);
}
// 注:如果该快照数据中记录的最近快照数据id等于currentId时,不会做任何操作
// 即如果在同一snapshot上多次更新,只会记录第一次更新的数据
// 具体逻辑流程:
// 假设A转给10个token给B,并触发snapshot。这时在A和B对应的Snapshots中并没有记录上面这笔交易。而在触发snapshot后,A又转给1个token给C。这时,在_beforeTokenTransfer()中,会将A在此次转账前的余额push到snapshots.ids和snapshots.values中
}
// 获取一个ids序列(uint256数组)中的最后一个元素
function _lastSnapshotId(uint256[] storage ids) private view returns (uint256) {
if (ids.length == 0) {
// 如果传入ids为空数组,直接返回0
return 0;
} else {
// 如果传入ids不是空数组,返回其尾部元素
return ids[ids.length - 1];
}
}
获取在snapshotId快照上的total supply。
function totalSupplyAt(uint256 snapshotId) public view virtual returns (uint256) {
// 调用_valueAt()方法,从total supply快照数据中查找snapshotId对应的值
// snapshotted: snapshotId是否存在于total supply快照数据中
// value: snapshotId在total supply快照数据中对应的值。如果不存在,则为0
(bool snapshotted, uint256 value) = _valueAt(snapshotId, _totalSupplySnapshots);
// 如果snapshotId存在于total supply快照数据中,返回对应value。如果不存在,直接返回当前的total supply
return snapshotted ? value : totalSupply();
}
// 在指定快照数据中检索指定快照id对应的值
// - snapshotId: 要检索的快照id
// - snapshots: 快照数据(即指定的Snapshots结构体)
function _valueAt(uint256 snapshotId, Snapshots storage snapshots) private view returns (bool, uint256) {
// 检验snapshotId的有效性:
// 1. 待查询snapshotId需要大于0
require(snapshotId > 0, "ERC20Snapshot: id is 0");
// 2. 待查询snapshotId需要小于等于快照id计数器的当前值(每产生一个新的快照id,id计数器会自增1)
require(snapshotId <= _getCurrentSnapshotId(), "ERC20Snapshot: nonexistent id");
// 从一个Snapshots结构体中检索一个有效snapshotId对应的值时会出现以下三种可能性:
// 1. 在该快照后,该值没有被修改过。所以在该snapshot id上没有任何与该值相关的记录;
// 2. 在该快照后,该值被修改过。所以在该snapshot id上可以查到与该值记录;
// 3. 在该快照后,该值没有被修改过。但是在其后面多个快照后被修改。这种情况,取大于该快照id的最小快照id来检索该值
// 使用二分法,从指定快照数据的ids中查找大于等于snapshotId的第一个元素所处在的index
// 注:如果snapshots.id为空,返回0
// findUpperBound()函数细节分析参见:https://learnblockchain.cn/article/6111
uint256 index = snapshots.ids.findUpperBound(snapshotId);
if (index == snapshots.ids.length) {
// 如果index为快照数据的ids的长度,表示该snapshotId大于ids中所有元素
// 即在该snapshotId上没有关于该值的记录
// 返回false和0
return (false, 0);
} else {
// 如果index不是快照数据的ids的长度,表示该snapshotId存在于snapshots.ids中
// 或小于等于snapshots.ids的尾部元素值
// 返回true和第一个大于等于snapshotId的快照上记录的值
return (true, snapshots.values[index]);
}
}
foundry代码验证:
contract ERC20SnapshotTest is Test {
MockERC20Snapshot private _testing = new MockERC20Snapshot("test name", "test symbol", address(this), 10000);
address private holder1 = address(1);
address private holder2 = address(2);
function test_TotalSupplyAt() external {
// revert if snapshot id is 0
vm.expectRevert("ERC20Snapshot: id is 0");
_testing.totalSupplyAt(0);
// revert if snapshot id is not created
vm.expectRevert("ERC20Snapshot: nonexistent id");
_testing.totalSupplyAt(1);
uint totalSupply = _testing.totalSupply();
assertEq(totalSupply, 10000);
_testing.snapshot();
assertEq(_testing.totalSupplyAt(1), totalSupply);
// mint
_testing.mint(address(this), 1);
_testing.mint(address(this), 2);
_testing.mint(address(this), 3);
totalSupply += 1 + 2 + 3;
_testing.snapshot();
// mint after snapshot
_testing.mint(address(this), 4);
assertEq(_testing.totalSupplyAt(2), totalSupply);
totalSupply += 4;
// burn
_testing.burn(address(this), 5);
_testing.burn(address(this), 6);
_testing.burn(address(this), 7);
totalSupply -= 5 + 6 + 7;
_testing.snapshot();
// burn after snapshot
_testing.burn(address(this), 8);
assertEq(_testing.totalSupplyAt(3), totalSupply);
totalSupply -= 8;
// transfer
_testing.transfer(holder1, 9);
_testing.transfer(holder2, 10);
_testing.snapshot();
// transfer after snapshot
vm.prank(holder1);
_testing.transfer(holder2, 1);
// totalSupplyAt(4) not change
assertEq(_testing.totalSupplyAt(4), totalSupply);
}
}
获取account地址在snapshotId快照上的余额。
function balanceOfAt(address account, uint256 snapshotId) public view virtual returns (uint256) {
// 调用_valueAt()方法,从account地址的快照数据中查找snapshotId对应的值
// snapshotted: snapshotId是否存在于account地址的快照数据中
// value: snapshotId在account地址的快照数据中对应的值。如果不存在,则为0
(bool snapshotted, uint256 value) = _valueAt(snapshotId, _accountBalanceSnapshots[account]);
// 如果snapshotId存在于account地址的快照数据中,返回对应value。如果不存在,直接返回account当前的余额
return snapshotted ? value : balanceOf(account);
}
foundry代码验证:
contract ERC20SnapshotTest is Test {
MockERC20Snapshot private _testing = new MockERC20Snapshot("test name", "test symbol", address(this), 10000);
address private holder1 = address(1);
address private holder2 = address(2);
function test_BalanceOfAt() external {
// revert if snapshot id is 0
vm.expectRevert("ERC20Snapshot: id is 0");
_testing.balanceOfAt(address(this), 0);
// revert if snapshot id is not created
vm.expectRevert("ERC20Snapshot: nonexistent id");
_testing.balanceOfAt(address(this), 1);
uint balance = _testing.balanceOf(address(this));
assertEq(balance, 10000);
_testing.snapshot();
assertEq(_testing.balanceOfAt(address(this), 1), balance);
assertEq(_testing.balanceOfAt(holder1, 1), 0);
assertEq(_testing.balanceOfAt(holder2, 1), 0);
// mint
_testing.mint(address(this), 1);
_testing.mint(address(this), 2);
_testing.mint(address(this), 3);
balance += 1 + 2 + 3;
_testing.snapshot();
// mint after snapshot
_testing.mint(address(this), 4);
assertEq(_testing.balanceOfAt(address(this), 2), balance);
assertEq(_testing.balanceOfAt(holder1, 2), 0);
assertEq(_testing.balanceOfAt(holder2, 2), 0);
balance += 4;
// burn
_testing.burn(address(this), 5);
_testing.burn(address(this), 6);
_testing.burn(address(this), 7);
balance -= 5 + 6 + 7;
_testing.snapshot();
// burn after snapshot
_testing.burn(address(this), 8);
assertEq(_testing.balanceOfAt(address(this), 3), balance);
assertEq(_testing.balanceOfAt(holder1, 3), 0);
assertEq(_testing.balanceOfAt(holder2, 3), 0);
balance -= 8;
// transfer
_testing.transfer(holder1, 9);
_testing.transfer(holder2, 10);
_testing.snapshot();
// transfer after snapshot
vm.prank(holder1);
_testing.transfer(address(this), 1);
assertEq(_testing.balanceOfAt(address(this), 4), balance - 9 - 10);
assertEq(_testing.balanceOfAt(holder1, 4), 9);
assertEq(_testing.balanceOfAt(holder2, 4), 10);
}
function test_MintAndBurnAndTransferInASnapshot() external {
uint totalSupply = _testing.totalSupply();
uint balanceHolder1;
uint balanceHolder2;
uint balanceHolderThis = _testing.balanceOf(address(this));
// snapshot 1
_testing.transfer(holder1, 1);
_testing.mint(holder2, 2);
_testing.burn(address(this), 3);
_testing.snapshot();
totalSupply = totalSupply + 2 - 3;
balanceHolder1 += 1;
balanceHolder2 += 2;
balanceHolderThis -= 1 + 3;
assertEq(_testing.totalSupplyAt(1), totalSupply);
assertEq(_testing.balanceOfAt(holder1, 1), balanceHolder1);
assertEq(_testing.balanceOfAt(holder2, 1), balanceHolder2);
assertEq(_testing.balanceOfAt(address(this), 1), balanceHolderThis);
// snapshot 2
_testing.burn(holder1, 1);
_testing.transfer(holder2, 4);
_testing.mint(address(this), 5);
_testing.snapshot();
totalSupply = totalSupply + 5 - 1;
balanceHolder1 -= 1;
balanceHolder2 += 4;
balanceHolderThis = balanceHolderThis - 4 + 5;
assertEq(_testing.totalSupplyAt(2), totalSupply);
assertEq(_testing.balanceOfAt(holder1, 2), balanceHolder1);
assertEq(_testing.balanceOfAt(holder2, 2), balanceHolder2);
assertEq(_testing.balanceOfAt(address(this), 2), balanceHolderThis);
// snapshot 3
_testing.mint(holder1, 6);
_testing.burn(holder2, 2);
vm.prank(holder2);
_testing.transfer(address(this), 3);
_testing.snapshot();
totalSupply = totalSupply + 6 - 2;
balanceHolder1 += 6;
balanceHolder2 -= 2 + 3;
balanceHolderThis += 3;
assertEq(_testing.totalSupplyAt(3), totalSupply);
assertEq(_testing.balanceOfAt(holder1, 3), balanceHolder1);
assertEq(_testing.balanceOfAt(holder2, 3), balanceHolder2);
assertEq(_testing.balanceOfAt(address(this), 3), balanceHolderThis);
}
}
ps:
本人热爱图灵,热爱中本聪,热爱V神。
以下是我个人的公众号,如果有技术问题可以关注我的公众号来跟我交流。
同时我也会在这个公众号上每周更新我的原创文章,喜欢的小伙伴或者老伙计可以支持一下!
如果需要转发,麻烦注明作者。十分感谢!
公众号名称:后现代泼痞浪漫主义奠基人