用Go代码模拟数组,足够硬核

发布时间:2024年01月09日

前奏

我们想要将一片连续的内存区域映射成高维度数组所使用的内存区域。说的更直白一些,就是想将高维数组映射到一维数组。
完成这个想法的前提是寻找高维数组和一维数组的下标关系。

关系映射

约定一哈,将左侧数组称为arr0,右侧一维数组称为arr1。

一维数组

假设将一个一维数组,映射到另一个一维数组。那么关系十分简单。如下图所示。
在这里插入图片描述arr0和 arr1长度相等,下标从0开始滚动直到arr等长度 - 1,假设arr0和arr1的长度为N,则:
i的取值范围是:[0,N)
k的取值范围是:[0,N)
即一维数组到一维数组的下标映射,使下标i和下标index同步即可。

arr1[0] = arr0[1]
arr1[1] = arr0[1]
...
arr1[index] = arr0[i]

公式:
index = i
arr1[index] = arr0[i]
arr1[i] = arr0[i]

二维数组

二维数组的下标向一维数组下标映射。
在这里插入图片描述假设左侧arr0的长度是M,arr0[0]中的一维数组长度是N,右侧arr1的长度是W
则下标关系可表示为:
i:[0,M)
j:[0,N)
index:[0,W)
那么,左右两个数组的长度关系可表示为:
W = M * N
若左侧数组arr0从前向后遍历,arr1同步遍历,则i,j与右侧数组的下标k关系可以表示为
index = i * N + j

arr1[0] = arr1[0][0]
arr1[1] = arr1[0][1]
arr1[2] = arr1[1][0]
arr1[3] = arr0[1][1]

在这里插入图片描述上述数组下标关系映射为

arr1[0] = arr0[0][0]
arr1[1] = arr0[0][1]
arr1[2] = arr0[0][2]
arr1[3] = arr0[1][0]
arr1[4] = arr0[1][1]
...
arr1[8] = arr0[2][2]

公式:
index = i * N + j
arr1[index] = arr0[i*N][j]
arr1[i * N + j] = arr0[i][j]

三维数组

3维数组到1维数组映射关系。

var arr [3][2][3]int
arr[0][0][0]  ==> arr1[0*2*3 + 0*3 + 0*1] =>  arr1[0]
arr[0][0][1]  ==> arr1[0*2*3 + 0*3 + 1*1] =>  arr1[1]
arr[0][0][2]  ==> arr1[0*2*3 + 0*3 + 2*1] =>  arr1[2]
arr[0][1][0]  ==> arr1[0*2*3 + 1*3 + 0*1] =>  arr1[3]
arr[0][1][1]  ==> arr1[0*2*3 + 1*3 + 1*1] =>  arr1[4]
arr[0][1][2]  ==> arr1[0*2*3 + 1*3 + 2*1] =>  arr1[5]
arr[1][0][0]  ==> arr1[1*2*3 + 0*3 + 0*1] =>  arr1[6]
arr[1][0][1]  ==> arr1[1*2*3 + 0*3 + 1*1] =>  arr1[7]
arr[1][0][2]  ==> arr1[1*2*3 + 0*3 + 2*1] =>  arr1[8]
arr[1][1][0]  ==> arr1[1*2*3 + 1*3 + 0*1] =>  arr1[9]
arr[1][1][1]  ==> arr1[1*2*3 + 1*3 + 1*1] =>  arr1[10]
arr[1][1][2]  ==> arr1[1*2*3 + 1*3 + 2*1] =>  arr1[11]
arr[2][0][0]  ==> arr1[2*2*3 + 0*3 + 0*1] =>  arr1[12]
arr[2][0][1]  ==> arr1[2*2*3 + 0*3 + 1*1] =>  arr1[13]
arr[2][0][2]  ==> arr1[2*2*3 + 0*3 + 2*1] =>  arr1[14]
arr[2][1][0]  ==> arr1[2*2*3 + 1*3 + 0*1] =>  arr1[15]
arr[2][1][1]  ==> arr1[2*2*3 + 1*3 + 1*1] =>  arr1[16]
arr[2][1][2]  ==> arr1[2*2*3 + 1*3 + 2*1] =>  arr1[17]

假设有两个数组

var [M][N][P]int arr0
var [M*N*P]int   arr1

假设3维数组下标分别是,i,j,k, 一维数组下标为index则:
index = i * N * P + j * P + k
arr1[index] = arr0[i][j][k]
arr1[i * N * P + j * P + k] = arr0[i][j][k]

总结一下公式

高维数组向一维数组映射,下标关系为:

var [A][B][C]...[Z]int  array0
var [A * B * C *...* Z] array1

array1[a * B*C...*Z + b * C*D*...*Z + z * 1] = array0[a][b]...[z]

假设我们有一个数组:dimensions := []int{2, 3, 5} // var array [2][3][5]int,存储了某数组各个维度的宽度。则可以通过逆序乘积的方式,求各个维度的系数,用下标i,j,k去乘以每个维度的系数就可以得到线性的一维数组下标,fmt.Println(i*tmpArray[0] + j*tmpArray[1] + k*tmpArray[2])

func TestDimension(t *testing.T) {
	dimensions := []int{2, 3, 5} //var array [2][3][5]int
	tmpArray := make([]int, len(dimensions))

	mul := 1 // 总乘积
	for i := len(dimensions) - 1; i >= 0; i-- {
		tmpArray[i] = mul
		mul *= dimensions[i]
	}
	//fmt.Println(dimensions)
	//fmt.Println(tmpArray)
	for i := 0; i < 2; i++ {
		for j := 0; j < 3; j++ {
			for k := 0; k < 5; k++ {
				fmt.Println(i*tmpArray[0] + j*tmpArray[1] + k*tmpArray[2])
			}
		}
	}
}

模拟数组代码

package Go

import (
	"fmt"
	"testing"
	"unsafe"
)

type array[T any] struct {
	data      unsafe.Pointer // 内存区域指针
	n         int            // 数组有n个维度
	dimension map[int]uint   // 每个维度的数据宽度,用于判定数组索引是否越界
	factor    map[int]uint   // 映射到一维数组的偏移量
	elemSize  uintptr        // 类型大小
}

// NewArray create an Array.
// array := NewArray[int](10,5,7,8) equals var array [10][5][7][8]int
func NewArray[T any](dimensions ...uint) *array[T] {
	n := len(dimensions) // n个维度的数组
	if n == 0 {
		panic("invalid dimensions")
	}
	dimensionMap := make(map[int]uint)
	factorMap := make(map[int]uint)

	f := uint(1)
	for i := range dimensions {
		dimensionMap[n-i-1] = dimensions[i]
		factorMap[n-i-1] = f // 反向索引映射
		f *= dimensions[n-i-1]
	}

	return &array[T]{
		data:      malloc[T](f),
		n:         n,
		dimension: dimensionMap,
		factor:    factorMap,
		elemSize:  unsafe.Sizeof(*(*T)(nil)),
	}
}

func malloc[T any](size uint) unsafe.Pointer {
	origin := unsafe.Sizeof(*(*T)(nil))
	memPointer := make([]uint8, uint(origin)*size)
	// defer runtime.KeepAlive(memPointer)
	return *(*unsafe.Pointer)(unsafe.Pointer(&memPointer))
}

func (a *array[T]) Len(dimensions ...uint) int {
	if len(dimensions) > a.n {
		panic("out of dimensions")
	}
	return 0
}

func (a *array[T]) Set(value T, index ...uint) {
	if len(index) != a.n {
		panic("out of dimensions")
	}
	location := uint(0)
	for i := range index {
		if index[i] >= a.dimension[i] || index[i] < 0 {
			panic("index out of bounds")
		}
		location += index[i] * a.factor[i]
	}
	*(*T)(unsafe.Add(a.data, uintptr(location)*a.elemSize)) = value
}

func (a *array[T]) Get(index ...uint) T {
	if len(index) != a.n {
		panic("out of dimensions")
	}
	location := uint(0)
	for i := range index {
		if index[i] >= a.dimension[i] || index[i] < 0 {
			panic("index out of bounds")
		}
		location += index[i] * a.factor[i]
	}
	return *(*T)(unsafe.Add(a.data, uintptr(location)*a.elemSize))
}

type student struct {
	Name string
	Age  int
}

func TestArray(t *testing.T) {
	studentArray := NewArray[student](3) // [10]student
	studentArray.Set(student{            // student[0] = student{...}
		Name: "zhang san",
		Age:  10,
	}, 0)

	studentArray.Set(student{ // student[1] = student{...}
		Name: "li si",
		Age:  20,
	}, 1)

	fmt.Println(studentArray.Get(0))
	fmt.Println(studentArray.Get(1))
	fmt.Println(studentArray.Get(2))

	array1 := NewArray[int](2, 2) // [2][2]int
	array1.Set(1, 0, 0)           // arr[0][0] = 1
	array1.Set(2, 0, 1)           // arr[0][1] = 2
	array1.Set(3, 1, 0)           // arr[1][0] = 3
	array1.Set(4, 1, 1)           // arr[1][1] = 4

	fmt.Println(array1.Get(0, 0)) // arr[0][0]
	fmt.Println(array1.Get(0, 1)) // arr[0][1]
	fmt.Println(array1.Get(1, 0)) // arr[1][0]
	fmt.Println(array1.Get(1, 1)) // arr[1][1]

	array2 := NewArray[int](3, 3, 3) // [3][3][3]int
	w := 1
	for i := 0; i < 3; i++ {
		for j := 0; j < 3; j++ {
			for k := 0; k < 3; k++ {
				array2.Set(w, uint(i), uint(j), uint(k))
				w++
			}
		}
	}

	w = 1
	for i := 0; i < 3; i++ {
		for j := 0; j < 3; j++ {
			for k := 0; k < 3; k++ {
				num := array2.Get(uint(i), uint(j), uint(k))
				if num != w {
					fmt.Println("---", num, w)
					panic("oi! 出错了")
				}
				w++
			}
		}
	}

	arr := *(*[27]int)(array2.data)
	fmt.Println(arr) // [1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27

}
文章来源:https://blog.csdn.net/dawnto/article/details/135479329
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。