Python torch 模块,unsqueeze() 实例源码

我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用torch.unsqueeze()

项目:RetinaNet    作者:c0nn3r    | 项目源码 | 文件源码
def __getitem__(self, index):
        """Load one COCO sample.

        Args:
            index (int): position in ``self.ids``.

        Returns:
            tuple: ``(image, target)``. ``target`` is the ``bbox`` of the first
            annotation returned by ``coco.loadAnns``, as a ``torch.Tensor`` with a
            trailing axis added by ``unsqueeze(-1)``. The optional
            ``transform`` / ``target_transform`` are applied before returning.
        """
        coco = self.coco
        img_id = self.ids[index]
        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        # Only the first annotation's box is used; an image with no
        # annotations raises IndexError here.
        first_box = torch.Tensor(annotations[0]['bbox'])
        target = first_box.unsqueeze(-1)

        file_name = coco.loadImgs(img_id)[0]['file_name']
        img = Image.open(os.path.join(self.root, file_name)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
项目:DistanceGAN    作者:sagiebenaim    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue
            # Swap: return an old image and store the new one in its slot.
            slot = random.randint(0, self.pool_size - 1)
            picked.append(self.images[slot].clone())
            self.images[slot] = sample
        return Variable(torch.cat(picked, 0))
项目:DeblurGAN    作者:KupynOrest    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue
            # Swap: return an old image and store the new one in its slot.
            slot = random.randint(0, self.pool_size - 1)
            picked.append(self.images[slot].clone())
            self.images[slot] = sample
        return Variable(torch.cat(picked, 0))
项目:lr-gan.pytorch    作者:jwyang    | 项目源码 | 文件源码
def clampT(self, Tin):
        """Clamp the six affine-transform parameters held in ``Tin``.

        Columns of ``Tin`` (dim 1) are read in the order
        x_s, x_r, x_t, y_r, y_s, y_t. Scale columns are clamped to
        ``[opt.maxobjscale, 2 * opt.maxobjscale]``, rotation columns to
        ``[-rot, rot]`` and translation columns to ``[-1.0, 1.0]``.
        ``opt`` and ``rot`` are module-level names defined elsewhere.

        Returns:
            The six clamped columns concatenated back along dim 1.
        """
        scale_lo, scale_hi = opt.maxobjscale, 2 * opt.maxobjscale
        bounds = (
            (scale_lo, scale_hi),  # x_s
            (-rot, rot),           # x_r
            (-1.0, 1.0),           # x_t
            (-rot, rot),           # y_r
            (scale_lo, scale_hi),  # y_s
            (-1.0, 1.0),           # y_t
        )
        columns = [Tin.select(1, col).clamp(lo, hi).unsqueeze(1)
                   for col, (lo, hi) in enumerate(bounds)]
        return torch.cat(columns, 1)
项目:Deep-learning-with-cats    作者:AlexiaJM    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue
            # Swap: return an old image and store the new one in its slot.
            slot = random.randint(0, self.pool_size - 1)
            picked.append(self.images[slot].clone())
            self.images[slot] = sample
        return Variable(torch.cat(picked, 0))

# Initialize fake image pools
项目:CycleGANwithPerceptionLoss    作者:EliasVansteenkiste    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
            elif random.uniform(0, 1) > 0.5:
                # Swap: return an old image and store the new one in its slot.
                slot = random.randint(0, self.pool_size - 1)
                picked.append(self.images[slot].clone())
                self.images[slot] = sample
            else:
                picked.append(sample)
        return Variable(torch.cat(picked, 0))
项目:pytorch_cycle_gan    作者:jinfagang    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
            elif random.uniform(0, 1) > 0.5:
                # Swap: return an old image and store the new one in its slot.
                slot = random.randint(0, self.pool_size - 1)
                picked.append(self.images[slot].clone())
                self.images[slot] = sample
            else:
                picked.append(sample)
        return Variable(torch.cat(picked, 0))
项目:pytorch-CycleGAN-and-pix2pix    作者:junyanz    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        Unlike sibling versions, this one iterates ``images`` directly and wraps
        the input in ``Variable`` when the pool is disabled.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``Variable(images)`` when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return Variable(images)
        picked = []
        for sample in images:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue
            # Swap: return an old image and store the new one in its slot.
            slot = random.randint(0, self.pool_size - 1)
            picked.append(self.images[slot].clone())
            self.images[slot] = sample
        return Variable(torch.cat(picked, 0))
项目:wasserstein-cyclegan    作者:abhiskk    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue
            # Swap: return an old image and store the new one in its slot.
            slot = random.randint(0, self.pool_size - 1)
            picked.append(self.images[slot].clone())
            self.images[slot] = sample
        return Variable(torch.cat(picked, 0))
项目:rarepepes    作者:kendricktan    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
            elif random.uniform(0, 1) > 0.5:
                # Swap: return an old image and store the new one in its slot.
                slot = random.randint(0, self.pool_size - 1)
                picked.append(self.images[slot].clone())
                self.images[slot] = sample
            else:
                picked.append(sample)
        return Variable(torch.cat(picked, 0))
项目:nmp_qc    作者:priba    | 项目源码 | 文件源码
def m_ggnn(self, h_v, h_w, e_vw, opt={}):
        """GGNN message function: edge-label-specific linear maps of neighbour states.

        Builds ``m`` of shape ``(h_w.size(0), h_w.size(1), self.args['out'])``
        with the same type as ``h_w.data``. For each neighbour slot ``w`` that
        passes the non-zero check below, and each edge label ``el`` in
        ``self.args['e_label']``, it computes per batch element ``b`` the row
        vector ``h_w[b, w, :] @ self.learn_args[0][i]``. It then stores that
        product masked by ``(el == e_vw[:, w, :])`` in ``m[:, w, :]``.

        Args:
            h_v: unused.
            h_w: indexed as ``h_w[:, w, :]``; presumably (batch, n_nodes, in_dim) -- TODO confirm.
            e_vw: indexed as ``e_vw[:, w, :]``; presumably edge labels per neighbour -- TODO confirm.
            opt: unused.

        Returns:
            ``m`` as described above.
        """

        m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data))

        for w in range(h_w.size(1)):
            # NOTE(review): intended to skip slots with no edges. In PyTorch >= 0.4
            # torch.nonzero returns a 2-D result whose size() is always truthy,
            # so this likely never skips; `.numel() > 0` would express the intent.
            if torch.nonzero(e_vw[:, w, :].data).size():
                for i, el in enumerate(self.args['e_label']):
                    # 1 where the edge label equals `el`, 0 elsewhere.
                    ind = (el == e_vw[:,w,:]).type_as(self.learn_args[0][i])

                    # learn_args[0][i] repeated along a new leading batch axis.
                    parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0), self.learn_args[0][i].size(0),
                                                                            self.learn_args[0][i].size(1))

                    # (P^T @ h^T)^T == h @ P for each batch element.
                    m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2),
                                                                        torch.transpose(torch.unsqueeze(h_w[:, w, :], 1),
                                                                                        1, 2)), 1, 2)
                    # NOTE(review): squeeze also drops the batch axis when batch size is 1.
                    m_w = torch.squeeze(m_w)
                    # NOTE(review): this overwrites m[:, w, :] on every label iteration,
                    # so only the last label's masked message survives. If contributions
                    # of different labels should combine, this probably wants `+=`;
                    # confirm intent.
                    m[:,w,:] = ind.expand_as(m_w)*m_w
        return m
项目:paysage    作者:drckf    | 项目源码 | 文件源码
def pdist(x: T.FloatTensor, y: T.FloatTensor) -> T.FloatTensor:
    """
    Pairwise Euclidean distances between the rows of x and the rows of y.

    Uses ||x_i - y_j||^2 = ||x_i||^2 - 2 x_i . y_j + ||y_j||^2. The squared
    value is clipped at 0 before the square root to absorb round-off.

    Args:
        x (tensor (num_samples_1, num_units))
        y (tensor (num_samples_2, num_units))

    Returns:
        tensor (num_samples_1, num_samples_2)

    """
    sq_norm_x = norm(x, axis=1) ** 2
    sq_norm_y = norm(y, axis=1) ** 2
    cross = dot(x, transpose(y))
    col_x = unsqueeze(sq_norm_x, axis=1)   # one value per row of x
    row_y = unsqueeze(sq_norm_y, axis=0)   # one value per row of y
    squared = add(row_y, add(col_x, -2 * cross))
    return torch.sqrt(clip(squared, a_min=0))
项目:pytorch-classification    作者:bearpaw    | 项目源码 | 文件源码
def colorize(x):
    ''' Converts a one-channel grayscale image to a color heatmap image.

    Args:
        x: a ``(H, W)`` or ``(1, H, W)`` image, or an ``(N, 1, H, W)`` batch.
           Only channel 0 is read; its values are passed through ``gauss``.

    Returns:
        A new tensor from ``torch.zeros`` of shape ``(3, H, W)`` (or
        ``(N, 3, H, W)`` for a batch), with every value above 1 set to 1.

    Raises:
        ValueError: if ``x`` is not 2-, 3- or 4-dimensional.
    '''
    if x.dim() == 2:
        # torch.unsqueeze has no ``out=`` argument: add the channel axis by
        # rebinding instead of trying to modify ``x`` in place.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        gray = x[0]
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(gray, .5, .6, .2) + gauss(gray, 1, .8, .3)
        cl[1] = gauss(gray, 1, .5, .3)
        cl[2] = gauss(gray, 1, .2, .3)
    elif x.dim() == 4:
        # Select channel 0 so each gauss() result is (N, H, W) like the target
        # slices; an (N, 1, H, W) value cannot be assigned to them when N > 1.
        gray = x[:, 0]
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        cl[:, 0] = gauss(gray, .5, .6, .2) + gauss(gray, 1, .8, .3)
        cl[:, 1] = gauss(gray, 1, .5, .3)
        cl[:, 2] = gauss(gray, 1, .2, .3)
    else:
        raise ValueError('colorize expects a 2-, 3- or 4-D tensor, got %d-D' % x.dim())
    # The red channel sums two bumps and can exceed 1; cap it for both
    # single images and batches.
    cl[cl.gt(1)] = 1
    return cl
项目:VIGAN    作者:chaoshangcs    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images
        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
                continue
            if random.uniform(0, 1) <= 0.5:
                picked.append(sample)
                continue
            # Swap: return an old image and store the new one in its slot.
            slot = random.randint(0, self.pool_size - 1)
            picked.append(self.images[slot].clone())
            self.images[slot] = sample
        return Variable(torch.cat(picked, 0))
项目:pytorch-semseg    作者:meetshah1995    | 项目源码 | 文件源码
def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):
    """Per-image bootstrapped (top-K) 2-D cross entropy, averaged over the batch.

    For each image separately, the per-pixel NLL of every pixel whose target
    is >= 0 is computed. The K largest values are summed and divided by K.
    The returned loss is the mean of these per-image values.

    Args:
        input: logits indexed as (n, c, h, w).
        target: class index per pixel; entries < 0 are ignored.
        K: number of hardest pixels kept per image.
        weight: optional class weights forwarded to ``F.nll_loss``.
        size_average: forwarded to the per-image helper, which ignores it.

    Returns:
        Scalar loss tensor.
    """
    def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True):
        n, c, h, w = input.size()
        # (n, c, h, w) -> (n*h*w, c): one row of log-probabilities per pixel.
        log_p = F.log_softmax(input, dim=1).transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
        keep_rows = target.view(n * h * w, 1).repeat(1, c) >= 0
        log_p = log_p[keep_rows].view(-1, c)

        valid_target = target[target >= 0]
        per_pixel = F.nll_loss(log_p, valid_target, weight=weight, reduce=False, size_average=False)
        hardest, _ = per_pixel.topk(K)
        return hardest.sum() / K

    batch_size = input.size()[0]
    total = 0.0
    # Bootstrap from each image, not from the entire batch.
    for idx in range(batch_size):
        total += _bootstrap_xentropy_single(input=torch.unsqueeze(input[idx], 0),
                                            target=torch.unsqueeze(target[idx], 0),
                                            K=K,
                                            weight=weight,
                                            size_average=size_average)
    return total / float(batch_size)
项目:python-utils    作者:zhijian-liu    | 项目源码 | 文件源码
def query(self, elements):
        """Return a batch mixing new elements with elements held in the buffer.

        Until ``self.size`` reaches ``self.capacity`` every incoming element is
        stored and returned as-is. After that, an element is exchanged for a
        randomly chosen stored element when ``random.uniform(0, 1) > 0.5``
        (the stored copy is returned, the new element takes its slot);
        otherwise it is returned as-is.

        Returns:
            ``elements`` unchanged when ``self.capacity == 0``; otherwise a
            ``Variable`` of the chosen elements concatenated along dim 0.
        """
        if self.capacity == 0:
            return elements
        selected = []
        for item in elements.data:
            item = item.unsqueeze(0)
            if self.size < self.capacity:
                self.size += 1
                self.elements.append(item)
                selected.append(item)
            elif random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.capacity - 1)
                selected.append(self.elements[slot].clone())
                self.elements[slot] = item
            else:
                selected.append(item)
        return Variable(torch.cat(selected, 0))
项目:dong_iccv_2017    作者:woozzu    | 项目源码 | 文件源码
def pairwise_ranking_loss(margin, x, v):
    """Bidirectional hinge ranking loss between two batches of embeddings.

    Rows of ``x`` and ``v`` are L2-normalised and compared by dot product, so
    matching pairs are on the diagonal of the score matrix.
    - For ``x``, each score is penalised by ``max(0, margin - own_row_diag + score)``.
    - For ``v``, the same hinge is used with the column's diagonal.
    ``margin * I`` is then subtracted to cancel the diagonal terms.
    Tensors go to CUDA unless the module-level ``args.no_cuda`` is set.

    Returns:
        Sum of both directions' hinge terms divided by the batch size.
    """
    batch = x.size(0)
    floor = torch.zeros(1)
    self_margin = margin * torch.eye(batch)
    if not args.no_cuda:
        floor, self_margin = floor.cuda(), self_margin.cuda()
    floor = Variable(floor)
    self_margin = Variable(self_margin)

    x_unit = x / torch.norm(x, 2, 1, keepdim=True)
    v_unit = v / torch.norm(v, 2, 1, keepdim=True)
    scores = torch.matmul(x_unit, v_unit.transpose(0, 1))
    matched = torch.diag(scores)
    hinge_x = torch.max(floor, margin - matched.unsqueeze(1) + scores) - self_margin
    hinge_v = torch.max(floor, margin - matched.unsqueeze(0) + scores) - self_margin
    return (hinge_x.sum() + hinge_v.sum()) / batch
项目:StackGAN_pytorch    作者:qizhex    | 项目源码 | 文件源码
def forward(self, x):
        """Two-layer net whose hidden activations are tiled over an h x w grid.

        ``self.linear1(x)`` is ReLU'd and reshaped to N x H x 1 x 1. It is then
        tiled to N x H x h x w (``H``, ``h`` and ``w`` are module-level names),
        flattened per sample and passed to ``self.linear2``.

        Returns:
            ``self.linear2`` output with one row per input row (N x D_out).
        """
        h_relu = self.linear1(x).clamp(min=0)
        h_relu = torch.unsqueeze(torch.unsqueeze(h_relu, 2), 3)  # -> N x H x 1 x 1
        # Use the actual batch size instead of a hard-coded 64 so any N works.
        n = h_relu.size(0)
        h_expand = h_relu.expand(n, H, h, w).contiguous().view(n, -1)  # -> N x (H*h*w)
        y_pred = self.linear2(h_expand)  # -> N x D_out
        return y_pred

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
项目:squad_rasor_nn    作者:hsgodhia    | 项目源码 | 文件源码
def _span_sums(self, p_lens, stt, end, max_p_len, batch_size, dim, max_ans_len):
        """Score every candidate span as ``stt[start] + end[start + k]``.

        Args:
            p_lens: (batch_size,) passage lengths.
            stt: (max_p_len, batch_size, dim) start-position projections.
            end: (max_p_len, batch_size, dim) end-position projections.
            max_p_len, batch_size, dim: sizes of the tensors above.
            max_ans_len: number of span lengths considered per start.

        Returns:
            span_sums: (batch_size, max_p_len * max_ans_len, dim). Entry
                ``s * max_ans_len + k`` is the span from ``s`` to ``s + k``.
            span_masks: (batch_size, max_p_len * max_ans_len) float mask,
                1 where ``s + k < p_lens[b]``.
        """
        # end_idxs[s, k] = s + k, the end position of span (start s, offset k).
        max_ans_len_range = torch.from_numpy(np.arange(max_ans_len)).unsqueeze(0)  # (1, max_ans_len)
        offsets = torch.from_numpy(np.arange(max_p_len)).unsqueeze(0).transpose(0, 1)  # (max_p_len, 1)
        grid = (offsets.size(0), max_ans_len_range.size(1))
        end_idxs = max_ans_len_range.expand(*grid) + offsets.expand(*grid)
        end_idxs_flat = end_idxs.view(-1, 1).squeeze(1)  # (max_p_len*max_ans_len,)

        # Pad `end` with max_ans_len - 1 zero rows so spans running past the
        # last position still index valid rows (they are masked out below).
        # Build the padding and indices on the device/type of `end` itself.
        # torch.cuda.is_available() + .cuda(0) broke CPU models on CUDA
        # hosts and models placed on any GPU other than 0.
        zero_t = end.data.new(max_ans_len - 1, batch_size, dim).zero_()
        if end.data.is_cuda:
            end_idxs_flat = end_idxs_flat.cuda(end.data.get_device())

        end_padded = torch.cat((end, Variable(zero_t)), 0)
        end_structed = end_padded[end_idxs_flat]  # (max_p_len*max_ans_len, batch_size, dim)
        end_structed = end_structed.view(max_p_len, max_ans_len, batch_size, dim)
        stt_shuffled = stt.unsqueeze(1)  # (max_p_len, 1, batch_size, dim)

        # FFNN([h_start, h_end]) * [w_1; w_2] = h_start*w_1 + h_end*w_2, so the
        # projections are computed once per position and summed per span here.
        span_sums = stt_shuffled.expand(max_p_len, max_ans_len, batch_size, dim) + end_structed
        span_sums_reshapped = span_sums.permute(2, 0, 1, 3).contiguous().view(batch_size, max_ans_len * max_p_len, dim)

        p_lens_shuffled = p_lens.unsqueeze(1)                 # (batch_size, 1)
        end_idxs_flat_shuffled = end_idxs_flat.unsqueeze(0)   # (1, max_p_len*max_ans_len)
        mask_shape = (p_lens_shuffled.size(0), end_idxs_flat_shuffled.size(1))
        span_masks_reshaped = Variable(end_idxs_flat_shuffled.expand(*mask_shape)) < p_lens_shuffled.expand(*mask_shape)
        span_masks_reshaped = span_masks_reshaped.float()

        return span_sums_reshapped, span_masks_reshaped

    #q_align_weights = self.softmax(q_align_mask_scores)  # (batch_size, max_p_len, max_q_len)
项目:squad_rasor_nn    作者:hsgodhia    | 项目源码 | 文件源码
def _span_sums(self, p_lens, stt, end, max_p_len, batch_size, dim, max_ans_len):
        """Score every candidate span as ``stt[start] + end[start + k]``.

        Args:
            p_lens: (batch_size,) passage lengths.
            stt: (max_p_len, batch_size, dim) start-position projections.
            end: (max_p_len, batch_size, dim) end-position projections.
            max_p_len, batch_size, dim: sizes of the tensors above.
            max_ans_len: number of span lengths considered per start.

        Returns:
            span_sums: (batch_size, max_p_len * max_ans_len, dim). Entry
                ``s * max_ans_len + k`` is the span from ``s`` to ``s + k``.
            span_masks: (batch_size, max_p_len * max_ans_len) float mask,
                1 where ``s + k < p_lens[b]``.
        """
        # end_idxs[s, k] = s + k, the end position of span (start s, offset k).
        max_ans_len_range = torch.from_numpy(np.arange(max_ans_len)).unsqueeze(0)  # (1, max_ans_len)
        offsets = torch.from_numpy(np.arange(max_p_len)).unsqueeze(0).transpose(0, 1)  # (max_p_len, 1)
        grid = (offsets.size(0), max_ans_len_range.size(1))
        end_idxs = max_ans_len_range.expand(*grid) + offsets.expand(*grid)
        end_idxs_flat = end_idxs.view(-1, 1).squeeze(1)  # (max_p_len*max_ans_len,)

        # Pad `end` with max_ans_len - 1 zero rows so spans running past the
        # last position still index valid rows (they are masked out below).
        # Build the padding and indices on the device/type of `end` itself.
        # torch.cuda.is_available() + .cuda(0) broke CPU models on CUDA
        # hosts and models placed on any GPU other than 0.
        zero_t = end.data.new(max_ans_len - 1, batch_size, dim).zero_()
        if end.data.is_cuda:
            end_idxs_flat = end_idxs_flat.cuda(end.data.get_device())

        end_padded = torch.cat((end, Variable(zero_t)), 0)
        end_structed = end_padded[end_idxs_flat]  # (max_p_len*max_ans_len, batch_size, dim)
        end_structed = end_structed.view(max_p_len, max_ans_len, batch_size, dim)
        stt_shuffled = stt.unsqueeze(1)  # (max_p_len, 1, batch_size, dim)

        # FFNN([h_start, h_end]) * [w_1; w_2] = h_start*w_1 + h_end*w_2, so the
        # projections are computed once per position and summed per span here.
        span_sums = stt_shuffled.expand(max_p_len, max_ans_len, batch_size, dim) + end_structed
        span_sums_reshapped = span_sums.permute(2, 0, 1, 3).contiguous().view(batch_size, max_ans_len * max_p_len, dim)

        p_lens_shuffled = p_lens.unsqueeze(1)                 # (batch_size, 1)
        end_idxs_flat_shuffled = end_idxs_flat.unsqueeze(0)   # (1, max_p_len*max_ans_len)
        mask_shape = (p_lens_shuffled.size(0), end_idxs_flat_shuffled.size(1))
        span_masks_reshaped = Variable(end_idxs_flat_shuffled.expand(*mask_shape)) < p_lens_shuffled.expand(*mask_shape)
        span_masks_reshaped = span_masks_reshaped.float()

        return span_sums_reshapped, span_masks_reshaped

    #q_align_weights = self.softmax(q_align_mask_scores)  # (batch_size, max_p_len, max_q_len)
项目:EarlyWarning    作者:wjlei1990    | 项目源码 | 文件源码
def predict_on_test(net, test_x, use_cuda=True):
    """Run ``net`` on each row of ``test_x``, one sample at a time.

    Args:
        net: callable model; receives a (1, n_features) float tensor.
        test_x: 2-D array-like indexed as ``test_x[idx, :]``.
        use_cuda: move each input to the GPU before calling ``net``. The
            default ``True`` keeps the original behaviour; pass ``False`` to
            predict on CPU.

    Returns:
        list of float: the first element of ``net``'s output for each row.
    """
    print("Predict...")
    pred_y = []
    for idx in range(len(test_x)):
        x = Variable(torch.unsqueeze(torch.Tensor(test_x[idx, :]), dim=0))
        if use_cuda:
            x = x.cuda()
        y_p = net(x)
        # Flatten before indexing so (1,) and (1, 1) outputs both give a
        # scalar; float() on a 1-element ndarray is deprecated in NumPy.
        pred_y.append(float(y_p.cpu().data.numpy().reshape(-1)[0]))
    return pred_y
项目:EarlyWarning    作者:wjlei1990    | 项目源码 | 文件源码
def main():
    """Train an RNN for 3 epochs on the loaded waveforms, printing progress.

    Relies on module-level names defined elsewhere: ``load_data``,
    ``make_dataloader``, ``RNN``, ``input_size``, ``hidden_size``,
    ``num_layers`` and ``LR``.
    """
    waveforms, magnitudes = load_data()
    loader = make_dataloader(waveforms, magnitudes)

    rnn = RNN(input_size, hidden_size, num_layers)
    print(rnn)

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    loss_func = nn.MSELoss()

    for epoch in range(3):
        loss_epoch = []
        for step, (batch_x, batch_y) in enumerate(loader):
            # Only sample 0 of each batch is used: its 2-D slice is transposed
            # and given a size-1 axis at dim 1.
            x = torch.unsqueeze(batch_x[0, :, :].t(), dim=1)
            print('Epoch: ', epoch, '| Step: ', step, '| x: ',
                  x.size(), '| y: ', batch_y.numpy())
            x = Variable(x)
            y = Variable(torch.Tensor([batch_y.numpy(), ]))
            prediction = rnn(x)
            loss = loss_func(prediction, y)
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()
            # `loss.data[0]` fails on the 0-dim loss tensors of current
            # PyTorch; flattening first works for 0-dim and 1-element losses.
            loss_value = float(loss.data.view(-1)[0])
            loss_epoch.append(loss_value)
            print("Current loss: %e --- loss mean: %f"
                  % (loss_value, np.mean(loss_epoch)))
项目:EarlyWarning    作者:wjlei1990    | 项目源码 | 文件源码
def predict_on_test(rnn, test_x, use_cuda=True):
    """Run ``rnn`` on each sample of ``test_x``, one sample at a time.

    Each sample ``test_x[idx, :, :]`` is transposed and given a size-1 axis
    at dim 1 before being passed to ``rnn``.

    Args:
        rnn: callable model.
        test_x: 3-D array-like indexed as ``test_x[idx, :, :]``.
        use_cuda: move each input to the GPU before calling ``rnn``. The
            default ``True`` keeps the original behaviour; pass ``False`` to
            predict on CPU.

    Returns:
        list of float: the first element of ``rnn``'s output for each sample.
    """
    print("Predict...")
    pred_y = []
    for idx in range(len(test_x)):
        x = Variable(torch.unsqueeze(torch.Tensor(test_x[idx, :, :]).t(), dim=1))
        if use_cuda:
            x = x.cuda()
        y_p = rnn(x)
        # Flatten before indexing so (1,) and (1, 1) outputs both give a
        # scalar; float() on a 1-element ndarray is deprecated in NumPy.
        pred_y.append(float(y_p.cpu().data.numpy().reshape(-1)[0]))
    return pred_y
项目:DBQA    作者:nanfeng1101    | 项目源码 | 文件源码
def forward(self, q_input, a_input):
        """Bilinear interaction: ``out[b, k] = q_input[b] @ W_k @ a_input[b]``.

        ``self.W`` is viewed as ``(input_size, dim * input_size)`` so that a
        single matrix product with ``q_input`` yields all ``dim`` slices
        ``W_k`` at once.

        Returns:
            Tensor of shape ``(batch, self.dim)``.
        """
        projected = q_input.mm(self.W.view(self.input_size, -1))
        projected = projected.view(-1, self.dim, self.input_size)  # (batch, dim, input_size)
        answer_col = a_input.unsqueeze(2)                           # (batch, input_size, 1)
        return torch.bmm(projected, answer_col).view(-1, self.dim)
项目:DBQA    作者:nanfeng1101    | 项目源码 | 文件源码
def forward(self, q_input, a_input, drop_rate):
        """
        Score a question/answer pair.

        Pipeline: embedding -> multi-CNN (inception) -> interaction ->
        batch norm -> MLP.

        :param q_input: question sentence vec
        :param a_input: answer sentence vec
        :param drop_rate: dropout rate passed to the MLP
        :return: the ``(prop, cate)`` pair produced by ``self.mlp``
        """
        # Embed both sentences and insert a singleton axis at dim 1,
        # presumably a channel axis for the CNN layers.
        q_emb, a_emb = (self.embedding(tokens).unsqueeze(1) for tokens in (q_input, a_input))
        q_vec, a_vec = self.inception_module_layers(q_emb, a_emb)
        normed = self.bn_layer(self.interact_layer(q_vec, a_vec))
        prop, cate = self.mlp(normed, drop_rate)
        return prop, cate
项目:URNN-PyTorch    作者:jingli9111    | 项目源码 | 文件源码
def _modReLU(self, h, bias):
        """
        modReLU: sign(h) * relu(|h| + bias), with ``bias`` repeated over dim 0 (the batch).
        """
        magnitude = torch.abs(h) + bias.expand(h.size(0), *bias.size())
        return torch.sign(h) * functional.relu(magnitude)
项目:URNN-PyTorch    作者:jingli9111    | 项目源码 | 文件源码
def _forward_rnn(cell, input_, length, hx):
        """Unroll ``cell`` over the leading (time) axis of ``input_``.

        Args:
            cell: callable ``cell(input_=x_t, hx=h_prev) -> h_t``.
            input_: sequence indexed as ``input_[time]``.
            length: unused (only referenced by the disabled masking below).
            hx: initial hidden state.

        Returns:
            (output, h_last): ``output`` stacks each step's hidden state along
            dim 0, and ``h_last`` is the state after the final step.
        """
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            h_next = cell(input_=input_[time], hx=hx)
            # mask = (time < length).float().unsqueeze(1).expand_as(h_next)
            # h_next = h_next*mask + hx*(1 - mask)
            output.append(h_next)
            # Carry the state forward; without this every step restarted from
            # the initial hx and no recurrence took place.
            hx = h_next
        output = torch.stack(output, 0)
        return output, hx
项目:URNN-PyTorch    作者:jingli9111    | 项目源码 | 文件源码
def _modReLU(self, h, bias):
        """
        modReLU: sign(h) * relu(|h| + bias), with ``bias`` repeated over dim 0 (the batch).
        """
        magnitude = torch.abs(h) + bias.expand(h.size(0), *bias.size())
        return torch.sign(h) * functional.relu(magnitude)
项目:URNN-PyTorch    作者:jingli9111    | 项目源码 | 文件源码
def _forward_rnn(cell, input_, length, hx):
        """Unroll ``cell`` over the leading (time) axis of ``input_``.

        Args:
            cell: callable ``cell(input_=x_t, hx=h_prev) -> h_t``.
            input_: sequence indexed as ``input_[time]``.
            length: unused (only referenced by the disabled masking below).
            hx: initial hidden state.

        Returns:
            (output, h_last): ``output`` stacks each step's hidden state along
            dim 0, and ``h_last`` is the state after the final step.
        """
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            h_next = cell(input_=input_[time], hx=hx)
            # mask = (time < length).float().unsqueeze(1).expand_as(h_next)
            # h_next = h_next*mask + hx*(1 - mask)
            output.append(h_next)
            # Carry the state forward; without this every step restarted from
            # the initial hx and no recurrence took place.
            hx = h_next
        output = torch.stack(output, 0)
        return output, hx
项目:nmp_qc    作者:priba    | 项目源码 | 文件源码
def u_ggnn(self, h_v, m_v, opt=None):
        """GGNN update function: run ``self.learn_modules[0]`` over the messages.

        ``m_v`` is transposed (dims 0 and 1) and passed as the first argument.
        ``h_v`` gets a new leading axis and is passed as the second argument.
        Element ``[0]`` of the result is transposed back (dims 0 and 1).

        ``opt`` is accepted for interface compatibility and is unused.
        """
        # Tensor.contiguous() returns a tensor instead of modifying its
        # receiver; the result must be rebound for the call to take effect.
        h_v = h_v.contiguous()
        m_v = m_v.contiguous()
        # NOTE(review): learn_modules[0] is presumably an nn.GRU, whose [0] is the
        # per-step output and [1] the final hidden state -- confirm [0] is intended.
        h_new = self.learn_modules[0](torch.transpose(m_v, 0, 1), torch.unsqueeze(h_v, 0))[0]
        return torch.transpose(h_new, 0, 1)
项目:nmp_qc    作者:priba    | 项目源码 | 文件源码
def forward(self, g, h_in, e):
        """Run the message-passing network for ``self.n_layers`` steps, then read out.

        Steps:
        1. Pad node features along dim 2 with zeros up to ``self.args['out']``.
        2. At each step:
           - compute messages with ``self.m[0]``;
           - mask them with ``g``;
           - sum them over dim 1;
           - update node states with ``self.u[0]``;
           - zero the states of nodes whose input features sum to 0 along dim 2.
        3. Pass the list of all states (initial plus one per step) to ``self.r``.

        Args:
            g: multiplied into the messages after ``torch.unsqueeze(g, 3)``;
                presumably an adjacency mask (batch, n, n) -- TODO confirm.
            h_in: node features indexed on dims 0-2; presumably (batch, n, in_dim) -- TODO confirm.
            e: edge features; ``e.size(3)`` is read, so rank >= 4.

        Returns:
            The readout result. If ``self.type == 'classification'``,
            ``nn.LogSoftmax`` is applied to it first.
        """

        h = []

        # Padding to some larger dimension d (= self.args['out']) with zeros
        h_t = torch.cat([h_in, Variable(
            torch.zeros(h_in.size(0), h_in.size(1), self.args['out'] - h_in.size(2)).type_as(h_in.data))], 2)

        h.append(h_t.clone())

        # Layer
        for t in range(0, self.n_layers):
            # Flatten edges and node states to 2-D for the message function.
            e_aux = e.view(-1, e.size(3))

            h_aux = h[t].view(-1, h[t].size(2))

            m = self.m[0].forward(h[t], h_aux, e_aux)
            m = m.view(h[0].size(0), h[0].size(1), -1, m.size(1))

            # Nodes without edge set message to 0
            m = torch.unsqueeze(g, 3).expand_as(m) * m

            # NOTE(review): squeeze also drops the batch axis when batch size is 1.
            m = torch.squeeze(torch.sum(m, 1))

            h_t = self.u[0].forward(h[t], m)

            # Delete virtual nodes
            # NOTE(review): since PyTorch 0.2, torch.sum drops the reduced axis by
            # default, so this is (batch, n) and likely cannot expand_as a 3-D h_t;
            # keepdim=True was probably intended -- confirm.
            h_t = (torch.sum(h_in, 2).expand_as(h_t) > 0).type_as(h_t) * h_t
            h.append(h_t)

        # Readout
        res = self.r.forward(h)

        if self.type == 'classification':
            # NOTE(review): LogSoftmax() without `dim` relies on implicit-dim
            # selection, which PyTorch deprecates.
            res = nn.LogSoftmax()(res)
        return res
项目:nmp_qc    作者:priba    | 项目源码 | 文件源码
def m_mpnn(self, h_v, h_w, e_vw, opt={}):
        """Edge-network message function: a learned matrix per edge times a node state.

        ``self.learn_modules[0](e_vw)`` is reshaped into a stack of
        ``(args['out'], args['in'])`` matrices. ``h_w`` rows are repeated and
        flattened to ``args['in']``-sized rows, and each matrix is multiplied
        (``torch.bmm``) with one such row.

        Args:
            h_v: only ``h_v.size(1)`` is read (the repeat count).
            h_w: presumably (n_rows, in_dim) node states -- TODO confirm.
            e_vw: edge features fed to ``self.learn_modules[0]``.
            opt: unused.

        Returns:
            The products squeezed to drop singleton axes.
        """
        # Matrices for each edge
        edge_output = self.learn_modules[0](e_vw)
        edge_output = edge_output.view(-1, self.args['out'], self.args['in'])

        # NOTE(review): h_w[..., None] has shape (S, D, 1); expanding it to
        # (S, h_v.size(1), D) only succeeds when D == h_v.size(1) or D == 1.
        # Confirm the intended layout (h_w[:, None, :] may have been meant).
        h_w_rows = h_w[..., None].expand(h_w.size(0), h_v.size(1), h_w.size(1)).contiguous()

        h_w_rows = h_w_rows.view(-1, self.args['in'])

        # Requires as many edge matrices as flattened rows.
        h_multiply = torch.bmm(edge_output, torch.unsqueeze(h_w_rows,2))

        # NOTE(review): squeeze also drops the leading axis when there is only one edge.
        m_new = torch.squeeze(h_multiply)

        return m_new
项目:ktorch    作者:farizrahman4u    | 项目源码 | 文件源码
def batch_dot(x, y, axes=None):
    """Batch-wise dot product in the style of Keras' ``batch_dot`` (unfinished port).

    Args:
        x, y: tensors whose first axis is the batch axis.
        axes: int or pair of ints naming the axis of ``x`` and of ``y`` to
            contract; an int is used for both.

    NOTE(review): the port is incomplete. ``_dot`` below is defined but never
    called and returns nothing, so ``batch_dot`` currently returns ``None``.
    """
    if isinstance(axes, int):
        axes = (axes, axes)

    def _dot(X):
        x, y = X
        x_ndim = len(x.size())
        y_ndim = len(y.size())
        if x_ndim <= 3 and y_ndim <= 3:
            # Pad both operands with trailing singleton axes up to rank 3.
            # (These loops previously referenced an undefined name `diff`.)
            x_diff = 3 - x_ndim if x_ndim < 3 else 0
            for i in range(x_diff):
                x = torch.unsqueeze(x, x_ndim + i)
            y_diff = 3 - y_ndim if y_ndim < 3 else 0
            for i in range(y_diff):
                y = torch.unsqueeze(y, y_ndim + i)
            # Move x's contracted axis to position 2 ...
            if axes[0] == 1:
                x = torch.transpose(x, 1, 2)
            elif axes[0] != 2:
                raise Exception('Invalid axis : ' + str(axes[0]))
            # ... and y's contracted axis to position 1 (this branch used to
            # transpose x instead of y).
            if axes[1] == 2:
                y = torch.transpose(y, 1, 2)
            elif axes[1] != 1:
                raise Exception('Invalid axis : ' + str(axes[1]))
            # -------TODO--------------#
项目:ktorch    作者:farizrahman4u    | 项目源码 | 文件源码
def expand_dims(x, axis=-1):
    """Insert a size-1 axis into ``x`` at position ``axis``.

    A negative ``axis`` counts from the end of the *output* shape, as in
    ``torch.unsqueeze``: ``axis=-1`` appends a trailing axis. Both the
    computed tensor and the reported output shape follow this rule.
    """
    def _expand_dims(x, axis=axis):
        return torch.unsqueeze(x, axis)

    def _compute_output_shape(x, axis=axis):
        shape = list(_get_shape(x))
        # list.insert(-1, v) would place v *before* the last element, whereas
        # torch.unsqueeze(x, -1) appends. Convert to the equivalent
        # non-negative position in the output shape first.
        if axis < 0:
            axis = len(shape) + axis + 1
        shape.insert(axis, 1)
        return shape

    return get_op(_expand_dims, output_shape=_compute_output_shape, arguments=[axis])(x)
项目:paysage    作者:drckf    | 项目源码 | 文件源码
def scatter_(mat: T.Tensor, inds: T.LongTensor, val: T.Scalar) -> T.Tensor:
    """
    Assign a value at specific points in a matrix.

    Row i of mat receives val in column inds[i].

    Note:
        Modifies mat in place (and also returns it).

    Args:
        mat: A matrix.
        inds: One column index per row of mat.
        val: The value to insert.

    Returns:
        tensor: mat, after the assignment.
    """
    column_index = inds.unsqueeze(1)
    return mat.scatter_(1, column_index, val)
项目:paysage    作者:drckf    | 项目源码 | 文件源码
def unsqueeze(tensor: T.Tensor, axis: int) -> T.Tensor:
    """
    Insert a new length-1 axis into tensor.

    Args:
        tensor: A tensor.
        axis: Position of the new axis; negative values count from the end.

    Returns:
        tensor: A view of the input with the extra axis.

    """
    return tensor.unsqueeze(axis)
项目:paysage    作者:drckf    | 项目源码 | 文件源码
def broadcast(vec: T.FloatTensor, matrix: T.FloatTensor) -> T.FloatTensor:
    """
    Broadcasts vec into the shape of matrix following numpy rules:

    vec ~ (N, 1) broadcasts to matrix ~ (N, M)
    vec ~ (1, N) and (N,) broadcast to matrix ~ (M, N)

    If matrix is itself 1D, a 1D vec is returned unchanged.

    Args:
        vec: A vector (either flat, row, or column).
        matrix: A matrix (i.e., a 2D tensor).

    Returns:
        tensor: A tensor of the same size as matrix containing the elements
                of the vector.

    Raises:
        BroadcastError: if vec cannot be expanded to the shape of matrix.

    """
    try:
        if ndim(vec) == 1:
            if ndim(matrix) == 1:
                return vec
            return vec.unsqueeze(0).expand(matrix.size(0), matrix.size(1))
        return vec.expand(matrix.size(0), matrix.size(1))
    except (ValueError, RuntimeError) as err:
        # torch's Tensor.expand reports incompatible sizes with RuntimeError;
        # catching only ValueError let shape errors escape untranslated.
        raise BroadcastError(
            'cannot broadcast vector of dimension {} '
            'onto matrix of dimension {}'.format(shape(vec), shape(matrix))
        ) from err
项目:paysage    作者:drckf    | 项目源码 | 文件源码
def repeat(tensor: T.FloatTensor, n: int) -> T.FloatTensor:
    """
    Repeat each element of a vector n times in a row.

    Assuming ``flatten`` reads row-major, [a, b] with n=2 becomes [a, a, b, b].

    Args:
        tensor: A vector (i.e., 1D tensor).
        n: The number of repeats.

    Returns:
        tensor: A vector of length len(tensor) * n.

    """
    # current implementation only works for vectors
    assert ndim(tensor) == 1
    as_column = tensor.unsqueeze(1)   # (len, 1)
    tiled = as_column.repeat(1, n)    # (len, n): row i holds n copies of tensor[i]
    return flatten(tiled)
项目:GAN_Liveness_Detection    作者:yunfan0621    | 项目源码 | 文件源码
def query(self, images):
        """Mix a batch of new images with images kept in a history pool.

        Per the original author, ``images`` is a Variable of size
        [batch_size, channel * 2, w, h]; its ``.data`` is iterated along dim 0.

        While fewer than ``self.pool_size`` images are pooled, every incoming
        image is stored and passed through. Once the pool is full, an image is
        swapped with a random pooled image when ``random.uniform(0, 1) > 0.5``
        (the pooled copy is returned, the new image takes its slot); otherwise
        it is passed through.

        Returns:
            ``images`` itself when ``self.pool_size == 0``; otherwise a
            ``Variable`` of the selected images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images

        picked = []
        for sample in images.data:
            sample = sample.unsqueeze(0)  # restore a batch axis of size 1
            if self.num_imgs < self.pool_size:
                # Pool still filling: keep the image and pass it through.
                self.num_imgs += 1
                self.images.append(sample)
                picked.append(sample)
            elif random.uniform(0, 1) > 0.5:
                # Randomly substitute: return an old image, store the new one.
                slot = random.randint(0, self.pool_size - 1)
                picked.append(self.images[slot].clone())
                self.images[slot] = sample
            else:
                picked.append(sample)

        return Variable(torch.cat(picked, 0))
项目:LSH_Memory    作者:RUSH-LAB    | 项目源码 | 文件源码
def index(batch_size, x):
    """Prepend each row's position (0 .. batch_size - 1) as a new first column of ``x``.

    The positions form an int64 column ``[[0], [1], ..., [batch_size - 1]]``
    that is concatenated with ``x`` along dim 1, so ``x`` must have
    ``batch_size`` rows.
    """
    row_ids = torch.arange(0, batch_size).long().unsqueeze(1)
    return torch.cat((row_ids, x), dim=1)
项目:LSH_Memory    作者:RUSH-LAB    | 项目源码 | 文件源码
def update(self, query, y, y_hat, y_hat_indices):
        """Update the key/value/age memory after a batch of lookups.

        Every slot's age is incremented. Then the batch is split by whether
        ``y_hat`` equals ``y``:
        - correct rows: the key at the matched slot becomes
          normalize(key + query row) and that slot's age is reset to 0;
        - incorrect rows: the slots with the largest noisy age are overwritten
          with the query rows (keys) and ``y`` labels (values), ages reset to 0.

        Args:
            query: Variable unpacked as (batch_size, dims); its ``.data`` rows
                are what gets written into ``self.keys``.
            y: Variable of target labels; ``y.data`` is compared with ``y_hat``
                and stored into ``self.values`` for incorrect rows.
            y_hat: predicted values, compared against ``y.data`` unsqueezed at
                dim 1 — presumably shaped (batch_size, 1); TODO confirm.
            y_hat_indices: memory slot indices that produced ``y_hat``.
        """
        # NOTE(review): batch_size and dims are unpacked but never used below.
        batch_size, dims = query.size()

        # 1) Untouched: age every memory slot by 1
        self.age += 1

        # Divide batch by correctness (1.0 where y_hat matches y, else 0.0)
        # NOTE(review): if exactly one row falls into a group, torch.squeeze on
        # the nonzero() result yields a 0-dim tensor whose size() has length 0,
        # so the `len(...size()) > 0` checks below treat that group as empty and
        # skip the row. Behaviour is PyTorch-version dependent — verify.
        result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y.data, dim=1))).float()
        incorrect_examples = torch.squeeze(torch.nonzero(1-result))
        correct_examples = torch.squeeze(torch.nonzero(result))

        incorrect = len(incorrect_examples.size()) > 0
        correct = len(correct_examples.size()) > 0

        # 2) Correct: if V[n1] = v
        # Update Key k[n1] <- normalize(q + K[n1]), Reset Age A[n1] <- 0
        if correct:
            correct_indices = y_hat_indices[correct_examples]
            correct_keys = self.keys[correct_indices]
            correct_query = query.data[correct_examples]

            new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
            self.keys[correct_indices] = new_correct_keys
            self.age[correct_indices] = 0

        # 3) Incorrect: if V[n1] != v
        # Select item with oldest age, Add random offset - n' = argmax_i(A[i]) + r_i
        # K[n'] <- q, V[n'] <- v, A[n'] <- 0
        if incorrect:
            incorrect_size = incorrect_examples.size()[0]
            incorrect_query = query.data[incorrect_examples]
            incorrect_values = y.data[incorrect_examples]

            # Uniform noise in [-age_noise, age_noise] breaks ties between equal
            # ages; topk along dim 0 then picks one slot per incorrect row.
            # NOTE(review): cuda=True is hard-coded, so this path presumably
            # requires the memory to live on the GPU — confirm.
            age_with_noise = self.age + random_uniform((self.memory_size, 1), -self.age_noise, self.age_noise, cuda=True)
            topk_values, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
            oldest_indices = torch.squeeze(topk_indices)

            self.keys[oldest_indices] = incorrect_query
            self.values[oldest_indices] = incorrect_values
            self.age[oldest_indices] = 0
项目:OpenNMT-py    作者:OpenNMT    | 项目源码 | 文件源码
def shape_transform(x):
    """Transform the size of the tensors to fit for conv input.

    Swaps dims 1 and 2 and then inserts a singleton dim at position 3,
    e.g. a (a, b, c) tensor becomes (a, c, b, 1).
    """
    swapped = x.transpose(1, 2)
    return swapped.unsqueeze(3)
项目:relational-networks    作者:kimhc6028    | 项目源码 | 文件源码
def forward(self, img, qst):
        """Run the relational network on an image batch and a question batch.

        Each cell of the ``self.conv`` feature map is tagged with its
        ``self.coord_tensor`` coordinates. Every ordered pair of cells (one of
        them also carrying the question) is scored by the ``g`` MLP. The pair
        scores are summed per sample and passed through the ``f`` MLP.

        Args:
            img: input batch for ``self.conv``; its output is assumed to be a
                square (mb, channels, d, d) feature map.
            qst: question batch, presumably (mb, q_dim).

        Returns:
            Output of ``self.fcout`` with the batch dimension kept, also
            when mb == 1.
        """
        x = self.conv(img)  # e.g. (64 x 24 x 5 x 5)

        # --- g: score all pairs of feature-map cells ---
        mb = x.size()[0]
        n_channels = x.size()[1]
        d = x.size()[2]
        # Number of cells of the (assumed square) d x d map. This used to be
        # hard-coded as 25, which only matched d == 5.
        n_cells = d * d
        # x_flat = (mb x n_cells x n_channels)
        x_flat = x.view(mb, n_channels, n_cells).permute(0, 2, 1)

        # add coordinates -> (mb x n_cells x n_channels+2)
        # (self.coord_tensor must be (mb x n_cells x 2) for this cat to work)
        x_flat = torch.cat([x_flat, self.coord_tensor], 2)

        # add question everywhere -> (mb x n_cells x 1 x q_dim)
        qst = torch.unsqueeze(qst, 1)
        qst = qst.repeat(1, n_cells, 1)
        qst = torch.unsqueeze(qst, 2)

        # cast all pairs against each other (C = n_channels + 2)
        x_i = torch.unsqueeze(x_flat, 1)        # (mb x 1 x n_cells x C)
        x_i = x_i.repeat(1, n_cells, 1, 1)      # (mb x n_cells x n_cells x C)
        x_j = torch.unsqueeze(x_flat, 2)        # (mb x n_cells x 1 x C)
        x_j = torch.cat([x_j, qst], 3)          # (mb x n_cells x 1 x C+q_dim)
        x_j = x_j.repeat(1, 1, n_cells, 1)      # (mb x n_cells x n_cells x C+q_dim)

        # concatenate all together -> (mb x n_cells x n_cells x 2*C+q_dim)
        x_full = torch.cat([x_i, x_j], 3)

        # Reshape for passing through network: one row per ordered cell pair.
        # The pair width is read from the tensor; it used to be hard-coded as 63.
        n_pairs = n_cells * n_cells
        x_ = x_full.view(mb * n_pairs, x_full.size()[3])
        x_ = self.g_fc1(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc2(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc3(x_)
        x_ = F.relu(x_)
        x_ = self.g_fc4(x_)
        x_ = F.relu(x_)

        # reshape again and sum over all pairs of each sample
        x_g = x_.view(mb, n_pairs, x_.size()[1])
        # view(mb, -1) rather than squeeze(): squeeze also removed the batch
        # dimension when mb == 1. view still flattens the (mb, 1, width) result
        # that older PyTorch versions return from sum() (which kept the dim).
        x_g = x_g.sum(1).view(mb, -1)

        # --- f: map the aggregated relation vector to the output ---
        x_f = self.f_fc1(x_g)
        x_f = F.relu(x_f)

        return self.fcout(x_f)
项目:MatchingNetworks    作者:gitabcworld    | 项目源码 | 文件源码
def run_validation_epoch(self):
        """
        Runs one validation epoch over every batch of ``self.val_loader``.

        Integer support-set labels are expanded to one-hot vectors of width
        ``self.classes_per_set`` before being fed to ``self.matchingNet``.

        :return: mean_validation_categorical_crossentropy_loss and mean_validation_accuracy
        """
        total_val_c_loss = 0.
        total_val_accuracy = 0.
        total_val_batches = len(self.val_loader)
        # Iterating a tqdm-wrapped iterable already advances the bar once per
        # batch. The old extra pbar.update(1) counted every batch twice.
        # `total` gives the bar a known length.
        pbar = tqdm(self.val_loader, total=total_val_batches)
        for x_support_set, y_support_set, x_target, target_y in pbar:

            x_support_set = Variable(x_support_set).float()
            y_support_set = Variable(y_support_set, requires_grad=False).long()
            x_target = Variable(x_target.squeeze()).float()
            y_target = Variable(target_y.squeeze(), requires_grad=False).long()

            # y_support_set: Add extra dimension for the one_hot
            y_support_set = torch.unsqueeze(y_support_set, 2)
            sequence_length = y_support_set.size()[1]
            batch_size = y_support_set.size()[0]
            y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                      self.classes_per_set).zero_()
            # put a 1 at each label's position along dim 2
            y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
            y_support_set_one_hot = Variable(y_support_set_one_hot)

            if self.isCudaAvailable:
                acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                     x_target.cuda(), y_target.cuda())
            else:
                acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                     x_target, y_target)

            iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
            pbar.set_description(iter_out)

            total_val_c_loss += c_loss_value.data[0]
            total_val_accuracy += acc.data[0]

        total_val_c_loss = total_val_c_loss / total_val_batches
        total_val_accuracy = total_val_accuracy / total_val_batches

        return total_val_c_loss, total_val_accuracy
项目:MatchingNetworks    作者:gitabcworld    | 项目源码 | 文件源码
def run_testing_epoch(self):
        """
        Runs one testing epoch over every batch of ``self.test_loader``.

        Integer support-set labels are expanded to one-hot vectors of width
        ``self.classes_per_set`` before being fed to ``self.matchingNet``.

        :return: mean_testing_categorical_crossentropy_loss and mean_testing_accuracy
        """
        total_test_c_loss = 0.
        total_test_accuracy = 0.
        total_test_batches = len(self.test_loader)
        # Iterating a tqdm-wrapped iterable already advances the bar once per
        # batch. The old extra pbar.update(1) counted every batch twice.
        # `total` gives the bar a known length.
        pbar = tqdm(self.test_loader, total=total_test_batches)
        for x_support_set, y_support_set, x_target, target_y in pbar:

            x_support_set = Variable(x_support_set).float()
            y_support_set = Variable(y_support_set, requires_grad=False).long()
            x_target = Variable(x_target.squeeze()).float()
            y_target = Variable(target_y.squeeze(), requires_grad=False).long()

            # y_support_set: Add extra dimension for the one_hot
            y_support_set = torch.unsqueeze(y_support_set, 2)
            sequence_length = y_support_set.size()[1]
            batch_size = y_support_set.size()[0]
            y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                      self.classes_per_set).zero_()
            # put a 1 at each label's position along dim 2
            y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
            y_support_set_one_hot = Variable(y_support_set_one_hot)

            if self.isCudaAvailable:
                acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                     x_target.cuda(), y_target.cuda())
            else:
                acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                     x_target, y_target)

            iter_out = "test_loss: {}, test_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
            pbar.set_description(iter_out)

            total_test_c_loss += c_loss_value.data[0]
            total_test_accuracy += acc.data[0]

        total_test_c_loss = total_test_c_loss / total_test_batches
        total_test_accuracy = total_test_accuracy / total_test_batches
        return total_test_c_loss, total_test_accuracy
项目:MatchingNetworks    作者:gitabcworld    | 项目源码 | 文件源码
def run_validation_epoch(self, total_val_batches):
        """
        Runs one validation epoch
        :param total_val_batches: Number of validation batches to evaluate, each
            drawn with self.data.get_batch(str_type='val', rotate_flag=False)
        :return: mean_validation_categorical_crossentropy_loss and mean_validation_accuracy
        """
        total_val_c_loss = 0.
        total_val_accuracy = 0.

        with tqdm.tqdm(total=total_val_batches) as pbar:
            for _ in range(total_val_batches):  # validation epoch
                x_support_set, y_support_set, x_target, y_target = \
                    self.data.get_batch(str_type='val', rotate_flag=False)

                x_support_set = Variable(torch.from_numpy(x_support_set), volatile=True).float()
                y_support_set = Variable(torch.from_numpy(y_support_set), volatile=True).long()
                x_target = Variable(torch.from_numpy(x_target), volatile=True).float()
                y_target = Variable(torch.from_numpy(y_target), volatile=True).long()

                # y_support_set: Add extra dimension for the one_hot
                y_support_set = torch.unsqueeze(y_support_set, 2)
                sequence_length = y_support_set.size()[1]
                batch_size = y_support_set.size()[0]
                y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                          self.classes_per_set).zero_()
                # put a 1 at each label's position along dim 2
                y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
                y_support_set_one_hot = Variable(y_support_set_one_hot)

                # Reshape channels: move the last axis (presumably channels,
                # given the intended size[4] -> dim 2 reordering) in front of
                # axes 2 and 3. This needs permute: view only reinterprets the
                # existing memory order, so it scrambled pixels whenever the
                # last axis had more than one entry.
                x_support_set = x_support_set.permute(0, 1, 4, 2, 3).contiguous()
                x_target = x_target.permute(0, 1, 4, 2, 3).contiguous()
                if self.isCudaAvailable:
                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                         x_target.cuda(), y_target.cuda())
                else:
                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                         x_target, y_target)

                iter_out = "val_loss: {}, val_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                pbar.set_description(iter_out)
                pbar.update(1)

                total_val_c_loss += c_loss_value.data[0]
                total_val_accuracy += acc.data[0]

        total_val_c_loss = total_val_c_loss / total_val_batches
        total_val_accuracy = total_val_accuracy / total_val_batches

        return total_val_c_loss, total_val_accuracy
项目:MatchingNetworks    作者:gitabcworld    | 项目源码 | 文件源码
def run_testing_epoch(self, total_test_batches):
        """
        Runs one testing epoch
        :param total_test_batches: Number of test batches to evaluate, each
            drawn with self.data.get_batch(str_type='test', rotate_flag=False)
        :return: mean_testing_categorical_crossentropy_loss and mean_testing_accuracy
        """
        total_test_c_loss = 0.
        total_test_accuracy = 0.
        with tqdm.tqdm(total=total_test_batches) as pbar:
            for _ in range(total_test_batches):
                x_support_set, y_support_set, x_target, y_target = \
                    self.data.get_batch(str_type='test', rotate_flag=False)

                x_support_set = Variable(torch.from_numpy(x_support_set), volatile=True).float()
                y_support_set = Variable(torch.from_numpy(y_support_set), volatile=True).long()
                x_target = Variable(torch.from_numpy(x_target), volatile=True).float()
                y_target = Variable(torch.from_numpy(y_target), volatile=True).long()

                # y_support_set: Add extra dimension for the one_hot
                y_support_set = torch.unsqueeze(y_support_set, 2)
                sequence_length = y_support_set.size()[1]
                batch_size = y_support_set.size()[0]
                y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                          self.classes_per_set).zero_()
                # put a 1 at each label's position along dim 2
                y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
                y_support_set_one_hot = Variable(y_support_set_one_hot)

                # Reshape channels: move the last axis (presumably channels,
                # given the intended size[4] -> dim 2 reordering) in front of
                # axes 2 and 3. This needs permute: view only reinterprets the
                # existing memory order, so it scrambled pixels whenever the
                # last axis had more than one entry.
                x_support_set = x_support_set.permute(0, 1, 4, 2, 3).contiguous()
                x_target = x_target.permute(0, 1, 4, 2, 3).contiguous()
                if self.isCudaAvailable:
                    acc, c_loss_value = self.matchingNet(x_support_set.cuda(), y_support_set_one_hot.cuda(),
                                                         x_target.cuda(), y_target.cuda())
                else:
                    acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                         x_target, y_target)

                iter_out = "test_loss: {}, test_accuracy: {}".format(c_loss_value.data[0], acc.data[0])
                pbar.set_description(iter_out)
                pbar.update(1)

                total_test_c_loss += c_loss_value.data[0]
                total_test_accuracy += acc.data[0]

        total_test_c_loss = total_test_c_loss / total_test_batches
        total_test_accuracy = total_test_accuracy / total_test_batches
        return total_test_c_loss, total_test_accuracy
项目:EarlyWarning    作者:wjlei1990    | 项目源码 | 文件源码
def main():
    """Train the RNN for 3 epochs, evaluate it on the test split, dump JSON results.

    Outputs:
    - each epoch's step losses go to ``output.disp.abs/loss.epoch_<k>.json``;
    - test targets, predictions, all losses, MSE and error std go to
      ``output.disp.abs/prediction.json``.

    Relies on names defined elsewhere in this script: ``load_data``,
    ``split_data``, ``make_dataloader``, ``RNN``, ``input_size``,
    ``hidden_size``, ``num_layers``, ``LR``, ``CUDA_FLAG``,
    ``predict_on_test``, ``dump_json`` and ``mean_squared_error``.
    """
    outputdir = "output.disp.abs"
    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    waveforms, magnitudes = load_data()
    data_split = split_data(waveforms, magnitudes, train_percentage=0.9)
    print("dimension of train x and y: ", data_split["train_x"].shape,
          data_split["train_y"].shape)
    print("dimension of test x and y: ", data_split["test_x"].shape,
          data_split["test_y"].shape)
    train_loader = make_dataloader(data_split["train_x"],
                                   data_split["train_y"])

    rnn = RNN(input_size, hidden_size, num_layers)
    # Move the model to the GPU only when the batches are moved there too.
    # The old unconditional rnn.cuda() broke the CPU branch of the loop below.
    # NOTE(review): confirm predict_on_test also works with a CPU model.
    if CUDA_FLAG:
        rnn.cuda()
    print(rnn)

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    loss_func = nn.MSELoss()

    # train
    ntrain = data_split["train_x"].shape[0]
    # Log roughly 100 times per epoch; same value as int(ntrain / 100 + 1).
    log_every = ntrain // 100 + 1
    all_loss = {}
    for epoch in range(3):
        loss_epoch = []
        for step, (batch_x, batch_y) in enumerate(train_loader):
            # Only batch_x[0] is used: it is transposed and a size-1 dim is
            # inserted at position 1 (presumably the RNN's batch dim).
            x = torch.unsqueeze(batch_x[0, :, :].t(), dim=1)
            # Log at steps 1, 1 + log_every, 1 + 2*log_every, ... (loss_epoch is
            # never empty there). The old `step % k == 1` never fired for k == 1.
            if step > 0 and (step - 1) % log_every == 0:
                print('Epoch: ', epoch, '| Step: %d/%d' % (step, ntrain),
                      "| Loss: %f" % np.mean(loss_epoch))
            if CUDA_FLAG:
                x = Variable(x).cuda()
                y = Variable(torch.Tensor([batch_y.numpy(), ])).cuda()
            else:
                x = Variable(x)
                y = Variable(torch.Tensor([batch_y.numpy(), ]))
            prediction = rnn(x)
            loss = loss_func(prediction, y)
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()
            loss_epoch.append(loss.data[0])
        all_loss["epoch_%d" % epoch] = loss_epoch

        outputfn = os.path.join(outputdir, "loss.epoch_%d.json" % epoch)
        print("epoch loss file: %s" % outputfn)
        dump_json(loss_epoch, outputfn)

    # test
    pred_y = predict_on_test(rnn, data_split["test_x"])
    test_y = data_split["test_y"]
    _mse = mean_squared_error(test_y, pred_y)
    _std = np.std(test_y - pred_y)
    print("MSE and error std: %f, %f" % (_mse, _std))

    outputfn = os.path.join(outputdir, "prediction.json")
    print("output file: %s" % outputfn)
    data = {"test_y": list(test_y), "test_y_pred": list(pred_y),
            "epoch_loss": all_loss, "mse": _mse, "err_std": _std}
    dump_json(data, outputfn)
项目:URNN-PyTorch    作者:jingli9111    | 项目源码 | 文件源码
def _EUNN(self, hx, thetaA, thetaB):

        L = self.capacity
        N = self.hidden_size

        sinA = torch.sin(self.thetaA)
        cosA = torch.cos(self.thetaA)
        sinB = torch.sin(self.thetaB)
        cosB = torch.cos(self.thetaB)

        I = Variable(torch.ones((L/2, 1)))
        O = Variable(torch.zeros((L/2, 1)))

        diagA = torch.stack((cosA, cosA), 2)
        offA = torch.stack((-sinA, sinA), 2)
        diagB = torch.stack((cosB, cosB), 2)
        offB = torch.stack((-sinB, sinB), 2)

        diagA = diagA.view(L/2, N)
        offA = offA.view(L/2, N)
        diagB = diagB.view(L/2, N-2)
        offB = offB.view(L/2, N-2)

        diagB = torch.cat((I, diagB, I), 1)
        offB = torch.cat((O, offB, O), 1)

        batch_size = hx.size()[0]
        x = hx
        for i in range(L/2):
#           # A
            y = x.view(batch_size, N/2, 2)
            y = torch.stack((y[:,:,1], y[:,:,0]), 2)
            y = y.view(batch_size, N)

            x = torch.mul(x, diagA[i].expand_as(x))
            y = torch.mul(y, offA[i].expand_as(x))

            x = x + y

            # B
            x_top = x[:,0]
            x_mid = x[:,1:-1].contiguous()
            x_bot = x[:,-1]
            y = x_mid.view(batch_size, N/2-1, 2)
            y = torch.stack((y[:, :, 1], y[:, :, 0]), 1)
            y = y.view(batch_size, N-2)
            x_top = torch.unsqueeze(x_top, 1)
            x_bot = torch.unsqueeze(x_bot, 1)
            # print x_top.size(), y.size(), x_bot.size()
            y = torch.cat((x_top, y, x_bot), 1)

            x = x * diagB[i].expand(batch_size, N)
            y = y * offB[i].expand(batch_size, N)

            x = x + y
        return x
项目:URNN-PyTorch    作者:jingli9111    | 项目源码 | 文件源码
def _EUNN(self, hx, thetaA, thetaB):

        L = self.capacity
        N = self.hidden_size

        sinA = torch.sin(self.thetaA)
        cosA = torch.cos(self.thetaA)
        sinB = torch.sin(self.thetaB)
        cosB = torch.cos(self.thetaB)

        I = Variable(torch.ones((L//2, 1)))
        O = Variable(torch.zeros((L//2, 1)))

        diagA = torch.stack((cosA, cosA), 2)
        offA = torch.stack((-sinA, sinA), 2)
        diagB = torch.stack((cosB, cosB), 2)
        offB = torch.stack((-sinB, sinB), 2)

        diagA = diagA.view(L//2, N)
        offA = offA.view(L//2, N)
        diagB = diagB.view(L//2, N-2)
        offB = offB.view(L//2, N-2)

        diagB = torch.cat((I, diagB, I), 1)
        offB = torch.cat((O, offB, O), 1)

        batch_size = hx.size()[0]
        x = hx
        for i in range(L//2):
#           # A
            y = x.view(batch_size, N//2, 2)
            y = torch.stack((y[:,:,1], y[:,:,0]), 2)
            y = y.view(batch_size, N)

            x = torch.mul(x, diagA[i].expand_as(x))
            y = torch.mul(y, offA[i].expand_as(x))

            x = x + y

            # B
            x_top = x[:,0]
            x_mid = x[:,1:-1].contiguous()
            x_bot = x[:,-1]
            y = x_mid.view(batch_size, N//2-1, 2)
            y = torch.stack((y[:, :, 1], y[:, :, 0]), 1)
            y = y.view(batch_size, N-2)
            x_top = torch.unsqueeze(x_top, 1)
            x_bot = torch.unsqueeze(x_bot, 1)
            # print x_top.size(), y.size(), x_bot.size()
            y = torch.cat((x_top, y, x_bot), 1)

            x = x * diagB[i].expand(batch_size, N)
            y = y * offB[i].expand(batch_size, N)

            x = x + y
        return x