Python torch 模块，dsmm() 实例源码

我们从 Python 开源项目中，提取了以下 9 个代码示例，用于说明如何使用 torch.dsmm()。

项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_dsmm(self):
    """Check that ``torch.dsmm`` matches ``torch.mm`` on the densified operand."""

    def check(di, dj, dk):
        # Left operand comes from the sparse generator (requested size
        # [di, dj]); right operand is a (dj, dk) tensor from self.randn.
        sparse_lhs = self._gen_sparse(2, 20, [di, dj])[0]
        dense_rhs = self.randn(dj, dk)

        actual = torch.dsmm(sparse_lhs, dense_rhs)
        reference = torch.mm(sparse_lhs.to_dense(), dense_rhs)
        self.assertEqual(actual, reference)

    # Small, medium and large problem sizes, in the same order as before.
    for shape in ((7, 5, 3), (1000, 100, 100), (3000, 64, 300)):
        check(*shape)
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_dsmm(self):
    """Compare ``torch.dsmm`` against a dense ``torch.mm`` for several shapes."""
    shapes = [
        (7, 5, 3),
        (1000, 100, 100),
        (3000, 64, 300),
    ]

    def run_case(di, dj, dk):
        # First element of the generator's result is the sparse operand.
        lhs = self._gen_sparse(2, 20, [di, dj])[0]
        rhs = self.randn(dj, dk)

        got = torch.dsmm(lhs, rhs)
        want = torch.mm(lhs.to_dense(), rhs)
        self.assertEqual(got, want)

    for di, dj, dk in shapes:
        run_case(di, dj, dk)
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_dsmm(self):
    """Verify the sparse-dense product from ``torch.dsmm`` against ``torch.mm``."""

    def assert_matches_dense(di, dj, dk):
        generated = self._gen_sparse(2, 20, [di, dj])
        sparse_mat = generated[0]
        dense_mat = self.randn(dj, dk)

        product = torch.dsmm(sparse_mat, dense_mat)
        # Reference value: densify the sparse operand and use the plain matmul.
        baseline = torch.mm(sparse_mat.to_dense(), dense_mat)
        self.assertEqual(product, baseline)

    for dims in [(7, 5, 3), (1000, 100, 100), (3000, 64, 300)]:
        assert_matches_dense(*dims)
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def test_interpolation():
    """Check that interpolating ``grid**2`` onto ``x`` reproduces ``x**2``.

    The interpolated values must agree with ``x**2`` to a relative error
    below 1e-5 at every point.
    """
    x = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)

    # Interpolation indices/coefficients, packed into a sparse matrix
    # (presumably one row per point in x — confirm against index_coef_to_sparse).
    J, C = Interpolation().interpolate(grid, x)
    W = utils.toeplitz.index_coef_to_sparse(J, C, len(grid))

    grid_values = grid.pow(2)
    expected = x.pow(2)

    # dsmm needs a 2-D right operand, hence the unsqueeze/squeeze round trip.
    interpolated = torch.dsmm(W, grid_values.unsqueeze(1)).squeeze()

    # The 1e-10 offset guards the division against a zero denominator.
    relative_error = torch.abs(interpolated - expected) / (expected + 1e-10)
    assert all(relative_error < 1e-5)
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def _derivative_quadratic_form_factory(self, *args):
    """Build a closure that evaluates the derivative of a quadratic form.

    The returned closure takes ``left_vectors`` and ``right_vectors`` and
    returns a tuple with one entry per element of ``args``:

    - 1 arg ``(columns,)``: ``(res,)``
    - 3 args ``(columns, W_left, W_right)``: ``(res, None, None)``
    - 4 args ``(columns, W_left, W_right, added_diag)``:
      ``(res, None, None, diag_grad)``

    For any other number of arguments the closure returns ``None``.
    """

    def closure(left_vectors, right_vectors):
        # Promote 1-D inputs to single-row matrices; the check is made on
        # left_vectors only and applied to both operands.
        is_vector = left_vectors.ndimension() == 1
        left_factor = left_vectors.unsqueeze(0) if is_vector else left_vectors
        right_factor = right_vectors.unsqueeze(0) if is_vector else right_vectors

        n_args = len(args)
        if n_args == 1:
            (columns,) = args
            return (kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),)

        if n_args not in (3, 4):
            return None

        columns, W_left, W_right = args[:3]
        trailing = (None, None)
        if n_args == 4:
            added_diag = args[3]
            # Only the first entry of the diagonal gradient is populated, and
            # it is computed from the factors *before* they are projected.
            diag_grad = columns.new(len(added_diag)).zero_()
            diag_grad[0] = (left_factor * right_factor).sum()
            trailing = (None, None, diag_grad)

        # Apply the transposed sparse W matrices to the row-stacked factors.
        left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
        right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()

        res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
        return (res,) + trailing

    return closure
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def forward(self, dense):
    """Multiply the stored sparse operand ``self.sparse`` by ``dense``.

    Uses ``bdsmm`` when ``self.sparse`` has three dimensions and
    ``torch.dsmm`` otherwise.
    """
    multiply = bdsmm if self.sparse.ndimension() == 3 else torch.dsmm
    return multiply(self.sparse, dense)
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def backward(self, grad_output):
    """Multiply the transposed sparse operand by ``grad_output``.

    A 3-dimensional ``self.sparse`` has its last two dimensions swapped and
    is handed to ``bdsmm``; anything else is transposed with ``.t()`` and
    handed to ``torch.dsmm``.
    """
    sparse = self.sparse
    if sparse.ndimension() != 3:
        return torch.dsmm(sparse.t(), grad_output)
    return bdsmm(sparse.transpose(1, 2), grad_output)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_dsmm(self):
    """Check ``torch.dsmm`` against ``torch.mm`` on the ``safeToDense`` conversion."""

    def verify(di, dj, dk):
        sparse_operand = self._gen_sparse(2, 20, [di, dj])[0]
        dense_operand = self.randn(dj, dk)

        result = torch.dsmm(sparse_operand, dense_operand)
        # The reference goes through the test class's own densifying helper
        # instead of calling .to_dense() directly.
        densified = self.safeToDense(sparse_operand)
        reference = torch.mm(densified, dense_operand)
        self.assertEqual(result, reference)

    for di, dj, dk in ((7, 5, 3), (1000, 100, 100), (3000, 64, 300)):
        verify(di, dj, dk)
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    r"""
    Multiply ``tensor`` by an interpolated Kronecker product of Toeplitz matrices.

    Given an interpolated matrix
    interp_left * (T_1 \otimes ... \otimes T_d) * interp_right^T, plus possibly
    an additional diagonal component, compute its product with a vector or
    matrix ``tensor``, where each T_i is a symmetric Toeplitz matrix.

    Args:
        - toeplitz_columns (d x m matrix) - columns of the d Toeplitz matrices
          T_i, each with length n_i
        - tensor (matrix p x k) - Vector (k=1) or matrix (k>1) to multiply with.
          A 1-D input is treated as a single column and the result is returned
          1-D again.
        - interp_left (sparse matrix nxm) - Left interpolation matrix. When it
          is None no interpolation is applied on either side.
        - interp_right (sparse matrix pxm) - Right interpolation matrix; it is
          applied transposed. Must be given whenever ``interp_left`` is given.
        - noise_diag (tensor p) - If not None, ``noise_diag * tensor``
          (row-wise scaling) is added to the product at the end.

    Returns:
        - tensor
    """
    output_dims = tensor.ndimension()
    noise_term = None

    # Work on a column matrix internally; undone before returning.
    if output_dims == 1:
        tensor = tensor.unsqueeze(1)

    if noise_diag is not None:
        # Scale every row of tensor by the matching diagonal entry.
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    if interp_left is not None:
        # Get interp_{r}^{T} tensor
        interp_right_tensor = torch.dsmm(interp_right.t(), tensor)
        # Get (T interp_{r}^{T}) tensor
        rhs = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, interp_right_tensor)

        # Get (interp_{l} T interp_{r}^{T})tensor
        output = torch.dsmm(interp_left, rhs)
    else:
        output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, tensor)

    if noise_term is not None:
        # Get (interp_{l} T interp_{r}^{T} + \sigma^{2}I)tensor
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)
    return output