# Python torch module: bmm() example source code (Python torch 模块,bmm() 实例源码)


def forward(self, inp, hidden):
        """Structured self-attention over BiLSTM states.

        Runs ``self.bilstm`` over ``inp``, then scores every time step with a
        two-layer MLP (``ws1`` -> tanh -> ``ws2``) that yields one score per
        attention hop. Positions whose token id equals the ``<pad>`` id get a
        -10000 bias, so they receive ~0 weight after the softmax over the
        sequence.

        Args:
            inp: token ids. They are transposed with ``(0, 1)`` below, so
                presumably ``[len, bsz]`` -- TODO confirm against the caller.
            hidden: initial recurrent state forwarded to ``self.bilstm``.

        Returns:
            ``(torch.bmm(alphas, outp), alphas)``, where ``alphas`` has shape
            ``[bsz, hop, len]``.
        """
        outp = self.bilstm.forward(inp, hidden)[0]
        size = outp.size()  # [bsz, len, nhid]
        compressed_embeddings = outp.view(-1, size[2])  # [bsz*len, nhid]
        transformed_inp = torch.transpose(inp, 0, 1).contiguous()  # [bsz, len]
        transformed_inp = transformed_inp.view(size[0], 1, size[1])  # [bsz, 1, len]
        # One copy of the token ids per hop, so the pad mask lines up with alphas.
        concatenated_inp = torch.cat(
            [transformed_inp for _ in range(self.attention_hops)], 1)  # [bsz, hop, len]

        hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
        alphas = self.ws2(hbar).view(size[0], size[1], -1)  # [bsz, len, hop]
        alphas = torch.transpose(alphas, 1, 2).contiguous()  # [bsz, hop, len]
        pad_mask = (concatenated_inp == self.dictionary.word2idx['<pad>']).float()
        penalized_alphas = alphas + (-10000 * pad_mask)  # [bsz, hop, len]
        alphas = self.softmax(penalized_alphas.view(-1, size[1]))  # [bsz*hop, len]
        alphas = alphas.view(size[0], self.attention_hops, size[1])  # [bsz, hop, len]
        return torch.bmm(alphas, outp), alphas
def backward(self, grad_output):
        """Backward pass of a batched add-matmul that reduces over the batch.

        The gradients computed here match an output of the form
        ``alpha * add_matrix + beta * sum_b(batch1[b] @ batch2[b])``.
        ``grad_output`` has the 2-D shape ``(n, p)`` and is broadcast to every
        batch element.

        Returns:
            ``(grad_add_matrix, grad_batch1, grad_batch2)``. An entry is
            ``None`` when the matching ``needs_input_grad`` flag is false.
        """
        batch1, batch2 = self.saved_tensors
        flags = self.needs_input_grad
        grad_add_matrix = None
        grad_batch1 = None
        grad_batch2 = None

        if flags[0]:
            grad_add_matrix = grad_output if self.alpha == 1 else grad_output.mul(self.alpha)

        if any(flags[1:]):
            # Every batch element contributed to the same (n, p) output.
            shared = grad_output.expand(
                batch1.size(0), batch1.size(1), batch2.size(2))
            if flags[1]:
                grad_batch1 = torch.bmm(shared, batch2.transpose(1, 2))
                if self.beta != 1:
                    grad_batch1 *= self.beta
            if flags[2]:
                grad_batch2 = torch.bmm(batch1.transpose(1, 2), shared)
                if self.beta != 1:
                    grad_batch2 *= self.beta

        return grad_add_matrix, grad_batch1, grad_batch2
def backward(self, grad_output):
        """Backward pass of a batched add-matmul that keeps the batch axis.

        The gradients computed here match an output of the form
        ``alpha * add_batch + beta * bmm(batch1, batch2)``. ``grad_output``
        already has the batched output shape.

        Returns:
            ``(grad_add_batch, grad_batch1, grad_batch2)``. An entry is
            ``None`` when the matching ``needs_input_grad`` flag is false.
        """
        batch1, batch2 = self.saved_tensors
        flags = self.needs_input_grad

        def scaled_by_beta(t):
            # In-place scale of a freshly computed bmm result.
            if self.beta != 1:
                t *= self.beta
            return t

        grad_add_batch = None
        if flags[0]:
            grad_add_batch = grad_output.mul(self.alpha) if self.alpha != 1 else grad_output

        grad_batch1 = (scaled_by_beta(torch.bmm(grad_output, batch2.transpose(1, 2)))
                       if flags[1] else None)
        grad_batch2 = (scaled_by_beta(torch.bmm(batch1.transpose(1, 2), grad_output))
                       if flags[2] else None)

        return grad_add_batch, grad_batch1, grad_batch2
def updateOutput(self, input):
        """Matrix product of ``input = (a, b)`` into ``self.output``.

        2-D inputs use ``torch.mm``. 3-D inputs, treated as a batch of
        matrices along dim 0, use ``torch.bmm``. ``self.transA`` /
        ``self.transB`` transpose the matrix dims of ``a`` / ``b`` first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            torch.mm(a, b, out=self.output)
        else:
            # For 3-D tensors the matrix dims are 1 and 2; (2, 3) would be
            # out of range.
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)

        return self.output
def updateOutput(self, input):
        """Matrix-vector product of ``input = (M, v)`` into ``self.output``.

        2-D ``M`` with 1-D ``v`` uses ``torch.mv``. 3-D ``M`` with 2-D ``v``
        is a batched product done with ``torch.bmm``; the result is then
        resized to ``(batch, rows)``. ``self.trans`` transposes ``M``'s
        matrix dims first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        M, v = input
        assert M.ndimension() == 2 or M.ndimension() == 3

        if M.ndimension() == 2:
            assert v.ndimension() == 1
            if self.trans:
                M = M.transpose(0, 1)
            self.output.resize_(M.size(0))
            torch.mv(M, v, out=self.output)
        else:
            assert v.ndimension() == 2
            if self.trans:
                M = M.transpose(1, 2)
            self.output.resize_(M.size(0), M.size(1), 1)
            # Treat each v[i] as a column so bmm applies; drop the unit dim after.
            torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))

        return self.output
def forward(self, q, k, v, attn_mask):
        """Multi-head attention with a residual connection.

        Each head projects ``q``/``k``/``v`` with its own slice of
        ``self.w_qs``/``self.w_ks``/``self.w_vs`` via one batched matmul. The
        heads run through ``self.attention``, are concatenated on the feature
        axis, projected with ``self.w_o``, passed through dropout, added to
        ``q`` and fed to ``self.lm``.

        Args:
            q, k, v: ``[bsz, len, d_model]`` tensors (``q.size()`` is unpacked
                as such below).
            attn_mask: repeated ``n_head`` times along dim 0 for the heads.

        Returns:
            ``self.lm(outputs + q)`` with ``outputs`` of shape
            ``[bsz, len_q, -1]``.
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        residual = q

        bsz, len_q, d_model = q.size()
        len_k, len_v = k.size(1), v.size(1)

        def reshape(x):
            """[bsz, len, d_model] -> [n_head x (bsz*len) x d_model]"""
            return x.repeat(n_head, 1, 1).view(n_head, -1, d_model)

        q_s, k_s, v_s = map(reshape, [q, k, v])

        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)
        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)

        # NOTE(review): assumes self.attention returns a single tensor of shape
        # [n_head*bsz, len_q, d_v]; confirm against the attention module.
        outputs = self.attention(q_s, k_s, v_s, attn_mask.repeat(n_head, 1, 1))
        # Split back into per-head [bsz, len_q, d_v] chunks, concatenate on features.
        outputs = torch.cat(torch.split(outputs, bsz, dim=0), dim=-1).view(-1, n_head * d_v)
        # F.dropout defaults to training=True; honour the module's train/eval mode.
        outputs = F.dropout(self.w_o(outputs), p=self.dropout,
                            training=self.training).view(bsz, len_q, -1)
        return self.lm(outputs + residual)
def forward(self, x, target_embedding, encoder_out):
        """Single attention step of a convolutional seq2seq decoder layer.

        Args:
            x: decoder state; also used as the residual.
            target_embedding: added to the projected state before scoring.
            encoder_out: pair ``(keys, values)``. ``keys`` is multiplied on the
                right of the query via ``self.bmm``; ``values`` is weighted by
                the attention scores.

        Returns:
            ``(x, attn_scores)``, where ``attn_scores`` are softmax-normalised
            over the last (source) dimension.
        """
        residual = x

        # attention
        x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
        x = self.bmm(x, encoder_out[0])

        # Softmax over the last dim. The explicit dim avoids the deprecated
        # implicit-dim behaviour of F.softmax and needs no 2-D view.
        x = F.softmax(x, dim=-1)
        attn_scores = x

        x = self.bmm(x, encoder_out[1])

        # scale attention output (s * sqrt(1/s) == sqrt(s))
        s = encoder_out[1].size(1)
        x = x * (s * math.sqrt(1.0 / s))

        # project back
        x = (self.out_projection(x) + residual) * math.sqrt(0.5)
        return x, attn_scores
def forward(self, q, k, v, attn_mask=None):
        """Scaled dot-product attention.

        Args:
            q: batched queries; ``torch.bmm(q, k.transpose(1, 2))`` forms the logits.
            k: batched keys.
            v: batched values.
            attn_mask: optional boolean mask with the same shape as the
                logits. True positions are filled with ``-inf`` before the
                softmax.

        Returns:
            ``(output, attn)``, where ``attn`` is the (dropped-out) softmax
            weight tensor.
        """
        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper

        if attn_mask is not None:

            assert attn_mask.size() == attn.size(), \
                    'Attention mask shape {} mismatch ' \
                    'with Attention logit tensor shape ' \
                    '{}.'.format(attn_mask.size(), attn.size())

            # Out-of-place fill keeps autograd aware of the masking
            # (an in-place fill on .data would bypass it).
            attn = attn.masked_fill(attn_mask, -float('inf'))

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn
def forward(self, inputs, context):
        """Global attention of a decoder state over a source context.

        inputs: batch x dim
        context: batch x sourceL x dim

        Returns ``(contextOutput, attn)``. ``attn`` is the ``self.sm``-
        normalised weight over sourceL. ``contextOutput`` is
        ``tanh(linear_out([weightedContext; inputs]))``.
        """
        targetT = self.linear_in(inputs).unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
        if self.mask is not None:
            # _INF is a module-level constant defined outside this snippet.
            attn = attn.masked_fill(self.mask, -_INF)
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

        weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        contextCombined = torch.cat((weightedContext, inputs), 1)

        contextOutput = self.tanh(self.linear_out(contextCombined))

        return contextOutput, attn
def forward(self, x, target_embedding, encoder_out):
        """Single attention step of a convolutional seq2seq decoder layer.

        Args:
            x: decoder state; also used as the residual.
            target_embedding: added to the projected state before scoring.
            encoder_out: pair ``(keys, values)`` consumed via ``self.bmm``.

        Returns:
            ``(x, attn_scores)``, where ``attn_scores`` are softmax-normalised
            over the last (source) dimension.
        """
        residual = x

        # attention
        x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
        x = self.bmm(x, encoder_out[0])

        # Softmax over the last dim. The explicit dim replaces the deprecated
        # implicit-dim call on a 2-D view.
        attn_scores = F.softmax(x, dim=-1)

        x = self.bmm(attn_scores, encoder_out[1])

        # scale attention output (s * sqrt(1/s) == sqrt(s))
        s = encoder_out[1].size(1)
        x = x * (s * math.sqrt(1.0 / s))

        # project back
        x = (self.out_projection(x) + residual) * math.sqrt(0.5)
        return x, attn_scores
def forward(self, input, context):
        """Global attention of a decoder state over a source context.

        input: batch x dim
        context: batch x sourceL x dim

        Returns ``(contextOutput, attn)``. ``attn`` is the ``self.sm``-
        normalised weight over sourceL; positions where ``self.mask`` is true
        are filled with ``-inf`` first.
        """
        targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
        if self.mask is not None:
            attn = attn.masked_fill(self.mask, -float('inf'))
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

        weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        contextCombined = torch.cat((weightedContext, input), 1)

        contextOutput = self.tanh(self.linear_out(contextCombined))

        return contextOutput, attn
def forward(self, input, context):
        """Global attention of a decoder state over a source context.

        input: batch x dim
        context: batch x sourceL x dim

        Returns ``(contextOutput, attn)``. ``attn`` is the ``self.sm``-
        normalised weight over sourceL; positions where ``self.mask`` is true
        are filled with ``-inf`` first.
        """
        targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
        if self.mask is not None:
            attn = attn.masked_fill(self.mask, -float('inf'))
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

        weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        contextCombined = torch.cat((weightedContext, input), 1)

        contextOutput = self.tanh(self.linear_out(contextCombined))

        return contextOutput, attn
def forward(self, x):
        """PointNet feature extractor.

        Aligns the input points with the STN-predicted transform, runs three
        per-point conv/bn stages and max-pools them into a 1024-d global
        feature.

        Returns:
            ``(global_feat, trans)`` when ``self.global_feat`` is set.
            Otherwise ``(features, trans)``, where ``features`` is the global
            feature tiled over ``self.num_points`` and concatenated with the
            first-stage per-point features on dim 1.
        """
        trans = self.stn(x)  # regress the transforming parameters using the STN
        x = x.transpose(2, 1)  # bz x n_pts x 3
        x = torch.bmm(x, trans)  # (bz x n_pts x 3) x (bz x 3 x 3)
        x = x.transpose(2, 1)  # bz x 3 x n_pts
        x = F.relu(self.bn1(self.conv1(x)))
        pointfeat = x  # per-point features after the first stage
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = self.mp1(x)
        x = x.view(-1, 1024)  # bz x 1024
        if self.global_feat:  # using global feats for classification
            return x, trans
        # Segmentation: give every point a copy of the global feature.
        x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([x, pointfeat], 1), trans
def backward(ctx, grad_output):
        """Backward of a batch-reducing add-matmul (static-autograd style).

        The gradients correspond to an output of the form
        ``alpha * add_matrix + beta * sum_b(batch1[b] @ batch2[b])``.

        Returns one gradient (or ``None``) per tensor input, followed by three
        ``None`` placeholders, presumably for the forward's non-tensor
        arguments -- confirm against the matching ``forward``.
        """
        batch1, batch2 = ctx.saved_variables
        need = ctx.needs_input_grad
        grads = [None, None, None]

        if need[0]:
            grads[0] = grad_output.mul(ctx.alpha) if ctx.alpha != 1 else grad_output

        if any(need[1:]):
            # Every batch element contributed to the same (n, p) output.
            shared = grad_output.expand(batch1.size(0), batch1.size(1), batch2.size(2))
            if need[1]:
                grads[1] = torch.bmm(shared, batch2.transpose(1, 2))
            if need[2]:
                grads[2] = torch.bmm(batch1.transpose(1, 2), shared)
            if ctx.beta != 1:
                for idx in (1, 2):
                    if grads[idx] is not None:
                        grads[idx] *= ctx.beta

        return grads[0], grads[1], grads[2], None, None, None
def backward(ctx, grad_output):
        """Backward of a batched add-matmul that keeps the batch axis.

        The gradients correspond to an output of the form
        ``alpha * add_batch + beta * bmm(batch1, batch2)``. Three trailing
        ``None`` values are returned, presumably for the forward's non-tensor
        arguments -- confirm against the matching ``forward``.
        """
        batch1, batch2 = ctx.saved_variables
        flags = ctx.needs_input_grad

        def beta_scaled(t):
            # Scale a freshly computed bmm result in place.
            if ctx.beta != 1:
                t *= ctx.beta
            return t

        grad_add_batch = None
        if flags[0]:
            grad_add_batch = grad_output if ctx.alpha == 1 else grad_output.mul(ctx.alpha)

        grad_batch1 = None
        if flags[1]:
            grad_batch1 = beta_scaled(torch.bmm(grad_output, batch2.transpose(1, 2)))

        grad_batch2 = None
        if flags[2]:
            grad_batch2 = beta_scaled(torch.bmm(batch1.transpose(1, 2), grad_output))

        return grad_add_batch, grad_batch1, grad_batch2, None, None, None
def updateOutput(self, input):
        """Matrix product of ``input = (a, b)`` into ``self.output``.

        2-D inputs use ``torch.mm``. 3-D inputs, treated as a batch of
        matrices along dim 0, use ``torch.bmm``. ``self.transA`` /
        ``self.transB`` transpose the matrix dims of ``a`` / ``b`` first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            torch.mm(a, b, out=self.output)
        else:
            # For 3-D tensors the matrix dims are 1 and 2; (2, 3) would be
            # out of range.
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)

        return self.output
def updateOutput(self, input):
        """Matrix-vector product of ``input = (M, v)`` into ``self.output``.

        2-D ``M`` with 1-D ``v`` uses ``torch.mv``. 3-D ``M`` with 2-D ``v``
        is a batched product done with ``torch.bmm``; the result is then
        resized to ``(batch, rows)``. ``self.trans`` transposes ``M``'s
        matrix dims first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        M, v = input
        assert M.ndimension() == 2 or M.ndimension() == 3

        if M.ndimension() == 2:
            assert v.ndimension() == 1
            if self.trans:
                M = M.transpose(0, 1)
            self.output.resize_(M.size(0))
            torch.mv(M, v, out=self.output)
        else:
            assert v.ndimension() == 2
            if self.trans:
                M = M.transpose(1, 2)
            self.output.resize_(M.size(0), M.size(1), 1)
            # Treat each v[i] as a column so bmm applies; drop the unit dim after.
            torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))

        return self.output
def _test_btrisolve(self, cast):
        a = torch.FloatTensor((((1.3722, -0.9020),
                                (1.8849, 1.9169)),
                               ((0.7187, -1.1695),
                                (-0.0139, 1.3572)),
                               ((-1.6181, 0.7148),
                                (1.3728, 0.1319))))
        b = torch.FloatTensor(((4.02, 6.19),
                               (-1.56, 4.00),
                               (9.81, -4.09)))
        a, b = cast(a), cast(b)
        info = cast(torch.IntTensor())
        LU_data, pivots = a.btrifact(info=info)
        self.assertEqual(info.abs().sum(), 0)
        x = torch.btrisolve(b, LU_data, pivots)
        b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
        self.assertEqual(b_, b)
def forward(self, input1):
        """Map a fixed homogeneous base grid through per-sample matrices.

        ``self.grid`` holds 3 homogeneous coordinates per position; it is
        viewed as ``(height*width, 3)`` below. For each sample ``i`` the
        output is ``grid @ input1[i].T``, reshaped to
        ``(batch, height, width, 2)``.

        Side effects:
            Stores ``self.input1`` and the batched copy ``self.batchgrid``,
            presumably for a backward pass (not visible here).
        """
        self.input1 = input1
        batch = input1.size(0)
        # Allocate directly on input1's device; the old ``.cuda()`` call
        # targeted the current device, which may differ from input1's.
        self.batchgrid = torch.zeros(torch.Size([batch]) + self.grid.size(),
                                     device=input1.device)
        # Broadcast copy replaces the per-sample Python loop.
        self.batchgrid[:] = self.grid

        batchgrid_temp = self.batchgrid.view(-1, self.height * self.width, 3)
        input_temp = torch.transpose(input1, 1, 2)
        output_temp = torch.bmm(batchgrid_temp, input_temp)
        return output_temp.view(-1, self.height, self.width, 2)
def forward(self, input, context):
        """Propagate input through the network.

        input: batch x dim
        context: batch x sourceL x dim

        Returns ``(h_tilde, attn)``. ``attn`` is the ``self.sm``-normalised
        weight over sourceL. ``h_tilde`` is
        ``tanh(linear_out([weighted_context; input]))``.
        """
        target = self.linear_in(input).unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, target).squeeze(2)  # batch x sourceL
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL

        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        h_tilde = torch.cat((weighted_context, input), 1)

        h_tilde = self.tanh(self.linear_out(h_tilde))

        return h_tilde, attn
def backward(ctx, grad_output):
        """Backward of a batch-reducing add-matmul (static-autograd style).

        The gradients correspond to an output of the form
        ``alpha * add_matrix + beta * sum_b(batch1[b] @ batch2[b])``. Three
        trailing ``None`` values are returned, presumably for the forward's
        non-tensor arguments -- confirm against the matching ``forward``.
        """
        batch1, batch2 = ctx.saved_variables
        want = ctx.needs_input_grad

        grad_add_matrix = None
        if want[0]:
            grad_add_matrix = grad_output
            if ctx.alpha != 1:
                grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

        grad_batch1 = grad_batch2 = None
        if not any(want[1:]):
            return grad_add_matrix, grad_batch1, grad_batch2, None, None, None

        n_batch, rows, cols = batch1.size(0), batch1.size(1), batch2.size(2)
        broadcast_grad = grad_output.expand(n_batch, rows, cols)

        if want[1]:
            grad_batch1 = torch.bmm(broadcast_grad, batch2.transpose(1, 2))
            if ctx.beta != 1:
                grad_batch1 *= ctx.beta

        if want[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), broadcast_grad)
            if ctx.beta != 1:
                grad_batch2 *= ctx.beta

        return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
def backward(ctx, grad_output):
        """Backward of a batched add-matmul that keeps the batch axis.

        The gradients correspond to an output of the form
        ``alpha * add_batch + beta * bmm(batch1, batch2)``. Three trailing
        ``None`` values are returned, presumably for the forward's non-tensor
        arguments -- confirm against the matching ``forward``.
        """
        batch1, batch2 = ctx.saved_variables
        want = ctx.needs_input_grad
        out = [None, None, None]

        if want[0]:
            out[0] = grad_output.mul(ctx.alpha) if ctx.alpha != 1 else grad_output

        products = (
            (1, lambda: torch.bmm(grad_output, batch2.transpose(1, 2))),
            (2, lambda: torch.bmm(batch1.transpose(1, 2), grad_output)),
        )
        for slot, compute in products:
            if want[slot]:
                grad = compute()
                if ctx.beta != 1:
                    grad *= ctx.beta
                out[slot] = grad

        return out[0], out[1], out[2], None, None, None
def updateOutput(self, input):
        """Matrix product of ``input = (a, b)`` into ``self.output``.

        2-D inputs use ``torch.mm``. 3-D inputs, treated as a batch of
        matrices along dim 0, use ``torch.bmm``. ``self.transA`` /
        ``self.transB`` transpose the matrix dims of ``a`` / ``b`` first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            torch.mm(a, b, out=self.output)
        else:
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)

        return self.output
def updateOutput(self, input):
        """Matrix-vector product of ``input = (M, v)`` into ``self.output``.

        2-D ``M`` with 1-D ``v`` uses ``torch.mv``. 3-D ``M`` with 2-D ``v``
        is a batched product done with ``torch.bmm``; the result is then
        resized to ``(batch, rows)``. ``self.trans`` transposes ``M``'s
        matrix dims first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        M, v = input
        assert M.ndimension() == 2 or M.ndimension() == 3

        if M.ndimension() == 2:
            assert v.ndimension() == 1
            if self.trans:
                M = M.transpose(0, 1)
            self.output.resize_(M.size(0))
            torch.mv(M, v, out=self.output)
        else:
            assert v.ndimension() == 2
            if self.trans:
                M = M.transpose(1, 2)
            self.output.resize_(M.size(0), M.size(1), 1)
            # Treat each v[i] as a column so bmm applies; drop the unit dim after.
            torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))

        return self.output
def _test_btrisolve(self, cast):
        a = torch.FloatTensor((((1.3722, -0.9020),
                                (1.8849, 1.9169)),
                               ((0.7187, -1.1695),
                                (-0.0139, 1.3572)),
                               ((-1.6181, 0.7148),
                                (1.3728, 0.1319))))
        b = torch.FloatTensor(((4.02, 6.19),
                               (-1.56, 4.00),
                               (9.81, -4.09)))
        a, b = cast(a), cast(b)
        info = cast(torch.IntTensor())
        LU_data, pivots = a.btrifact(info=info)
        self.assertEqual(info.abs().sum(), 0)
        x = torch.btrisolve(b, LU_data, pivots)
        b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
        self.assertEqual(b_, b)
def forward(self, x):
        """PointNet feature extractor.

        Aligns the input points with the STN-predicted transform, runs three
        per-point conv/bn stages and max-pools them into a 1024-d global
        feature.

        Returns:
            ``(global_feat, trans)`` when ``self.global_feat`` is set.
            Otherwise ``(features, trans)``, where ``features`` is the global
            feature tiled over ``self.num_points`` and concatenated with the
            first-stage per-point features on dim 1.
        """
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)  # apply the per-sample transform to every point
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = self.mp1(x)
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans
        x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([x, pointfeat], 1), trans
def calc_score(self, att_query, att_keys):
        """Compute attention scores between queries and keys.

        att_query is: b x t_q x n
        att_keys is b x t_k x n
        return b x t_q x t_k scores

        ``self.mode`` selects the scorer:
        - ``'bahdanau'``: ``self.linear_att(tanh(q + k))`` for every (q, k) pair.
        - ``'dot_prod'``: ``q @ k^T``, divided by ``sqrt(n)`` when
          ``self.normalize`` is set.

        Raises:
            ValueError: for any other mode. Previously this surfaced as an
                ``UnboundLocalError`` on ``out``.
        """
        b, t_k, n = list(att_keys.size())
        t_q = att_query.size(1)
        if self.mode == 'bahdanau':
            att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
            att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
            sum_qk = att_query + att_keys
            sum_qk = sum_qk.view(b * t_k * t_q, n)
            out = self.linear_att(torch.tanh(sum_qk)).view(b, t_q, t_k)
        elif self.mode == 'dot_prod':
            out = torch.bmm(att_query, att_keys.transpose(1, 2))
            if self.normalize:
                out.div_(n ** 0.5)
        else:
            raise ValueError('Unknown attention mode: {!r}'.format(self.mode))
        return out
def forward(self, v, z):
        """Apply the Householder reflection defined by ``v`` to ``z``.

        :param v: batch_size (B) x latent_size (L)
        :param z: batch_size (B) x latent_size (L)
        :return: z_new = z - 2 * v v_T z / ||v||_2^2   (B x L)
        """
        # v * v_T : batch_dot( B x L x 1 * B x 1 x L ) = B x L x L
        vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))
        # v * v_T * z : batch_dot( B x L x L * B x L x 1 ).squeeze(2) = B x L
        vvTz = torch.bmm(vvT, z.unsqueeze(2)).squeeze(2)
        # Squared norm of each row, kept as B x 1 so it broadcasts over L.
        # (Expanding a 1-D length-B tensor to (B, L) only works when B == L.)
        norm_sq = torch.sum(v * v, 1, keepdim=True)
        # calculate new z
        z_new = z - 2 * vvTz / norm_sq
        return z_new
def forward(self, L, z):
        """Multiply ``z`` by a unit-diagonal lower-triangular matrix built from ``L``.

        :param L: batch_size (B) x latent_size^2 (L^2)
        :param z: batch_size (B) x latent_size (L)
        :return: z_new = LT * z, where LT keeps the strictly-lower part of
            ``L`` (reshaped to L x L) and has ones on the diagonal
        """
        size = self.args.z1_size
        L_matrix = L.view(-1, size, size)  # B x L x L
        # Build the diagonal on L's own device/dtype. This avoids depending
        # on the separate ``args.cuda`` flag matching where L actually lives.
        eye = torch.eye(size, dtype=L_matrix.dtype, device=L_matrix.device)
        # Batch of lower-triangular matrices with ones on the diagonal.
        LT = torch.tril(L_matrix, diagonal=-1) + eye

        # z_new = LT * z : B x L x L * B x L x 1 -> B x L
        z_new = torch.bmm(LT, z.unsqueeze(2)).squeeze(2)

        return z_new
def forward(ctx, theta, size):
        """Generate an ``[N, H, W, 2]`` sampling grid from affine matrices ``theta``.

        On CUDA the cuDNN kernel fills the grid. On CPU a base grid of
        homogeneous coordinates ``(x, y, 1)`` is built, with ``x`` spanning
        [-1, 1] along W and ``y`` along H, and multiplied by
        ``theta^T``. The base grid is stored on ``ctx`` for the backward pass.

        Args:
            theta: batched affine matrices; transposed with ``(1, 2)`` for the bmm.
            size: ``torch.Size([N, C, H, W])`` of the target output.
        """
        assert isinstance(size, torch.Size)
        N, C, H, W = size
        ctx.size = size
        if theta.is_cuda:
            ctx.is_cuda = True
            grid = theta.new(N, H, W, 2)
            theta = theta.contiguous()
            torch._C._cudnn_affine_grid_generator_forward(theta, grid, N, C, H, W)
        else:
            ctx.is_cuda = False
            base_grid = theta.new(N, H, W, 3)
            # A single sample along an axis is placed at -1.
            xs = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
            ys = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
            base_grid[:, :, :, 0] = xs               # broadcast over N and H
            base_grid[:, :, :, 1] = ys.unsqueeze(1)  # broadcast over N and W
            base_grid[:, :, :, 2] = 1
            ctx.base_grid = base_grid
            grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
            grid = grid.view(N, H, W, 2)
        return grid
def backward(ctx, grad_grid):
        """Gradient of the affine grid w.r.t. ``theta``.

        On CUDA the cuDNN kernel computes it. On CPU it is
        ``(base_grid^T @ grad_grid)^T`` per sample, with shape ``(N, 2, 3)``,
        using the base grid that ``forward`` stored on ``ctx``.

        Returns:
            ``(grad_theta, None)``; ``size`` receives no gradient.
        """
        N, C, H, W = ctx.size
        assert grad_grid.size() == torch.Size([N, H, W, 2])
        assert ctx.is_cuda == grad_grid.is_cuda
        if grad_grid.is_cuda:
            grad_theta = grad_grid.new(N, 2, 3)
            grad_grid = grad_grid.contiguous()
            torch._C._cudnn_affine_grid_generator_backward(grad_theta, grad_grid,
                                                           N, C, H, W)
        else:
            base_grid = ctx.base_grid
            # reshape (not view): incoming gradients may be non-contiguous.
            grad_theta = torch.bmm(
                base_grid.view(N, H * W, 3).transpose(1, 2),
                grad_grid.reshape(N, H * W, 2))
            grad_theta = grad_theta.transpose(1, 2)

        return grad_theta, None
def backward(ctx, grad_output):
        """Backward of a batch-reducing add-matmul with a broadcastable addend.

        The gradients correspond to an output of the form
        ``alpha * add_matrix + beta * sum_b(batch1[b] @ batch2[b])``. The
        addend's gradient is reduced back to ``ctx.add_matrix_size`` by
        ``maybe_unexpand``, a helper defined outside this snippet. Three
        trailing ``None`` values are returned, presumably for the forward's
        non-tensor arguments.
        """
        batch1, batch2 = ctx.saved_variables
        need = ctx.needs_input_grad
        result_add = None
        result_b1 = None
        result_b2 = None

        if need[0]:
            result_add = maybe_unexpand(grad_output, ctx.add_matrix_size)
            if ctx.alpha != 1:
                result_add = result_add.mul(ctx.alpha)

        if any(need[1:]):
            # Every batch element contributed to the same (n, p) output.
            per_batch = grad_output.expand(batch1.size(0), batch1.size(1), batch2.size(2))
            if need[1]:
                result_b1 = torch.bmm(per_batch, batch2.transpose(1, 2))
                if ctx.beta != 1:
                    result_b1 *= ctx.beta
            if need[2]:
                result_b2 = torch.bmm(batch1.transpose(1, 2), per_batch)
                if ctx.beta != 1:
                    result_b2 *= ctx.beta

        return result_add, result_b1, result_b2, None, None, None
def backward(ctx, grad_output):
        """Backward of a batched add-matmul with a broadcastable addend.

        The gradients correspond to an output of the form
        ``alpha * add_batch + beta * bmm(batch1, batch2)``. The addend's
        gradient is reduced back to ``ctx.add_batch_size`` by
        ``maybe_unexpand``, a helper defined outside this snippet. Three
        trailing ``None`` values are returned, presumably for the forward's
        non-tensor arguments.
        """
        batch1, batch2 = ctx.saved_variables
        need = ctx.needs_input_grad

        def apply_beta(t):
            # In-place scale of a freshly computed bmm result.
            if ctx.beta != 1:
                t *= ctx.beta
            return t

        grad_add_batch = None
        if need[0]:
            reduced = maybe_unexpand(grad_output, ctx.add_batch_size)
            grad_add_batch = reduced if ctx.alpha == 1 else reduced.mul(ctx.alpha)

        grad_batch1 = apply_beta(torch.bmm(grad_output, batch2.transpose(1, 2))) if need[1] else None
        grad_batch2 = apply_beta(torch.bmm(batch1.transpose(1, 2), grad_output)) if need[2] else None

        return grad_add_batch, grad_batch1, grad_batch2, None, None, None
def updateOutput(self, input):
        """Matrix product of ``input = (a, b)`` into ``self.output``.

        2-D inputs use ``torch.mm``. 3-D inputs, treated as a batch of
        matrices along dim 0, use ``torch.bmm``. ``self.transA`` /
        ``self.transB`` transpose the matrix dims of ``a`` / ``b`` first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            torch.mm(a, b, out=self.output)
        else:
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)

        return self.output
def updateOutput(self, input):
        """Matrix-vector product of ``input = (M, v)`` into ``self.output``.

        2-D ``M`` with 1-D ``v`` uses ``torch.mv``. 3-D ``M`` with 2-D ``v``
        is a batched product done with ``torch.bmm``; the result is then
        resized to ``(batch, rows)``. ``self.trans`` transposes ``M``'s
        matrix dims first.

        Returns:
            ``self.output``, resized and filled in place.
        """
        M, v = input
        assert M.ndimension() == 2 or M.ndimension() == 3

        if M.ndimension() == 2:
            assert v.ndimension() == 1
            if self.trans:
                M = M.transpose(0, 1)
            self.output.resize_(M.size(0))
            torch.mv(M, v, out=self.output)
        else:
            assert v.ndimension() == 2
            if self.trans:
                M = M.transpose(1, 2)
            self.output.resize_(M.size(0), M.size(1), 1)
            # Treat each v[i] as a column so bmm applies; drop the unit dim after.
            torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))

        return self.output
def _test_btrisolve(self, cast):
        a = torch.FloatTensor((((1.3722, -0.9020),
                                (1.8849, 1.9169)),
                               ((0.7187, -1.1695),
                                (-0.0139, 1.3572)),
                               ((-1.6181, 0.7148),
                                (1.3728, 0.1319))))
        b = torch.FloatTensor(((4.02, 6.19),
                               (-1.56, 4.00),
                               (9.81, -4.09)))
        a, b = cast(a), cast(b)
        info = cast(torch.IntTensor())
        LU_data, pivots = a.btrifact(info=info)
        self.assertEqual(info.abs().sum(), 0)
        x = torch.btrisolve(b, LU_data, pivots)
        b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
        self.assertEqual(b_, b)
def backward(self, gradE):
        # Backward pass for an encoding/aggregation op with saved inputs A, X, C.
        # NOTE(review): this snippet is garbled -- HTML extraction dropped the
        # text inside several expressions (the Variable(...) allocations, the
        # isinstance(...) targets, the kernel calls and the gradX/gradC copies),
        # so it is not valid Python as shown. Restore it from the upstream
        # source before use.
        # What remains shows: A, X, C are unpacked from self.saved_variables;
        # gradient holders appear to be allocated under A's CUDA device; the
        # code branches on torch.cuda.FloatTensor / torch.cuda.DoubleTensor and
        # raises RuntimeError('Unimplemented data type!') otherwise; the tail of
        # the last garbled line mentions gradE and A.sum(1).unsqueeze(2) summed
        # over dim 0 (presumably the gradC term -- confirm upstream).
        # Returns (gradA, gradX, gradC).
        A, X, C = self.saved_variables
        with torch.cuda.device_of(A):
            gradA = Variable(
            gradX = Variable(
            gradC = Variable(
        if isinstance(, torch.cuda.FloatTensor):
            with torch.cuda.device_of(
        elif isinstance(, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(
            raise RuntimeError('Unimplemented data type!'), gradE).data)*A.sum(1).unsqueeze(2)).sum(0).data)
        return gradA, gradX, gradC
def backward(self, gradE):
        # Backward pass for an encoding/aggregation op with saved tensors A, X, C.
        # NOTE(review): this snippet is garbled -- the right-hand sides of the
        # gradA/gradX/gradC allocations and the names of the CUDA kernels called
        # in both type branches were lost in extraction, and the `else:` before
        # the raise is missing. It is not valid Python as shown; restore it
        # from the upstream source before use.
        # What remains shows: the code branches on torch.cuda.FloatTensor /
        # torch.cuda.DoubleTensor, passing (..., gradE, A, X, C) to a kernel
        # under A's device, and raises RuntimeError for other types.
        # Returns (gradA, gradX, gradC).
        A, X, C = self.saved_tensors
        with torch.cuda.device_of(A):
            gradA =
            gradX =
            gradC =
        if isinstance(A, torch.cuda.FloatTensor):
            with torch.cuda.device_of(A):
                    gradE, A, X, C)
        elif isinstance(A, torch.cuda.DoubleTensor):
            with torch.cuda.device_of(A):
                    gradE, A, X, C)
            raise RuntimeError('Unimplemented data type!')
        # gradX is filled in place with the batched product A @ gradE
        # (presumably the gradient w.r.t. X -- confirm upstream).
        gradX.copy_(torch.bmm(A, gradE))
        return gradA, gradX, gradC
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve KKT equations for the affine step.

    Eliminates the primal block with the LU factorization ``Q_LU``, solves
    the reduced system with ``S_LU`` for the stacked multipliers ``w``, then
    back-substitutes for the primal and slack steps. The rows of ``h`` and
    ``w`` hold the equality part first and the inequality part second.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when there are no equality
        constraints (``neq == 0``).
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.btrisolve(*Q_LU)
    if neq > 0:
        h = torch.cat((invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry,
                       invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz), 1)
    else:
        h = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz

    w = -(h.btrisolve(*S_LU))

    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)
    g2 = -rs - w[:, neq:]

    dx = g1.btrisolve(*Q_LU)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
def forward(self, x):
        """PointNet feature extractor with an optional input transform.

        When ``self.trans`` is set, the points are first aligned with the
        STN-predicted transform. Three per-point conv/bn stages follow, and a
        max over the point axis gives a 1024-d global feature.

        Returns:
            ``x`` alone when ``self.trans`` is false. Otherwise
            ``(global_feat, trans)`` if ``self.global_feat``, or
            ``(features, trans)``, where ``features`` is the global feature
            tiled over ``self.num_points`` and concatenated with the
            first-stage per-point features on dim 1.
        """
        if self.trans:
            trans = self.stn(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans)  # apply the per-sample transform to every point
            x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x, _ = torch.max(x, 2)
        x = x.view(-1, 1024)
        if not self.trans:
            return x
        if self.global_feat:
            return x, trans
        x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([x, pointfeat], 1), trans
def forward(self, output, context):
        """Dot-product attention of decoder outputs over encoder context.

        Args:
            output: ``(batch, out_len, dim)`` decoder states.
            context: ``(batch, in_len, dim)`` encoder states.

        Returns:
            ``(output, attn)``. ``attn`` is ``(batch, out_len, in_len)``
            softmax weights; positions where ``self.mask`` is true are set to
            ``-inf`` first. ``output`` is
            ``tanh(linear_out([mix; output]))`` reshaped to
            ``(batch, out_len, dim)``.
        """
        batch_size = output.size(0)
        hidden_size = output.size(2)
        input_size = context.size(1)
        # (batch, out_len, dim) * (batch, in_len, dim) -> (batch, out_len, in_len)
        attn = torch.bmm(output, context.transpose(1, 2))
        if self.mask is not None:
            attn = attn.masked_fill(self.mask, -float('inf'))
        # Explicit dim: the implicit-dim form of F.softmax is deprecated.
        attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)

        # (batch, out_len, in_len) * (batch, in_len, dim) -> (batch, out_len, dim)
        mix = torch.bmm(attn, context)

        # concat -> (batch, out_len, 2*dim)
        combined = torch.cat((mix, output), dim=2)
        # output -> (batch, out_len, dim)
        output = torch.tanh(self.linear_out(combined.view(-1, 2 * hidden_size))).view(batch_size, -1, hidden_size)

        return output, attn
def forward(self, h, att_feats, p_att_feats):
        """Additive attention over (pre-projected) attention features.

        Args:
            h: hidden state, projected by ``self.h2att``.
            att_feats: features viewed as ``(-1, att_size, self.rnn_size)``.
            p_att_feats: the same features already projected to
                ``self.att_hid_size``.

        Returns:
            The softmax-weighted sum of ``att_feats`` over ``att_size``:
            ``batch * rnn_size``.
        """
        # The p_att_feats here is already projected
        att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size
        att = p_att_feats.view(-1, att_size, self.att_hid_size)

        att_h = self.h2att(h)                               # batch * att_hid_size
        att_h = att_h.unsqueeze(1).expand_as(att)           # batch * att_size * att_hid_size
        dot = torch.tanh(att + att_h)                       # batch * att_size * att_hid_size
        dot = dot.view(-1, self.att_hid_size)               # (batch * att_size) * att_hid_size
        dot = self.alpha_net(dot)                           # (batch * att_size) * 1
        dot = dot.view(-1, att_size)                        # batch * att_size

        # Explicit dim: the implicit-dim form of F.softmax is deprecated.
        weight = F.softmax(dot, dim=1)                      # batch * att_size
        att_feats_ = att_feats.view(-1, att_size, self.rnn_size)  # batch * att_size * att_feat_size
        att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1)  # batch * att_feat_size

        return att_res
def forward(self, input, hidden, encoder_output, encoder_outputs):
        """One decoding step of an attention GRU decoder.

        Args:
            input: token index for this step, embedded and viewed as (1, 1, -1).
            hidden: GRU hidden state; ``hidden[0]`` joins the attention input.
            encoder_output: unused here; kept for interface compatibility.
            encoder_outputs: encoder states, weighted by the attention.

        Returns:
            ``(log_probs, hidden, attn_weights)``.

        Tensors stay on the device of the inputs and parameters. The earlier
        per-line ``.cuda()`` calls relied on a global ``use_cuda`` flag
        defined outside this function.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # Attention weights from the current embedding and hidden state.
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        for _ in range(self.n_layers):
            output = F.relu(output)
            output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights
def forward(self, query, ref):
        """Score encoder states against the current decoder query.

        Args:
            query: hidden state of the decoder at the current time step,
                ``batch x dim``.
            ref: the set of hidden states from the encoder,
                ``sourceL x batch x hidden_dim``.

        Returns:
            ``(e, logits)``. ``e`` is the projected reference,
            ``batch_size x hidden_dim x sourceL``. ``logits`` is
            ``batch x sourceL``. When ``self.use_tanh`` is set the logits are
            ``self.C * self.tanh(u)``; otherwise they are the raw scores ``u``.
        """
        # ref is now [batch_size x hidden_dim x sourceL]
        ref = ref.permute(1, 2, 0)
        q = self.project_query(query).unsqueeze(2)  # batch x dim x 1
        e = self.project_ref(ref)  # batch_size x hidden_dim x sourceL
        # Repeat the query along sourceL: batch x dim x sourceL
        expanded_q = q.repeat(1, 1, e.size(2))
        # batch x 1 x hidden_dim
        v_view = self.v.unsqueeze(0).expand(
                expanded_q.size(0), len(self.v)).unsqueeze(1)
        # [batch_size x 1 x hidden_dim] * [batch_size x hidden_dim x sourceL]
        u = torch.bmm(v_view, self.tanh(expanded_q + e)).squeeze(1)
        if self.use_tanh:
            # Squash the scores, then rescale them by self.C.
            logits = self.C * self.tanh(u)
        else:
            logits = u
        return e, logits
def forward(self, inputs):
        """Encode ``inputs``, refine the final encoder state, then decode it.

        Args:
            inputs: embedded inputs, documented upstream as
                ``[embedding_dim x batch_size x sourceL]``. Dimension 1 is used
                as the batch size when tiling the initial encoder state.

        Returns:
            Whatever ``self.decoder`` produces from the refined state.
        """
        (encoder_hx, encoder_cx) = self.encoder.enc_init_state
        # Tile the (1-D) initial state across the batch: 1 x batch x hidden.
        encoder_hx = encoder_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)
        encoder_cx = encoder_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)

        # encoder forward pass
        enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx))

        # grab the hidden state and process it via the process block
        process_block_state = enc_h_t[-1]
        for _ in range(self.n_process_block_iters):
            ref, logits = self.process_block(process_block_state, enc_outputs)
            # Glimpse: attention-weighted sum of the projected references,
            # (batch x hidden x sourceL) @ (batch x sourceL x 1) -> batch x hidden.
            # NOTE(review): softmax over sourceL reconstructed from a truncated
            # line; confirm against the original process-block glimpse.
            process_block_state = torch.bmm(
                ref, torch.softmax(logits, dim=1).unsqueeze(2)).squeeze(2)
        # produce the final scalar output
        out = self.decoder(process_block_state)
        return out
def forward(self, dec_out, enc_outs, enc_att=None, mask=None):
        """Attend over encoder outputs and combine the context with ``dec_out``.

        Args:
            dec_out: torch.Tensor(batch_size x hid_dim)
            enc_outs: torch.Tensor(seq_len x batch_size x hid_dim)
            enc_att: (optional), torch.Tensor(seq_len x batch_size x att_dim)
            mask: (optional), byte or bool tensor (batch_size x seq_len);
                positions where ``mask`` is 0/False get zero attention weight.

        Returns:
            ``(context, weights)``. ``context`` is ``tanh(linear_out([context;
            dec_out]))``. ``weights`` are the attention weights
            (batch_size x seq_len).
        """
        # (batch x seq_len)
        weights = self.scorer(dec_out, enc_outs, enc_att=enc_att)

        if mask is not None:
            # Out-of-place fill keeps autograd intact. Comparing with 0 works
            # for both byte and bool masks (``1 - bool_mask`` raises).
            weights = weights.masked_fill(mask == 0, -float('inf'))

        weights = F.softmax(weights, dim=1)

        # (eq 7)
        context = weights.unsqueeze(1).bmm(enc_outs.transpose(0, 1)).squeeze(1)
        # (eq 5) linear out combining context and hidden
        context = torch.tanh(self.linear_out(torch.cat([context, dec_out], 1)))

        return context, weights
def _access(self, memory_vb): # write
        variables needed:
            wl_curr_vb: [batch_size x num_heads x mem_hei]
            erase_vb:   [batch_size x num_heads x mem_wid]
                     -> /in (0, 1)
            add_vb:     [batch_size x num_heads x mem_wid]
                     -> w/ no restrictions in range
            memory_vb:  [batch_size x mem_hei x mem_wid]
            memory_vb:  [batch_size x mem_hei x mem_wid]

        # first let's do erasion
        weighted_erase_vb = torch.bmm(self.wl_curr_vb.contiguous().view(-1, self.mem_hei, 1),
                                      self.erase_vb.contiguous().view(-1, 1, self.mem_wid)).view(-1, self.num_heads, self.mem_hei, self.mem_wid)
        keep_vb = - weighted_erase_vb, dim=1)
        memory_vb = memory_vb * keep_vb
        # finally let's write (do addition)
        return memory_vb + torch.bmm(self.wl_curr_vb.transpose(1, 2), self.add_vb)
def _access(self, memory_vb): # write
        variables needed:
            wl_curr_vb: [batch_size x num_heads x mem_hei]
            erase_vb:   [batch_size x num_heads x mem_wid]
                     -> /in (0, 1)
            add_vb:     [batch_size x num_heads x mem_wid]
                     -> w/ no restrictions in range
            memory_vb:  [batch_size x mem_hei x mem_wid]
            memory_vb:  [batch_size x mem_hei x mem_wid]

        # first let's do erasion
        weighted_erase_vb = torch.bmm(self.wl_curr_vb.contiguous().view(-1, self.mem_hei, 1),
                                      self.erase_vb.contiguous().view(-1, 1, self.mem_wid)).view(-1, self.num_heads, self.mem_hei, self.mem_wid)
        keep_vb = - weighted_erase_vb, dim=1)
        memory_vb = memory_vb * keep_vb
        # finally let's write (do addition)
        return memory_vb + torch.bmm(self.wl_curr_vb.transpose(1, 2), self.add_vb)
def forward(ctx, theta, size):
        """Build an affine sampling grid (CPU only).

        Args:
            ctx: autograd context; ``size``, ``is_cuda`` and ``base_grid`` are
                stored on it.
            theta: affine matrices, ``N x 2 x 3``. Their transpose is
                multiplied against ``N x (H*W) x 3`` homogeneous points.
            size: ``torch.Size`` ``(N, C, H, W)`` of the target output.

        Returns:
            Grid of ``(x, y)`` sampling coordinates, ``N x H x W x 2``.
        """
        assert type(size) == torch.Size
        N, C, H, W = size
        ctx.size = size
        if theta.is_cuda:
            # This code path does not support CUDA inputs.
            assert False
        ctx.is_cuda = False
        # Homogeneous base grid of (x, y, 1) with x, y in [-1, 1]; allocated
        # with theta's tensor type.
        base_grid = theta.new(N, H, W, 3)
        linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
        base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0])
        linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
        base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
        base_grid[:, :, :, 2] = 1
        ctx.base_grid = base_grid
        grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
        grid = grid.view(N, H, W, 2)
        return grid
def updateOutput(self, input):
        """Multiply a pair of matrices, or a pair of batches of matrices.

        ``input`` is ``(a, b)``. Both must be 2-D (``torch.mm``) or both 3-D
        (``torch.bmm``). If ``self.transA`` / ``self.transB`` are set, the
        matching operand is transposed first (its last two dims for 3-D
        input). The product is written into ``self.output``, which is also
        returned.
        """
        assert len(input) == 2
        a, b = input
        assert a.ndimension() == 2 or a.ndimension() == 3
        assert a.dim() == b.dim()

        if a.ndimension() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            torch.mm(a, b, out=self.output)
        else:
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)

            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(a, b, out=self.output)

        return self.output