Deap 框架细节介绍

发布时间:2023年12月21日

创建一个 gp.PrimitiveSet 对象,对象名为 MAIN,自变量为 3

pset = gp.PrimitiveSet("MAIN", 3)
print(pset)
<deap.gp.PrimitiveSet object at 0x000001FBE182AB20>

gp.py(均为产生函数集与节点集)

PrimitiveSet

class PrimitiveSet(PrimitiveSetTyped):
    """Class same as :class:`~deap.gp.PrimitiveSetTyped`, except there is no
    definition of type.
    """

    def __init__(self, name, arity, prefix="ARG"): # 此时 name = MAIN, arity = 3,
        # 创建 arity 个类对象,案例中为 3,
        args = [__type__] * arity
        PrimitiveSetTyped.__init__(self, name, args, __type__, prefix)

在这里插入图片描述

PrimitiveSetTyped

class PrimitiveSetTyped(object):
    """Class that contains the primitives that can be used to solve a
    Strongly Typed GP problem. The set also defined the researched
    function return type, and input arguments type and number.
    """

    # name = MAIN,in_types=args(为每个变量创建的类组成的列表), ret_type = __type__(元类)
    def __init__(self, name, in_types, ret_type, prefix="ARG"):
        # 存储终端节点的字典
        self.terminals = defaultdict(list)

        # defaultdict(list) 为默认 value 为 list 的字典
        self.primitives = defaultdict(list) # 后续可通过pset.primitives[pset.ret][i] 依次提取函数, .name 可查看函数名
        self.arguments = []

        # setting "__builtins__" to None avoid the context
        # being polluted by builtins function when evaluating
        # GP expression.
        self.context = {"__builtins__": None}
        self.mapping = dict()

        # 存储终端节点个数
        self.terms_count = 0

        # 存储函数节点个数
        self.prims_count = 0

        # PrimitiveSet 类名
        self.name = name

        self.ret = ret_type # __type__(元类)

        # 自变量集合
        self.ins = in_types
        for i, type_ in enumerate(in_types):
            # 为每个自变量从 0 编号,如这里的 ARG0, ARG1, ARG2
            arg_str = "{prefix}{index}".format(prefix=prefix, index=i)

            # 将 arg_str 添加到 arguments 列表中
            self.arguments.append(arg_str)

            # 将每个自变量转换为终端节点类类型,
            term = Terminal(arg_str, True, type_)

            # 将 term 分为终端节点与函数节点,分别添加到 dict_ 中
            self._add(term)

            # 终端节点数 + 1
            self.terms_count += 1

Terminal

class Terminal(object):
    """Class that encapsulates terminal primitive in expression. Terminals can
    be values or 0-arity functions.
    """
    __slots__ = ('name', 'value', 'ret', 'conv_fct')

    def __init__(self, terminal, symbolic, ret):
        self.ret = ret # 节点类型,__type__
        self.value = terminal # ARG0
        self.name = str(terminal) # 'ARG0'
        self.conv_fct = str if symbolic else repr # 见详解 1 处

_add

def _add(self, prim):
    def addType(dict_, ret_type):
        if ret_type not in dict_:
            new_list = []
            for type_, list_ in list(dict_.items()):
                if issubclass(type_, ret_type):
                    for item in list_:
                        if item not in new_list:
                            new_list.append(item)
            dict_[ret_type] = new_list

    addType(self.primitives, prim.ret) # primitives 函数节点组成的字典,prim.ret 当前传入的节点类型
    addType(self.terminals, prim.ret) # self.terminals,终端节点组成的字典

    self.mapping[prim.name] = prim # prim.name = ARG0

    # 判断 prim 是否为 函数类型,即 Primitive
    if isinstance(prim, Primitive): # isinstance 判断是否为同一类型
        for type_ in prim.args:
            addType(self.primitives, type_)
            addType(self.terminals, type_)
        dict_ = self.primitives
    else:
        # 否则为终端节点类型
        dict_ = self.terminals

    for type_ in dict_:
        if issubclass(prim.ret, type_): # 判断 prim.ret 是否为 type_ 的子类
            dict_[type_].append(prim) # 将 prim 添加到字典 dict_ 中

addPrimitive(PrimitiveSet)

# arity 函数参数个数,如-为2,+为2,log为1,exp为1, ariry=2,name='minus'
def addPrimitive(self, primitive, arity, name=None):
    """Add primitive *primitive* with arity *arity* to the set.
    If a name *name* is provided, it will replace the attribute __name__
    attribute to represent/identify the primitive.
    """
    assert arity > 0, "arity should be >= 1"
    # 参数产生的类的组合,比如这里 - 含有两个参数,则 args = [<class 'object'>, <class 'object'>]
    args = [__type__] * arity
    # primitive:定义的函数,args = [<class 'object'>, <class 'object'>], __type__ = 元类, name='minus'
    PrimitiveSetTyped.addPrimitive(self, primitive, args, __type__, name)

addPrimitive(PrimitiveSetTyped)

# in_types = [<class 'object'>, <class 'object'>], ret_type = __type__(元类)
def addPrimitive(self, primitive, in_types, ret_type, name=None):
    """Add a primitive to the set.

    :param primitive: callable object or a function.
    :parma in_types: list of primitives arguments' type
    :param ret_type: type returned by the primitive.
    :param name: alternative name for the primitive instead
                 of its __name__ attribute.
    """
    if name is None:
    	"""此处设置 name 参数,默认为定义时的函数名,如 def add(int a, int b), 则 name 默认为 add"""
        name = primitive.__name__
    prim = Primitive(name, in_types, ret_type)

    assert name not in self.context or \
           self.context[name] is primitive, \
        "Primitives are required to have a unique name. " \
        "Consider using the argument 'name' to rename your " \
        "second '%s' primitive." % (name,)
    # 判断 prim 是否从未被添加进 _dict 字典中,若没有,则添加进去
    self._add(prim)
    self.context[prim.name] = primitive # 用于表示此名称的节点为函数节点
    self.prims_count += 1 # 函数节点个数 + 1

Primitive

class Primitive(object):
    """Class that encapsulates a primitive and when called with arguments it
    returns the Python code to call the primitive with the arguments.

        >>> pr = Primitive("mul", (int, int), int)
        >>> pr.format(1, 2)
        'mul(1, 2)'
    """
    # 定义实例化此类后能访问的属性,元组以外的属性均不可访问
    __slots__ = ('name', 'arity', 'args', 'ret', 'seq')

    def __init__(self, name, args, ret):
        self.name = name # 函数名,minus
        self.arity = len(args) # 参数个数,2
        self.args = args # 参数个数定义的类组成的集合,[<class 'object'>, <class 'object'>]
        self.ret = ret # 当前类属性 <class 'object'>
        args = ", ".join(map("{{{0}}}".format, list(range(self.arity)))) # 为每个参数编号, args = '{0}, {1}'
        
        # 'minus({0}, {1})'
        self.seq = "{name}({args})".format(name=self.name, args=args)

小技巧

# 在函数集中随便选择一个,
prim = random.choice(pset.primitives[type_]) # type_ = pset.ret

# 在终点集中随机选择一个节点
term = random.choice(pset.terminals[type_])
  • type_ = pset.ret:表示原语集合的类型
  • pset.primitives:函数集,用字典表示
pset.ret # <class 'object'>
pset.primitives # 函数集,字典类型
pset.primitives[pset.ret] # 在字典类型中提取 pset.ret(<class 'object'>),即可得到所有函数集中的函数集合
# 寻找所有函数组成的列表中,函数名 == "add",对应的索引位置,通常会返回一个列表,比如[1], 通过在代码最后加一个[0]提取, 即可得到 1 这个值
index = [primitive_index for primitive_index in range(len(pset.primitives[pset.ret])) if pset.primitives[pset.ret][primitive_index].name == "add"]

提取所有终端节点,返回一个所有终端节点组成的集合

pset.terminals # # 终端集,字典类型
pset.terminals[pset.ret]

提取以 Node 索引对应的节点为根节点的子树

tree=PrimitiveTree(genHalfAndHalf(pset, 1, 3))
for Node in range(len(tree)):
	tree[tree.searchSubtree(Node)] # tree.searchSubtree(Node) 返回的是一个切片 slice

获取函数集中第 i 个加入到函数集中的函数的函数名:

pset.primitives[pset.ret][i].name

所有的自变量组成的列表,[ARG0, ARG1, …]

pset.arguments

详解 1

这是一个Python中的三元表达式(ternary expression),用于根据条件选择不同的表达式。在这个三元表达式中,strrepr是Python中的两个内置函数,分别用于将对象转换为字符串和计算对象的可打印表示。symbolic是一个条件,如果为True,则选择str,否则选择repr

具体来说,这个三元表达式的意义是:

如果symbolic为True,则返回一个字符串表示;否则返回对象的可打印表示。

示例代码:

x = 10
symbolic = True
result = str(x) if symbolic else repr(x)
print(result)  # 输出 '10'

在上面的代码中,由于symbolic为True,所以选择了str(x),即将x转换为字符串'10'。因此,result的值为'10'
如果将symbolic改为False,则会选择repr(x),即返回x的可打印表示:

x = 10
symbolic = False
result = str(x) if symbolic else repr(x)
print(result)  # 输出 '10'

在这个例子中,由于symbolic为False,所以选择了repr(x),即返回x的可打印表示10。因此,result的值为10

终端节点参数:

  • ret:节点的类型
  • value:节点值
  • name:节点名称
  • conv_fct:节点信息显示

整体代码

import random

import numpy
from deap import gp, creator, base, tools


def minus(left,right):
    return numpy.subtract(left,right)

pset = gp.PrimitiveSet("MAIN", 3)
print(pset)

pset.addPrimitive(minus, 2, 'minus')

pset.renameArguments(ARG0='x0')

# 定义terminal set
pset.addTerminal(3)  # 常数
pset.addEphemeralConstant('num',lambda:random.randint(-5,5))  # 随机数:(name,随机数函数)

# 三种生成方法,full,grow,half(primitive set,min_depth,max_depth)
# Deap 框架中树的最大深度是91, 最小深度是0(只有一个叶子节点)
expr1 = gp.genFull(pset, 1, 3)    # list
expr2 = gp.genGrow(pset, 1, 3)
expr3 = gp.genHalfAndHalf(pset, 1, 3)
tree = gp.PrimitiveTree(expr1)  # 将表达式形式转换成树形结构

# 编译表达式:必须经过编译才能进行运算
function1 = gp.compile(tree, pset)
# function2 = gp.compile(expr2,pset)
result = function1(1, 2)   # 输入变量值能够得到该表达式的结果
print('result:', result)

gp.genFull、gp.genGrow、gp.genHalfAndHalf

  • gp.genFull:产生一棵满二叉树,树高等于树的深度时停止产生树,先从函数集中选择,当树高等于树的深度时,再从终端集中进行选择,最终选出的节点顺序存储在列表中。整体为一个类结构:[<deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.num object at 0x0000026F4903C270>, <deap.gp.Terminal object at 0x0000026F49038880>,
  • 注: 三个方法最终返回的都是一个列表,列表依次存储了树形结构的节点
def genFull(pset, min_, max_, type_=None):
    """Generate an expression where each leaf has the same depth
    between *min* and *max*.

    :param pset: Primitive set from which primitives are selected.
    :param min_: Minimum height of the produced trees.
    :param max_: Maximum Height of the produced trees.
    :param type_: The type that should return the tree when called, when
                  :obj:`None` (default) the type of :pset: (pset.ret)
                  is assumed.
    :returns: A full tree with all leaves at the same depth.
    """

    def condition(height, depth): # 终止条件:产生树的节点深度恰好等于树高
        """Expression generation stops when the depth is equal to height."""
        return depth == height

    return generate(pset, min_, max_, condition, type_)

generate

def generate(pset, min_, max_, condition, type_=None):
    """创建一颗树:树中的节点均存储在列表中,当满足 condition 条件时停止.

    :param pset: Primitive set from which primitives are selected.
    :param min_: 产生的树的最小高度
    :param max_: 产生的树的最大高度
    :param condition: The condition is a function that takes two arguments,
                      the height of the tree to build and the current
                      depth in the tree.
    :param type_: 返回树的类型,默认应该为 pset: (pset.ret)

    :returns: A grown tree with leaves at possibly different depths
              depending on the condition function.
    """
    if type_ is None:
        type_ = pset.ret #  pset.ret 为节点类别 <deap.gp.PrimitiveSet object at 0x000001C5DC245160>
    expr = [] # 将在函数集与终端集选出的节点放入到 expr 中
    height = random.randint(min_, max_) # 定义树高:从最小深度与最大深度之间随机选择一个值
    stack = [(0, type_)] # 利用栈结构来产生树结构,初始时深度为 0,节点为 type_ 类型
    while len(stack) != 0:
        depth, type_ = stack.pop()
        if condition(height, depth):
            try:
                # =在终点集中随机选择一个节点
                term = random.choice(pset.terminals[type_])
            except IndexError:
                _, _, traceback = sys.exc_info()
                raise IndexError("The gp.generate function tried to add " \
                                  "a terminal of type '%s', but there is " \
                                  "none available." % (type_,)).with_traceback(traceback)
            if isclass(term): # 需要判断是否为一个类类型
                term = term()
            expr.append(term) # 将选出的终点集添加进去
        else:
            try:
                prim = random.choice(pset.primitives[type_]) # 从函数集中随机选择一个节点
            except IndexError:
                _, _, traceback = sys.exc_info()
                raise IndexError("The gp.generate function tried to add " \
                                  "a primitive of type '%s', but there is " \
                                  "none available." % (type_,)).with_traceback(traceback)
            expr.append(prim) # 将选出的函数添加到 expr 中
            for arg in reversed(prim.args): # 遍历函数参数类
                stack.append((depth + 1, arg)) # 添加到栈中,后续产生更深的树
    return expr # 由许多节点组成的列表

gp.PrimitiveTree(产生树结构)

将:[<deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.num object at 0x0000026F4903C270>, <deap.gp.Terminal object at 0x0000026F49038880>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Terminal object at 0x0000026F49038740>, <deap.gp.Terminal object at 0x0000026F49038880>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Terminal object at 0x0000026F49038740>, <deap.gp.num object at 0x0000026F4903C360>, <deap.gp.Primitive object at 0x0000026F49035D10>, <deap.gp.Terminal object at 0x0000026F49038740>, <deap.gp.Terminal object at 0x0000026F49038740>]

转化为: minus(minus(minus(-1, 3), minus(ARG1, 3)), minus(minus(ARG1, 1), minus(ARG1, ARG1)))

gp.compile(tree, pset)

def compile(expr, pset):
    """Compile the expression *expr*.

    :param expr: Expression to compile. It can either be a PrimitiveTree,
                 a string of Python code or any object that when
                 converted into string produced a valid Python code
                 expression.
    :param pset: Primitive set against which the expression is compile.
    :returns: a function if the primitive set has 1 or more arguments,
              or return the results produced by evaluating the tree.
    """
    code = str(expr)
    if len(pset.arguments) > 0:
        # This section is a stripped version of the lambdify
        # function of SymPy 0.6.6.
        args = ",".join(arg for arg in pset.arguments)
        code = "lambda {args}: {code}".format(args=args, code=code)
    try:
        return eval(code, pset.context, {})
    except MemoryError:
        _, _, traceback = sys.exc_info()
        raise MemoryError("DEAP : Error in tree evaluation :"
                            " Python cannot evaluate a tree higher than 90. "
                            "To avoid this problem, you should use bloat control on your "
                            "operators. See the DEAP documentation for more information. "
                            "DEAP will now abort.").with_traceback(traceback)

这段代码实现了一个编译函数compile,用于将表达式编译为可执行的Python代码。函数接受两个参数:expr表示要编译的表达式(一个列表),pset表示该表达式所使用的原语集合。

如果pset中定义了一个或多个参数,编译后的代码将会被包装在一个lambda函数中,并使用原语集合中定义的参数名称。否则,编译后的代码将会直接执行。

在编译完成后,函数使用eval函数将代码字符串转换为可执行的Python代码,并在pset.context命名空间中执行它。如果编译过程中出现内存错误,函数会抛出一个MemoryError异常,并显示相关错误信息。最终,函数返回一个可执行的Python函数或表达式的求值结果。

这个函数实现了将表达式编译成可执行的Python代码的功能,可以用于在遗传算法或其他优化算法中对表达式进行求解。函数接受两个参数:

  • expr:要编译的表达式,可以是一个PrimitiveTree对象、一个Python代码字符串,或者任何能够被转换为有效Python代码表达式的对象。

  • pset:表达式使用的原语集合。原语集合是一个包含基本运算符、函数、常量和变量的集合,用于构建表达式。
    如果原语集合中定义了参数,则编译的代码将会被封装在一个lambda函数中,并使用原语集合中定义的参数名称。否则,编译后的代码将会直接执行。最后,函数使用eval函数将代码字符串转换为可执行的Python代码,并在pset.context命名空间中执行它。如果编译过程中出现内存错误,函数会抛出一个MemoryError异常,并显示相关错误信息。

函数的返回值取决于原语集合中是否定义了参数。如果原语集合中定义了一个或多个参数,则函数将返回一个可执行的Python函数。否则,函数将返回表达式求值的结果。

searchSubtree

def searchSubtree(self, begin):
    """Return a slice object that corresponds to the
    range of values that defines the subtree which has the
    element with index *begin* as its root.
    """
    end = begin + 1
    total = self[begin].arity
    while total > 0:
        total += self[end].arity - 1
        end += 1
    return slice(begin, end)

Toolbox

register

def register(self, alias, function, *args, **kargs):
    """
        >>> def func(a, b, c=3):
        ...     print a, b, c
        ...
        >>> tools = Toolbox()
        >>> tools.register("myFunc", func, 2, c=4)
        >>> tools.myFunc(3)
        2 3 4
    """
    pfunc = partial(function, *args, **kargs) # 
    pfunc.__name__ = alias
    pfunc.__doc__ = function.__doc__

    if hasattr(function, "__dict__") and not isinstance(function, type):
        # Some functions don't have a dictionary, in these cases
        # simply don't copy it. Moreover, if the function is actually
        # a class, we do not want to copy the dictionary.
        pfunc.__dict__.update(function.__dict__.copy())

    setattr(self, alias, pfunc)

References

[1] 【Python】详解类的 repr() 方法
[2] python slots 详解(上篇)

文章来源:https://blog.csdn.net/qq_46450354/article/details/130846779
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。