Python torch module: size() code examples

We extracted the following 2 code examples from open-source Python projects to illustrate how to use torch.Size() (the shape type returned by Tensor.size()).
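
For orientation, here is a minimal, self-contained sketch of Tensor.size() and the torch.Size it returns; the tensor shapes are illustrative and not taken from either project below.

import torch

x = torch.zeros(32, 4, 84, 84)                   # e.g. a batch of 32 four-frame states
print(x.size())                                  # torch.Size([32, 4, 84, 84])
print(x.size(0))                                 # 32 -- size along a single dimension
assert x.size() == torch.Size([32, 4, 84, 84])   # torch.Size compares like a tuple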

Project: categorical-dqn    Author: floringogianu    | Project source | File source
def _batch2torch(self, batch, batch_size):
        """ List of transitions -> Batch of transitions -> pytorch tensors.

            Returns:
                states: torch.Size([batch_size, hist_len, w, h])
                a/r/d: torch.Size([batch_size, 1])
        """
        # See the PyTorch DQN tutorial for this batch-transposition idiom.
        # (t1, t2, ... tn) -> t((s1, s2, ..., sn), (a1, a2, ... an) ...)
        batch = BatchTransition(*zip(*batch))

        # lists to tensors
        state_batch = torch.cat(batch.state, 0).type(self.dtype.FT) / 255
        action_batch = self.dtype.LT(batch.action).unsqueeze(1)
        reward_batch = self.dtype.FT(batch.reward).unsqueeze(1)
        next_state_batch = torch.cat(batch.state_, 0).type(self.dtype.FT) / 255
        # [False, False, True, False] -> [1, 1, 0, 1]::ByteTensor
        mask = 1 - self.dtype.BT(batch.done).unsqueeze(1)

        return [batch_size, state_batch, action_batch, reward_batch,
                next_state_batch, mask]
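
The method above depends on helpers defined elsewhere in the repository (BatchTransition, self.dtype). The following is a minimal sketch of the shape bookkeeping it performs, using hypothetical stand-ins for those helpers; the repo's actual definitions may differ.

import torch
from collections import namedtuple

# Hypothetical stand-ins for the repo's helpers, for illustration only.
BatchTransition = namedtuple("BatchTransition",
                             ["state", "action", "reward", "state_", "done"])

class TorchTypes:            # assumed to expose tensor constructors as FT/LT/BT
    FT = torch.FloatTensor
    LT = torch.LongTensor
    BT = torch.ByteTensor

# Fake three transitions with 4-frame, 24x24 uint8 states.
transitions = [(torch.zeros(1, 4, 24, 24).byte(), 1, 0.0,
                torch.zeros(1, 4, 24, 24).byte(), False) for _ in range(3)]
batch = BatchTransition(*zip(*transitions))

state_batch = torch.cat(batch.state, 0).type(TorchTypes.FT) / 255
action_batch = TorchTypes.LT(batch.action).unsqueeze(1)
mask = 1 - TorchTypes.BT(batch.done).unsqueeze(1)

print(state_batch.size())    # torch.Size([3, 4, 24, 24])
print(action_batch.size())   # torch.Size([3, 1])
print(mask.size())           # torch.Size([3, 1])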
Project: categorical-dqn    Author: floringogianu    | Project source | File source
def __init__(self, env, env_type, hist_len, state_dims, cuda=None):
        super(PreprocessFrames, self).__init__(env)

        self.env_type = env_type
        self.state_dims = state_dims
        self.hist_len = hist_len
        self.env_wh = self.env.observation_space.shape[0:2]
        self.env_ch = self.env.observation_space.shape[2]
        self.wxh = self.env_wh[0] * self.env_wh[1]

        # need to find a better way
        if self.env_type == "atari":
            self._preprocess = self._atari_preprocess
        elif self.env_type == "catch":
            self._preprocess = self._catch_preprocess
        print("[Preprocess Wrapper] for %s with state history of %d frames."
              % (self.env_type, hist_len))

        self.cuda = False if cuda is None else cuda
        self.dtype = dtype = TorchTypes(self.cuda)
        self.rgb = dtype.FT([.2126, .7152, .0722])  # RGB -> luminance weights

        # torch.Size([1, 4, 24, 24])
        """
        self.hist_state = torch.FloatTensor(1, hist_len, *state_dims)
        self.hist_state.fill_(0)
        """

        self.d = OrderedDict({i: torch.FloatTensor(1, 1, *state_dims).fill_(0)
                              for i in range(hist_len)})
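
The OrderedDict above keeps one single-frame tensor per history slot; concatenating its values along dimension 1 reproduces the [1, hist_len, *state_dims] state mentioned in the comment. A minimal sketch, assuming hist_len = 4 and state_dims = (24, 24):

import torch
from collections import OrderedDict

hist_len, state_dims = 4, (24, 24)
d = OrderedDict({i: torch.FloatTensor(1, 1, *state_dims).fill_(0)
                 for i in range(hist_len)})

# Stack the per-slot frames along the history dimension.
hist_state = torch.cat(list(d.values()), 1)
print(hist_state.size())    # torch.Size([1, 4, 24, 24])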