Python torch 模块，round() 实例源码

我们从 Python 开源项目中提取了以下 8 个代码示例，用于说明如何使用 torch.round()。

项目:pytorch-dnc    作者:jingweiz    | 项目源码 | 文件源码
def visual(self, input_ts, target_ts, mask_ts, output_ts=None):
        """Print a human-readable dump of batch element 0 of the given tensors.

        Each channel (last dimension) of ``input_ts`` and ``target_ts`` is
        rendered on its own line with ``self._readable``. Only channel 0 of
        ``mask_ts`` is rendered. If ``output_ts`` is given, it is multiplied by
        ``mask_ts`` and rounded to the nearest integer before being rendered
        channel by channel; otherwise no "Output:" section is printed.

        Shapes, per the original author's notes (not checked here):
            input_ts:  [(num_wordsx2+2) x batch_size x (len_word+2)]
            target_ts: [(num_wordsx2+2) x batch_size x (len_word)]
            mask_ts:   [(num_wordsx2+2) x batch_size x (len_word)]
            output_ts: [(num_wordsx2+2) x batch_size x (len_word)]

        Returns:
            None. All output goes to stdout via ``print``.
        """
        def _rows(ts):
            # One readable line per channel of batch element 0; each line shows
            # that channel's values along dimension 0.
            return '\n'.join(self._readable(ts[:, 0, i]) for i in range(ts.size(2)))

        if output_ts is not None:
            # Multiplying by the mask zeroes positions where the mask is 0.
            output_ts = torch.round(output_ts * mask_ts)

        input_strings = 'Input:\n' + _rows(input_ts)
        target_strings = 'Target:\n' + _rows(target_ts)
        # Only the first mask channel is shown.
        mask_strings = 'Mask:\n' + self._readable(mask_ts[:, 0, 0])
        output_strings = None if output_ts is None else 'Output:\n' + _rows(output_ts)

        print(input_strings)
        print(target_strings)
        print(mask_strings)
        if output_strings is not None:
            print(output_strings)
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_round(self):
        self._testMath(torch.round, round)
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_round(self):
        self._testMath(torch.round, round)
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_round(self):
        self._testMath(torch.round, round)
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_round(self):
        self._testMath(torch.round, round)
项目:pytorch-dnc    作者:jingweiz    | 项目源码 | 文件源码
def visual(self, input_ts, target_ts, mask_ts, output_ts=None):
        """Print a human-readable dump of batch element 0 of the given tensors.

        ``input_ts`` is first passed through ``self._unnormalize_repeats``
        (defined elsewhere; presumably it undoes a normalization applied to
        the repeat-count input; TODO confirm). Each channel (last dimension)
        of the input and of ``target_ts`` is rendered on its own line with
        ``self._readable``. Only channel 0 of ``mask_ts`` is rendered. If
        ``output_ts`` is given, it is multiplied by ``mask_ts`` and rounded to
        the nearest integer before being rendered channel by channel;
        otherwise no "Output:" section is printed.

        Shapes, per the original author's notes (not checked here):
            input_ts:  [(num_wordsx(repeats+1)+3) x batch_size x (len_word+2)]
            target_ts: [(num_wordsx(repeats+1)+3) x batch_size x (len_word+1)]
            mask_ts:   [(num_wordsx(repeats+1)+3) x batch_size x (len_word+1)]
            output_ts: [(num_wordsx(repeats+1)+3) x batch_size x (len_word+1)]

        Returns:
            None. All output goes to stdout via ``print``.
        """
        def _rows(ts):
            # One readable line per channel of batch element 0; each line shows
            # that channel's values along dimension 0.
            return '\n'.join(self._readable(ts[:, 0, i]) for i in range(ts.size(2)))

        input_ts = self._unnormalize_repeats(input_ts)
        if output_ts is not None:
            # Multiplying by the mask zeroes positions where the mask is 0.
            output_ts = torch.round(output_ts * mask_ts)

        input_strings = 'Input:\n' + _rows(input_ts)
        target_strings = 'Target:\n' + _rows(target_ts)
        # Only the first mask channel is shown.
        mask_strings = 'Mask:\n' + self._readable(mask_ts[:, 0, 0])
        output_strings = None if output_ts is None else 'Output:\n' + _rows(output_ts)

        print(input_strings)
        print(target_strings)
        print(mask_strings)
        if output_strings is not None:
            print(output_strings)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_round(self):
        self._testMath(torch.round, round)
项目:ktorch    作者:farizrahman4u    | 项目源码 | 文件源码
def round(x):
    """Round ``x`` element-wise with ``torch.round``, wrapped through ``get_op``.

    ``get_op`` is defined elsewhere in ktorch; presumably it turns a plain
    tensor function into a backend operation (TODO confirm). Note that this
    definition shadows the builtin ``round`` within its module.
    """
    rounding_op = get_op(lambda tensor: torch.round(tensor))
    return rounding_op(x)