Aloha 机械臂的学习记录4——act:detr_vae.py的代码部分

发布时间:2024年01月19日

detr_vae.py的原始代码如下:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np

import IPython
e = IPython.embed


def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std * eps


def get_sinusoid_encoding_table(n_position, d_hid):
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)


class DETRVAE(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names, vq, vq_class, vq_dim, action_dim):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
        self.state_dim, self.action_dim = state_dim, action_dim
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim)  # project qpos to embedding

        print(f'Use VQ: {self.vq}, {self.vq_class}, {self.vq_dim}')
        if self.vq:
            self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
        else:
            self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

        # decoder extra parameters
        if self.vq:
            self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
        else:
            self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent


    def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
        bs, _ = qpos.shape
        if self.encoder is None:
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)
            probs = binaries = mu = logvar = None
        else:
            # cvae encoder
            is_training = actions is not None # train or val
            ### Obtain latent z from action sequence
            if is_training:
                # project action sequence to embedding dim, and concat with a CLS token
                action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
                qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
                qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
                cls_embed = self.cls_embed.weight # (1, hidden_dim)
                cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
                encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
                encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
                # do not mask cls token
                cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
                is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
                # obtain position embedding
                pos_embed = self.pos_table.clone().detach()
                pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
                # query model
                encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
                encoder_output = encoder_output[0] # take cls output only
                latent_info = self.latent_proj(encoder_output)
                
                if self.vq:
                    logits = latent_info.reshape([*latent_info.shape[:-1], self.vq_class, self.vq_dim])
                    probs = torch.softmax(logits, dim=-1)
                    binaries = F.one_hot(torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(-1), self.vq_dim).view(-1, self.vq_class, self.vq_dim).float()
                    binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
                    probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
                    straigt_through = binaries_flat - probs_flat.detach() + probs_flat
                    latent_input = self.latent_out_proj(straigt_through)
                    mu = logvar = None
                else:
                    probs = binaries = None
                    mu = latent_info[:, :self.latent_dim]
                    logvar = latent_info[:, self.latent_dim:]
                    latent_sample = reparametrize(mu, logvar)
                    latent_input = self.latent_out_proj(latent_sample)

            else:
                mu = logvar = binaries = probs = None
                if self.vq:
                    latent_input = self.latent_out_proj(vq_sample.view(-1, self.vq_class * self.vq_dim))
                else:
                    latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
                    latent_input = self.latent_out_proj(latent_sample)

        return latent_input, probs, binaries, mu, logvar

    def forward(self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        latent_input, probs, binaries, mu, logvar = self.encode(qpos, actions, is_pad, vq_sample)

        # cvae decoder
        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[cam_id](image[:, cam_id])
                features = features[0] # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar], probs, binaries



class CNNMLP(nn.Module):
    def __init__(self, backbones, state_dim, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim) # TODO add more
        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
            backbone_down_projs = []
            for backbone in backbones:
                down_proj = nn.Sequential(
                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                    nn.Conv2d(128, 64, kernel_size=5),
                    nn.Conv2d(64, 32, kernel_size=5)
                )
                backbone_down_projs.append(down_proj)
            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

            mlp_in_dim = 768 * len(backbones) + state_dim
            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=self.action_dim, hidden_depth=2)
        else:
            raise NotImplementedError

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        is_training = actions is not None # train or val
        bs, _ = qpos.shape
        # Image observation features and position embeddings
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0] # take the last layer feature
            pos = pos[0] # not used
            all_cam_features.append(self.backbone_down_projs[cam_id](features))
        # flatten everything
        flattened_features = []
        for cam_feature in all_cam_features:
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, axis=1) # 768 each
        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
        a_hat = self.mlp(features)
        return a_hat


def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    trunk = nn.Sequential(*mods)
    return trunk


def build_encoder(args):
    d_model = args.hidden_dim # 256
    dropout = args.dropout # 0.1
    nhead = args.nheads # 8
    dim_feedforward = args.dim_feedforward # 2048
    num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
    normalize_before = args.pre_norm # False
    activation = "relu"

    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, normalize_before)
    encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

    return encoder


def build(args):
    state_dim = 14 # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    for _ in args.camera_names:
        backbone = build_backbone(args)
        backbones.append(backbone)

    transformer = build_transformer(args)

    if args.no_encoder:
        encoder = None
    else:
        encoder = build_transformer(args)

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        vq=args.vq,
        vq_class=args.vq_class,
        vq_dim=args.vq_dim,
        action_dim=args.action_dim,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model

def build_cnnmlp(args):
    state_dim = 14 # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    for _ in args.camera_names:
        backbone = build_backbone(args)
        backbones.append(backbone)

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model

让我们通读这段代码:

函数reparametrize部分:

def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std * eps

????????这个 reparametrize 函数是变分自编码器(VAE)中的一个重要部分,用于实现所谓的“重参数化技巧”。重参数化技巧是一种在训练VAE时常用的方法,它允许模型在训练过程中通过反向传播更新其参数。下面是对这个函数的详细解释:

函数参数

  • mu:均值向量,代表编码后的潜在空间中的均值。
  • logvar:对数方差向量,它是方差的对数形式。使用对数方差而不是直接使用方差可以提供数值稳定性,特别是在计算标准差时。

函数操作

计算标准差

  • std = logvar.div(2).exp()

  • 这一步将对数方差除以2,然后对结果取指数,得到方差的平方根,即标准差。在正态分布中,方差是标准差的平方,因此这里先除以2,再取指数得到标准差。

生成随机噪声

  • eps = Variable(std.data.new(std.size()).normal_())

    这一步生成与标准差具有相同形状的随机噪声(从标准正态分布中采样)。Variable 是 PyTorch 中的一个类,用于封装张量,以便自动计算梯度。这里使用 std.data.new(std.size()).normal_() 生成正态分布的随机数。

重参数化操作

  • return mu + std * eps

  • 最后,将均值与标准差和随机噪声的乘积相加。这实际上是从以 mu 为均值、std 为标准差的正态分布中采样。这种方法使得采样操作可导,允许在训练过程中通过反向传播算法更新 mulogvar

作用

????????使用重参数化技巧可以使VAE的训练通过随机梯度下降(或其他优化算法)进行,因为它允许模型在训练过程中反向传播梯度。这对于训练生成模型如VAE至关重要,因为它允许模型学习如何编码输入数据到一个潜在的、连续的表示空间中,并从这个空间中有效地生成新的样本。

函数get_sinusoid_encoding_table部分:

def get_sinusoid_encoding_table(n_position, d_hid):
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

????????这段代码定义了一个函数 get_sinusoid_encoding_table,用于生成正弦波编码表(Sinusoidal Positional Encoding),这是在 Transformer 模型中用于位置编码的一种方法。这种编码方式是为了使模型能够利用序列中元素的顺序信息。下面是对这个函数的详细解释:

函数参数

  • n_position:表示编码表中的位置数,即序列的最大长度。
  • d_hid:表示隐藏层的维度,即编码向量的大小。

函数操作

  1. 定义获取位置角度向量的函数

get_position_angle_vec(position):这个内部函数为给定的位置生成一个角度向量。向量中的每个元素对应于该位置的不同维度。对于每个维度 hid_j,该位置的角度计算为 position / (10000^(2 * hid_j / d_hid))。这种计算方式确保了不同位置的角度变化在所有维度上是不同的,从而让模型能够区分序列中不同的位置。

  1. 生成正弦波编码表

创建一个数组 sinusoid_table,其中包含从 0 到 n_position-1 的每个位置的角度向量。

对于表中的偶数索引维度(0::2),使用 np.sin 函数应用正弦变换。

对于表中的奇数索引维度(1::2),使用 np.cos 函数应用余弦变换。

返回值

返回一个经过正弦和余弦变换的编码表,并使用 torch.FloatTensor 将其转换为 PyTorch 张量,并通过 unsqueeze(0) 增加一个维度,这通常用于批处理。

作用

????????这种正弦波位置编码方式为 Transformer 模型提供了一种有效的方式来编码序列中元素的位置信息。由于 Transformer 模型本身不包含任何递归或卷积层,因此无法自然地处理序列数据中的顺序信息。通过添加正弦波位置编码,模型能够利用位置信息来更好地理解和处理序列数据。这种编码方式是 Transformer 架构的一个关键组成部分,广泛应用于自然语言处理和其他序列处理任务中。

DETRVAE类:

class DETRVAE(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names, vq, vq_class, vq_dim, action_dim):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
        self.state_dim, self.action_dim = state_dim, action_dim
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim)  # project qpos to embedding

        print(f'Use VQ: {self.vq}, {self.vq_class}, {self.vq_dim}')
        if self.vq:
            self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
        else:
            self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

        # decoder extra parameters
        if self.vq:
            self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
        else:
            self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent


    def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
        bs, _ = qpos.shape
        if self.encoder is None:
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)
            probs = binaries = mu = logvar = None
        else:
            # cvae encoder
            is_training = actions is not None # train or val
            ### Obtain latent z from action sequence
            if is_training:
                # project action sequence to embedding dim, and concat with a CLS token
                action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
                qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
                qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
                cls_embed = self.cls_embed.weight # (1, hidden_dim)
                cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
                encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
                encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
                # do not mask cls token
                cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
                is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
                # obtain position embedding
                pos_embed = self.pos_table.clone().detach()
                pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
                # query model
                encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
                encoder_output = encoder_output[0] # take cls output only
                latent_info = self.latent_proj(encoder_output)
                
                if self.vq:
                    logits = latent_info.reshape([*latent_info.shape[:-1], self.vq_class, self.vq_dim])
                    probs = torch.softmax(logits, dim=-1)
                    binaries = F.one_hot(torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(-1), self.vq_dim).view(-1, self.vq_class, self.vq_dim).float()
                    binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
                    probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
                    straigt_through = binaries_flat - probs_flat.detach() + probs_flat
                    latent_input = self.latent_out_proj(straigt_through)
                    mu = logvar = None
                else:
                    probs = binaries = None
                    mu = latent_info[:, :self.latent_dim]
                    logvar = latent_info[:, self.latent_dim:]
                    latent_sample = reparametrize(mu, logvar)
                    latent_input = self.latent_out_proj(latent_sample)

            else:
                mu = logvar = binaries = probs = None
                if self.vq:
                    latent_input = self.latent_out_proj(vq_sample.view(-1, self.vq_class * self.vq_dim))
                else:
                    latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
                    latent_input = self.latent_out_proj(latent_sample)

        return latent_input, probs, binaries, mu, logvar

    def forward(self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        latent_input, probs, binaries, mu, logvar = self.encode(qpos, actions, is_pad, vq_sample)

        # cvae decoder
        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[cam_id](image[:, cam_id])
                features = features[0] # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar], probs, binaries

????????这段代码定义了一个名为 DETRVAE 的类,它继承自 PyTorch 的 nn.Module,并且似乎是一种结合了变分自编码器(VAE)和 Transformer 的深度学习模型。这种模型可能用于处理包含图像、位置和动作序列的复杂数据。让我们逐步解析这个类的主要部分:

初始化函数 __init__

  • 参数:包括用于提取特征的神经网络(backbones)、Transformer模型、编码器、状态维度、查询数量、摄像头名称、以及与变分量化(VQ)相关的参数。
  • 初始化各种网络层和嵌入,例如:
    • action_headis_pad_head 是线性层,用于最终的动作预测。
    • query_embed 是嵌入层,用于处理对象查询。
    • 根据是否提供了 backbones,选择不同的特征提取方法。

编码器函数 encode

????????这个函数实现了 VAE 编码器的功能,将输入数据(如机器人的位置 qpos 和动作序列 actions)编码为潜在空间的表示。

  • 变分量化(VQ):如果启用,使用特殊的量化技术来处理潜在表示。
  • 重参数化技巧:在标准VAE中使用,将编码的均值和对数方差转换为潜在样本。

前向传播函数 forward

  • 处理图像数据、位置信息、环境状态和动作序列。
  • 如果有 backbones,使用这些网络来处理图像特征并与位置和潜在输入结合。
  • 如果没有 backbones,直接处理位置和环境状态。
  • 使用 Transformer 模型处理合成的特征。
  • 最终,通过动作头 action_head 和填充标记头 is_pad_head 生成动作预测和填充标记预测。

DETRVAE类中的函数__init__部分:

    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names, vq, vq_class, vq_dim, action_dim):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
        self.state_dim, self.action_dim = state_dim, action_dim
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim)  # project qpos to embedding

        print(f'Use VQ: {self.vq}, {self.vq_class}, {self.vq_dim}')
        if self.vq:
            self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
        else:
            self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

        # decoder extra parameters
        if self.vq:
            self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
        else:
            self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent

????????这段代码定义了 DETRVAE 类的初始化方法,是一个结合了变分自编码器(VAE)和 Transformer 架构的深度学习模型的构造函数。该模型设计用于处理包含图像、状态信息和动作序列的复杂数据。以下是对初始化方法的详细解析:

初始化函数 __init__

该函数用于初始化 DETRVAE 模型的各个组件。

  • 参数

    • backbones:用于特征提取的卷积神经网络(CNN)模块。
    • transformer:Transformer 架构的模块。
    • encoder:用于编码输入数据的模块。
    • state_dim:环境中机器人状态的维度。
    • num_queries:对象查询的数量,即 DETR 可以在单个图像中检测的最大对象数量。
    • camera_names:摄像头名称列表。
    • vq, vq_class, vq_dim:与变分量化相关的参数。
    • action_dim:动作的维度。
  • 模型组件

    • self.num_queries:存储对象查询的数量。
    • self.transformerself.encoder:存储传入的 Transformer 和编码器模块。
    • self.action_headself.is_pad_head:线性层,用于动作预测和填充(pad)标记预测。
    • self.query_embed:嵌入层,用于处理对象查询。
  • 特征提取

    • 如果提供了 backbones,使用它们来处理图像特征,并通过 self.input_proj 进行投影。
    • 否则,直接处理状态信息。
  • 变分编码器组件

    • self.latent_dim:潜在空间的维度。
    • self.cls_embed:额外的分类(cls)标记嵌入。
    • self.encoder_action_projself.encoder_joint_proj:线性层,用于将动作和位置(qpos)投影到嵌入空间。
    • self.latent_proj:线性层,用于将隐藏状态投影到潜在的标准差和方差上。
    • self.register_buffer('pos_table', ...):注册一个正弦波位置编码表。
  • 解码器额外参数

    • self.latent_out_proj:线性层,用于将潜在样本投影回嵌入空间。
    • self.additional_pos_embed:用于位置和潜在输入的学习位置嵌入。

?

DETRVAE类中的函数encode部分:

    def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
        bs, _ = qpos.shape
        if self.encoder is None:
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)
            probs = binaries = mu = logvar = None
        else:
            # cvae encoder
            is_training = actions is not None # train or val
            ### Obtain latent z from action sequence
            if is_training:
                # project action sequence to embedding dim, and concat with a CLS token
                action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
                qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
                qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
                cls_embed = self.cls_embed.weight # (1, hidden_dim)
                cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
                encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
                encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
                # do not mask cls token
                cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
                is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
                # obtain position embedding
                pos_embed = self.pos_table.clone().detach()
                pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
                # query model
                encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
                encoder_output = encoder_output[0] # take cls output only
                latent_info = self.latent_proj(encoder_output)
                
                if self.vq:
                    logits = latent_info.reshape([*latent_info.shape[:-1], self.vq_class, self.vq_dim])
                    probs = torch.softmax(logits, dim=-1)
                    binaries = F.one_hot(torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(-1), self.vq_dim).view(-1, self.vq_class, self.vq_dim).float()
                    binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
                    probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
                    straigt_through = binaries_flat - probs_flat.detach() + probs_flat
                    latent_input = self.latent_out_proj(straigt_through)
                    mu = logvar = None
                else:
                    probs = binaries = None
                    mu = latent_info[:, :self.latent_dim]
                    logvar = latent_info[:, self.latent_dim:]
                    latent_sample = reparametrize(mu, logvar)
                    latent_input = self.latent_out_proj(latent_sample)

            else:
                mu = logvar = binaries = probs = None
                if self.vq:
                    latent_input = self.latent_out_proj(vq_sample.view(-1, self.vq_class * self.vq_dim))
                else:
                    latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
                    latent_input = self.latent_out_proj(latent_sample)

        return latent_input, probs, binaries, mu, logvar

????????这段代码是 DETRVAE 类中 encode 方法的实现,它负责将输入数据编码为潜在空间的表示。这个方法是变分自编码器(VAE)和变分量化(VQ)技术的结合。下面是对这个方法的详细解释:

函数参数

  • qpos:机器人或实体的位置信息。
  • actions:动作序列。
  • is_pad:用于指示序列中填充(padding)部分的标记。
  • vq_sample:变分量化的样本。

编码过程

判断编码器是否存在

  • 如果没有提供编码器(self.encoder 为空),则创建一个零向量作为潜在样本,并通过投影层(self.latent_out_proj)转换。

使用编码器

  • 判断是否在训练模式(is_training),即检查是否提供了动作序列(actions)。
  • 对于训练模式:
    • 将动作序列通过一个线性层(self.encoder_action_proj)投影到嵌入空间。
    • 对位置信息(qpos)执行类似的投影(self.encoder_joint_proj)。
    • 将类别(CLS)标记的嵌入添加到序列的开始。
    • 组合这些嵌入,并将它们输入到编码器。
    • 使用位置编码(pos_embed)和可选的填充掩码(is_pad)。
    • 获取编码器的输出,并将其投影到潜在空间的表示。

变分量化(VQ)或标准VAE处理

  • 如果使用 VQ:
    • 将编码器输出转换为离散的概率分布。
    • 从这些概率中采样以得到二进制表示。
    • 将这些二进制表示投影回嵌入空间。
  • 否则(标准 VAE):
    • 从编码器输出中提取均值(mu)和对数方差(logvar)。
    • 使用重参数化技巧来生成潜在样本。
    • 将潜在样本投影回嵌入空间。

非训练模式

  • 如果不是训练模式,根据是否使用 VQ,创建一个零向量或使用提供的 VQ 样本来生成潜在输入。

返回值

  • 返回潜在输入、概率、二进制表示、均值和对数方差。

DETRVAE类中的函数forward部分:

    def forward(self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        latent_input, probs, binaries, mu, logvar = self.encode(qpos, actions, is_pad, vq_sample)

        # cvae decoder
        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[cam_id](image[:, cam_id])
                features = features[0] # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar], probs, binaries

????? 这段代码定义了 DETRVAE 类的 forward 方法,它实现了模型的前向传播过程,即如何处理输入数据并生成输出。该方法结合了变分自编码器(VAE)和 Transformer 架构。下面是对这个方法的详细解释:

函数参数

  • qpos:机器人或实体的位置信息,维度为 [batch, qpos_dim]。
  • qvel:机器人或实体的速度信息。
  • effort:可能表示机器人或实体的努力水平或其他类似的测量值。
  • image:图像数据,维度为 [batch, num_cam, channel, height, width]。
  • env_state:环境状态,这里未使用(None)。
  • actions:动作序列,维度为 [batch, seq, action_dim]。
  • is_pad:填充(padding)标记,用于标识序列中的填充部分。
  • vq_sample:变分量化的样本。

编码过程

  • 调用 self.encode 方法对输入数据进行编码,生成潜在的表示(latent_input)、概率(probs)、二进制表示(binaries)、均值(mu)和对数方差(logvar)。

解码过程

  • 使用特征提取网络(如果 backbones 不为空):
    • 对每个摄像头图像使用卷积网络(backbones)提取特征。
    • 将提取的图像特征和位置信息进行投影(self.input_proj)和组合。
    • 将这些特征与潜在输入和机器人状态(通过 self.input_proj_robot_state 处理)一起传递给 Transformer 模型。
  • 直接处理位置和环境状态(如果没有 backbones):
    • 将位置信息和环境状态信息通过投影层处理。
    • 将这些信息作为输入传递给 Transformer 模型。

输出

  • 使用 Transformer 模型的输出,通过 action_headis_pad_head 生成动作预测和填充标记预测。
  • 返回动作预测(a_hat)、填充标记预测(is_pad_hat)、以及编码阶段生成的统计数据和概率信息。

CNNMLP类:

class CNNMLP(nn.Module):
    def __init__(self, backbones, state_dim, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim) # TODO add more
        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
            backbone_down_projs = []
            for backbone in backbones:
                down_proj = nn.Sequential(
                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                    nn.Conv2d(128, 64, kernel_size=5),
                    nn.Conv2d(64, 32, kernel_size=5)
                )
                backbone_down_projs.append(down_proj)
            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

            mlp_in_dim = 768 * len(backbones) + state_dim
            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=self.action_dim, hidden_depth=2)
        else:
            raise NotImplementedError

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        is_training = actions is not None # train or val
        bs, _ = qpos.shape
        # Image observation features and position embeddings
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0] # take the last layer feature
            pos = pos[0] # not used
            all_cam_features.append(self.backbone_down_projs[cam_id](features))
        # flatten everything
        flattened_features = []
        for cam_feature in all_cam_features:
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, axis=1) # 768 each
        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
        a_hat = self.mlp(features)
        return a_hat

????????这段代码定义了一个名为 CNNMLP 的类,它继承自 PyTorch 的 nn.Module。这个类似乎是为了实现一个结合卷积神经网络(CNN)和多层感知机(MLP)的模型,主要用于处理图像和状态信息,并输出动作预测。以下是对这个类的详细解析:

初始化函数 __init__

  • 参数

    • backbones:用于特征提取的 CNN 模块列表。
    • state_dim:环境中机器人状态的维度。
    • camera_names:摄像头名称列表。
  • 模型组件

    • self.action_head:线性层,输出维度为 state_dim,用于动作预测。
    • self.backbones:CNN 模块列表,用于从图像中提取特征。
    • self.backbone_down_projs:CNN 下采样投影,用于将特征图的维度降低。
    • self.mlp:多层感知机,用于从合并的特征中生成动作预测。

前向传播函数 forward

  • 参数

    • qpos:机器人或实体的位置信息,维度为 [batch, qpos_dim]。
    • image:图像数据,维度为 [batch, num_cam, channel, height, width]。
    • env_state:环境状态,这里未使用。
    • actions:动作序列,用于判断是否处于训练模式。
  • 处理过程

    • 提取每个摄像头的图像特征。
    • 使用定义在 self.backbone_down_projs 中的下采样投影将特征图的维度降低。
    • 将所有摄像头的特征平铺并合并。
    • 将合并的特征与位置信息(qpos)拼接。
    • 通过 MLP(self.mlp)处理合并的特征以生成动作预测(a_hat)。

CNNMLP类中的函数__init__部分:

    def __init__(self, backbones, state_dim, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim) # TODO add more
        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
            backbone_down_projs = []
            for backbone in backbones:
                down_proj = nn.Sequential(
                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                    nn.Conv2d(128, 64, kernel_size=5),
                    nn.Conv2d(64, 32, kernel_size=5)
                )
                backbone_down_projs.append(down_proj)
            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

            mlp_in_dim = 768 * len(backbones) + state_dim
            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=self.action_dim, hidden_depth=2)
        else:
            raise NotImplementedError

????????这段代码定义了一个名为 CNNMLP 的类的初始化函数,该类继承自 PyTorch 的 nn.ModuleCNNMLP 类是一个深度学习模型,它结合了卷积神经网络(CNN)和多层感知机(MLP)来处理图像和状态信息。以下是对初始化函数的详细解析:

初始化函数 __init__

  • 参数

    • backbones:一个 PyTorch 模块列表,每个模块是一个卷积神经网络,用于图像特征提取。
    • state_dim:表示环境中机器人状态的维度。
    • camera_names:摄像头名称列表,用于标识不同的图像输入。
  • 模型组件初始化

    • self.camera_names:存储传入的摄像头名称。
    • self.action_head:一个线性层,其输入维度为 1000(这里似乎是硬编码的值,可能需要根据实际情况调整),输出维度为 state_dim,用于最终的动作预测。
    • self.backbones:包含传入的 CNN 模块的模块列表。
    • self.backbone_down_projs:对于每个 CNN 模块,创建一个下采样投影序列,该序列包括三个卷积层,用于逐步减少特征图的维度。
    • self.mlp:一个 MLP,用于处理合并后的图像特征和状态信息。其输入维度是基于所有 CNN 提取的特征维度之和加上状态维度计算的。

功能

  • 这个类被设计为先通过多个卷积网络提取图像特征,然后通过下采样投影降低特征维度。
  • 提取的特征与机器人的状态信息结合,然后通过一个 MLP 网络来生成动作预测。
  • 这种模型可能用于机器人控制、自动化任务执行或其他需要同时处理图像和状态信息的应用。

注意事项

  • 模型的输入和输出维度:在实际应用中,可能需要根据特定任务和数据调整模型的输入和输出维度,例如 self.action_head 的输入维度和 self.mlp 的各个参数。
  • backbones 模块:需要确保 backbones 中的每个模块都有一个 num_channels 属性,这个属性表示该模块输出特征图的通道数。
  • 代码中提到的 "transformer" 和 "num_queries" 等参数在初始化函数中未被使用,可能是遗留的注释或预留的扩展点。

CNNMLP类中的函数forward部分:

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        """
        is_training = actions is not None # train or val
        bs, _ = qpos.shape
        # Image observation features and position embeddings
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0] # take the last layer feature
            pos = pos[0] # not used
            all_cam_features.append(self.backbone_down_projs[cam_id](features))
        # flatten everything
        flattened_features = []
        for cam_feature in all_cam_features:
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, axis=1) # 768 each
        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
        a_hat = self.mlp(features)
        return a_hat

????????这段代码定义了 CNNMLP 类的 forward 方法,它实现了模型的前向传播过程,即如何处理输入数据并生成动作预测。该方法主要涉及图像特征提取和多层感知机(MLP)的应用。以下是对这个方法的详细解释:

函数参数

  • qpos:机器人或实体的位置信息,维度为 [batch, qpos_dim]。
  • image:图像数据,维度为 [batch, num_cam, channel, height, width]。
  • env_state:环境状态,这里未使用(None)。
  • actions:动作序列,维度为 [batch, seq, action_dim]。这个参数用于判断是否处于训练模式。

图像特征提取

  • 遍历每个摄像头,使用对应的 backbones CNN 模块提取图像特征。
  • 使用 self.backbone_down_projs 中定义的下采样投影进一步处理每个摄像头的特征。
  • 将所有摄像头的特征平铺(flatten)并合并。

动作预测

  • 将平铺的图像特征与位置信息(qpos)拼接。
  • 通过 MLP(self.mlp)处理合并的特征以生成动作预测(a_hat)。

返回值

  • 返回动作预测 a_hat

函数mlp部分:

def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    trunk = nn.Sequential(*mods)
    return trunk

????????这段代码定义了一个名为 mlp 的函数,用于构建一个多层感知机(MLP)网络。这个网络由多个线性层(全连接层)和非线性激活函数(ReLU)组成。以下是对这个函数的详细解释:

函数参数

  • input_dim:输入层的维度。
  • hidden_dim:隐藏层的维度。
  • output_dim:输出层的维度。
  • hidden_depth:隐藏层的数量。

构建过程

  1. 无隐藏层hidden_depth == 0):

    • 如果没有隐藏层,函数只创建一个从输入维度到输出维度的线性层。
  2. 有隐藏层

    • 如果有一个或多个隐藏层,函数首先创建一个从输入维度到隐藏维度的线性层,后面跟着一个 ReLU 激活函数。
    • 然后,对于每一个额外的隐藏层,添加一个从隐藏维度到隐藏维度的线性层,后面跟着一个 ReLU 激活函数。
    • 最后,添加一个从隐藏维度到输出维度的线性层。
  3. 组合模块

    • 使用 nn.Sequential 将所有创建的模块(线性层和激活函数)按顺序组合在一起,形成完整的 MLP 网络。

返回值

  • 函数返回构建好的 MLP 网络。

函数build_encoder部分:

def build_encoder(args):
    d_model = args.hidden_dim # 256
    dropout = args.dropout # 0.1
    nhead = args.nheads # 8
    dim_feedforward = args.dim_feedforward # 2048
    num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
    normalize_before = args.pre_norm # False
    activation = "relu"

    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, normalize_before)
    encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

    return encoder

????????这段代码定义了一个名为 build_encoder 的函数,用于构建一个 Transformer 编码器。这个函数根据提供的参数来配置和创建编码器。以下是对这个函数的详细解释:

函数参数

  • args:一个包含多个配置参数的对象。这些参数可能是从命令行解析得到的,或者在某个配置文件中定义。

参数配置

  • d_model:隐藏层的维度。
  • dropout:在层中使用的 dropout 比率。
  • nhead:多头注意力机制中的头数。
  • dim_feedforward:前馈网络中的维度。
  • num_encoder_layers:编码器中的层数。
  • normalize_before:是否在每个子层之前进行层归一化(Layer Normalization)。
  • activation:激活函数的类型,在这个例子中是 "relu"。

构建过程

  1. 创建单个编码器层

    • 使用 TransformerEncoderLayer 创建一个编码器层,配置它使用上述参数。
  2. 创建层归一化(如果启用):

    • 如果 normalize_before 为真,则创建一个 nn.LayerNorm 层用于归一化。
  3. 创建编码器

    • 使用 TransformerEncoder 创建编码器,包含多个编码器层和可选的层归一化。

返回值

  • 函数返回构建好的 Transformer 编码器。

函数build部分:

def build(args):
    state_dim = 14 # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    for _ in args.camera_names:
        backbone = build_backbone(args)
        backbones.append(backbone)

    transformer = build_transformer(args)

    if args.no_encoder:
        encoder = None
    else:
        encoder = build_transformer(args)

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        vq=args.vq,
        vq_class=args.vq_class,
        vq_dim=args.vq_dim,
        action_dim=args.action_dim,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model

????????这段代码定义了一个名为 build 的函数,用于构建一个名为 DETRVAE 的复合模型,结合了卷积神经网络(用于图像处理)、Transformer 架构(用于序列数据处理)和变分自编码器(VAE)。以下是对这个函数的详细解释:

函数参数

  • args:一个包含多个配置参数的对象。

构建过程

  1. 设置状态维度

    • state_dim 被设置为 14,这是机器人或环境状态的维度。
  2. 构建图像处理的卷积网络(CNN)背景模型

    • 对于 args.camera_names 中的每个摄像头,使用 build_backbone 函数构建一个卷积网络,并将其添加到 backbones 列表中。
  3. 构建 Transformer 模型

    • 调用 build_transformer 函数构建 Transformer 模型。
  4. 条件性地构建编码器

    • 如果 args.no_encoder 为真,则不构建编码器,否则使用 build_transformer 函数构建编码器。
  5. 构建 DETRVAE 模型

    • 使用上述构建的组件以及从 args 中提取的其他参数来初始化 DETRVAE 类的实例。
  6. 计算模型的参数数量

    • 计算并打印模型的可训练参数总数。

返回值

  • 返回构建好的 DETRVAE 模型实例。

函数build_cnnmlp部分:

def build_cnnmlp(args):
    state_dim = 14 # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    backbones = []
    for _ in args.camera_names:
        backbone = build_backbone(args)
        backbones.append(backbone)

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))

    return model

????????这段代码定义了一个名为 build_cnnmlp 的函数,用于构建一个名为 CNNMLP 的深度学习模型。该模型结合了卷积神经网络(CNN)和多层感知机(MLP),主要用于处理图像和状态信息。以下是对这个函数的详细解释:

函数参数

  • args:一个包含多个配置参数的对象。

构建过程

  1. 设置状态维度

    • state_dim 被设置为 14,这可能是机器人或环境状态的维度。
  2. 构建图像处理的卷积网络(CNN)背景模型

    • 对于 args.camera_names 中的每个摄像头,使用 build_backbone 函数构建一个卷积网络,并将其添加到 backbones 列表中。
  3. 构建 CNNMLP 模型

    • 使用上述构建的 backbones,状态维度 state_dim 和摄像头名称 args.camera_names 来初始化 CNNMLP 类的实例。
  4. 计算模型的参数数量

    • 计算并打印模型的可训练参数总数。

返回值

  • 返回构建好的 CNNMLP 模型实例。
文章来源:https://blog.csdn.net/qq_54900679/article/details/135702186
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。