Python torch 模块,diag() 实例源码


项目:MMD-GAN    作者:OctoberChang    | 项目源码 | 文件源码
def _mix_rbf_kernel(X, Y, sigma_list):
    """Evaluate a sum of RBF kernels on the stacked samples ``[X; Y]``.

    Args:
        X: 2-D tensor, one sample per row.
        Y: 2-D tensor with the same number of rows as ``X``.
        sigma_list: iterable of RBF bandwidths; one kernel is summed per entry.

    Returns:
        Tuple ``(K_XX, K_XY, K_YY, d)``: the three ``m x m`` blocks of the
        summed kernel matrix and ``d = len(sigma_list)``.
    """
    assert(X.size(0) == Y.size(0))
    m = X.size(0)

    # Stack both sample sets so a single Gram matrix yields all three blocks.
    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    # ||z_i - z_j||^2 = ||z_i||^2 - 2 <z_i, z_j> + ||z_j||^2
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.0
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K += torch.exp(-gamma * exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)
项目:pytorch-a2c-ppo-acktr    作者:ikostrikov    | 项目源码 | 文件源码
def orthogonal(tensor, gain=1):
    """Fill ``tensor`` in place with a (semi-)orthogonal matrix scaled by ``gain``.

    Tensors with more than two dimensions are treated as a matrix whose rows
    are indexed by the first dimension and whose columns are all trailing
    dimensions flattened.

    Args:
        tensor: tensor with at least 2 dimensions; overwritten in place.
        gain: optional scaling factor applied after orthogonalisation.

    Returns:
        The same ``tensor`` object.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor[0].numel()
    flattened = torch.Tensor(rows, cols).normal_(0, 1)

    # QR needs a tall matrix to yield orthonormal columns; transpose wide input.
    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Flip column signs by the sign of R's diagonal to make Q uniform
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph.expand_as(q)

    # Undo the transpose so q has the (rows, cols) layout again.
    if rows < cols:
        q.t_()

    tensor.view_as(q).copy_(q)
    tensor.mul_(gain)
    return tensor
项目:e2e-model-learning    作者:locuslab    | 项目源码 | 文件源码
def __init__(self, params, eps=1e-2):
        """Build the constant QP matrices on the GPU from the cost parameters.

        Args:
            params: mapping with scalar entries ``c_quad``, ``b_quad``,
                ``h_quad``, ``c_lin``, ``b_lin``, ``h_lin`` and an array ``d``
                whose length ``k`` fixes the problem size (``1 + 2*k`` variables).
            eps: scale of the identity stored in ``self.eps_eye``.
        """
        super(SolveNewsvendor, self).__init__()
        k = len(params['d'])
        # Diagonal quadratic cost over the 1 + 2k variables.
        self.Q = Variable(torch.diag(torch.Tensor(
            [params['c_quad']] + [params['b_quad']]*k + [params['h_quad']]*k))
            .cuda())
        # Linear cost, same variable ordering as Q.
        self.p = Variable(torch.Tensor(
            [params['c_lin']] + [params['b_lin']]*k + [params['h_lin']]*k)
            .cuda())
        # Inequality matrix: two k-row blocks tied to `d`, then -I for the
        # sign constraints on every variable.
        self.G = Variable(torch.cat([
            torch.cat([-torch.ones(k,1), -torch.eye(k), torch.zeros(k,k)], 1),
            torch.cat([torch.ones(k,1), torch.zeros(k,k), -torch.eye(k)], 1),
            -torch.eye(1 + 2*k)], 0).cuda())
        self.h = Variable(torch.Tensor(
            np.concatenate([-params['d'], params['d'], np.zeros(1+ 2*k)])).cuda())
        self.one = Variable(torch.Tensor([1])).cuda()
        self.eps_eye = eps * Variable(torch.eye(1 + 2*k).cuda()).unsqueeze(0)
项目:e2e-model-learning    作者:locuslab    | 项目源码 | 文件源码
def forward(self, y):
        """Solve one QP per batch row, scaling the stored costs by ``y``.

        Args:
            y: 2-D tensor of shape ``(nBatch, k)``.

        Returns:
            The first column of the QP solution, shape ``(nBatch, 1)``.
        """
        nBatch, k = y.size()

        # Per-sample diagonal scaling [1, y_i, y_i] of the quadratic cost.
        Q_scale = torch.cat([torch.diag(torch.cat(
            [self.one, y[i], y[i]])).unsqueeze(0) for i in range(nBatch)], 0)
        Q = self.Q.unsqueeze(0).expand_as(Q_scale).mul(Q_scale)
        # Same scaling applied to the linear cost.
        p_scale = torch.cat([Variable(torch.ones(nBatch,1).cuda()), y, y], 1)
        p = self.p.unsqueeze(0).expand_as(p_scale).mul(p_scale)
        G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
        h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
        # Empty tensor passed for both equality-constraint arguments.
        e = Variable(torch.Tensor().cuda()).double()

        out = QPFunction(verbose=False)\
            (Q.double(), p.double(), G.double(), h.double(), e, e).float()

        return out[:,:1]
项目:block    作者:bamos    | 项目源码 | 文件源码
def test_np():
    """Check that ``block`` assembles the same matrix as ``np.bmat`` for NumPy inputs."""
    # Fixed seed so a failure is reproducible.
    npr.seed(0)

    nx, nineq, neq = 4, 6, 7
    Q = npr.randn(nx, nx)
    G = npr.randn(nineq, nx)
    A = npr.randn(neq, nx)
    D = np.diag(npr.rand(nineq))

    # Reference matrix with every zero/identity block written out explicitly.
    K_ = np.bmat((
        (Q, np.zeros((nx, nineq)), G.T, A.T),
        (np.zeros((nineq, nx)), D, np.eye(nineq), np.zeros((nineq, neq))),
        (G, np.eye(nineq), np.zeros((nineq, nineq + neq))),
        (A, np.zeros((neq, nineq + nineq + neq)))
    ))

    # Same layout using the 0 / 'I' shorthands.
    K = block((
        (Q,   0, G.T, A.T),
        (0,   D, 'I',   0),
        (G, 'I',   0,   0),
        (A,   0,   0,   0)
    ))

    assert np.allclose(K_, K)
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_constant(self):
        """A tensor captured while tracing must be copied into the trace."""
        x = Variable(torch.randn(2, 2), requires_grad=True)

        trace = torch._C._tracer_enter((x,), 0)

        y = Variable(torch.diag(torch.Tensor([2, 2])))
        z = x.matmul(y)

        torch._C._tracer_exit((z,))
        function = torch._C._jit_createAutogradClosure(trace)

        z2 = function()(x)
        self.assertEqual(z, z2)

        # Overwrite the original constant: if the trace had kept a reference
        # instead of a clone, the result below would change.
        y.data.fill_(1000)  # make sure the data has been cloned

        x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
        z3 = function()(x2)
        self.assertEqual(z3.data, torch.ones(2, 2) * 4)
项目:NeuroNLP2    作者:XuezheMax    | 项目源码 | 文件源码
def decode(self, input_word, input_char, input_pos, mask=None, length=None, hx=None, leading_symbolic=0):
        """Greedily pick a head for every position, then decode arc types.

        Args:
            input_word, input_char, input_pos: inputs forwarded to ``self.forward``.
            mask: optional mask forwarded to ``self.forward``; where the mask
                returned by it is 0, scores are set to -inf.
            length, hx: forwarded unchanged to ``self.forward``.
            leading_symbolic: forwarded to ``self._decode_types``.

        Returns:
            Tuple ``(heads, types)`` of NumPy arrays.
        """
        # out_arc shape [batch, length, length]
        out_arc, out_type, mask, length = self.forward(input_word, input_char, input_pos,
                                                       mask=mask, length=length, hx=hx)
        out_arc = out_arc.data
        batch, max_len, _ = out_arc.size()
        # set diagonal elements to -inf
        out_arc = out_arc + torch.diag(out_arc.new(max_len).fill_(-np.inf))
        # set invalid positions to -inf
        if mask is not None:
            # minus_mask = (1 - mask.data).byte().view(batch, max_len, 1)
            minus_mask = (1 - mask.data).byte().unsqueeze(2)
            out_arc.masked_fill_(minus_mask, -np.inf)

        # compute naive predictions.
        # prediction shape = [batch, length]
        _, heads = out_arc.max(dim=1)

        types = self._decode_types(out_type, heads, leading_symbolic)

        return heads.cpu().numpy(), types.data.cpu().numpy()
项目:NeuroNLP2    作者:XuezheMax    | 项目源码 | 文件源码
def forward(self, input_h, input_c, mask=None):
        '''Compute attention energies with self-attachment disabled.

        Args:
            input_h: Tensor
                the head input tensor with shape = [batch, length, input_size]
            input_c: Tensor
                the child input tensor with shape = [batch, length, input_size]
            mask: Tensor or None
                the mask tensor with shape = [batch, length]

        Returns: Tensor
            the energy tensor with shape = [batch, num_label, length, length]

        '''
        batch, length, _ = input_h.size()
        # [batch, num_labels, length, length]
        output = self.attention(input_h, input_c, mask_d=mask, mask_e=mask)
        # set diagonal elements to -inf
        output = output + Variable(torch.diag(output.data.new(length).fill_(-np.inf)))
        return output
项目:torchsample    作者:ncullen93    | 项目源码 | 文件源码
def th_corrcoef(x):
    """
    mimics np.corrcoef

    Args:
        x: 2-D tensor; each row is one variable, each column one observation.

    Returns:
        Square tensor of pairwise row correlations, clamped to [-1, 1].
    """
    # calculate covariance matrix of rows
    # keepdim so the per-row mean broadcasts against x for any shape
    mean_x = th.mean(x, 1, keepdim=True)
    xm = x.sub(mean_x.expand_as(x))
    c = xm.mm(xm.t())
    c = c / (x.size(1) - 1)

    # normalize covariance matrix
    d = th.diag(c)
    stddev = th.pow(d, 0.5)
    c = c.div(stddev.expand_as(c))
    c = c.div(stddev.expand_as(c).t())

    # clamp between -1 and 1
    c = th.clamp(c, -1.0, 1.0)

    return c
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_constant(self):
        """A tensor captured while tracing must be copied into the trace."""
        x = Variable(torch.randn(2, 2), requires_grad=True)

        trace = torch._C._tracer_enter((x,), 0)

        y = Variable(torch.diag(torch.Tensor([2, 2])))
        z = x.matmul(y)

        torch._C._tracer_exit((z,))
        function = torch._C._jit_createInterpreterFactory(trace)

        z2 = function()(x)
        self.assertEqual(z, z2)

        # Overwrite the original constant: if the trace had kept a reference
        # instead of a clone, the result below would change.
        y.data.fill_(1000)  # make sure the data has been cloned

        x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
        z3 = function()(x2)
        self.assertEqual(z3.data, torch.ones(2, 2) * 4)
项目:OpenNMT-py    作者:OpenNMT    | 项目源码 | 文件源码
def forward(self, input):
        """Turn a batch of square score matrices into matrix-tree marginals.

        Args:
            input: tensor of shape ``(batch, n, n)``; the diagonal of each
                matrix holds the root scores.

        Returns:
            Tensor with the same shape as ``input``.
        """
        laplacian = input.exp() + self.eps
        output = input.clone()
        n = input.size(1)
        # Build the diagonal mask once, with the same tensor type as the input,
        # so the module works on whatever device the input lives on.
        diag_mask = Variable(input.data.new(n, n).copy_(torch.eye(n)).ne(0))
        for b in range(input.size(0)):
            lap = laplacian[b].masked_fill(diag_mask, 0)
            lap = -lap + torch.diag(lap.sum(0))
            # store roots on diagonal
            lap[0] = input[b].diag().exp()
            inv_laplacian = lap.inverse()

            factor = inv_laplacian.diag().unsqueeze(1)\
                                         .expand_as(input[b]).transpose(0, 1)
            term1 = input[b].exp().mul(factor).clone()
            term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone()
            term1[:, 0] = 0
            term2[0] = 0
            output[b] = term1 - term2
            roots_output = input[b].diag().exp().mul(
                inv_laplacian.transpose(0, 1)[0])
            output[b] = output[b] + torch.diag(roots_output)
        return output
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_diag(self):
        """``torch.diag`` must give the same result with and without an output tensor."""
        matrix = torch.rand(100, 100)
        expected = torch.diag(matrix)
        # Legacy calling convention: the result tensor is passed first.
        filled = torch.Tensor()
        torch.diag(filled, matrix)
        self.assertEqual(expected, filled)
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_eig(self):
        """Check ``torch.eig`` results, output buffers and V*diag(e)*V' reconstruction.

        Uses the legacy calling convention where output tensors come first.
        """
        a = torch.Tensor(((1.96,  0.00,  0.00,  0.00,  0.00),
                        (-6.49,  3.80,  0.00,  0.00,  0.00),
                        (-0.47, -6.39,  4.17,  0.00,  0.00),
                        (-7.20,  1.50, -1.51,  5.70,  0.00),
                        (-0.65, -6.34,  2.67,  1.80, -7.10))).t().contiguous()
        e = torch.eig(a)[0]
        ee, vv = torch.eig(a, True)
        te = torch.Tensor()
        tv = torch.Tensor()
        eee, vvv = torch.eig(te, tv, a, True)
        self.assertEqual(e, ee, 1e-12)
        self.assertEqual(ee, eee, 1e-12)
        self.assertEqual(ee, te, 1e-12)
        self.assertEqual(vv, vvv, 1e-12)
        self.assertEqual(vv, tv, 1e-12)

        # test reuse
        X = torch.randn(4,4)
        # X'X is symmetric, so V * diag(e) * V' should rebuild it
        X = torch.mm(X.t(), X)
        e, v = torch.zeros(4,2), torch.zeros(4,4)
        torch.eig(e, v, X, True)
        # column 0 of e is used as the eigenvalue vector
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        torch.eig(e, v, X, True)
        Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        # test non-contiguous
        X = torch.randn(4, 4)
        X = torch.mm(X.t(), X)
        e = torch.zeros(4, 2, 2)[:,1]
        v = torch.zeros(4, 2, 4)[:,1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.eig(e, v, X, True)
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_symeig(self):
        """Check that ``torch.symeig`` output rebuilds its input as V*diag(e)*V'.

        Uses the legacy calling convention where output tensors come first.
        """
        xval = torch.rand(100,3)
        cov = torch.mm(xval.t(), xval)
        rese = torch.zeros(3)
        resv = torch.zeros(3,3)

        # First call to symeig
        self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
        torch.symeig(rese, resv, cov.clone(), True)
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # Second call to symeig
        self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
        torch.symeig(rese, resv, cov.clone(), True)
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # test non-contiguous
        # NOTE(review): X is 5x5 while e/v are allocated with size 4, and
        # X.t() * X is an element-wise product - confirm both are intended.
        X = torch.rand(5, 5)
        X = X.t() * X
        e = torch.zeros(4, 2).select(1, 1)
        v = torch.zeros(4, 2, 4)[:,1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.symeig(e, v, X, True)
        Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_svd(self):
        """Check ``torch.svd`` results, output buffers and U*diag(S)*V' reconstruction.

        Uses the legacy calling convention where output tensors come first.
        """
        a=torch.Tensor(((8.79,  6.11, -9.15,  9.57, -3.49,  9.84),
                        (9.93,  6.91, -7.93,  1.64,  4.02,  0.15),
                        (9.83,  5.04,  4.86,  8.83,  9.80, -8.99),
                        (5.45, -0.27,  4.85,  0.74, 10.00, -6.02),
                        (3.16,  7.98,  3.01,  5.80,  4.27, -5.31))).t().clone()
        u, s, v = torch.svd(a)
        uu = torch.Tensor()
        ss = torch.Tensor()
        vv = torch.Tensor()
        uuu, sss, vvv = torch.svd(uu, ss, vv, a)
        self.assertEqual(u, uu, 0, 'torch.svd')
        self.assertEqual(u, uuu, 0, 'torch.svd')
        self.assertEqual(s, ss, 0, 'torch.svd')
        self.assertEqual(s, sss, 0, 'torch.svd')
        self.assertEqual(v, vv, 0, 'torch.svd')
        self.assertEqual(v, vvv, 0, 'torch.svd')

        # test reuse
        X = torch.randn(4, 4)
        U, S, V = torch.svd(X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        torch.svd(U, S, V, X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # test non-contiguous
        X = torch.randn(5, 5)
        U = torch.zeros(5, 2, 5)[:,1]
        S = torch.zeros(5, 2)[:,1]
        V = torch.zeros(5, 2, 5)[:,1]

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        self.assertFalse(S.is_contiguous(), 'S is contiguous')
        self.assertFalse(V.is_contiguous(), 'V is contiguous')
        torch.svd(U, S, V, X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
项目:pyro    作者:uber    | 项目源码 | 文件源码
def _test_jacobian(self, input_dim, hidden_dim):
        """Compare the flow's analytic log-det-Jacobian with a finite-difference one.

        Also asserts that, after applying the flow's permutation, the numeric
        Jacobian has a fully non-zero diagonal and an all-zero strict lower
        triangle.
        """
        numeric_jac = torch.zeros(input_dim, input_dim)
        iaf = InverseAutoregressiveFlow(input_dim, hidden_dim, sigmoid_bias=0.5)

        def nonzero(x):
            # 1.0 where x != 0, 0.0 elsewhere
            return torch.sign(torch.abs(x))

        x = Variable(torch.randn(1, input_dim))
        base_out = iaf(x)
        # Forward differences: perturb input j, read the change in output k.
        for j in range(input_dim):
            for k in range(input_dim):
                bump = torch.zeros(1, input_dim)
                bump[0, j] = self.epsilon
                bumped_out = iaf(x + Variable(bump))
                slope = (bumped_out - base_out) / self.epsilon
                numeric_jac[j, k] = float(slope[0, k].data.cpu().numpy()[0])

        perm = iaf.get_arn().get_permutation()
        permuted_jacobian = numeric_jac.clone()
        for j in range(input_dim):
            for k in range(input_dim):
                permuted_jacobian[j, k] = numeric_jac[perm[j], perm[k]]

        analytic_ldt = iaf.log_det_jacobian(base_out).data.cpu().numpy()[0]
        numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian)))
        ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt)

        structure = nonzero(permuted_jacobian)
        diag_sum = torch.sum(torch.diag(structure))
        lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=-1))

        self.assertTrue(ldt_discrepancy < self.epsilon)
        self.assertTrue(diag_sum == float(input_dim))
        self.assertTrue(lower_sum == float(0.0))
项目:MMD-GAN    作者:OctoberChang    | 项目源码 | 文件源码
def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Compute the squared MMD estimate from three kernel blocks.

    Args:
        K_XX, K_XY, K_YY: ``m x m`` kernel matrices (X and Y are assumed to
            have the same number of samples).
        const_diagonal: ``False`` to read the diagonals from ``K_XX`` /
            ``K_YY``; otherwise the constant value every diagonal entry has.
        biased: if true, include the diagonal terms and normalise by ``m*m``;
            otherwise drop them and normalise the XX / YY terms by ``m*(m-1)``.

    Returns:
        The squared MMD estimate.
    """
    m = K_XX.size(0)    # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
    else:
        diag_X = torch.diag(K_XX)                       # (m,)
        diag_Y = torch.diag(K_YY)                       # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X             # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y             # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)                     # K_{XY}^T * e

    Kt_XX_sum = Kt_XX_sums.sum()                       # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()                       # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()                       # e^T * K_{XY} * e

    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m))
    else:
        mmd2 = (Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m))

    return mmd2
项目:action-detection    作者:yjxiong    | 项目源码 | 文件源码
def forward(self, pred, labels, targets):
        """Smooth-L1 regression loss on the two outputs of each sample's own class.

        Args:
            pred: tensor of shape ``(N, num_classes, 2)``.
            labels: 1-based class index per sample, shape ``(N,)``.
            targets: regression targets, ``2 * N`` values in total.

        Returns:
            ``self.smooth_l1_loss`` over the selected predictions, times 2.
        """
        indexer = labels.data - 1
        # prep[i, j] = pred[i, labels[j] - 1]; the diagonal (i == j) pairs each
        # sample with its own class.
        prep = pred[:, indexer, :]
        class_pred = torch.cat((torch.diag(prep[:, :,  0]).view(-1, 1),
                                torch.diag(prep[:, :, 1]).view(-1, 1)),
                               dim=1)
        loss = self.smooth_l1_loss(class_pred.view(-1), targets.view(-1)) * 2
        return loss
项目:e2e-model-learning    作者:locuslab    | 项目源码 | 文件源码
def forward(self, z0, mu, dg, d2g):
        """Solve one QP per batch row, built from ``z0``, ``mu``, ``dg`` and ``d2g``.

        Args:
            z0: tensor of shape ``(nBatch, n)``.
            mu, dg, d2g: tensors combined element-wise with ``z0``.

        Returns:
            The output of ``QPFunction`` for the batch.
        """
        nBatch, n = z0.size()

        # Per-sample diagonal quadratic term diag(d2g_i + 1).
        Q = torch.cat([torch.diag(d2g[i] + 1).unsqueeze(0)
            for i in range(nBatch)], 0).double()
        p = (dg - d2g*z0 - mu).double()
        G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
        h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))

        out = QPFunction(verbose=False)(Q, p, G, h, self.e, self.e)
        return out
项目:e2c-pytorch    作者:ethanluoyc    | 项目源码 | 文件源码
def cov(self):
        """This should only be called when NormalDistribution represents one sample

        Returns:
            ``A * diag(sigma^2) * A'`` with ``A = I + v r'`` when both ``v`` and
            ``r`` are set, otherwise ``diag(sigma^2)``.
        """
        if self.v is not None and self.r is not None:
            assert self.v.dim() == 1
            # Length of v, not its rank: the identity must be D x D.
            dim = self.v.size(0)
            v = self.v.unsqueeze(1)  # D * 1 vector
            rt = self.r.unsqueeze(0)  # 1 * D vector
            A = torch.eye(dim) + v.mm(rt)
            return A.mm(torch.diag(self.sigma.pow(2)).mm(A.t()))
        else:
            return torch.diag(self.sigma.pow(2))
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def orthogonal(tensor, gain=1):
    """Fills the input Tensor or Variable with a (semi) orthogonal matrix, as described in "Exact solutions to the
    nonlinear dynamics of learning in deep linear neural networks" - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional torch.Tensor or autograd.Variable, where n >= 2
        gain: optional scaling factor

    Returns:
        The ``tensor`` argument, filled in place.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> nn.init.orthogonal(w)
    """
    # Write through the underlying data of a Variable. Unwrapping here rather
    # than recursing keeps this safe when ``.data`` is again a Variable instance.
    data = tensor.data if isinstance(tensor, Variable) else tensor

    if data.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = data.size(0)
    cols = data[0].numel()
    flattened = torch.Tensor(rows, cols).normal_(0, 1)
    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Flip column signs by the sign of R's diagonal to make Q uniform
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph.expand_as(q)
    # Pad zeros to Q (if rows smaller than cols)
    if rows < cols:
        padding = torch.zeros(rows, cols - rows)
        if q.is_cuda:
            q = torch.cat([q, padding.cuda()], 1)
        else:
            q = torch.cat([q, padding], 1)

    data.view_as(q).copy_(q)
    data.mul_(gain)
    return tensor
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_diag(self):
        """``torch.diag`` must give the same result with and without ``out=``."""
        matrix = torch.rand(100, 100)
        expected = torch.diag(matrix)
        filled = torch.Tensor()
        torch.diag(matrix, out=filled)
        self.assertEqual(expected, filled)
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_eig(self):
        """Check ``torch.eig`` results, ``out=`` buffers and V*diag(e)*V' reconstruction."""
        a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
                          (-6.49, 3.80, 0.00, 0.00, 0.00),
                          (-0.47, -6.39, 4.17, 0.00, 0.00),
                          (-7.20, 1.50, -1.51, 5.70, 0.00),
                          (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
        e = torch.eig(a)[0]
        ee, vv = torch.eig(a, True)
        te = torch.Tensor()
        tv = torch.Tensor()
        eee, vvv = torch.eig(a, True, out=(te, tv))
        self.assertEqual(e, ee, 1e-12)
        self.assertEqual(ee, eee, 1e-12)
        self.assertEqual(ee, te, 1e-12)
        self.assertEqual(vv, vvv, 1e-12)
        self.assertEqual(vv, tv, 1e-12)

        # test reuse
        X = torch.randn(4, 4)
        # X'X is symmetric, so V * diag(e) * V' should rebuild it
        X = torch.mm(X.t(), X)
        e, v = torch.zeros(4, 2), torch.zeros(4, 4)
        torch.eig(X, True, out=(e, v))
        # column 0 of e is used as the eigenvalue vector
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        # test non-contiguous
        X = torch.randn(4, 4)
        X = torch.mm(X.t(), X)
        e = torch.zeros(4, 2, 2)[:, 1]
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_symeig(self):
        """Check that ``torch.symeig`` output rebuilds its input as V*diag(e)*V'."""
        xval = torch.rand(100, 3)
        cov = torch.mm(xval.t(), xval)
        rese = torch.zeros(3)
        resv = torch.zeros(3, 3)

        # First call to symeig
        self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # Second call to symeig
        self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # test non-contiguous
        # NOTE(review): X is 5x5 while e/v are allocated with size 4, and
        # X.t() * X is an element-wise product - confirm both are intended.
        X = torch.rand(5, 5)
        X = X.t() * X
        e = torch.zeros(4, 2).select(1, 1)
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.symeig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_svd(self):
        """Check ``torch.svd`` results, ``out=`` buffers and U*diag(S)*V' reconstruction."""
        a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                          (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                          (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                          (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                          (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
        u, s, v = torch.svd(a)
        uu = torch.Tensor()
        ss = torch.Tensor()
        vv = torch.Tensor()
        uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
        self.assertEqual(u, uu, 0, 'torch.svd')
        self.assertEqual(u, uuu, 0, 'torch.svd')
        self.assertEqual(s, ss, 0, 'torch.svd')
        self.assertEqual(s, sss, 0, 'torch.svd')
        self.assertEqual(v, vv, 0, 'torch.svd')
        self.assertEqual(v, vvv, 0, 'torch.svd')

        # test reuse
        X = torch.randn(4, 4)
        U, S, V = torch.svd(X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # test non-contiguous
        X = torch.randn(5, 5)
        U = torch.zeros(5, 2, 5)[:, 1]
        S = torch.zeros(5, 2)[:, 1]
        V = torch.zeros(5, 2, 5)[:, 1]

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        self.assertFalse(S.is_contiguous(), 'S is contiguous')
        self.assertFalse(V.is_contiguous(), 'V is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
项目:dong_iccv_2017    作者:woozzu    | 项目源码 | 文件源码
def pairwise_ranking_loss(margin, x, v):
    """Bidirectional hinge ranking loss between row-aligned ``x`` and ``v``.

    Rows are L2-normalised, then every mismatched pair is penalised when its
    similarity comes within ``margin`` of the matched pair's similarity.
    Tensors are moved to the GPU unless the module-level ``args.no_cuda`` is set.

    Returns:
        Sum of both directional losses divided by the number of rows of ``x``.
    """
    n = x.size(0)
    zero = torch.zeros(1)
    diag_margin = margin * torch.eye(n)
    if not args.no_cuda:
        zero = zero.cuda()
        diag_margin = diag_margin.cuda()
    zero = Variable(zero)
    diag_margin = Variable(diag_margin)

    x = x / torch.norm(x, 2, 1, keepdim=True)
    v = v / torch.norm(v, 2, 1, keepdim=True)
    sim = torch.matmul(x, v.transpose(0, 1))
    matched = torch.diag(sim)
    # Subtracting diag_margin cancels the `margin` each matched pair adds.
    loss_x = torch.max(zero, margin - torch.unsqueeze(matched, 1) + sim) - diag_margin
    loss_v = torch.max(zero, margin - torch.unsqueeze(matched, 0) + sim) - diag_margin
    total = torch.sum(loss_x) + torch.sum(loss_v)
    return total / n
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def orthogonal(tensor, gain=1):
    """Fills the input Tensor or Variable with a (semi) orthogonal matrix, as described in "Exact solutions to the
    nonlinear dynamics of learning in deep linear neural networks" - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional torch.Tensor or autograd.Variable, where n >= 2
        gain: optional scaling factor

    Returns:
        The ``tensor`` argument, filled in place.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> nn.init.orthogonal(w)
    """
    # Write through the underlying data of a Variable. Unwrapping here rather
    # than recursing keeps this safe when ``.data`` is again a Variable instance.
    data = tensor.data if isinstance(tensor, Variable) else tensor

    if data.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = data.size(0)
    cols = data[0].numel()
    flattened = torch.Tensor(rows, cols).normal_(0, 1)
    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Flip column signs by the sign of R's diagonal to make Q uniform
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph.expand_as(q)
    # Pad zeros to Q (if rows smaller than cols)
    if rows < cols:
        padding = torch.zeros(rows, cols - rows)
        if q.is_cuda:
            q = torch.cat([q, padding.cuda()], 1)
        else:
            q = torch.cat([q, padding], 1)

    data.view_as(q).copy_(q)
    data.mul_(gain)
    return tensor
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_diag(self):
        """``torch.diag`` must give the same result with and without ``out=``."""
        matrix = torch.rand(100, 100)
        expected = torch.diag(matrix)
        filled = torch.Tensor()
        torch.diag(matrix, out=filled)
        self.assertEqual(expected, filled)
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_eig(self):
        """Check ``torch.eig`` results, ``out=`` buffers and V*diag(e)*V' reconstruction."""
        a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
                          (-6.49, 3.80, 0.00, 0.00, 0.00),
                          (-0.47, -6.39, 4.17, 0.00, 0.00),
                          (-7.20, 1.50, -1.51, 5.70, 0.00),
                          (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
        e = torch.eig(a)[0]
        ee, vv = torch.eig(a, True)
        te = torch.Tensor()
        tv = torch.Tensor()
        eee, vvv = torch.eig(a, True, out=(te, tv))
        self.assertEqual(e, ee, 1e-12)
        self.assertEqual(ee, eee, 1e-12)
        self.assertEqual(ee, te, 1e-12)
        self.assertEqual(vv, vvv, 1e-12)
        self.assertEqual(vv, tv, 1e-12)

        # test reuse
        X = torch.randn(4, 4)
        # X'X is symmetric, so V * diag(e) * V' should rebuild it
        X = torch.mm(X.t(), X)
        e, v = torch.zeros(4, 2), torch.zeros(4, 4)
        torch.eig(X, True, out=(e, v))
        # column 0 of e is used as the eigenvalue vector
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        # test non-contiguous
        X = torch.randn(4, 4)
        X = torch.mm(X.t(), X)
        e = torch.zeros(4, 2, 2)[:, 1]
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_symeig(self):
        """Check that ``torch.symeig`` output rebuilds its input as V*diag(e)*V'."""
        xval = torch.rand(100, 3)
        cov = torch.mm(xval.t(), xval)
        rese = torch.zeros(3)
        resv = torch.zeros(3, 3)

        # First call to symeig
        self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # Second call to symeig
        self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # test non-contiguous
        # NOTE(review): X is 5x5 while e/v are allocated with size 4, and
        # X.t() * X is an element-wise product - confirm both are intended.
        X = torch.rand(5, 5)
        X = X.t() * X
        e = torch.zeros(4, 2).select(1, 1)
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.symeig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_svd(self):
        """Check ``torch.svd`` results, ``out=`` buffers and U*diag(S)*V' reconstruction."""
        a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                          (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                          (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                          (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                          (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
        u, s, v = torch.svd(a)
        uu = torch.Tensor()
        ss = torch.Tensor()
        vv = torch.Tensor()
        uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
        self.assertEqual(u, uu, 0, 'torch.svd')
        self.assertEqual(u, uuu, 0, 'torch.svd')
        self.assertEqual(s, ss, 0, 'torch.svd')
        self.assertEqual(s, sss, 0, 'torch.svd')
        self.assertEqual(v, vv, 0, 'torch.svd')
        self.assertEqual(v, vvv, 0, 'torch.svd')

        # test reuse
        X = torch.randn(4, 4)
        U, S, V = torch.svd(X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # test non-contiguous
        X = torch.randn(5, 5)
        U = torch.zeros(5, 2, 5)[:, 1]
        S = torch.zeros(5, 2)[:, 1]
        V = torch.zeros(5, 2, 5)[:, 1]

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        self.assertFalse(S.is_contiguous(), 'S is contiguous')
        self.assertFalse(V.is_contiguous(), 'V is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
项目:block    作者:bamos    | 项目源码 | 文件源码
def test_torch():
    """Check that ``block`` matches an explicit ``torch.cat`` layout, for tensors and Variables."""
    import torch
    from torch.autograd import Variable

    # Fixed seed so a failure is reproducible.
    torch.manual_seed(0)

    nx, nineq, neq = 4, 6, 7
    Q = torch.randn(nx, nx)
    G = torch.randn(nineq, nx)
    A = torch.randn(neq, nx)
    D = torch.diag(torch.rand(nineq))

    # Reference matrix: each block row is concatenated along columns, then the
    # four rows are stacked.
    K_ = torch.cat((
        torch.cat((Q, torch.zeros(nx, nineq).type_as(Q), G.t(), A.t()), 1),
        torch.cat((torch.zeros(nineq, nx).type_as(Q), D,
                   torch.eye(nineq).type_as(Q),
                   torch.zeros(nineq, neq).type_as(Q)), 1),
        torch.cat((G, torch.eye(nineq).type_as(Q), torch.zeros(
            nineq, nineq + neq).type_as(Q)), 1),
        torch.cat((A, torch.zeros((neq, nineq + nineq + neq))), 1)
    ))

    K = block((
        (Q,   0, G.t(), A.t()),
        (0,   D,   'I',     0),
        (G, 'I',     0,     0),
        (A,   0,     0,     0)
    ))

    assert (K - K_).norm() == 0.0

    # Same layout mixing Variables and plain tensors.
    K = block((
        (Variable(Q),   0, G.t(), Variable(A.t())),
        (0,   Variable(D),   'I',     0),
        (Variable(G), 'I',     0,     0),
        (A,   0,     0,     0)
    ))

    assert (K.data - K_).norm() == 0.0
项目:block    作者:bamos    | 项目源码 | 文件源码
def test_linear_operator():
    """Check that a ``LinearOperator`` built by ``block`` acts like the dense ``np.bmat`` matrix."""
    # Fixed seed so a failure is reproducible.
    npr.seed(0)

    nx, nineq, neq = 4, 6, 7
    Q = npr.randn(nx, nx)
    G = npr.randn(nineq, nx)
    A = npr.randn(neq, nx)
    D = np.diag(npr.rand(nineq))

    K_ = np.bmat((
        (Q, np.zeros((nx, nineq)), G.T, A.T),
        (np.zeros((nineq, nx)), D, np.eye(nineq), np.zeros((nineq, neq))),
        (G, np.eye(nineq), np.zeros((nineq, nineq + neq))),
        (A, np.zeros((neq, nineq + nineq + neq)))
    ))

    Q_lo = sla.aslinearoperator(Q)
    G_lo = sla.aslinearoperator(G)
    A_lo = sla.aslinearoperator(A)
    D_lo = sla.aslinearoperator(D)

    K = block((
        (Q_lo,    0,    G.T,    A.T),
        (0,    D_lo,    'I',      0),
        (G_lo,  'I',      0,      0),
        (A_lo,    0,      0,      0)
    ), arrtype=sla.LinearOperator)

    # matrix-vector product
    w1 = np.random.randn(K_.shape[1])
    assert np.allclose(K_.dot(w1), K.dot(w1))
    # adjoint product
    w2 = np.random.randn(K_.shape[0])
    assert np.allclose(K_.T.dot(w2), K.H.dot(w2))
    # matrix-matrix product
    W = np.random.randn(*K_.shape)
    assert np.allclose(K_.dot(W), K.dot(W))
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def forward(ctx, input, diagonal_idx=0):
        """Apply ``input.diag(diagonal_idx)`` and store the offset on ``ctx``."""
        # Stored so backward can use the same offset.
        ctx.diagonal_idx = diagonal_idx
        result = input.diag(diagonal_idx)
        return result
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def backward(ctx, grad_output):
        """Return ``grad_output.diag(ctx.diagonal_idx)`` and ``None`` for the offset argument."""
        grad_input = grad_output.diag(ctx.diagonal_idx)
        return grad_input, None
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def phi(A):
        """
        Return lower triangle of A and halve the diagonal.
        """
        B = A.tril()

        # diag(diag(B)) is B's diagonal as a matrix; subtracting half halves it.
        B = B - 0.5 * torch.diag(torch.diag(B))

        return B
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_eig(self):
    """Exercise ``torch.eig``: plain call, ``out=`` buffers, buffer reuse
    and non-contiguous output buffers.

    In the reconstruction checks, column 0 of the eigenvalue output ``e``
    is used as the diagonal of ``V diag(e) V'``.
    """
    a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
                      (-6.49, 3.80, 0.00, 0.00, 0.00),
                      (-0.47, -6.39, 4.17, 0.00, 0.00),
                      (-7.20, 1.50, -1.51, 5.70, 0.00),
                      (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
    e = torch.eig(a)[0]
    ee, vv = torch.eig(a, True)
    te = torch.Tensor()
    tv = torch.Tensor()
    eee, vvv = torch.eig(a, True, out=(te, tv))
    self.assertEqual(e, ee, 1e-12)
    self.assertEqual(ee, eee, 1e-12)
    self.assertEqual(ee, te, 1e-12)
    self.assertEqual(vv, vvv, 1e-12)
    self.assertEqual(vv, tv, 1e-12)

    # test reuse
    X = torch.randn(4, 4)
    # X' X is symmetric.
    X = torch.mm(X.t(), X)
    e, v = torch.zeros(4, 2), torch.zeros(4, 4)
    torch.eig(X, True, out=(e, v))
    Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
    self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
    self.assertFalse(v.is_contiguous(), 'V is contiguous')

    # Second call writes into the same (now non-contiguous) buffers.
    torch.eig(X, True, out=(e, v))
    Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
    self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
    self.assertFalse(v.is_contiguous(), 'V is contiguous')

    # test non-contiguous
    X = torch.randn(4, 4)
    X = torch.mm(X.t(), X)
    e = torch.zeros(4, 2, 2)[:, 1]
    v = torch.zeros(4, 2, 4)[:, 1]
    self.assertFalse(v.is_contiguous(), 'V is contiguous')
    self.assertFalse(e.is_contiguous(), 'E is contiguous')
    torch.eig(X, True, out=(e, v))
    Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
    self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_symeig(self):
    """Exercise ``torch.symeig`` with ``out=`` buffers: first use, reuse,
    and non-contiguous buffers, each checked via ``V diag(e) V'``."""
    xval = torch.rand(100, 3)
    cov = torch.mm(xval.t(), xval)
    rese = torch.zeros(3)
    resv = torch.zeros(3, 3)

    # First call to symeig
    self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
    torch.symeig(cov.clone(), True, out=(rese, resv))
    ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
    self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

    # Second call to symeig
    self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
    torch.symeig(cov.clone(), True, out=(rese, resv))
    ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
    self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

    # test non-contiguous
    X = torch.rand(5, 5)
    # Element-wise product of X' and X; the result is symmetric.
    X = X.t() * X
    # NOTE(review): the output buffers are created with 4 rows while X is
    # 5x5 -- kept exactly as in the source; presumably symeig resizes them.
    e = torch.zeros(4, 2).select(1, 1)
    v = torch.zeros(4, 2, 4)[:, 1]
    self.assertFalse(v.is_contiguous(), 'V is contiguous')
    self.assertFalse(e.is_contiguous(), 'E is contiguous')
    torch.symeig(X, True, out=(e, v))
    Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
    self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_svd(self):
    """Exercise ``torch.svd``: plain call versus ``out=`` buffers, buffer
    reuse, and non-contiguous buffers, each checked via ``U diag(S) V'``."""
    a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                      (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                      (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                      (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                      (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
    u, s, v = torch.svd(a)
    uu = torch.Tensor()
    ss = torch.Tensor()
    vv = torch.Tensor()
    uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
    self.assertEqual(u, uu, 0, 'torch.svd')
    self.assertEqual(u, uuu, 0, 'torch.svd')
    self.assertEqual(s, ss, 0, 'torch.svd')
    self.assertEqual(s, sss, 0, 'torch.svd')
    self.assertEqual(v, vv, 0, 'torch.svd')
    self.assertEqual(v, vvv, 0, 'torch.svd')

    # test reuse
    X = torch.randn(4, 4)
    U, S, V = torch.svd(X)
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    # The factors returned above are fed back in as output buffers.
    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    # test non-contiguous
    X = torch.randn(5, 5)
    U = torch.zeros(5, 2, 5)[:, 1]
    S = torch.zeros(5, 2)[:, 1]
    V = torch.zeros(5, 2, 5)[:, 1]

    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    self.assertFalse(S.is_contiguous(), 'S is contiguous')
    self.assertFalse(V.is_contiguous(), 'V is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def pending_test_diag():
    """Compare the lazy variable's ``diag()`` with ``torch.diag(WKW)``.

    ``WKW``, ``lazy_kronecker_product_var`` and ``utils`` are names defined
    outside this function.
    """
    diag_actual = torch.diag(WKW)
    diag_res = lazy_kronecker_product_var.diag()
    # NOTE(review): the first argument was missing in the source;
    # ``diag_res.data`` is a reconstruction -- confirm against the project.
    assert utils.approx_equal(diag_res.data, diag_actual)
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def add_diag(self, diag):
    """Return a new ``MulLazyVariable`` over the same factors with ``diag``
    folded into its ``added_diag`` term.

    When a diagonal term already exists, ``diag`` is added to it.
    """
    if self.added_diag is None:
        # NOTE(review): this argument was truncated in the source. The
        # expansion to length ``self.size()[0]`` is a reconstruction,
        # inferred from how ``added_diag`` is consumed by the sibling
        # ``diag``/``evaluate`` methods -- confirm against the project.
        return MulLazyVariable(*self.lazy_vars,
                               added_diag=diag.expand(self.size()[0]))
    else:
        return MulLazyVariable(*self.lazy_vars,
                               added_diag=self.added_diag + diag)
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def diag(self):
    """Return the element-wise product of every factor's ``diag()``.

    The running product starts from a vector of ones of length
    ``self.size()[0]``; ``self.added_diag`` is added at the end when it is
    not ``None``.
    """
    length = self.size()[0]
    product = Variable(torch.ones(length))
    for factor in self.lazy_vars:
        product = product * factor.diag()

    if self.added_diag is None:
        return product
    return product + self.added_diag
项目:gpytorch    作者:jrg365    | 项目源码 | 文件源码
def evaluate(self):
    """Return the element-wise product of every factor's ``evaluate()``.

    When ``self.added_diag`` is not ``None``, the matrix built by
    ``self.added_diag.diag()`` is added to the product. With no factors
    and no ``added_diag`` the result is ``None``.
    """
    res = None
    for lazy_var in self.lazy_vars:
        term = lazy_var.evaluate()
        # The first factor seeds the product; later factors multiply into
        # it. Without this branch the first factor would be multiplied by
        # itself and the remaining factors ignored.
        if res is None:
            res = term
        else:
            res = res * term

    if self.added_diag is not None:
        res = res + self.added_diag.diag()
    return res
项目:qpth    作者:locuslab    | 项目源码 | 文件源码
def factor_kkt(U_S, R, d):
    """ Factor the U22 block that we can only do after we know D.

    Writes ``torch.potrf(R + diag(1 / d))`` in place into the trailing
    ``nineq x nineq`` corner of ``U_S``, where ``nineq = R.size(0)``.
    Nothing is returned.
    """
    n_ineq = R.size(0)
    # The diagonal matrix is built from a CPU copy of d and then cast back
    # to d's tensor type.
    inv_d = torch.diag(1 / d.cpu()).type_as(d)
    factor = torch.potrf(R + inv_d)
    U_S[-n_ineq:, -n_ineq:] = factor
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def random_square_matrix_of_rank(l, rank):
    """Return a random ``l x l`` matrix whose rank is ``rank``.

    A random matrix is decomposed with SVD, its singular values are edited
    so that exactly the first ``rank`` of them are non-zero, and the matrix
    is rebuilt as ``u * diag(s) * v'``.

    Raises:
        AssertionError: if ``rank`` exceeds ``l``.
    """
    assert rank <= l
    A = torch.randn(l, l)
    u, s, v = A.svd()
    for i in range(l):
        if i >= rank:
            # Drop every singular value beyond the requested rank.
            s[i] = 0
        elif s[i] == 0:
            # Keep the leading values non-zero so the rank is not lower.
            s[i] = 1
    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def random_fullrank_matrix_distinct_singular_value(l):
    """Return a random ``l x l`` matrix with singular values
    ``1/(l+1), 2/(l+1), ..., l/(l+1)``.

    The singular vectors come from the SVD of a random matrix; only the
    singular values are replaced before rebuilding ``u * diag(s) * v'``.
    """
    A = torch.randn(l, l)
    u, _, v = A.svd()
    # A float start keeps ``arange`` floating-point, so the in-place scaling
    # below is valid even where an integer ``arange`` yields an integer tensor.
    s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_diag(self):
    """Check that ``torch.diag`` gives the same result with and without an
    explicit ``out=`` tensor."""
    source = torch.rand(100, 100)
    expected = torch.diag(source)
    # Empty tensor used as the output buffer.
    buffer = torch.Tensor()
    torch.diag(source, out=buffer)
    self.assertEqual(expected, buffer)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_symeig(self):
    """Exercise ``torch.symeig`` with ``out=`` buffers: first use, reuse,
    and non-contiguous buffers, each checked via ``V diag(e) V'``."""
    xval = torch.rand(100, 3)
    cov = torch.mm(xval.t(), xval)
    rese = torch.zeros(3)
    resv = torch.zeros(3, 3)

    # First call to symeig
    self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
    torch.symeig(cov.clone(), True, out=(rese, resv))
    ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
    self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

    # Second call to symeig
    self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
    torch.symeig(cov.clone(), True, out=(rese, resv))
    ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
    self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

    # test non-contiguous
    X = torch.rand(5, 5)
    # Element-wise product of X' and X; the result is symmetric.
    X = X.t() * X
    # NOTE(review): the output buffers are created with 4 rows while X is
    # 5x5 -- kept exactly as in the source; presumably symeig resizes them.
    e = torch.zeros(4, 2).select(1, 1)
    v = torch.zeros(4, 2, 4)[:, 1]
    self.assertFalse(v.is_contiguous(), 'V is contiguous')
    self.assertFalse(e.is_contiguous(), 'E is contiguous')
    torch.symeig(X, True, out=(e, v))
    Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
    self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_svd(self):
    """Exercise ``torch.svd``: plain call versus ``out=`` buffers, buffer
    reuse, and non-contiguous buffers, each checked via ``U diag(S) V'``."""
    a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                      (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                      (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                      (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                      (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
    u, s, v = torch.svd(a)
    uu = torch.Tensor()
    ss = torch.Tensor()
    vv = torch.Tensor()
    uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
    self.assertEqual(u, uu, 0, 'torch.svd')
    self.assertEqual(u, uuu, 0, 'torch.svd')
    self.assertEqual(s, ss, 0, 'torch.svd')
    self.assertEqual(s, sss, 0, 'torch.svd')
    self.assertEqual(v, vv, 0, 'torch.svd')
    self.assertEqual(v, vvv, 0, 'torch.svd')

    # test reuse
    X = torch.randn(4, 4)
    U, S, V = torch.svd(X)
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    # The factors returned above are fed back in as output buffers.
    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    # test non-contiguous
    X = torch.randn(5, 5)
    U = torch.zeros(5, 2, 5)[:, 1]
    S = torch.zeros(5, 2)[:, 1]
    V = torch.zeros(5, 2, 5)[:, 1]

    self.assertFalse(U.is_contiguous(), 'U is contiguous')
    self.assertFalse(S.is_contiguous(), 'S is contiguous')
    self.assertFalse(V.is_contiguous(), 'V is contiguous')
    torch.svd(X, out=(U, S, V))
    Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
    self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
项目:pyinn    作者:szagoruyko    | 项目源码 | 文件源码
def cublas_dgmm(A, x, out=None):
    """Multiply ``A`` by the diagonal matrix built from the vector ``x``.

    When ``len(x)`` equals ``A.size(-1)`` the result is ``A * diag(x)``;
    otherwise ``len(x)`` must equal ``A.size(0)`` and the result is
    ``diag(x) * A``. For a square ``A`` the first case is the one taken.

    Args:
        A: contiguous tensor.
        x: 1-D tensor of the same tensor type as ``A``.
        out: optional contiguous tensor with the same size and type as
            ``A`` that receives the result; allocated when omitted.

    Returns:
        The tensor holding the product.

    Raises:
        AssertionError: if any of the shape, type or contiguity
            preconditions is violated.
    """
    if out is not None:
        assert out.is_contiguous() and out.size() == A.size()
    else:
        # NOTE(review): the allocation was truncated in the source; a fresh
        # tensor of A's type and size is a reconstruction -- confirm.
        out = A.new(A.size())
    assert x.dim() == 1
    assert x.numel() == A.size(-1) or x.numel() == A.size(0)
    assert A.type() == x.type() == out.type()
    assert A.is_contiguous()

    if not isinstance(A, (torch.cuda.FloatTensor, torch.cuda.DoubleTensor)):
        # Fallback for everything that is not a CUDA float/double tensor:
        # materialise diag(x) and use a regular matrix product.
        if x.numel() == A.size(-1):
            return torch.mm(A, torch.diag(x), out=out.view_as(A))
        else:
            return torch.mm(torch.diag(x), A, out=out.view_as(A))
    else:
        if x.numel() == A.size(-1):
            m, n = A.size(-1), A.numel() // A.size(-1)
            mode = 'l'
        elif x.numel() == A.size(0):
            n, m = A.size(0), A.numel() // A.size(0)
            mode = 'r'
        lda, ldc = m, m
        incx = 1
        handle = torch.cuda.current_blas_handle()
        stream = torch.cuda.current_stream()._as_parameter_
        # Imported lazily so the non-cuBLAS path works without scikit-cuda.
        from skcuda import cublas
        cublas.cublasSetStream(handle, stream)
        args = [handle, mode, m, n, A.data_ptr(), lda, x.data_ptr(), incx, out.data_ptr(), ldc]
        # NOTE(review): the two calls below were missing in the source;
        # reconstructed as the single/double precision dgmm routines.
        if isinstance(A, torch.cuda.FloatTensor):
            cublas.cublasSdgmm(*args)
        elif isinstance(A, torch.cuda.DoubleTensor):
            cublas.cublasDdgmm(*args)
        return out