Python torch 模块 eq() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 torch.eq()。

项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def forward(self, input, target):
        """Hinge embedding loss forward pass.

        The loss is the sum of ``input`` over entries whose target is not -1,
        plus the sum of ``margin - input`` (floored at 0 via the legacy
        ``cmax_``) over entries whose target is not 1. It is divided by
        ``input.nelement()`` when ``self.size_average`` is set. ``(input,
        target)`` are saved for backward, and the scalar is returned wrapped in
        a 1-element tensor built with ``input.new``.
        """
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        buffer.fill_(self.margin).add_(-1, input)
        # Legacy element-wise max against 0. NOTE(review): presumably the same
        # as clamp_(min=0); confirm for the targeted torch version.
        buffer.cmax_(0)
        # Entries labelled 1 must not add the hinge term.
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if self.size_average:
            output = output / input.nelement()

        self.save_for_backward(input, target)
        return input.new((output,))
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def forward(self, input, target):
        """Hinge embedding loss forward pass.

        The loss is the sum of ``input`` over entries whose target is not -1,
        plus the sum of ``max(0, margin - input)`` over entries whose target is
        not 1. It is divided by ``input.nelement()`` when ``self.size_average``
        is set. ``(input, target)`` are saved for backward, and the scalar is
        returned wrapped in a 1-element tensor built with ``input.new``.
        """
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        buffer.fill_(self.margin).add_(-1, input)
        buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if self.size_average:
            output = output / input.nelement()

        self.save_for_backward(input, target)
        return input.new((output,))
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Legacy-nn forward of the hinge embedding criterion.

        Stores and returns ``self.output``. That is the sum of ``input`` over
        entries where ``y`` is not -1, plus the sum of ``max(0, margin -
        input)`` over entries where ``y`` is not 1, divided by
        ``input.nelement()`` when ``self.sizeAverage`` is set.
        ``self.buffer`` is allocated on first use and reused as scratch space.
        """
        if self.buffer is None:
            self.buffer = input.new()
        self.buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        self.buffer[torch.eq(y, -1.)] = 0
        self.output = self.buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        self.buffer.fill_(self.margin).add_(-1, input)
        self.buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        self.buffer[torch.eq(y, 1.)] = 0
        self.output = self.output + self.buffer.sum()

        if self.sizeAverage:
            self.output = self.output / input.nelement()

        return self.output
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def forward(self, input, target):
        """Hinge embedding loss forward pass.

        The loss is the sum of ``input`` over entries whose target is not -1,
        plus the sum of ``max(0, margin - input)`` over entries whose target is
        not 1. It is divided by ``input.nelement()`` when ``self.size_average``
        is set. ``(input, target)`` are saved for backward, and the scalar is
        returned wrapped in a 1-element tensor built with ``input.new``.
        """
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        buffer.fill_(self.margin).add_(-1, input)
        buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if self.size_average:
            output = output / input.nelement()

        self.save_for_backward(input, target)
        return input.new((output,))
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Legacy-nn forward of the hinge embedding criterion.

        Stores and returns ``self.output``. That is the sum of ``input`` over
        entries where ``y`` is not -1, plus the sum of ``max(0, margin -
        input)`` over entries where ``y`` is not 1, divided by
        ``input.nelement()`` when ``self.sizeAverage`` is set.
        ``self.buffer`` is allocated on first use and reused as scratch space.
        """
        if self.buffer is None:
            self.buffer = input.new()
        self.buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        self.buffer[torch.eq(y, -1.)] = 0
        self.output = self.buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        self.buffer.fill_(self.margin).add_(-1, input)
        self.buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        self.buffer[torch.eq(y, 1.)] = 0
        self.output = self.output + self.buffer.sum()

        if self.sizeAverage:
            self.output = self.output / input.nelement()

        return self.output
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def forward(ctx, input, target, margin, size_average):
        """Hinge embedding loss forward pass (ctx-based Function).

        Records ``margin`` and ``size_average`` on ``ctx`` and saves ``(input,
        target)``. The loss is the sum of ``input`` over entries whose target
        is not -1, plus the sum of ``max(0, margin - input)`` over entries whose
        target is not 1. It is divided by ``input.nelement()`` when
        ``size_average`` is set and returned as a 1-element tensor.
        """
        ctx.margin = margin
        ctx.size_average = size_average
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        buffer.fill_(ctx.margin).add_(-1, input)
        buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if ctx.size_average:
            output = output / input.nelement()

        ctx.save_for_backward(input, target)
        return input.new((output,))
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Legacy-nn forward of the hinge embedding criterion.

        Stores and returns ``self.output``. That is the sum of ``input`` over
        entries where ``y`` is not -1, plus the sum of ``max(0, margin -
        input)`` over entries where ``y`` is not 1, divided by
        ``input.nelement()`` when ``self.sizeAverage`` is set.
        ``self.buffer`` is allocated on first use and reused as scratch space.
        """
        if self.buffer is None:
            self.buffer = input.new()
        self.buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        self.buffer[torch.eq(y, -1.)] = 0
        self.output = self.buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        self.buffer.fill_(self.margin).add_(-1, input)
        self.buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        self.buffer[torch.eq(y, 1.)] = 0
        self.output = self.output + self.buffer.sum()

        if self.sizeAverage:
            self.output = self.output / input.nelement()

        return self.output
项目:NeuroNLP2    作者:XuezheMax    | 项目源码 | 文件源码
def decode(self, input_word, input_char, target=None, mask=None, length=None, hx=None, leading_symbolic=0):
        """Decode tag sequences with the CRF and optionally count correct tags.

        Returns ``(preds, None)`` when no ``target`` is given. Otherwise returns
        ``(preds, n_correct)``, where ``n_correct`` is the number of positions
        with ``preds == target`` (each weighted by ``mask`` when one is
        available). When ``length`` is given, ``target`` is first cut to
        ``length.max()`` columns.
        """
        # rnn_out: [batch, length, tag_space]
        rnn_out, _, mask, length = self._get_rnn_output(input_word, input_char, mask=mask, length=length, hx=hx)
        best_paths = self.crf.decode(rnn_out, mask=mask, leading_symbolic=leading_symbolic)

        if target is None:
            return best_paths, None

        if length is not None:
            target = target[:, :length.max()]

        hits = torch.eq(best_paths, target.data).float()
        if mask is not None:
            hits = hits * mask.data
        return best_paths, hits.sum()
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def forward(ctx, input, target, margin, size_average):
        """Hinge embedding loss forward pass (ctx-based Function).

        Records ``margin`` and ``size_average`` on ``ctx`` and saves ``(input,
        target)``. The loss is the sum of ``input`` over entries whose target
        is not -1, plus the sum of ``max(0, margin - input)`` over entries whose
        target is not 1. It is divided by ``input.nelement()`` when
        ``size_average`` is set and returned as a 1-element tensor.
        """
        ctx.margin = margin
        ctx.size_average = size_average
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        buffer.fill_(ctx.margin).add_(-1, input)
        buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if ctx.size_average:
            output = output / input.nelement()

        ctx.save_for_backward(input, target)
        return input.new((output,))
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Legacy-nn forward of the hinge embedding criterion.

        Stores and returns ``self.output``. That is the sum of ``input`` over
        entries where ``y`` is not -1, plus the sum of ``max(0, margin -
        input)`` over entries where ``y`` is not 1, divided by
        ``input.nelement()`` when ``self.sizeAverage`` is set.
        ``self.buffer`` is allocated on first use and reused as scratch space.
        """
        if self.buffer is None:
            self.buffer = input.new()
        self.buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        self.buffer[torch.eq(y, -1.)] = 0
        self.output = self.buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        self.buffer.fill_(self.margin).add_(-1, input)
        self.buffer.clamp_(min=0)
        # Entries labelled 1 must not add the hinge term.
        self.buffer[torch.eq(y, 1.)] = 0
        self.output = self.output + self.buffer.sum()

        if self.sizeAverage:
            self.output = self.output / input.nelement()

        return self.output
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def backward(self, grad_output):
        """Cosine embedding loss backward pass (legacy result-first torch API).

        Uses the w1/w22/w32/w/_outputs tensors left on ``self`` by forward.
        Returns ``(grad_input1, grad_input2, None)``; the labels ``y`` get no
        gradient.
        """
        v1, v2, y = self.saved_tensors

        buffer = v1.new()
        _idx = self._new_idx(v1)

        gw1 = grad_output.new()
        gw2 = grad_output.new()
        # The gradient w.r.t. v1 starts from v2 and vice versa.
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        # In these legacy calls the first argument appears to receive the result.
        torch.mul(buffer, self.w1, self.w22)
        gw1.addcmul_(-1, buffer.expand_as(v1), v1)
        gw1.mul_(self.w.expand_as(v1))

        torch.mul(buffer, self.w1, self.w32)
        gw2.addcmul_(-1, buffer.expand_as(v1), v2)
        gw2.mul_(self.w.expand_as(v1))

        # No gradient where the stored per-row loss term is <= 0.
        torch.le(_idx, self._outputs, 0)
        _idx = _idx.view(-1, 1).expand(gw1.size())
        gw1[_idx] = 0
        gw2[_idx] = 0

        # Use a fresh mask: _idx is now an expanded (stride-0) view and must not
        # be reused as the destination of the next comparison.
        _pos_idx = self._new_idx(v1)
        torch.eq(_pos_idx, y, 1)
        _pos_idx = _pos_idx.view(-1, 1).expand(gw2.size())
        # Rows labelled 1 get their gradient negated.
        gw1[_pos_idx] = gw1[_pos_idx].mul_(-1)
        gw2[_pos_idx] = gw2[_pos_idx].mul_(-1)

        if self.size_average:
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))

        # Scale by the scalar upstream gradient. The old code multiplied by the
        # grad_output tensor itself, whose shape does not match gw1/gw2.
        grad_output_val = grad_output[0]
        if grad_output_val != 1:
            gw1.mul_(grad_output_val)
            gw2.mul_(grad_output_val)

        return gw1, gw2, None
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def backward(self, grad_output):
        """Gradient of the hinge embedding loss with respect to ``input``.

        The gradient equals ``target``, zeroed wherever ``target == -1`` and
        ``input > margin``. It is then averaged over ``input.nelement()`` when
        ``self.size_average`` is set, and scaled by ``grad_output[0]`` unless
        that value is 1. Returns ``(grad_input, None)``.
        """
        input, target = self.saved_tensors
        negative = torch.eq(target, -1)
        beyond_margin = torch.gt(input, self.margin)
        inactive = torch.mul(negative, beyond_margin)

        grad_input = input.new()
        grad_input.resize_as_(input)
        grad_input.copy_(target)
        grad_input[inactive] = 0

        if self.size_average:
            grad_input.mul_(1. / input.nelement())

        scale = grad_output[0]
        if scale != 1:
            grad_input.mul_(scale)

        return grad_input, None
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Legacy-nn forward of the hinge embedding criterion.

        Stores and returns ``self.output``. That is the sum of ``input`` over
        entries where ``y`` is not -1, plus the sum of ``margin - input``
        (floored at 0 by the legacy ``cmax_``) over entries where ``y`` is not
        1, divided by ``input.nelement()`` when ``self.sizeAverage`` is set.
        """
        # Explicit checks instead of ``self.buffer or input.new()``. That form
        # truth-tests a tensor, which is ambiguous and raises once the buffer is
        # non-empty, i.e. on every call after the first. An empty buffer is
        # still replaced, as before.
        if self.buffer is None or self.buffer.nelement() == 0:
            self.buffer = input.new()
        self.buffer.resize_as_(input).copy_(input)
        # Entries labelled -1 must not add their raw input value.
        self.buffer[torch.eq(y, -1.)] = 0
        self.output = self.buffer.sum()

        # Legacy add_(value, tensor) call: presumably buffer = margin - input.
        self.buffer.fill_(self.margin).add_(-1, input)
        self.buffer.cmax_(0)
        # Entries labelled 1 must not add the hinge term.
        self.buffer[torch.eq(y, 1.)] = 0
        self.output = self.output + self.buffer.sum()

        if self.sizeAverage:
            self.output = self.output / input.nelement()

        return self.output
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def updateGradInput(self, input, y):
        """Legacy-nn gradient of the hinge embedding criterion.

        Fills and returns ``self.gradInput`` with a copy of ``y``. Entries where
        ``y == -1`` and ``input > self.margin`` are zeroed, and the result is
        scaled by ``1 / input.nelement()`` when ``self.sizeAverage`` is set.
        """
        grad = self.gradInput
        grad.resize_as_(input)
        grad.copy_(y)
        hinge_off = torch.mul(torch.eq(y, -1), torch.gt(input, self.margin))
        grad[hinge_off] = 0

        if self.sizeAverage:
            grad.mul_(1. / input.nelement())

        return self.gradInput
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_logical(self):
        """gt/lt/eq must partition every element, and gt + lt must equal ne.

        The data spans [-2, 2) and the first row is forced to exactly 1, so
        every comparison, including ``eq``, sees matching elements. The old
        range [-1, 1) could never contain a value > 1 or == 1.
        """
        x = torch.rand(100, 100) * 4 - 2
        x[0].fill_(1)

        xgt = torch.gt(x, 1)
        xlt = torch.lt(x, 1)

        xeq = torch.eq(x, 1)
        xne = torch.ne(x, 1)

        # gt and lt are disjoint, so together they mark exactly the "not equal" elements.
        neqs = xgt + xlt
        # Adding the eq mask must cover every element exactly once.
        all_cmp = neqs + xeq
        self.assertEqual(neqs.sum(), xne.sum(), 0)
        self.assertEqual(x.nelement(), all_cmp.sum())
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def forward(self, input1, input2, y):
        """Cosine embedding loss forward pass.

        Per row, cos = w1 * w, where:
          - w1 is the row dot product of input1 and input2;
          - w = sqrt(w22 * w32);
          - w22 and w32 are the reciprocals of the squared row norms of input1
            and input2, each plus 1e-12.
        Rows with y == -1 contribute max(0, cos - margin); rows with y == 1
        contribute 1 - cos. The sum, divided by ``y.size(0)`` when
        ``self.size_average`` is set, is returned as a 1-element tensor.
        The intermediates w1/w22/w32/w/_outputs stay on ``self`` for backward.
        """
        self.w1 = input1.new()
        self.w22 = input1.new()
        self.w = input1.new()
        self.w32 = input1.new()
        self._outputs = input1.new()

        _idx = input1.new().byte()

        buffer = torch.mul(input1, input2)
        # NOTE(review): select(1, 0) below needs these sums to keep the reduced
        # dimension; on torch versions where keepdim defaults to False this
        # needs keepdim=True. TODO confirm the targeted version.
        torch.sum(buffer, 1, out=self.w1)

        epsilon = 1e-12
        torch.mul(input1, input1, out=buffer)
        torch.sum(buffer, 1, out=self.w22).add_(epsilon)

        # _outputs temporarily holds ones so that w22 <- 1 / w22.
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self._outputs, self.w22, out=self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        torch.mul(input2, input2, out=buffer)
        torch.sum(buffer, 1, out=self.w32).add_(epsilon)
        torch.div(self._outputs, self.w32, out=self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # Per-row cosine, reduced to a 1-D view.
        torch.mul(self.w1, self.w, out=self._outputs)
        self._outputs = self._outputs.select(1, 0)

        # y == -1: max(0, cos - margin).
        torch.eq(y, -1, out=_idx)
        self._outputs[_idx] = self._outputs[_idx].add_(-self.margin).clamp_(min=0)
        # y == 1: 1 - cos.
        torch.eq(y, 1, out=_idx)
        self._outputs[_idx] = self._outputs[_idx].mul_(-1).add_(1)

        output = self._outputs.sum()

        if self.size_average:
            output = output / y.size(0)

        self.save_for_backward(input1, input2, y)
        return input1.new((output,))
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def backward(self, grad_output):
        """Cosine embedding loss backward pass.

        Uses the w1/w22/w32/w/_outputs tensors left on ``self`` by forward.
        Returns ``(grad_input1, grad_input2, None)``; the labels ``y`` get no
        gradient.
        """
        v1, v2, y = self.saved_tensors

        buffer = v1.new()
        _idx = v1.new().byte()

        gw1 = grad_output.new()
        gw2 = grad_output.new()
        # The gradient w.r.t. v1 starts from v2 and vice versa.
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        # Presumably gw1 = w * (v2 - w1 * w22 * v1), using the legacy
        # addcmul_(value, tensor1, tensor2) form.
        torch.mul(self.w1, self.w22, out=buffer)
        gw1.addcmul_(-1, buffer.expand_as(v1), v1)
        gw1.mul_(self.w.expand_as(v1))

        # Same for the second input, with w32 and v2.
        torch.mul(self.w1, self.w32, out=buffer)
        gw2.addcmul_(-1, buffer.expand_as(v1), v2)
        gw2.mul_(self.w.expand_as(v1))

        # No gradient where the stored per-row loss term is <= 0.
        torch.le(self._outputs, 0, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw1.size())
        gw1[_idx] = 0
        gw2[_idx] = 0

        # Rows labelled 1 get their gradient negated.
        # NOTE(review): _idx is now an expanded (stride-0) view and is reused as
        # the out= target here; a freshly allocated mask would be safer.
        torch.eq(y, 1, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw2.size())
        gw1[_idx] = gw1[_idx].mul_(-1)
        gw2[_idx] = gw2[_idx].mul_(-1)

        if self.size_average:
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))

        # Scale by the upstream gradient's first element.
        grad_output_val = grad_output[0]
        if grad_output_val != 1:
            gw1.mul_(grad_output_val)
            gw2.mul_(grad_output_val)

        return gw1, gw2, None
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def backward(self, grad_output):
        """Gradient of the hinge embedding loss with respect to ``input``.

        The gradient equals ``target``, zeroed wherever ``target == -1`` and
        ``input > margin``. It is then averaged over ``input.nelement()`` when
        ``self.size_average`` is set, and scaled by ``grad_output[0]`` unless
        that value is 1. Returns ``(grad_input, None)``.
        """
        input, target = self.saved_tensors
        negative = torch.eq(target, -1)
        beyond_margin = torch.gt(input, self.margin)
        inactive = torch.mul(negative, beyond_margin)

        grad_input = input.new()
        grad_input.resize_as_(input)
        grad_input.copy_(target)
        grad_input[inactive] = 0

        if self.size_average:
            grad_input.mul_(1. / input.nelement())

        scale = grad_output[0]
        if scale != 1:
            grad_input.mul_(scale)

        return grad_input, None
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def updateGradInput(self, input, y):
        """Legacy-nn gradient of the hinge embedding criterion.

        Fills and returns ``self.gradInput`` with a copy of ``y``. Entries where
        ``y == -1`` and ``input > self.margin`` are zeroed, and the result is
        scaled by ``1 / input.nelement()`` when ``self.sizeAverage`` is set.
        """
        grad = self.gradInput
        grad.resize_as_(input)
        grad.copy_(y)
        hinge_off = torch.mul(torch.eq(y, -1), torch.gt(input, self.margin))
        grad[hinge_off] = 0

        if self.sizeAverage:
            grad.mul_(1. / input.nelement())

        return self.gradInput
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_logical(self):
        """gt/lt/eq must partition every element, and gt + lt must equal ne.

        The data spans [-2, 2) and the first row is forced to exactly 1, so
        every comparison, including ``eq``, sees matching elements. The old
        range [-1, 1) could never contain a value > 1 or == 1.
        """
        x = torch.rand(100, 100) * 4 - 2
        x[0].fill_(1)

        xgt = torch.gt(x, 1)
        xlt = torch.lt(x, 1)

        xeq = torch.eq(x, 1)
        xne = torch.ne(x, 1)

        # gt and lt are disjoint, so together they mark exactly the "not equal" elements.
        neqs = xgt + xlt
        # Adding the eq mask must cover every element exactly once.
        all_cmp = neqs + xeq
        self.assertEqual(neqs.sum(), xne.sum(), 0)
        self.assertEqual(x.nelement(), all_cmp.sum())
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_comparison_ops(self):
        """Each tensor comparison operator must agree, element by element, with
        the same operator applied to the individual scalar entries."""
        x = torch.randn(5, 5)
        y = torch.randn(5, 5)

        # Checked in the order ==, !=, <, <=, >, >=.
        comparisons = (
            lambda a, b: a == b,
            lambda a, b: a != b,
            lambda a, b: a < b,
            lambda a, b: a <= b,
            lambda a, b: a > b,
            lambda a, b: a >= b,
        )
        for compare in comparisons:
            mask = compare(x, y)
            for idx in iter_indices(x):
                self.assertIs(compare(x[idx], y[idx]), mask[idx] == 1)
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def forward(self, input1, input2, y):
        """Cosine embedding loss forward pass.

        Per row, cos = w1 * w, where:
          - w1 is the row dot product of input1 and input2;
          - w = sqrt(w22 * w32);
          - w22 and w32 are the reciprocals of the squared row norms of input1
            and input2, each plus 1e-12.
        Rows with y == -1 contribute max(0, cos - margin); rows with y == 1
        contribute 1 - cos. The sum, divided by ``y.size(0)`` when
        ``self.size_average`` is set, is returned as a 1-element tensor.
        The intermediates w1/w22/w32/w/_outputs stay on ``self`` for backward.
        """
        self.w1 = input1.new()
        self.w22 = input1.new()
        self.w = input1.new()
        self.w32 = input1.new()
        self._outputs = input1.new()

        _idx = input1.new().byte()

        buffer = torch.mul(input1, input2)
        # keepdim=True keeps a column shape so that select(1, 0) works below.
        torch.sum(buffer, 1, out=self.w1, keepdim=True)

        epsilon = 1e-12
        torch.mul(input1, input1, out=buffer)
        torch.sum(buffer, 1, out=self.w22, keepdim=True).add_(epsilon)

        # _outputs temporarily holds ones so that w22 <- 1 / w22.
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self._outputs, self.w22, out=self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        torch.mul(input2, input2, out=buffer)
        torch.sum(buffer, 1, out=self.w32, keepdim=True).add_(epsilon)
        torch.div(self._outputs, self.w32, out=self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # Per-row cosine, reduced to a 1-D view.
        torch.mul(self.w1, self.w, out=self._outputs)
        self._outputs = self._outputs.select(1, 0)

        # y == -1: max(0, cos - margin).
        torch.eq(y, -1, out=_idx)
        self._outputs[_idx] = self._outputs[_idx].add_(-self.margin).clamp_(min=0)
        # y == 1: 1 - cos.
        torch.eq(y, 1, out=_idx)
        self._outputs[_idx] = self._outputs[_idx].mul_(-1).add_(1)

        output = self._outputs.sum()

        if self.size_average:
            output = output / y.size(0)

        self.save_for_backward(input1, input2, y)
        return input1.new((output,))
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def backward(self, grad_output):
        """Cosine embedding loss backward pass.

        Uses the w1/w22/w32/w/_outputs tensors left on ``self`` by forward.
        Returns ``(grad_input1, grad_input2, None)``; the labels ``y`` get no
        gradient.
        """
        v1, v2, y = self.saved_tensors

        buffer = v1.new()
        _idx = v1.new().byte()

        gw1 = grad_output.new()
        gw2 = grad_output.new()
        # The gradient w.r.t. v1 starts from v2 and vice versa.
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        # Presumably gw1 = w * (v2 - w1 * w22 * v1), using the legacy
        # addcmul_(value, tensor1, tensor2) form.
        torch.mul(self.w1, self.w22, out=buffer)
        gw1.addcmul_(-1, buffer.expand_as(v1), v1)
        gw1.mul_(self.w.expand_as(v1))

        # Same for the second input, with w32 and v2.
        torch.mul(self.w1, self.w32, out=buffer)
        gw2.addcmul_(-1, buffer.expand_as(v1), v2)
        gw2.mul_(self.w.expand_as(v1))

        # No gradient where the stored per-row loss term is <= 0.
        torch.le(self._outputs, 0, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw1.size())
        gw1[_idx] = 0
        gw2[_idx] = 0

        # Rows labelled 1 get their gradient negated.
        # NOTE(review): _idx is now an expanded (stride-0) view and is reused as
        # the out= target here; a freshly allocated mask would be safer.
        torch.eq(y, 1, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw2.size())
        gw1[_idx] = gw1[_idx].mul_(-1)
        gw2[_idx] = gw2[_idx].mul_(-1)

        if self.size_average:
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))

        # Scale by the upstream gradient's first element.
        grad_output_val = grad_output[0]
        if grad_output_val != 1:
            gw1.mul_(grad_output_val)
            gw2.mul_(grad_output_val)

        return gw1, gw2, None
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def backward(self, grad_output):
        """Gradient of the hinge embedding loss with respect to ``input``.

        The gradient equals ``target``, zeroed wherever ``target == -1`` and
        ``input > margin``. It is then averaged over ``input.nelement()`` when
        ``self.size_average`` is set, and scaled by ``grad_output[0]`` unless
        that value is 1. Returns ``(grad_input, None)``.
        """
        input, target = self.saved_tensors
        negative = torch.eq(target, -1)
        beyond_margin = torch.gt(input, self.margin)
        inactive = torch.mul(negative, beyond_margin)

        grad_input = input.new()
        grad_input.resize_as_(input)
        grad_input.copy_(target)
        grad_input[inactive] = 0

        if self.size_average:
            grad_input.mul_(1. / input.nelement())

        scale = grad_output[0]
        if scale != 1:
            grad_input.mul_(scale)

        return grad_input, None
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def updateGradInput(self, input, y):
        """Legacy-nn gradient of the hinge embedding criterion.

        Fills and returns ``self.gradInput`` with a copy of ``y``. Entries where
        ``y == -1`` and ``input > self.margin`` are zeroed, and the result is
        scaled by ``1 / input.nelement()`` when ``self.sizeAverage`` is set.
        """
        grad = self.gradInput
        grad.resize_as_(input)
        grad.copy_(y)
        hinge_off = torch.mul(torch.eq(y, -1), torch.gt(input, self.margin))
        grad[hinge_off] = 0

        if self.sizeAverage:
            grad.mul_(1. / input.nelement())

        return self.gradInput
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_logical(self):
        """gt/lt/eq must partition every element, and gt + lt must equal ne.

        The data spans [-2, 2) and the first row is forced to exactly 1, so
        every comparison, including ``eq``, sees matching elements. The old
        range [-1, 1) could never contain a value > 1 or == 1.
        """
        x = torch.rand(100, 100) * 4 - 2
        x[0].fill_(1)

        xgt = torch.gt(x, 1)
        xlt = torch.lt(x, 1)

        xeq = torch.eq(x, 1)
        xne = torch.ne(x, 1)

        # gt and lt are disjoint, so together they mark exactly the "not equal" elements.
        neqs = xgt + xlt
        # Adding the eq mask must cover every element exactly once.
        all_cmp = neqs + xeq
        self.assertEqual(neqs.sum(), xne.sum(), 0)
        self.assertEqual(x.nelement(), all_cmp.sum())
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_comparison_ops(self):
        """Each tensor comparison operator must agree, element by element, with
        the same operator applied to the individual scalar entries."""
        x = torch.randn(5, 5)
        y = torch.randn(5, 5)

        # Checked in the order ==, !=, <, <=, >, >=.
        comparisons = (
            lambda a, b: a == b,
            lambda a, b: a != b,
            lambda a, b: a < b,
            lambda a, b: a <= b,
            lambda a, b: a > b,
            lambda a, b: a >= b,
        )
        for compare in comparisons:
            mask = compare(x, y)
            for idx in iter_indices(x):
                self.assertIs(compare(x[idx], y[idx]), mask[idx] == 1)
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def backward(ctx, grad_output):
        """Cosine embedding loss backward pass (ctx-based Function).

        Uses the w1/w22/w32/w/_outputs tensors stored on ``ctx``. Returns
        gradients for the two inputs, followed by ``None`` for each remaining
        forward argument.
        """
        v1, v2, y = ctx.saved_tensors

        buffer = v1.new()
        _idx = v1.new().byte()

        gw1 = grad_output.new()
        gw2 = grad_output.new()
        # The gradient w.r.t. v1 starts from v2 and vice versa.
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        # Presumably gw1 = w * (v2 - w1 * w22 * v1), using the legacy
        # addcmul_(value, tensor1, tensor2) form.
        torch.mul(ctx.w1, ctx.w22, out=buffer)
        gw1.addcmul_(-1, buffer.expand_as(v1), v1)
        gw1.mul_(ctx.w.expand_as(v1))

        # Same for the second input, with w32 and v2.
        torch.mul(ctx.w1, ctx.w32, out=buffer)
        gw2.addcmul_(-1, buffer.expand_as(v1), v2)
        gw2.mul_(ctx.w.expand_as(v1))

        # No gradient where the stored per-row loss term is <= 0.
        torch.le(ctx._outputs, 0, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw1.size())
        gw1[_idx] = 0
        gw2[_idx] = 0

        # Rows labelled 1 get their gradient negated.
        # NOTE(review): _idx is now an expanded (stride-0) view and is reused as
        # the out= target here; a freshly allocated mask would be safer.
        torch.eq(y, 1, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw2.size())
        gw1[_idx] = gw1[_idx].mul_(-1)
        gw2[_idx] = gw2[_idx].mul_(-1)

        if ctx.size_average:
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))

        # Scale by the upstream gradient's first element.
        grad_output_val = grad_output[0]
        if grad_output_val != 1:
            gw1.mul_(grad_output_val)
            gw2.mul_(grad_output_val)

        return gw1, gw2, None, None, None
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def forward(ctx, input, target, grad_output, margin, size_average):
        """Compute the hinge embedding loss gradient with respect to ``input``.

        Records ``margin``/``size_average`` on ``ctx`` and saves ``(input,
        target, grad_output)``. Returns ``target`` with entries zeroed where
        ``target == -1`` and ``input > margin``. The result is averaged over
        ``input.nelement()`` when ``size_average`` is set, and scaled by
        ``grad_output[0]`` unless that value is 1.
        """
        ctx.margin, ctx.size_average = margin, size_average
        ctx.save_for_backward(input, target, grad_output)

        ignore = torch.mul(torch.eq(target, -1), torch.gt(input, ctx.margin))
        grad_input = input.new().resize_as_(input)
        grad_input.copy_(target)
        grad_input[ignore] = 0

        if ctx.size_average:
            grad_input.mul_(1. / input.nelement())

        upstream = grad_output[0]
        if upstream != 1:
            grad_input.mul_(upstream)

        return grad_input
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def updateGradInput(self, input, y):
        """Legacy-nn gradient of the hinge embedding criterion.

        Fills and returns ``self.gradInput`` with a copy of ``y``. Entries where
        ``y == -1`` and ``input > self.margin`` are zeroed, and the result is
        scaled by ``1 / input.nelement()`` when ``self.sizeAverage`` is set.
        """
        grad = self.gradInput
        grad.resize_as_(input)
        grad.copy_(y)
        hinge_off = torch.mul(torch.eq(y, -1), torch.gt(input, self.margin))
        grad[hinge_off] = 0

        if self.sizeAverage:
            grad.mul_(1. / input.nelement())

        return self.gradInput
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_logical(self):
        """gt/lt/eq must partition every element, and gt + lt must equal ne.

        The data spans [-2, 2) and the first row is forced to exactly 1, so
        every comparison, including ``eq``, sees matching elements. The old
        range [-1, 1) could never contain a value > 1 or == 1.
        """
        x = torch.rand(100, 100) * 4 - 2
        x[0].fill_(1)

        xgt = torch.gt(x, 1)
        xlt = torch.lt(x, 1)

        xeq = torch.eq(x, 1)
        xne = torch.ne(x, 1)

        # gt and lt are disjoint, so together they mark exactly the "not equal" elements.
        neqs = xgt + xlt
        # Adding the eq mask must cover every element exactly once.
        all_cmp = neqs + xeq
        self.assertEqual(neqs.sum(), xne.sum(), 0)
        self.assertEqual(x.nelement(), all_cmp.sum())
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_comparison_ops(self):
        """Each tensor comparison operator must agree, element by element, with
        the same operator applied to the individual scalar entries."""
        x = torch.randn(5, 5)
        y = torch.randn(5, 5)

        # Checked in the order ==, !=, <, <=, >, >=.
        comparisons = (
            lambda a, b: a == b,
            lambda a, b: a != b,
            lambda a, b: a < b,
            lambda a, b: a <= b,
            lambda a, b: a > b,
            lambda a, b: a >= b,
        )
        for compare in comparisons:
            mask = compare(x, y)
            for idx in iter_indices(x):
                self.assertIs(compare(x[idx], y[idx]), mask[idx] == 1)
项目:sourceseparation_misc    作者:ycemsubakan    | 项目源码 | 文件源码
def form_mixtures(digit1, digit2, loader, arguments):
    """Build per-digit datasets and their additive mixtures from ``loader``.

    Each batch ``(ft, tar)`` is split into the rows labelled ``digit1`` and
    the rows labelled ``digit2``. The model inputs are either Gaussian noise
    of width ``arguments.L1`` (``input_type == 'noise'``) or the selected
    features themselves (``'autoenc'``). The mixture set is the element-wise
    sum of the first ``min(N1, N2)`` rows of the two digit sets.

    Returns:
        (loader1, loader2, loader_mix): non-shuffling DataLoaders using
        ``arguments.batch_size``.

    Raises:
        ValueError: if ``arguments.input_type`` is neither 'noise' nor 'autoenc'.
    """
    def _select_digit(features, labels, digit):
        # Rows of ``features`` labelled ``digit``. Assumes ``labels`` is 1-D
        # (one label per row); TODO confirm. view(-1) keeps the index 1-D even
        # when exactly one row matches, where squeeze() would make it 0-dim.
        inds = torch.nonzero(torch.eq(labels, digit)).view(-1)
        return torch.index_select(features, dim=0, index=inds)

    dataset1, dataset2 = [], []
    for i, (ft, tar) in enumerate(loader):
        dataset1.append(_select_digit(ft, tar, digit1))
        dataset2.append(_select_digit(ft, tar, digit2))
        print(i)

    dataset1 = torch.cat(dataset1, dim=0)
    dataset2 = torch.cat(dataset2, dim=0)

    if arguments.input_type == 'noise':
        inp1 = torch.randn(dataset1.size(0), arguments.L1)
        inp2 = torch.randn(dataset2.size(0), arguments.L1)
    elif arguments.input_type == 'autoenc':
        inp1 = dataset1
        inp2 = dataset2
    else:
        raise ValueError('Whaaaaaat input_type?')

    N1, N2 = dataset1.size(0), dataset2.size(0)
    Nmix = min(N1, N2)

    dataset_mix = dataset1[:Nmix] + dataset2[:Nmix]

    # NOTE(review): unlike the two datasets below, this uses a bare
    # ``TensorDataset`` with a ``lens`` argument sized Nmix, while inp1 and
    # dataset1 hold N1 rows. Confirm the intended class and length.
    dataset1 = TensorDataset(data_tensor=inp1,
                             target_tensor=dataset1,
                             lens=[1]*Nmix)
    dataset2 = data_utils.TensorDataset(data_tensor=inp2,
                                        target_tensor=dataset2)
    dataset_mix = data_utils.TensorDataset(data_tensor=dataset_mix,
                                           target_tensor=torch.ones(Nmix))

    kwargs = {'num_workers': 1, 'pin_memory': True} if arguments.cuda else {}
    loader1 = data_utils.DataLoader(dataset1, batch_size=arguments.batch_size, shuffle=False, **kwargs)
    loader2 = data_utils.DataLoader(dataset2, batch_size=arguments.batch_size, shuffle=False, **kwargs)
    loader_mix = data_utils.DataLoader(dataset_mix, batch_size=arguments.batch_size, shuffle=False, **kwargs)

    return loader1, loader2, loader_mix
项目:NeuroNLP2    作者:XuezheMax    | 项目源码 | 文件源码
def loss(self, input_word, input_char, target, mask=None, length=None, hx=None, leading_symbolic=0):
        """Token-level NLL loss, correct-prediction count and predictions.

        Returns ``(loss, num_correct, preds)``. ``preds`` is the arg-max label
        per position, ignoring the first ``leading_symbolic`` labels. With a
        mask, the NLL over masked log-probabilities is divided by
        ``mask.sum()``, and only unmasked positions are counted as correct.
        Without a mask, it is divided by the number of positions
        ``batch * length``. Presumably ``self.nll_loss`` sums over positions;
        TODO confirm.
        """
        # [batch, length, num_labels]
        output, mask, length = self.forward(input_word, input_char, mask=mask, length=length, hx=hx)
        # [batch, length, num_labels]
        output = self.dense_softmax(output)
        # preds = [batch, length]
        _, preds = torch.max(output[:, :, leading_symbolic:], dim=2)
        preds += leading_symbolic

        batch, seq_len, num_labels = output.size()
        num_positions = batch * seq_len
        # [batch * length, num_labels]
        output_size = (num_positions, num_labels)
        output = output.view(output_size)

        # NOTE(review): mask.size(1) raises if mask is None while length is set.
        if length is not None and target.size(1) != mask.size(1):
            max_len = length.max()
            target = target[:, :max_len].contiguous()

        if mask is not None:
            # TODO for Pytorch 2.0.4, first take nllloss then mask (no need of broadcast for mask)
            return self.nll_loss(self.logsoftmax(output) * mask.contiguous().view(output_size[0], 1),
                                 target.view(-1)) / mask.sum(), \
                   (torch.eq(preds, target).type_as(mask) * mask).sum(), preds
        else:
            # Normalise by batch * length. The previous code multiplied
            # output_size[0] * output_size[1] after output_size had become
            # (batch * length, num_labels), i.e. it also divided by num_labels.
            return self.nll_loss(self.logsoftmax(output), target.view(-1)) / num_positions, \
                   (torch.eq(preds, target).type_as(output)).sum(), preds
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def backward(ctx, grad_output):
        """Cosine embedding loss backward pass (ctx-based Function).

        Uses the w1/w22/w32/w/_outputs tensors stored on ``ctx``. Returns
        gradients for the two inputs, followed by ``None`` for each remaining
        forward argument.
        """
        v1, v2, y = ctx.saved_tensors

        buffer = v1.new()
        _idx = v1.new().byte()

        gw1 = grad_output.new()
        gw2 = grad_output.new()
        # The gradient w.r.t. v1 starts from v2 and vice versa.
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        # Presumably gw1 = w * (v2 - w1 * w22 * v1), using the legacy
        # addcmul_(value, tensor1, tensor2) form.
        torch.mul(ctx.w1, ctx.w22, out=buffer)
        gw1.addcmul_(-1, buffer.expand_as(v1), v1)
        gw1.mul_(ctx.w.expand_as(v1))

        # Same for the second input, with w32 and v2.
        torch.mul(ctx.w1, ctx.w32, out=buffer)
        gw2.addcmul_(-1, buffer.expand_as(v1), v2)
        gw2.mul_(ctx.w.expand_as(v1))

        # No gradient where the stored per-row loss term is <= 0.
        torch.le(ctx._outputs, 0, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw1.size())
        gw1[_idx] = 0
        gw2[_idx] = 0

        # Rows labelled 1 get their gradient negated.
        # NOTE(review): _idx is now an expanded (stride-0) view and is reused as
        # the out= target here; a freshly allocated mask would be safer.
        torch.eq(y, 1, out=_idx)
        _idx = _idx.view(-1, 1).expand(gw2.size())
        gw1[_idx] = gw1[_idx].mul_(-1)
        gw2[_idx] = gw2[_idx].mul_(-1)

        if ctx.size_average:
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))

        # Scale by the upstream gradient's first element.
        grad_output_val = grad_output[0]
        if grad_output_val != 1:
            gw1.mul_(grad_output_val)
            gw2.mul_(grad_output_val)

        return gw1, gw2, None, None, None
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def forward(ctx, input, target, grad_output, margin, size_average):
        """Compute the hinge embedding loss gradient with respect to ``input``.

        Records ``margin``/``size_average`` on ``ctx`` and saves ``(input,
        target, grad_output)``. Returns ``target`` with entries zeroed where
        ``target == -1`` and ``input > margin``. The result is averaged over
        ``input.nelement()`` when ``size_average`` is set, and scaled by
        ``grad_output[0]`` unless that value is 1.
        """
        ctx.margin, ctx.size_average = margin, size_average
        ctx.save_for_backward(input, target, grad_output)

        ignore = torch.mul(torch.eq(target, -1), torch.gt(input, ctx.margin))
        grad_input = input.new().resize_as_(input)
        grad_input.copy_(target)
        grad_input[ignore] = 0

        if ctx.size_average:
            grad_input.mul_(1. / input.nelement())

        upstream = grad_output[0]
        if upstream != 1:
            grad_input.mul_(upstream)

        return grad_input
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def updateGradInput(self, input, y):
        """Legacy-nn gradient of the hinge embedding criterion.

        Fills and returns ``self.gradInput`` with a copy of ``y``. Entries where
        ``y == -1`` and ``input > self.margin`` are zeroed, and the result is
        scaled by ``1 / input.nelement()`` when ``self.sizeAverage`` is set.
        """
        grad = self.gradInput
        grad.resize_as_(input)
        grad.copy_(y)
        hinge_off = torch.mul(torch.eq(y, -1), torch.gt(input, self.margin))
        grad[hinge_off] = 0

        if self.sizeAverage:
            grad.mul_(1. / input.nelement())

        return self.gradInput
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_logical(self):
        """gt/lt/eq must partition every element, and gt + lt must equal ne.

        The data spans [-2, 2) and the first row is forced to exactly 1, so
        every comparison, including ``eq``, sees matching elements. The old
        range [-1, 1) could never contain a value > 1 or == 1.
        """
        x = torch.rand(100, 100) * 4 - 2
        x[0].fill_(1)

        xgt = torch.gt(x, 1)
        xlt = torch.lt(x, 1)

        xeq = torch.eq(x, 1)
        xne = torch.ne(x, 1)

        # gt and lt are disjoint, so together they mark exactly the "not equal" elements.
        neqs = xgt + xlt
        # Adding the eq mask must cover every element exactly once.
        all_cmp = neqs + xeq
        self.assertEqual(neqs.sum(), xne.sum(), 0)
        self.assertEqual(x.nelement(), all_cmp.sum())
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_comparison_ops(self):
        """Each tensor comparison operator must agree, element by element, with
        the same operator applied to the individual scalar entries."""
        x = torch.randn(5, 5)
        y = torch.randn(5, 5)

        # Checked in the order ==, !=, <, <=, >, >=.
        comparisons = (
            lambda a, b: a == b,
            lambda a, b: a != b,
            lambda a, b: a < b,
            lambda a, b: a <= b,
            lambda a, b: a > b,
            lambda a, b: a >= b,
        )
        for compare in comparisons:
            mask = compare(x, y)
            for idx in iter_indices(x):
                self.assertIs(compare(x[idx], y[idx]), mask[idx] == 1)
项目:paysage    作者:drckf    | 项目源码 | 文件源码
def equal(x: T.FloatTensor, y: T.FloatTensor) -> T.ByteTensor:
    """
    Compare two tensors element by element for equality.

    Args:
        x: A tensor.
        y: A tensor.

    Returns:
        tensor (of bools): True where x and y hold equal values, False elsewhere.

    """
    mask = torch.eq(x, y)
    return mask
项目:LSH_Memory    作者:RUSH-LAB    | 项目源码 | 文件源码
def update(self, query, y, y_hat, y_hat_indices):
        """Write a batch of query results back into the memory.

        Every slot's age is incremented first. For examples whose prediction
        ``y_hat`` matches ``y``, the hit slot's key becomes
        ``normalize(key + query)`` and its age is reset. For the other
        examples, the oldest slots are overwritten with the query/label and
        their age is reset. "Oldest" is judged by age perturbed with
        ``random_uniform`` noise bounded by ``age_noise``.
        """
        # 1) Untouched: Increment memory by 1
        self.age += 1

        # Divide batch by correctness. Squeezing only dim 1 keeps the batch
        # dimension when the batch holds a single example. Assumes y_hat is
        # [batch, 1]; TODO confirm.
        result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y.data, dim=1)), dim=1).float()
        # view(-1) keeps the index vectors 1-D. With squeeze(), a single hit
        # became a 0-dim tensor and was then treated as "no examples".
        incorrect_examples = torch.nonzero(1 - result).view(-1)
        correct_examples = torch.nonzero(result).view(-1)

        incorrect = incorrect_examples.numel() > 0
        correct = correct_examples.numel() > 0

        # 2) Correct: if V[n1] = v
        # Update Key k[n1] <- normalize(q + K[n1]), Reset Age A[n1] <- 0
        if correct:
            correct_indices = y_hat_indices[correct_examples]
            correct_keys = self.keys[correct_indices]
            correct_query = query.data[correct_examples]

            new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
            self.keys[correct_indices] = new_correct_keys
            self.age[correct_indices] = 0

        # 3) Incorrect: if V[n1] != v
        # Select item with oldest age, Add random offset - n' = argmax_i(A[i]) + r_i
        # K[n'] <- q, V[n'] <- v, A[n'] <- 0
        if incorrect:
            incorrect_size = incorrect_examples.size(0)
            incorrect_query = query.data[incorrect_examples]
            incorrect_values = y.data[incorrect_examples]

            # NOTE(review): cuda=True is hard-coded, so this path requires a GPU.
            age_with_noise = self.age + random_uniform((self.memory_size, 1), -self.age_noise, self.age_noise, cuda=True)
            _, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
            # 1-D slot indices, even when only one slot is selected.
            oldest_indices = topk_indices.view(-1)

            self.keys[oldest_indices] = incorrect_query
            self.values[oldest_indices] = incorrect_values
            self.age[oldest_indices] = 0
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def forward(self, input1, input2, y):
        """Cosine embedding loss forward pass (legacy result-first torch API).

        Assumes the legacy signatures in which the first tensor argument of
        torch.mul/sum/div/eq receives the result (TODO confirm for the targeted
        torch version). Under that reading, per row cos = w1 * w, where:
          - w1 is the row dot product of input1 and input2;
          - w = sqrt(w22 * w32);
          - w22 and w32 are the reciprocals of the squared row norms of input1
            and input2, each plus 1e-12.
        Rows with y == -1 contribute cos - margin, floored at 0 by cmax_;
        rows with y == 1 contribute 1 - cos. The sum, divided by ``y.size(0)``
        when ``self.size_average`` is set, is returned as a 1-element tensor.
        """
        self.w1  = input1.new()
        self.w22 = input1.new()
        self.w  = input1.new()
        self.w32 = input1.new()
        self._outputs = input1.new()

        buffer = input1.new()
        _idx = self._new_idx(input1)

        torch.mul(buffer, input1, input2)
        torch.sum(self.w1, buffer, 1)

        epsilon = 1e-12
        torch.mul(buffer, input1, input1)
        torch.sum(self.w22, buffer, 1).add_(epsilon)

        # _outputs temporarily holds ones so that w22 becomes 1 / w22.
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self.w22, self._outputs, self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        torch.mul(buffer, input2, input2)
        torch.sum(self.w32, buffer, 1).add_(epsilon)
        torch.div(self.w32, self._outputs, self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # Per-row cosine, reduced to a 1-D view.
        torch.mul(self._outputs, self.w1, self.w)
        self._outputs = self._outputs.select(1, 0)

        # y == -1 rows: cos - margin, floored at 0.
        torch.eq(_idx, y, -1)
        self._outputs[_idx] = self._outputs[_idx].add_(-self.margin).cmax_(0)
        # y == 1 rows: 1 - cos.
        torch.eq(_idx, y, 1)
        self._outputs[_idx] = self._outputs[_idx].mul_(-1).add_(1)

        output = self._outputs.sum()

        if self.size_average:
            output = output / y.size(0)

        self.save_for_backward(input1, input2, y)
        return input1.new((output,))
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Compute the cosine embedding loss for the pair ``input = (input1, input2)``.

        For each row i, the cosine similarity cos_i between input1[i] and
        input2[i] becomes ``1 - cos_i`` where ``y[i] == 1`` and
        ``max(0, cos_i - self.margin)`` where ``y[i] == -1``; rows with any other
        target keep cos_i.  The values are summed and, if ``self.sizeAverage``
        is set, divided by ``y.size(0)``.

        Work tensors are allocated on the first call and reused afterwards.
        Written against the legacy torch API in which the destination tensor is
        the first positional argument (``torch.mul(out, a, b)`` etc.).

        Args:
            input: a pair ``(input1, input2)``; presumably (batch, dim) tensors,
                since the reductions below sum over dim 1 -- TODO confirm.
            y: per-row targets (1 or -1).

        Returns:
            The loss; it is also stored in ``self.output``.
        """
        input1, input2 = input[0], input[1]

        # keep backward compatibility: lazily create the work tensors.
        # Compare with None explicitly -- truth-testing a non-empty tensor is
        # ambiguous (torch raises), so `not self.buffer` broke every call after
        # the first one had filled the buffer.
        if self.buffer is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self._outputs = input1.new()

            # comparison operators behave differently from cuda/c implementations,
            # so the mask must live on the inputs' device. is_cuda covers every
            # CUDA tensor type, not only torch.cuda.FloatTensor.
            if input1.is_cuda:
                self._idx = torch.cuda.ByteTensor()
            else:
                self._idx = torch.ByteTensor()

        # w1 = row-wise dot product <input1_i, input2_i>.
        torch.mul(self.buffer, input1, input2)
        torch.sum(self.w1, self.buffer, 1)

        # epsilon keeps the reciprocals below finite for all-zero rows.
        epsilon = 1e-12
        torch.mul(self.buffer, input1, input1)
        torch.sum(self.w22, self.buffer, 1).add_(epsilon)
        # self._outputs is also used as a temporary buffer of ones:
        # w22 = 1 / (|input1_i|^2 + eps)
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self.w22, self._outputs, self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        # w = sqrt(w22 * w32) = 1 / (|input1_i| * |input2_i|) (up to eps).
        torch.mul(self.buffer, input2, input2)
        torch.sum(self.w32, self.buffer, 1).add_(epsilon)
        torch.div(self.w32, self._outputs, self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # _outputs = per-row cosine similarity; select drops the kept dim.
        torch.mul(self._outputs, self.w1, self.w)
        self._outputs = self._outputs.select(1, 0)

        # Dissimilar pairs (y == -1): max(0, cos - margin).
        torch.eq(self._idx, y, -1)
        self._outputs[self._idx] = self._outputs[self._idx].add_(-self.margin).cmax_(0)
        # Similar pairs (y == 1): 1 - cos.
        torch.eq(self._idx, y, 1)
        self._outputs[self._idx] = self._outputs[self._idx].mul_(-1).add_(1)

        self.output = self._outputs.sum()

        if self.sizeAverage:
            self.output = self.output / y.size(0)

        return self.output
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def test_topk(self):
        """Check ``Tensor.topk`` against a reference built on ``Tensor.sort``.

        A random 3-D tensor (each side drawn from [1, SIZE]) is checked,
        optionally with two of its dims transposed, for random ``dim``/``k``
        and both selection directions.
        """
        def topKViaSort(t, k, dim, descending):
            # Reference result: sort along `dim` and keep the first k entries.
            sorted_vals, sorted_inds = t.sort(dim, descending)
            return sorted_vals.narrow(dim, 0, k), sorted_inds.narrow(dim, 0, k)

        def compareTensors(t, res1, ind1, res2, ind2, dim):
            # Values should be exactly equivalent
            self.assertEqual(res1, res2, 0)

            # Indices might differ based on the implementation, since there is
            # no guarantee of the relative order of selection
            if not ind1.eq(ind2).all():
                # To verify that the indices represent equivalent elements,
                # gather from the input using the topk indices and compare the
                # gathered values against the sorted values
                vals = t.gather(dim, ind2)
                self.assertEqual(res1, vals, 0)

        def compare(t, k, dim, largest):
            # `largest` doubles as sort's `descending`: both select the top end.
            topKVal, topKInd = t.topk(k, dim, largest, True)
            sortKVal, sortKInd = topKViaSort(t, k, dim, largest)
            compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)

        t = torch.rand(random.randint(1, SIZE),
                       random.randint(1, SIZE),
                       random.randint(1, SIZE))

        for _kTries in range(3):
            for _dimTries in range(3):
                for transpose in (True, False):
                    for largest in (True, False):
                        testTensor = t
                        if transpose:
                            # Pick two distinct dims to swap.
                            dim1 = random.randrange(t.ndimension())
                            dim2 = dim1
                            while dim1 == dim2:
                                dim2 = random.randrange(t.ndimension())

                            testTensor = t.transpose(dim1, dim2)

                        dim = random.randrange(testTensor.ndimension())
                        k = random.randint(1, testTensor.size(dim))
                        compare(testTensor, k, dim, largest)
项目:sru    作者:taolei87    | 项目源码 | 文件源码
def __iter__(self):
        """Yield padded tensor batches built from ``self.data``.

        Each element of ``self.data`` is a list of examples. Every example is a
        tuple of 7 fields when ``self.eval`` is truthy and 9 fields otherwise,
        read as: 0 context token ids, 1 per-token context feature vectors,
        2 context tag ids, 3 context entity ids, 4 question token ids,
        then (training only) 5 answer start and 6 answer end, and finally the
        text and span as the last two fields.

        Id sequences are right-padded with 0. The masks are nonzero wherever the
        id equals 0 (padding or a genuine id 0). When ``self.gpu`` is truthy the
        input tensors and masks (not the answer targets) are copied to pinned
        memory.

        Yields:
            eval:  (context_id, context_feature, context_tag, context_ent,
                    context_mask, question_id, question_mask, text, span)
            train: the same, with ``y_s, y_e`` inserted before ``text``.

        Raises:
            ValueError: if the examples do not have the expected field count.
        """
        def pad_ids(docs, length):
            # Right-pad each id sequence with 0 into a (len(docs), length) LongTensor.
            padded = torch.LongTensor(len(docs), length).fill_(0)
            for row, doc in enumerate(docs):
                padded[row, :len(doc)] = torch.LongTensor(doc)
            return padded

        for batch in self.data:
            batch_size = len(batch)
            batch = list(zip(*batch))
            expected_fields = 7 if self.eval else 9
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(batch) != expected_fields:
                raise ValueError('expected %d fields per example, got %d'
                                 % (expected_fields, len(batch)))

            context_len = max(len(x) for x in batch[0])
            context_id = pad_ids(batch[0], context_len)

            # Take the feature width from the first non-empty context, so an
            # empty leading document no longer raises IndexError.
            feature_len = next((len(doc[0]) for doc in batch[1] if len(doc) > 0), 0)
            context_feature = torch.Tensor(batch_size, context_len, feature_len).fill_(0)
            for i, doc in enumerate(batch[1]):
                for j, feature in enumerate(doc):
                    context_feature[i, j, :] = torch.Tensor(feature)

            context_tag = pad_ids(batch[2], context_len)
            context_ent = pad_ids(batch[3], context_len)

            question_len = max(len(x) for x in batch[4])
            question_id = pad_ids(batch[4], question_len)

            context_mask = torch.eq(context_id, 0)
            question_mask = torch.eq(question_id, 0)
            if not self.eval:
                y_s = torch.LongTensor(batch[5])
                y_e = torch.LongTensor(batch[6])
            text = list(batch[-2])
            span = list(batch[-1])
            if self.gpu:
                context_id = context_id.pin_memory()
                context_feature = context_feature.pin_memory()
                context_tag = context_tag.pin_memory()
                context_ent = context_ent.pin_memory()
                context_mask = context_mask.pin_memory()
                question_id = question_id.pin_memory()
                question_mask = question_mask.pin_memory()
            if self.eval:
                yield (context_id, context_feature, context_tag, context_ent, context_mask,
                       question_id, question_mask, text, span)
            else:
                yield (context_id, context_feature, context_tag, context_ent, context_mask,
                       question_id, question_mask, y_s, y_e, text, span)
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Compute the cosine embedding loss for the pair ``input = (input1, input2)``.

        For each row i, the cosine similarity cos_i between input1[i] and
        input2[i] becomes ``1 - cos_i`` where ``y[i] == 1`` and
        ``max(0, cos_i - self.margin)`` where ``y[i] == -1``; rows with any other
        target keep cos_i.  The values are summed and, if ``self.sizeAverage``
        is set, divided by ``y.size(0)``.

        Work tensors are allocated on the first call and reused afterwards.

        Args:
            input: a pair ``(input1, input2)``; presumably (batch, dim) tensors,
                since the reductions below sum over dim 1 -- TODO confirm.
            y: per-row targets (1 or -1).

        Returns:
            The loss; it is also stored in ``self.output``.
        """
        input1, input2 = input[0], input[1]

        # keep backward compatibility: lazily create the work tensors.
        if self.buffer is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self._outputs = input1.new()

            # comparison operators behave differently from cuda/c implementations,
            # so the mask must live on the inputs' device. is_cuda covers every
            # CUDA tensor type, not only torch.cuda.FloatTensor.
            if input1.is_cuda:
                self._idx = torch.cuda.ByteTensor()
            else:
                self._idx = torch.ByteTensor()

        # w1 = row-wise dot product <input1_i, input2_i>.
        torch.mul(input1, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w1)

        # epsilon keeps the reciprocals below finite for all-zero rows.
        epsilon = 1e-12
        torch.mul(input1, input1, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w22).add_(epsilon)
        # self._outputs is also used as a temporary buffer of ones:
        # w22 = 1 / (|input1_i|^2 + eps)
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self._outputs, self.w22, out=self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        # w = sqrt(w22 * w32) = 1 / (|input1_i| * |input2_i|) (up to eps).
        torch.mul(input2, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w32).add_(epsilon)
        torch.div(self._outputs, self.w32, out=self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # _outputs = per-row cosine similarity, flattened to one value per row.
        # view(-1) works whether torch.sum kept the reduced dimension ((N, 1),
        # older torch) or dropped it ((N,)); select(1, 0) fails on the latter.
        torch.mul(self.w1, self.w, out=self._outputs)
        self._outputs = self._outputs.view(-1)

        # Dissimilar pairs (y == -1): max(0, cos - margin).
        torch.eq(y, -1, out=self._idx)
        self._outputs[self._idx] = self._outputs[self._idx].add_(-self.margin).clamp_(min=0)
        # Similar pairs (y == 1): 1 - cos.
        torch.eq(y, 1, out=self._idx)
        self._outputs[self._idx] = self._outputs[self._idx].mul_(-1).add_(1)

        self.output = self._outputs.sum()

        if self.sizeAverage:
            self.output = self.output / y.size(0)

        return self.output
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_topk(self):
        """Check ``Tensor.topk`` against a reference built on ``Tensor.sort``.

        A random 3-D tensor (each side drawn from [1, SIZE]) is checked,
        optionally with two of its dims transposed, for random ``dim``/``k``
        and both selection directions.
        """
        def topKViaSort(t, k, dim, descending):
            # Reference result: sort along `dim` and keep the first k entries.
            sorted_vals, sorted_inds = t.sort(dim, descending)
            return sorted_vals.narrow(dim, 0, k), sorted_inds.narrow(dim, 0, k)

        def compareTensors(t, res1, ind1, res2, ind2, dim):
            # Values should be exactly equivalent
            self.assertEqual(res1, res2, 0)

            # Indices might differ based on the implementation, since there is
            # no guarantee of the relative order of selection
            if not ind1.eq(ind2).all():
                # To verify that the indices represent equivalent elements,
                # gather from the input using the topk indices and compare the
                # gathered values against the sorted values
                vals = t.gather(dim, ind2)
                self.assertEqual(res1, vals, 0)

        def compare(t, k, dim, largest):
            # `largest` doubles as sort's `descending`: both select the top end.
            topKVal, topKInd = t.topk(k, dim, largest, True)
            sortKVal, sortKInd = topKViaSort(t, k, dim, largest)
            compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)

        t = torch.rand(random.randint(1, SIZE),
                       random.randint(1, SIZE),
                       random.randint(1, SIZE))

        for _kTries in range(3):
            for _dimTries in range(3):
                for transpose in (True, False):
                    for largest in (True, False):
                        testTensor = t
                        if transpose:
                            # Pick two distinct dims to swap.
                            dim1 = random.randrange(t.ndimension())
                            dim2 = dim1
                            while dim1 == dim2:
                                dim2 = random.randrange(t.ndimension())

                            testTensor = t.transpose(dim1, dim2)

                        dim = random.randrange(testTensor.ndimension())
                        k = random.randint(1, testTensor.size(dim))
                        compare(testTensor, k, dim, largest)
项目:unsupervised-treelstm    作者:jihunchoi    | 项目源码 | 文件源码
def evaluate(args):
    """Evaluate a trained SNLIModel checkpoint on a pickled SNLI dataset.

    Loads the dataset from ``args.data``, rebuilds the model from the
    hyper-parameters in ``args``, loads the weights from ``args.model``, and
    prints parameter counts, the number of correctly classified examples and
    the accuracy.  ``args.gpu > -1`` selects the CUDA device to run on.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only point
    # args.data at files you created or otherwise trust.
    with open(args.data, 'rb') as f:
        test_dataset: SNLIDataset = pickle.load(f)
    word_vocab = test_dataset.word_vocab
    label_vocab = test_dataset.label_vocab
    model = SNLIModel(num_classes=len(label_vocab), num_words=len(word_vocab),
                      word_dim=args.word_dim, hidden_dim=args.hidden_dim,
                      clf_hidden_dim=args.clf_hidden_dim,
                      clf_num_layers=args.clf_num_layers,
                      use_leaf_rnn=args.leaf_rnn,
                      intra_attention=args.intra_attention,
                      use_batchnorm=args.batchnorm,
                      dropout_prob=args.dropout,
                      bidirectional=args.bidirectional)
    num_params = sum(np.prod(p.size()) for p in model.parameters())
    num_embedding_params = np.prod(model.word_embedding.weight.size())
    print(f'# of parameters: {num_params}')
    print(f'# of word embedding parameters: {num_embedding_params}')
    print(f'# of parameters (excluding word embeddings): '
          f'{num_params - num_embedding_params}')
    # Deserialize the checkpoint onto the CPU: a checkpoint saved from a GPU
    # run would otherwise be restored onto its original CUDA device and fail on
    # a CPU-only machine or a different GPU.  The model is moved to the
    # requested device right below.
    state_dict = torch.load(args.model,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()
    if args.gpu > -1:
        model.cuda(args.gpu)
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate)
    num_correct = 0
    num_data = len(test_dataset)
    for batch in test_data_loader:
        pre = wrap_with_variable(batch['pre'], volatile=True, gpu=args.gpu)
        hyp = wrap_with_variable(batch['hyp'], volatile=True, gpu=args.gpu)
        pre_length = wrap_with_variable(batch['pre_length'], volatile=True,
                                        gpu=args.gpu)
        hyp_length = wrap_with_variable(batch['hyp_length'], volatile=True,
                                        gpu=args.gpu)
        label = wrap_with_variable(batch['label'], volatile=True, gpu=args.gpu)
        logits = model(pre=pre, pre_length=pre_length,
                       hyp=hyp, hyp_length=hyp_length)
        # Predicted class = argmax over the class dimension of the logits.
        label_pred = logits.max(1)[1]
        num_correct_batch = torch.eq(label, label_pred).long().sum()
        num_correct_batch = unwrap_scalar_variable(num_correct_batch)
        num_correct += num_correct_batch
    print(f'# data: {num_data}')
    print(f'# correct: {num_correct}')
    print(f'Accuracy: {num_correct / num_data:.4f}')
项目:unsupervised-treelstm    作者:jihunchoi    | 项目源码 | 文件源码
def evaluate(args):
    """Evaluate a trained SSTModel checkpoint on the SST test split.

    Builds the SST splits (dropping 'neutral' examples unless
    ``args.fine_grained`` is set), builds vocabularies over all splits,
    rebuilds the model from the hyper-parameters in ``args``, loads the weights
    from ``args.model``, and prints parameter counts, the number of correctly
    classified test examples and the accuracy.
    """
    text_field = data.Field(lower=args.lower, include_lengths=True,
                            batch_first=True)
    label_field = data.Field(sequential=False)

    def _drop_neutral(ex):
        # Binary setting: keep only non-neutral examples.
        return ex.label != 'neutral'

    filter_pred = None if args.fine_grained else _drop_neutral
    dataset_splits = datasets.SST.splits(
        root='./data/sst', text_field=text_field, label_field=label_field,
        fine_grained=args.fine_grained, train_subtrees=True,
        filter_pred=filter_pred)
    test_dataset = dataset_splits[2]

    text_field.build_vocab(*dataset_splits)
    label_field.build_vocab(*dataset_splits)

    print(f'Number of classes: {len(label_field.vocab)}')

    _, _, test_loader = data.BucketIterator.splits(
        datasets=dataset_splits, batch_size=args.batch_size, device=args.gpu)

    num_classes = len(label_field.vocab)
    model = SSTModel(num_classes=num_classes, num_words=len(text_field.vocab),
                     word_dim=args.word_dim, hidden_dim=args.hidden_dim,
                     clf_hidden_dim=args.clf_hidden_dim,
                     clf_num_layers=args.clf_num_layers,
                     use_leaf_rnn=args.leaf_rnn,
                     bidirectional=args.bidirectional,
                     intra_attention=args.intra_attention,
                     use_batchnorm=args.batchnorm,
                     dropout_prob=args.dropout)
    num_params = sum(np.prod(p.size()) for p in model.parameters())
    num_embedding_params = np.prod(model.word_embedding.weight.size())
    print(f'# of parameters: {num_params}')
    print(f'# of word embedding parameters: {num_embedding_params}')
    print(f'# of parameters (excluding word embeddings): '
          f'{num_params - num_embedding_params}')
    # Deserialize the checkpoint onto the CPU: a checkpoint saved from a GPU
    # run would otherwise be restored onto its original CUDA device and fail on
    # a CPU-only machine or a different GPU.  The model is moved to the
    # requested device right below.
    state_dict = torch.load(args.model,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()
    if args.gpu > -1:
        model.cuda(args.gpu)
    num_correct = 0
    num_data = len(test_dataset)
    for batch in test_loader:
        words, length = batch.text
        label = batch.label
        length = wrap_with_variable(length, volatile=True, gpu=args.gpu)
        logits = model(words=words, length=length)
        # Predicted class = argmax over the class dimension of the logits.
        label_pred = logits.max(1)[1]
        num_correct_batch = torch.eq(label, label_pred).long().sum()
        num_correct_batch = unwrap_scalar_variable(num_correct_batch)
        num_correct += num_correct_batch
    print(f'# data: {num_data}')
    print(f'# correct: {num_correct}')
    print(f'Accuracy: {num_correct / num_data:.4f}')
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def updateOutput(self, input, y):
        """Compute the cosine embedding loss for the pair ``input = (input1, input2)``.

        For each row i, the cosine similarity cos_i between input1[i] and
        input2[i] becomes ``1 - cos_i`` where ``y[i] == 1`` and
        ``max(0, cos_i - self.margin)`` where ``y[i] == -1``; rows with any other
        target keep cos_i.  The values are summed and, if ``self.sizeAverage``
        is set, divided by ``y.size(0)``.

        Work tensors are allocated on the first call and reused afterwards.

        Args:
            input: a pair ``(input1, input2)``; presumably (batch, dim) tensors,
                since the reductions below sum over dim 1 -- TODO confirm.
            y: per-row targets (1 or -1).

        Returns:
            The loss; it is also stored in ``self.output``.
        """
        input1, input2 = input[0], input[1]

        # keep backward compatibility: lazily create the work tensors.
        if self.buffer is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self._outputs = input1.new()

            # comparison operators behave differently from cuda/c implementations,
            # so the mask must live on the inputs' device. is_cuda covers every
            # CUDA tensor type, not only torch.cuda.FloatTensor.
            if input1.is_cuda:
                self._idx = torch.cuda.ByteTensor()
            else:
                self._idx = torch.ByteTensor()

        # w1 = row-wise dot product <input1_i, input2_i>, shape (N, 1).
        torch.mul(input1, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w1, keepdim=True)

        # epsilon keeps the reciprocals below finite for all-zero rows.
        epsilon = 1e-12
        torch.mul(input1, input1, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w22, keepdim=True).add_(epsilon)
        # self._outputs is also used as a temporary buffer of ones:
        # w22 = 1 / (|input1_i|^2 + eps)
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self._outputs, self.w22, out=self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        # w = sqrt(w22 * w32) = 1 / (|input1_i| * |input2_i|) (up to eps).
        torch.mul(input2, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w32, keepdim=True).add_(epsilon)
        torch.div(self._outputs, self.w32, out=self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        # _outputs = per-row cosine similarity; select drops the kept dim.
        torch.mul(self.w1, self.w, out=self._outputs)
        self._outputs = self._outputs.select(1, 0)

        # Dissimilar pairs (y == -1): max(0, cos - margin).
        torch.eq(y, -1, out=self._idx)
        self._outputs[self._idx] = self._outputs[self._idx].add_(-self.margin).clamp_(min=0)
        # Similar pairs (y == 1): 1 - cos.
        torch.eq(y, 1, out=self._idx)
        self._outputs[self._idx] = self._outputs[self._idx].mul_(-1).add_(1)

        self.output = self._outputs.sum()

        if self.sizeAverage:
            self.output = self.output / y.size(0)

        return self.output