Python numpy.testing 模块,assert_equal() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例,用于说明如何使用 numpy.testing.assert_equal()。

项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_safe_binop():
    """Cross-check ``mt.extint_safe_binop`` against Python integer arithmetic.

    Every (operation, a, b) combination drawn from ``INT64_VALUES`` is
    evaluated with the plain Python operator.  When that reference result
    lies outside ``[INT64_MIN, INT64_MAX]`` the checked routine has to raise
    ``OverflowError``; otherwise it has to return the same value.
    """
    # Each entry pairs the Python reference operator with the integer code
    # that is handed to extint_safe_binop for that operation.
    op_table = [
        (operator.add, 1),
        (operator.sub, 2),
        (operator.mul, 3),
    ]

    with exc_iter(op_table, INT64_VALUES, INT64_VALUES) as cases:
        for (py_func, op_code), a, b in cases:
            expected = py_func(a, b)

            if expected < INT64_MIN or expected > INT64_MAX:
                assert_raises(OverflowError, mt.extint_safe_binop,
                              a, b, op_code)
                continue

            actual = mt.extint_safe_binop(a, b, op_code)
            # assert_equal is slow, so only call it once a mismatch is known
            if expected != actual:
                assert_equal(actual, expected)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_divmod_128_64():
    """Cross-check ``mt.extint_divmod_128_64`` against Python's ``divmod``.

    The reference quotient and remainder are computed on the magnitude of
    ``a`` and negated when ``a`` is negative.  For every pair drawn from
    ``INT128_VALUES`` x ``INT64_POS_VALUES`` the routine must reproduce both
    reference values and satisfy ``b*d + dr == a``.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
        for a, b in it:
            if a >= 0:
                c, cr = divmod(a, b)
            else:
                c, cr = divmod(-a, b)
                c = -c
                cr = -cr

            d, dr = mt.extint_divmod_128_64(a, b)

            # Only fall through to the assert_equal calls on a mismatch.
            # The remainder is compared with its own reference (cr != dr);
            # the guard used to read ``d != dr`` (quotient vs. remainder),
            # which is true whenever quotient and remainder differ and so
            # made the guard pointless for nearly every input.
            if c != d or cr != dr or b*d + dr != a:
                assert_equal(d, c)
                assert_equal(dr, cr)
                assert_equal(b*d + dr, a)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def check_may_share_memory_exact(a, b):
    """Check ``np.may_share_memory`` with ``max_work=MAY_SHARE_EXACT``.

    The answer is compared with a brute-force overlap test: both arrays are
    zeroed, ``a`` is filled with ones, and ``b`` is inspected for any nonzero
    element.  Both arrays are overwritten in the process.  The default call
    is also required to agree with ``max_work=MAY_SHARE_BOUNDS``.
    """
    # Query the exact answer before the arrays are modified below.
    got = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT)

    default_answer = np.may_share_memory(a, b)
    bounds_answer = np.may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS)
    assert_equal(default_answer, bounds_answer)

    # Ones written through ``a`` are visible in ``b`` only if memory overlaps.
    a.fill(0)
    b.fill(0)
    a.fill(1)
    exact = b.any()

    if got != exact:
        base_offset = (a.__array_interface__['data'][0]
                       - b.__array_interface__['data'][0])
        details = [
            "base_a - base_b = %r" % (base_offset,),
            "shape_a = %r" % (a.shape,),
            "shape_b = %r" % (b.shape,),
            "strides_a = %r" % (a.strides,),
            "strides_b = %r" % (b.strides,),
            "size_a = %r" % (a.size,),
            "size_b = %r" % (b.size,),
        ]
        err_msg = "    " + "\n    ".join(details)
    else:
        err_msg = ""

    assert_equal(got, exact, err_msg=err_msg)
项目:F_UNCLE    作者:fraserphysics    | 项目源码 | 文件源码
def test_update_prior(self):
        """Check that a prior can be set at construction and replaced later.

        ``update_prior`` is expected to hand back a different model instance
        whose prior equals, but is not the same object as, the list passed in.
        """
        original = PhysicsModel(prior=[3.5])
        npt.assert_equal(original.prior, [3.5])

        replacement_prior = [2.5]
        updated = original.update_prior(replacement_prior)
        npt.assert_equal(
            updated.prior,
            replacement_prior,
            err_msg="Test that new prior set correctly",
        )

        same_instance = original is updated
        self.assertFalse(same_instance,
                         msg="Test that update_prior gives a new instance")

        prior_is_aliased = updated.prior is replacement_prior
        self.assertFalse(prior_is_aliased,
                         msg="Test that prior is not linked to passed value")
项目:F_UNCLE    作者:fraserphysics    | 项目源码 | 文件源码
def test_call(self):
        """Call ``self.simSimp`` on a ``SimpleModel`` and inspect the result.

        The result must have three entries: an array, a one-element sequence
        holding an array, and a dict that carries an ``IUSpline`` under
        ``'mean_fn'``.
        """
        data = self.simSimp({'simp': SimpleModel([2, 1])})

        # Structure of the returned triple
        self.assertEqual(len(data), 3)
        self.assertIsInstance(data[0], np.ndarray)
        self.assertEqual(len(data[1]), 1)
        self.assertIsInstance(data[1][0], np.ndarray)
        extras = data[2]
        self.assertIsInstance(extras, dict)
        self.assertTrue('mean_fn' in extras)
        self.assertIsInstance(extras['mean_fn'], IUSpline)

        # Values: the first array is 0..9, the second follows from it
        grid = np.arange(10)
        npt.assert_equal(data[0], grid)
        npt.assert_equal(data[1][0], (2 * grid)**2 + 1 * grid)
项目:introspective    作者:numeristical    | 项目源码 | 文件源码
def test_transform_data():
    """
    Check ``mli.transform_data``, which turns raw data into the arrays that
    are used for fitting a function.

    """
    from matplotlib import mlab

    # Real data first: passing a record array or the path of the same CSV
    # file has to end up in the same x and y arrays.
    csv_path = op.join(data_path, 'ortho.csv')
    x1, y1, n1 = mli.transform_data(mlab.csv2rec(csv_path))
    x2, y2, n2 = mli.transform_data(csv_path)
    for from_records, from_path in ((x1, x2), (y1, y2)):
        npt.assert_equal(from_records, from_path)

    # Hand-made data, for which the expected answer is known exactly:
    trials = np.array([[0.1, 2], [0.1, 1], [0.2, 2], [0.2, 2], [0.3, 1],
                       [0.3, 1]])
    my_data = pd.DataFrame(trials, columns=['contrast1', 'answer'])
    my_x, my_y, my_n = mli.transform_data(my_data)

    npt.assert_equal(my_x, np.array([0.1, 0.2, 0.3]))
    npt.assert_equal(my_y, np.array([0.5, 0, 1.0]))
    npt.assert_equal(my_n, np.array([2, 2, 2]))
项目:krpcScripts    作者:jwvanderbeck    | 项目源码 | 文件源码
def test_safe_binop():
    """Cross-check ``mt.extint_safe_binop`` against Python integer arithmetic.

    Every (operation, a, b) combination drawn from ``INT64_VALUES`` is
    evaluated with the plain Python operator.  When that reference result
    lies outside ``[INT64_MIN, INT64_MAX]`` the checked routine has to raise
    ``OverflowError``; otherwise it has to return the same value.
    """
    # Each entry pairs the Python reference operator with the integer code
    # that is handed to extint_safe_binop for that operation.
    op_table = [
        (operator.add, 1),
        (operator.sub, 2),
        (operator.mul, 3),
    ]

    with exc_iter(op_table, INT64_VALUES, INT64_VALUES) as cases:
        for (py_func, op_code), a, b in cases:
            expected = py_func(a, b)

            if expected < INT64_MIN or expected > INT64_MAX:
                assert_raises(OverflowError, mt.extint_safe_binop,
                              a, b, op_code)
                continue

            actual = mt.extint_safe_binop(a, b, op_code)
            # assert_equal is slow, so only call it once a mismatch is known
            if expected != actual:
                assert_equal(actual, expected)
项目:krpcScripts    作者:jwvanderbeck    | 项目源码 | 文件源码
def test_divmod_128_64():
    """Cross-check ``mt.extint_divmod_128_64`` against Python's ``divmod``.

    The reference quotient and remainder are computed on the magnitude of
    ``a`` and negated when ``a`` is negative.  For every pair drawn from
    ``INT128_VALUES`` x ``INT64_POS_VALUES`` the routine must reproduce both
    reference values and satisfy ``b*d + dr == a``.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
        for a, b in it:
            if a >= 0:
                c, cr = divmod(a, b)
            else:
                c, cr = divmod(-a, b)
                c = -c
                cr = -cr

            d, dr = mt.extint_divmod_128_64(a, b)

            # Only fall through to the assert_equal calls on a mismatch.
            # The remainder is compared with its own reference (cr != dr);
            # the guard used to read ``d != dr`` (quotient vs. remainder),
            # which is true whenever quotient and remainder differ and so
            # made the guard pointless for nearly every input.
            if c != d or cr != dr or b*d + dr != a:
                assert_equal(d, c)
                assert_equal(dr, cr)
                assert_equal(b*d + dr, a)
项目:krpcScripts    作者:jwvanderbeck    | 项目源码 | 文件源码
def check_may_share_memory_exact(a, b):
    """Check ``np.may_share_memory`` with ``max_work=MAY_SHARE_EXACT``.

    The answer is compared with a brute-force overlap test: both arrays are
    zeroed, ``a`` is filled with ones, and ``b`` is inspected for any nonzero
    element.  Both arrays are overwritten in the process.  The default call
    is also required to agree with ``max_work=MAY_SHARE_BOUNDS``.
    """
    # Query the exact answer before the arrays are modified below.
    got = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT)

    default_answer = np.may_share_memory(a, b)
    bounds_answer = np.may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS)
    assert_equal(default_answer, bounds_answer)

    # Ones written through ``a`` are visible in ``b`` only if memory overlaps.
    a.fill(0)
    b.fill(0)
    a.fill(1)
    exact = b.any()

    if got != exact:
        base_offset = (a.__array_interface__['data'][0]
                       - b.__array_interface__['data'][0])
        details = [
            "base_a - base_b = %r" % (base_offset,),
            "shape_a = %r" % (a.shape,),
            "shape_b = %r" % (b.shape,),
            "strides_a = %r" % (a.strides,),
            "strides_b = %r" % (b.strides,),
            "size_a = %r" % (a.size,),
            "size_b = %r" % (b.size,),
        ]
        err_msg = "    " + "\n    ".join(details)
    else:
        err_msg = ""

    assert_equal(got, exact, err_msg=err_msg)
项目:deepcpg    作者:cangermueller    | 项目源码 | 文件源码
def test_read(self):
        """Read the first and last data file unshuffled and spot-check the
        first and last five records of the combined result.
        """
        cpg_name = '/outputs/cpg/BS27_4_SER'
        selected_files = [self.data_files[0], self.data_files[-1]]
        data = hdf.read(selected_files, ['pos', 'chromo', cpg_name],
                        shuffle=False)

        # (window, expected chromosome, expected positions, expected states)
        expectations = [
            (slice(None, 5), b'18',
             [3000023, 3000086, 3000092, 3000103, 3000163],
             [1, 1, 1, -1, 0]),
            (slice(-5, None), b'19',
             [4447803, 4447814, 4447818, 4447821, 4447847],
             [1, 1, 1, 1, 1]),
        ]
        for window, chromo, positions, states in expectations:
            assert np.all(data['chromo'][window] == chromo)
            npt.assert_equal(data['pos'][window], positions)
            npt.assert_equal(data[cpg_name][window], states)
项目:DeepLearnTute    作者:DouglasOrr    | 项目源码 | 文件源码
def test_parse_uji():
    """Check ``data.parse_uji`` on empty input and on a small UJI sample.

    The parser is given a sequence of text lines.  For the sample below the
    expected result is one ``(character, strokes)`` pair per ``WORD`` entry,
    where each stroke is a list of ``(x, y)`` integer tuples taken from the
    numbers after ``#`` on a ``POINTS`` line.
    """
    # No input lines: nothing is produced.
    npt.assert_equal([], list(data.parse_uji([])))

    # Two words: 'a' with one 3-point stroke, 'b' with a 2-point and a
    # 3-point stroke.  The sample text is split into lines before parsing,
    # so the `//` lines, the blank line and the whitespace-only last line are
    # part of the input as well.
    npt.assert_equal([('a', [[(0, 100), (-200, 300), (456, 777)]]),
                      ('b', [[(1, 2), (3, 4)],
                             [(5, 6), (7, 8), (9, -1)]])],
                     list(data.parse_uji('''
//
// UJI: 100 units per millimetre
//
// ASCII char: a
WORD a some-arbitrary-STRING
  NUMSTROKES 1
    POINTS 3 # 0 100 -200 300 456 777

// more WORD NUMSTROKES POINTS to come...
WORD b some-arbitrary-STRING
  NUMSTROKES 2
    POINTS 2 # 1 2 3 4
    POINTS 3 # 5 6 7 8 9 -1
                     '''.split('\n'))))
项目:indigo    作者:mbdriscoll    | 项目源码 | 文件源码
def test_compat_Zpad(backend, X, Y, Z, P, K):
    """Compare the backend's ``Zpad`` operator with the one from ``pymr``.

    The input has shape ``(X, Y, Z, K)``; the output grows by ``P`` on both
    sides of each of the first three axes.  Skipped when ``pymr`` cannot be
    imported.
    """
    pymr = pytest.importorskip('pymr')
    b = backend()

    in_shape = (X, Y, Z, K)
    out_shape = tuple(extent + 2*P for extent in (X, Y, Z)) + (K,)

    x = indigo.util.rand64c(*in_shape)

    # Reference operator works on the full 4-D shapes; the backend operator
    # is built from the three leading axes only.
    reference_op = pymr.linop.Zpad(out_shape, in_shape, dtype=x.dtype)
    backend_op = b.Zpad(out_shape[:3], in_shape[:3], dtype=x.dtype)

    x_as_columns = np.asfortranarray(x.reshape((-1, K), order='F'))
    x_as_vector = pymr.util.vec(x)
    y_exp = reference_op * x_as_vector
    y_act = backend_op * x_as_columns

    npt.assert_equal(y_act.flatten(order='F'), y_exp)
项目:pulse2percept    作者:uwescience    | 项目源码 | 文件源码
def test_TimeSeries():
    """Check ``max``, ``max_frame`` and indexing of ``utils.TimeSeries``."""
    peak_value = 2.0
    peak_index = 156
    # np.random.rand draws from [0, 1), so the planted 2.0 is the maximum
    raw = np.random.rand(10, 10, 1000)
    raw[4, 4, peak_index] = peak_value
    ts = utils.TimeSeries(1.0, raw)

    # Largest element: reported together with where it was planted
    t_at_peak, value_at_peak = ts.max()
    npt.assert_equal(t_at_peak, peak_index)
    npt.assert_equal(value_at_peak, peak_value)

    # Largest frame: the whole slice that contains the planted value
    t_at_peak, peak_frame = ts.max_frame()
    npt.assert_equal(t_at_peak, peak_index)
    npt.assert_equal(peak_frame.data, raw[:, :, peak_index])

    # Indexing gives another TimeSeries wrapping the indexed data
    is_time_series = isinstance(ts[3], utils.TimeSeries)
    npt.assert_equal(is_time_series, True)
    npt.assert_equal(ts[3].data, ts.data[3])
项目:pulse2percept    作者:uwescience    | 项目源码 | 文件源码
def test_gamma():
    """Check argument validation and basic properties of ``utils.gamma``."""
    tsample = 0.005 / 1000

    # Each of these argument triples has to be rejected with a ValueError
    rejected = [
        (0, 0.1, tsample),
        (2, -0.1, tsample),
        (2, 0.1, -tsample),
    ]
    for bad_args in rejected:
        with pytest.raises(ValueError):
            t, g = utils.gamma(*bad_args)

    for tau in [0.001, 0.01, 0.1]:
        for n in [1, 2, 5]:
            t, g = utils.gamma(n, tau, tsample)

            # Time axis starts at zero and advances in steps of tsample
            expected_t = np.arange(0, t[-1] + tsample / 2.0, tsample)
            npt.assert_equal(expected_t, t)
            if n > 1:
                npt.assert_equal(g[0], 0.0)

            # Area under the curve is normalized (to two decimals)
            area = np.trapz(np.abs(g), dx=tsample)
            npt.assert_almost_equal(area, 1.0, decimal=2)

            # Peak sits at tau * (n - 1)
            peak_time = g.argmax() * tsample
            npt.assert_almost_equal(peak_time, tau * (n - 1))
项目:pulse2percept    作者:uwescience    | 项目源码 | 文件源码
def test_BaseModel():
    """``p2p.retina.BaseModel`` must be subclassed with ``model_cascade``."""
    # The base class itself refuses to be instantiated
    with pytest.raises(TypeError):
        p2p.retina.BaseModel(0.01)

    # ... and so does a child that does not provide `model_cascade()`
    class WithoutCascade(p2p.retina.BaseModel):
        pass

    with pytest.raises(TypeError):
        WithoutCascade()

    # A child that does provide it can be built and used
    class WithCascade(p2p.retina.BaseModel):

        def model_cascade(self, inval):
            return inval

    model = WithCascade(tsample=0.1)
    npt.assert_equal(model.tsample, 0.1)
    npt.assert_equal(model.model_cascade(2.4), 2.4)
项目:pulse2percept    作者:uwescience    | 项目源码 | 文件源码
def test_axon_dist_from_soma():
    """Check ``p2p.retina.axon_dist_from_soma`` on a 3x3 grid."""
    def diagonal_axon(*linspace_args):
        # Points (v, v) for v running over np.linspace(*linspace_args)
        return np.array([[v, v] for v in np.linspace(*linspace_args)])

    def distances(axon):
        # Only the second return value is of interest here
        _, dist = p2p.retina.axon_dist_from_soma(axon, xg, yg)
        return dist

    # A small grid
    xg, yg = np.meshgrid([-1, 0, 1], [-1, 0, 1], indexing='xy')

    # When axon locations are snapped to the grid, a really short axon
    # should have zero distance to the soma:
    for x_soma in [-1.0, -0.2, 0.51]:
        short_axon = diagonal_axon(x_soma, x_soma + 0.01)
        npt.assert_almost_equal(distances(short_axon), 0.0)

    # On this simple grid, a diagonal axon should have dist [0, sqrt(2), 2],
    # whichever corner it starts from and however finely it is sampled:
    for sign in [-1.0, 1.0]:
        for num in [10, 20, 50]:
            full_diagonal = diagonal_axon(sign, -sign, num)
            npt.assert_almost_equal(distances(full_diagonal),
                                    np.array([0.0, np.sqrt(2), 2.0]))

    # An axon that does not live near the grid should return infinite
    # distance
    far_axon = diagonal_axon(1000.0, 1500.0)
    npt.assert_equal(np.isinf(distances(far_axon)), True)
项目:pulse2percept    作者:uwescience    | 项目源码 | 文件源码
def test_load_video_metadata():
    """Check ``files.load_video_metadata`` on a missing file, on the
    scikit-video sample clip, and with ``skvideo`` made unimportable.
    """
    reload(files)

    # A file that does not exist is reported as an OSError
    with pytest.raises(OSError):
        files.load_video_metadata('nothing_there.mp4')

    # Known values for the bundled sample video
    from skvideo import datasets
    metadata = files.load_video_metadata(datasets.bikes())
    expected_fields = [
        ('@codec_name', 'h264'),
        ('@duration_ts', '128000'),
        ('@r_frame_rate', '25/1'),
    ]
    for field, value in expected_fields:
        npt.assert_equal(metadata[field], value)

    # Trigger an import error by replacing the skvideo modules with
    # empty dicts
    stubbed_modules = {"skvideo": {}, "skvideo.utils": {}}
    with mock.patch.dict("sys.modules", stubbed_modules):
        with pytest.raises(ImportError):
            reload(files)
            files.load_video_metadata(datasets.bikes())
项目:lambda-numba    作者:rlhotovy    | 项目源码 | 文件源码
def test_safe_binop():
    """Cross-check ``mt.extint_safe_binop`` against Python integer arithmetic.

    Every (operation, a, b) combination drawn from ``INT64_VALUES`` is
    evaluated with the plain Python operator.  When that reference result
    lies outside ``[INT64_MIN, INT64_MAX]`` the checked routine has to raise
    ``OverflowError``; otherwise it has to return the same value.
    """
    # Each entry pairs the Python reference operator with the integer code
    # that is handed to extint_safe_binop for that operation.
    op_table = [
        (operator.add, 1),
        (operator.sub, 2),
        (operator.mul, 3),
    ]

    with exc_iter(op_table, INT64_VALUES, INT64_VALUES) as cases:
        for (py_func, op_code), a, b in cases:
            expected = py_func(a, b)

            if expected < INT64_MIN or expected > INT64_MAX:
                assert_raises(OverflowError, mt.extint_safe_binop,
                              a, b, op_code)
                continue

            actual = mt.extint_safe_binop(a, b, op_code)
            # assert_equal is slow, so only call it once a mismatch is known
            if expected != actual:
                assert_equal(actual, expected)
项目:lambda-numba    作者:rlhotovy    | 项目源码 | 文件源码
def test_divmod_128_64():
    """Cross-check ``mt.extint_divmod_128_64`` against Python's ``divmod``.

    The reference quotient and remainder are computed on the magnitude of
    ``a`` and negated when ``a`` is negative.  For every pair drawn from
    ``INT128_VALUES`` x ``INT64_POS_VALUES`` the routine must reproduce both
    reference values and satisfy ``b*d + dr == a``.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
        for a, b in it:
            if a >= 0:
                c, cr = divmod(a, b)
            else:
                c, cr = divmod(-a, b)
                c = -c
                cr = -cr

            d, dr = mt.extint_divmod_128_64(a, b)

            # Only fall through to the assert_equal calls on a mismatch.
            # The remainder is compared with its own reference (cr != dr);
            # the guard used to read ``d != dr`` (quotient vs. remainder),
            # which is true whenever quotient and remainder differ and so
            # made the guard pointless for nearly every input.
            if c != d or cr != dr or b*d + dr != a:
                assert_equal(d, c)
                assert_equal(dr, cr)
                assert_equal(b*d + dr, a)
项目:lambda-numba    作者:rlhotovy    | 项目源码 | 文件源码
def check_may_share_memory_exact(a, b):
    """Check ``np.may_share_memory`` with ``max_work=MAY_SHARE_EXACT``.

    The answer is compared with a brute-force overlap test: both arrays are
    zeroed, ``a`` is filled with ones, and ``b`` is inspected for any nonzero
    element.  Both arrays are overwritten in the process.  The default call
    is also required to agree with ``max_work=MAY_SHARE_BOUNDS``.
    """
    # Query the exact answer before the arrays are modified below.
    got = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT)

    default_answer = np.may_share_memory(a, b)
    bounds_answer = np.may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS)
    assert_equal(default_answer, bounds_answer)

    # Ones written through ``a`` are visible in ``b`` only if memory overlaps.
    a.fill(0)
    b.fill(0)
    a.fill(1)
    exact = b.any()

    if got != exact:
        base_offset = (a.__array_interface__['data'][0]
                       - b.__array_interface__['data'][0])
        details = [
            "base_a - base_b = %r" % (base_offset,),
            "shape_a = %r" % (a.shape,),
            "shape_b = %r" % (b.shape,),
            "strides_a = %r" % (a.strides,),
            "strides_b = %r" % (b.strides,),
            "size_a = %r" % (a.size,),
            "size_b = %r" % (b.size,),
        ]
        err_msg = "    " + "\n    ".join(details)
    else:
        err_msg = ""

    assert_equal(got, exact, err_msg=err_msg)
项目:cesium_web    作者:cesium-ml    | 项目源码 | 文件源码
def test_robust_literal_eval():
    """Test util.robust_literal_eval"""
    # (parameter name, raw string handed in, value expected back)
    cases = [
        ("n_estimators", "1000", 1000),
        ("max_features", "auto", "auto"),
        ("min_weight_fraction_leaf", "0.34", 0.34),
        ("bootstrap", "True", True),
        ("class_weight", "{'a': 0.2, 'b': 0.8}", {'a': 0.2, 'b': 0.8}),
        ("max_features2", "[150.3, 20, 'auto']", [150.3, 20, "auto"]),
    ]
    expected = {name: value for name, _, value in cases}
    params = {name: util.robust_literal_eval(raw) for name, raw, _ in cases}
    npt.assert_equal(params, expected)
项目:cesium_web    作者:cesium-ml    | 项目源码 | 文件源码
def test_download_prediction_csv_class(driver, project, dataset, featureset,
                                       model, prediction):
    """Download the prediction results and compare the CSV line by line.

    The downloaded file is removed again whether or not the comparison
    succeeds.
    """
    csv_path = '/tmp/cesium_prediction_results.csv'
    expected_lines = [
        'ts_name,label,prediction',
        '0,Mira,Mira',
        '1,Classical_Cepheid,Classical_Cepheid',
        '2,Mira,Mira',
        '3,Classical_Cepheid,Classical_Cepheid',
        '4,Mira,Mira',
    ]

    driver.get('/')
    _click_download(project.id, driver)
    assert os.path.exists(csv_path)
    try:
        # No delimiter is given, so each line is read as a single string
        downloaded = np.genfromtxt(csv_path, dtype='str')
        npt.assert_equal(downloaded, expected_lines)
    finally:
        os.remove(csv_path)
项目:cesium_web    作者:cesium-ml    | 项目源码 | 文件源码
def test_download_prediction_csv_regr(driver, project, dataset, featureset,
                                      model, prediction):
    """Download the prediction results and compare header and numeric rows.

    The downloaded file is removed again whether or not the comparison
    succeeds.
    """
    csv_path = '/tmp/cesium_prediction_results.csv'
    expected_rows = [
        [0, 2.2, 2.2],
        [1, 3.4, 3.4],
        [2, 4.4, 4.4],
        [3, 2.2, 2.2],
        [4, 3.1, 3.1],
    ]

    driver.get('/')
    _click_download(project.id, driver)
    assert os.path.exists(csv_path)
    try:
        results = np.genfromtxt(csv_path, dtype='str', delimiter=',')
        header, body = results[0], results[1:]
        npt.assert_equal(header, ['ts_name', 'label', 'prediction'])
        # Everything was read as text; convert before comparing numbers
        numeric_rows = [[float(cell) for cell in row] for row in body]
        npt.assert_array_almost_equal(numeric_rows, expected_rows)
    finally:
        os.remove(csv_path)
项目:kubeface    作者:hammerlab    | 项目源码 | 文件源码
def test_put_and_get_to_bucket(bucket):
    """Round-trip a text payload through ``storage`` and then delete it."""
    payload = "ABCDe" * 1000
    # The current time (with the dot removed) keeps the object name unique
    stamp = str(time.time()).replace(".", "")
    file_name = "kubeface-test-%s.txt" % (stamp)
    name = "%s/%s" % (bucket, file_name)

    storage.put(name, BytesIO(payload.encode("UTF-8")))

    # Visible both by its full name and by a name prefix
    testing.assert_equal(storage.list_contents(name), [file_name])
    by_prefix = storage.list_contents("%s/kubeface-test-" % bucket)
    testing.assert_(file_name in by_prefix)

    # Content comes back unchanged
    fetched = storage.get(name).read().decode("UTF-8")
    testing.assert_equal(fetched, payload)

    # Gone from the bucket listing after deletion
    storage.delete(name)
    remaining = storage.list_contents("%s/" % bucket)
    testing.assert_(file_name not in remaining)
项目:kubeface    作者:hammerlab    | 项目源码 | 文件源码
def test_move(bucket):
    """Store an object, move it, and check that only the new name exists."""
    payload = "ABCDe" * 1000
    # The current time (with the dot removed) keeps the object name unique
    stamp = str(time.time()).replace(".", "")
    file_name = "kubeface-test-%s.txt" % (stamp)
    moved_file_name = "moved-%s" % file_name
    name = "%s/%s" % (bucket, file_name)
    name2 = "%s/moved-%s" % (bucket, file_name)

    storage.put(name, BytesIO(payload.encode("UTF-8")))
    testing.assert_equal(storage.list_contents(name), [file_name])

    # After the move the old name lists nothing and the new name one entry
    storage.move(name, name2)
    testing.assert_equal(storage.list_contents(name), [])
    testing.assert_equal(storage.list_contents(name2), [moved_file_name])

    # The payload survived the move
    fetched = storage.get(name2).read().decode("UTF-8")
    testing.assert_equal(fetched, payload)

    # Clean up and make sure the moved object is gone as well
    storage.delete(name2)
    remaining = storage.list_contents("%s/" % bucket)
    testing.assert_(moved_file_name not in remaining)
项目:kubeface    作者:hammerlab    | 项目源码 | 文件源码
def test_worker_exception_delayed(bucket):
    """With ``--kubeface-wait-to-raise-task-exception`` a failing task only
    raises when its result is requested from the mapper.
    """
    c = client_from_commandline_args([
        "--kubeface-poll-seconds", "1.1",
        "--kubeface-backend", "local-process",
        "--kubeface-storage", bucket,
        "--kubeface-wait-to-raise-task-exception",
    ])

    def check_job_counts(not_done, including_done):
        # Same two summary queries, in the same order, after every step
        testing.assert_equal(
            len(c.job_summary(include_done=False)), not_done)
        testing.assert_equal(
            len(c.job_summary(include_done=True)), including_done)

    # x == 2 divides by zero; x == 0 and x == 1 give -1 and -2
    mapper = c.map(lambda x: 2 / (x - 2), range(10))
    for expected in (-1, -2):
        testing.assert_equal(next(mapper), expected)
    check_job_counts(1, 1)

    testing.assert_raises(ZeroDivisionError, next, mapper)
    check_job_counts(0, 1)

    # The mapper is exhausted after the failure
    testing.assert_raises(StopIteration, next, mapper)
    check_job_counts(0, 1)
项目:kubeface    作者:hammerlab    | 项目源码 | 文件源码
def test_job_command(bucket):
    """Check the job command output while a job runs and after it is done."""
    cache_key = 'FOOBARBAZ'
    c = client_from_commandline_args([
        "--kubeface-poll-seconds", "1.1",
        "--kubeface-backend", "local-process",
        "--kubeface-storage", bucket,
    ])

    mapper = c.map(math.exp, range(10), cache_key=cache_key)
    testing.assert_equal(next(mapper), 1)

    # While results are still pending the job is listed ...
    default_listing = run_job_command(bucket, [])
    assert 'FOOBARBAZ' in default_listing
    # ... and the --include-done listing reports it as active
    full_listing = run_job_command(bucket, ["--include-done"])
    job_line = find_line_with("FOOBARBAZ", full_listing, nth=1)
    assert 'active' in (job_line)

    # Drain the remaining results; the job then leaves the default listing
    for _ in mapper:
        pass
    assert 'FOOBARBAZ' not in run_job_command(bucket, [])
项目:deliver    作者:orchestor    | 项目源码 | 文件源码
def test_safe_binop():
    """Cross-check ``mt.extint_safe_binop`` against Python integer arithmetic.

    Every (operation, a, b) combination drawn from ``INT64_VALUES`` is
    evaluated with the plain Python operator.  When that reference result
    lies outside ``[INT64_MIN, INT64_MAX]`` the checked routine has to raise
    ``OverflowError``; otherwise it has to return the same value.
    """
    # Each entry pairs the Python reference operator with the integer code
    # that is handed to extint_safe_binop for that operation.
    op_table = [
        (operator.add, 1),
        (operator.sub, 2),
        (operator.mul, 3),
    ]

    with exc_iter(op_table, INT64_VALUES, INT64_VALUES) as cases:
        for (py_func, op_code), a, b in cases:
            expected = py_func(a, b)

            if expected < INT64_MIN or expected > INT64_MAX:
                assert_raises(OverflowError, mt.extint_safe_binop,
                              a, b, op_code)
                continue

            actual = mt.extint_safe_binop(a, b, op_code)
            # assert_equal is slow, so only call it once a mismatch is known
            if expected != actual:
                assert_equal(actual, expected)
项目:deliver    作者:orchestor    | 项目源码 | 文件源码
def test_divmod_128_64():
    """Cross-check ``mt.extint_divmod_128_64`` against Python's ``divmod``.

    The reference quotient and remainder are computed on the magnitude of
    ``a`` and negated when ``a`` is negative.  For every pair drawn from
    ``INT128_VALUES`` x ``INT64_POS_VALUES`` the routine must reproduce both
    reference values and satisfy ``b*d + dr == a``.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
        for a, b in it:
            if a >= 0:
                c, cr = divmod(a, b)
            else:
                c, cr = divmod(-a, b)
                c = -c
                cr = -cr

            d, dr = mt.extint_divmod_128_64(a, b)

            # Only fall through to the assert_equal calls on a mismatch.
            # The remainder is compared with its own reference (cr != dr);
            # the guard used to read ``d != dr`` (quotient vs. remainder),
            # which is true whenever quotient and remainder differ and so
            # made the guard pointless for nearly every input.
            if c != d or cr != dr or b*d + dr != a:
                assert_equal(d, c)
                assert_equal(dr, cr)
                assert_equal(b*d + dr, a)
项目:deliver    作者:orchestor    | 项目源码 | 文件源码
def check_may_share_memory_exact(a, b):
    """Check ``np.may_share_memory`` with ``max_work=MAY_SHARE_EXACT``.

    The answer is compared with a brute-force overlap test: both arrays are
    zeroed, ``a`` is filled with ones, and ``b`` is inspected for any nonzero
    element.  Both arrays are overwritten in the process.  The default call
    is also required to agree with ``max_work=MAY_SHARE_BOUNDS``.
    """
    # Query the exact answer before the arrays are modified below.
    got = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT)

    default_answer = np.may_share_memory(a, b)
    bounds_answer = np.may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS)
    assert_equal(default_answer, bounds_answer)

    # Ones written through ``a`` are visible in ``b`` only if memory overlaps.
    a.fill(0)
    b.fill(0)
    a.fill(1)
    exact = b.any()

    if got != exact:
        base_offset = (a.__array_interface__['data'][0]
                       - b.__array_interface__['data'][0])
        details = [
            "base_a - base_b = %r" % (base_offset,),
            "shape_a = %r" % (a.shape,),
            "shape_b = %r" % (b.shape,),
            "strides_a = %r" % (a.strides,),
            "strides_b = %r" % (b.strides,),
            "size_a = %r" % (a.size,),
            "size_b = %r" % (b.size,),
        ]
        err_msg = "    " + "\n    ".join(details)
    else:
        err_msg = ""

    assert_equal(got, exact, err_msg=err_msg)
项目:SHED    作者:xpdAcq    | 项目源码 | 文件源码
def test_zip(n, n2, kwargs, expected):
    """Zip two event-model streams and check what reaches the sink.

    Every zipped pair must consist of two documents that differ, and all
    four document names must show up.  The streams are replayed twice.
    """
    def random_image_docs(count):
        # `count` random 10x10 images wrapped by to_event_model
        images = [np.random.random((10, 10)) for _ in range(count)]
        return list(to_event_model(
            images,
            output_info=[('pe1_image', {'dtype': 'array'})]
        ))

    source = Stream()
    source2 = Stream()
    L = es.zip(source, source2, **kwargs).sink_to_list()

    s = random_image_docs(n)
    s2 = random_image_docs(n2)

    for _ in range(2):
        L.clear()
        # The second stream is emitted completely before the first one
        for doc in s2:
            source2.emit(doc)
        for doc in s:
            source.emit(doc)

        seen_names = set()
        for name, (first_doc, second_doc) in L:
            seen_names.add(name)
            # assert_equal has to fail, i.e. the two documents differ
            assert_raises(AssertionError, assert_equal, first_doc, second_doc)
        assert seen_names == {'start', 'descriptor', 'event', 'stop'}
项目:pyAFQ    作者:yeatmanlab    | 项目源码 | 文件源码
def test_fit_dki():
    """Run the ``pyAFQ_dki`` command on generated data and check its outputs.

    The first element of the command result must be 0 (presumably the exit
    status -- confirm against ``runner.run_command``).  Each per-metric
    output file is then checked with ``assert_image_shape_affine`` against
    the affine of the input image and its shape without the last axis.
    """
    with nbtmp.InTemporaryDirectory() as tmpdir:
        fbval = op.join(tmpdir, 'dki.bval')
        fbvec = op.join(tmpdir, 'dki.bvec')
        fdata = op.join(tmpdir, 'dki.nii.gz')
        make_dki_data(fbval, fbvec, fdata)
        cmd = ["pyAFQ_dki", "-d", fdata, "-l", fbval, "-c", fbvec,
               "-o", tmpdir]
        out = runner.run_command(cmd)
        npt.assert_equal(out[0], 0)
        # Get expected values.  They come from the input image and are the
        # same for every metric, so the image is loaded once up front rather
        # than once per output file.
        img = nib.load(fdata)
        affine = img.get_affine()
        shape = img.shape[:-1]
        names = ['FA', 'MD', 'AD', 'RD', 'MK', 'AK', 'RK']
        for n in names:
            fname = op.join(tmpdir, "dki_%s.nii.gz" % n)
            assert_image_shape_affine(fname, shape, affine)
项目:pyAFQ    作者:yeatmanlab    | 项目源码 | 文件源码
def test_predict_dki():
    """Fit with ``pyAFQ_dki``, predict with ``pyAFQ_dki_predict`` and check
    that the prediction reproduces the generated input data.
    """
    with nbtmp.InTemporaryDirectory() as tmpdir:
        fbval = op.join(tmpdir, 'dki.bval')
        fbvec = op.join(tmpdir, 'dki.bvec')
        fdata = op.join(tmpdir, 'dki.nii.gz')
        make_dki_data(fbval, fbvec, fdata)

        # Step 1: fit; the first element of the result has to be 0
        fit_cmd = ["pyAFQ_dki", "-d", fdata, "-l", fbval, "-c", fbvec,
                   "-o", tmpdir]
        fit_out = runner.run_command(fit_cmd)
        npt.assert_equal(fit_out[0], 0)

        # Step 2: predict from the fitted parameters
        fparams = op.join(tmpdir, "dki_params.nii.gz")
        predict_cmd = ["pyAFQ_dki_predict", "-p", fparams, "-l", fbval,
                       "-c", fbvec, "-o", tmpdir, '-b', '0']
        predict_out = runner.run_command(predict_cmd)
        npt.assert_equal(predict_out[0], 0)

        # Step 3: the prediction matches the data that was fitted
        pred_img = nib.load(op.join(tmpdir, "dki_prediction.nii.gz"))
        pred = pred_img.get_data()
        data_img = nib.load(op.join(tmpdir, "dki.nii.gz"))
        data = data_img.get_data()
        npt.assert_array_almost_equal(pred, data)
项目:pyAFQ    作者:yeatmanlab    | 项目源码 | 文件源码
def test_fit_dti():
    """Run the ``pyAFQ_dti`` command on generated data and check its outputs.

    The first element of the command result must be 0 (presumably the exit
    status -- confirm against ``runner.run_command``).  Each per-metric
    output file is then checked with ``assert_image_shape_affine`` against
    the affine of the input image and its shape without the last axis.
    """
    with nbtmp.InTemporaryDirectory() as tmpdir:
        fbval = op.join(tmpdir, 'dti.bval')
        fbvec = op.join(tmpdir, 'dti.bvec')
        fdata = op.join(tmpdir, 'dti.nii.gz')
        make_dti_data(fbval, fbvec, fdata)
        cmd = ["pyAFQ_dti", "-d", fdata, "-l", fbval, "-c", fbvec,
               "-o", tmpdir, '-b', '0']
        out = runner.run_command(cmd)
        npt.assert_equal(out[0], 0)
        # Get expected values.  They come from the input image and are the
        # same for every metric, so the image is loaded once up front rather
        # than once per output file.
        img = nib.load(fdata)
        affine = img.get_affine()
        shape = img.shape[:-1]
        names = ['FA', 'MD', 'AD', 'RD']
        for n in names:
            fname = op.join(tmpdir, "dti_%s.nii.gz" % n)
            assert_image_shape_affine(fname, shape, affine)
项目:pyAFQ    作者:yeatmanlab    | 项目源码 | 文件源码
def test_predict_dti():
    """Fit with ``pyAFQ_dti``, predict with ``pyAFQ_dti_predict`` and check
    that the prediction reproduces the generated input data.
    """
    with nbtmp.InTemporaryDirectory() as tmpdir:
        fbval = op.join(tmpdir, 'dti.bval')
        fbvec = op.join(tmpdir, 'dti.bvec')
        fdata = op.join(tmpdir, 'dti.nii.gz')
        make_dti_data(fbval, fbvec, fdata)

        # Step 1: fit; the first element of the result has to be 0
        fit_cmd = ["pyAFQ_dti", "-d", fdata, "-l", fbval, "-c", fbvec,
                   "-o", tmpdir]
        fit_out = runner.run_command(fit_cmd)
        npt.assert_equal(fit_out[0], 0)

        # Step 2: predict from the fitted parameters
        fparams = op.join(tmpdir, "dti_params.nii.gz")
        predict_cmd = ["pyAFQ_dti_predict", "-p", fparams, "-l", fbval,
                       "-c", fbvec, "-o", tmpdir, '-b', '0']
        predict_out = runner.run_command(predict_cmd)
        npt.assert_equal(predict_out[0], 0)

        # Step 3: the prediction matches the data that was fitted
        pred_img = nib.load(op.join(tmpdir, "dti_prediction.nii.gz"))
        pred = pred_img.get_data()
        data_img = nib.load(op.join(tmpdir, "dti.nii.gz"))
        data = data_img.get_data()
        npt.assert_array_almost_equal(pred, data)
项目:pyAFQ    作者:yeatmanlab    | 项目源码 | 文件源码
def test_fit_csd():
    """Fit CSD at two ``sh_order`` values and check the coefficient count.

    For each order the output file must exist and the size of its last axis
    must map back to that order through ``calculate_max_order``.
    """
    fdata, fbval, fbvec = dpd.get_data('small_64D')
    with nbtmp.InTemporaryDirectory() as tmpdir:
        bvals_txt = op.join(tmpdir, 'bvals.txt')
        bvecs_txt = op.join(tmpdir, 'bvecs.txt')

        # Convert from npy to txt:
        bvals = np.load(fbval)
        bvecs = np.load(fbvec)
        np.savetxt(bvals_txt, bvals)
        np.savetxt(bvecs_txt, bvecs)

        for sh_order in [4, 6]:
            fname = csd.fit_csd(fdata, bvals_txt, bvecs_txt,
                                out_dir=tmpdir, sh_order=sh_order)
            npt.assert_(op.exists(fname))
            n_coeffs = nib.load(fname).shape[-1]
            npt.assert_equal(sh_order, calculate_max_order(n_coeffs))
项目:pyAFQ    作者:yeatmanlab    | 项目源码 | 文件源码
def test_read_write_trk():
    """Write streamlines to a trk file and read them back, first as they
    are and then with an affine attached to the file.
    """
    straight = np.array([[0, 0, 0], [0, 0, 0.5], [0, 0, 1], [0, 0, 1.5]])
    diagonal = np.array([[0, 0, 0], [0, 0.5, 0.5], [0, 1, 1]])
    sl = [straight, diagonal]

    with nbtmp.InTemporaryDirectory() as tmpdir:
        fname = op.join(tmpdir, 'sl.trk')

        # Plain round trip: exactly the same streamlines come back
        aus.write_trk(fname, sl)
        npt.assert_equal(list(aus.read_trk(fname)), sl)

        # What happens if this set of streamlines has some funky affine
        # associated with it?  Build one with a random scale and a fixed
        # translation.
        aff = np.eye(4) * np.random.rand()
        aff[:3, 3] = np.array([1, 2, 3])
        aff[3, 3] = 1

        # We move the streamlines, and report the inverse of the affine:
        moved = move_streamlines(sl, aff)
        aus.write_trk(fname, moved, affine=np.linalg.inv(aff))

        # When we read this, we get back what we put in (to 4 decimals),
        # comparing each streamline in turn:
        for new, old in zip(aus.read_trk(fname), sl):
            npt.assert_almost_equal(new, old, decimal=4)
项目:pyAFQ    作者:yeatmanlab    | 项目源码 | 文件源码
def test_parfor():
    """Run ``para.parfor`` with every engine/backend pairing and compare
    single results with a direct ``power_it`` call.
    """
    my_array = np.arange(100).reshape(10, 10)
    i, j = np.random.randint(0, 9, 2)
    my_list = list(my_array.ravel())

    for engine in ["joblib", "dask", "serial"]:
        for backend in ["threading", "multiprocessing"]:
            how = dict(engine=engine, backend=backend)

            # With out_shape, the result is indexed like the original array
            reshaped = para.parfor(power_it, my_list,
                                   out_shape=my_array.shape, **how)
            npt.assert_equal(reshaped[i, j], power_it(my_array[i, j]))

            # If it's not reshaped, the first item should be the item 0, 0:
            flat = para.parfor(power_it, my_list, **how)
            npt.assert_equal(flat[0], power_it(my_array[0, 0]))
项目:Alfred    作者:jkachhadia    | 项目源码 | 文件源码
def test_safe_binop():
    """Cross-check ``mt.extint_safe_binop`` against Python integer arithmetic.

    Every (operation, a, b) combination drawn from ``INT64_VALUES`` is
    evaluated with the plain Python operator.  When that reference result
    lies outside ``[INT64_MIN, INT64_MAX]`` the checked routine has to raise
    ``OverflowError``; otherwise it has to return the same value.
    """
    # Each entry pairs the Python reference operator with the integer code
    # that is handed to extint_safe_binop for that operation.
    op_table = [
        (operator.add, 1),
        (operator.sub, 2),
        (operator.mul, 3),
    ]

    with exc_iter(op_table, INT64_VALUES, INT64_VALUES) as cases:
        for (py_func, op_code), a, b in cases:
            expected = py_func(a, b)

            if expected < INT64_MIN or expected > INT64_MAX:
                assert_raises(OverflowError, mt.extint_safe_binop,
                              a, b, op_code)
                continue

            actual = mt.extint_safe_binop(a, b, op_code)
            # assert_equal is slow, so only call it once a mismatch is known
            if expected != actual:
                assert_equal(actual, expected)
项目:Alfred    作者:jkachhadia    | 项目源码 | 文件源码
def test_divmod_128_64():
    """Cross-check ``mt.extint_divmod_128_64`` against Python's ``divmod``.

    The reference quotient and remainder are computed on the magnitude of
    ``a`` and negated when ``a`` is negative.  For every pair drawn from
    ``INT128_VALUES`` x ``INT64_POS_VALUES`` the routine must reproduce both
    reference values and satisfy ``b*d + dr == a``.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
        for a, b in it:
            if a >= 0:
                c, cr = divmod(a, b)
            else:
                c, cr = divmod(-a, b)
                c = -c
                cr = -cr

            d, dr = mt.extint_divmod_128_64(a, b)

            # Only fall through to the assert_equal calls on a mismatch.
            # The remainder is compared with its own reference (cr != dr);
            # the guard used to read ``d != dr`` (quotient vs. remainder),
            # which is true whenever quotient and remainder differ and so
            # made the guard pointless for nearly every input.
            if c != d or cr != dr or b*d + dr != a:
                assert_equal(d, c)
                assert_equal(dr, cr)
                assert_equal(b*d + dr, a)
项目:Alfred    作者:jkachhadia    | 项目源码 | 文件源码
def check_may_share_memory_exact(a, b):
    """Check ``np.may_share_memory`` with ``max_work=MAY_SHARE_EXACT``.

    The answer is compared with a brute-force overlap test: both arrays are
    zeroed, ``a`` is filled with ones, and ``b`` is inspected for any nonzero
    element.  Both arrays are overwritten in the process.  The default call
    is also required to agree with ``max_work=MAY_SHARE_BOUNDS``.
    """
    # Query the exact answer before the arrays are modified below.
    got = np.may_share_memory(a, b, max_work=MAY_SHARE_EXACT)

    default_answer = np.may_share_memory(a, b)
    bounds_answer = np.may_share_memory(a, b, max_work=MAY_SHARE_BOUNDS)
    assert_equal(default_answer, bounds_answer)

    # Ones written through ``a`` are visible in ``b`` only if memory overlaps.
    a.fill(0)
    b.fill(0)
    a.fill(1)
    exact = b.any()

    if got != exact:
        base_offset = (a.__array_interface__['data'][0]
                       - b.__array_interface__['data'][0])
        details = [
            "base_a - base_b = %r" % (base_offset,),
            "shape_a = %r" % (a.shape,),
            "shape_b = %r" % (b.shape,),
            "strides_a = %r" % (a.strides,),
            "strides_b = %r" % (b.strides,),
            "size_a = %r" % (a.size,),
            "size_b = %r" % (b.size,),
        ]
        err_msg = "    " + "\n    ".join(details)
    else:
        err_msg = ""

    assert_equal(got, exact, err_msg=err_msg)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_to_128():
    """``mt.extint_to_128`` must return every ``INT64_VALUES`` entry as is."""
    with exc_iter(INT64_VALUES) as cases:
        for (a,) in cases:
            converted = mt.extint_to_128(a)
            # Cheap comparison first; assert_equal only reports a mismatch
            if a != converted:
                assert_equal(converted, a)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_to_64():
    """``mt.extint_to_64`` must return values inside ``[INT64_MIN,
    INT64_MAX]`` unchanged and raise ``OverflowError`` for all others.
    """
    with exc_iter(INT128_VALUES) as cases:
        for (a,) in cases:
            if INT64_MIN <= a <= INT64_MAX:
                converted = mt.extint_to_64(a)
                # Cheap comparison first; assert_equal only reports a
                # mismatch
                if a != converted:
                    assert_equal(converted, a)
            else:
                assert_raises(OverflowError, mt.extint_to_64, a)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_mul_64_64():
    """``mt.extint_mul_64_64`` must agree with Python's ``a * b`` for all
    pairs of ``INT64_VALUES``.
    """
    with exc_iter(INT64_VALUES, INT64_VALUES) as cases:
        for a, b in cases:
            expected = a * b
            actual = mt.extint_mul_64_64(a, b)
            # Cheap comparison first; assert_equal only reports a mismatch
            if expected != actual:
                assert_equal(actual, expected)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_add_128():
    """``mt.extint_add_128`` must agree with Python's ``a + b`` and raise
    ``OverflowError`` when the sum leaves ``[INT128_MIN, INT128_MAX]``.
    """
    with exc_iter(INT128_VALUES, INT128_VALUES) as cases:
        for a, b in cases:
            expected = a + b
            if INT128_MIN <= expected <= INT128_MAX:
                actual = mt.extint_add_128(a, b)
                # Cheap comparison first; assert_equal only reports a
                # mismatch
                if expected != actual:
                    assert_equal(actual, expected)
            else:
                assert_raises(OverflowError, mt.extint_add_128, a, b)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_neg_128():
    """``mt.extint_neg_128`` must agree with Python's unary minus for all
    ``INT128_VALUES``.
    """
    with exc_iter(INT128_VALUES) as cases:
        for (a,) in cases:
            expected = -a
            actual = mt.extint_neg_128(a)
            # Cheap comparison first; assert_equal only reports a mismatch
            if expected != actual:
                assert_equal(actual, expected)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_shl_128():
    """Check ``mt.extint_shl_128`` against a left shift by one bit.

    The reference shifts the magnitude of ``a``, keeps the low 128 bits of
    the result and then restores the sign.
    """
    low_128_bits = 2**128-1
    with exc_iter(INT128_VALUES) as cases:
        for (a,) in cases:
            negative = a < 0
            magnitude = -a if negative else a
            shifted = (magnitude << 1) & low_128_bits
            expected = -shifted if negative else shifted

            actual = mt.extint_shl_128(a)
            # Cheap comparison first; assert_equal only reports a mismatch
            if expected != actual:
                assert_equal(actual, expected)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_shr_128():
    """Check ``mt.extint_shr_128`` against a right shift by one bit.

    The reference shifts the magnitude of ``a`` and then restores the sign,
    so negative values are not shifted arithmetically.
    """
    with exc_iter(INT128_VALUES) as cases:
        for (a,) in cases:
            expected = -((-a) >> 1) if a < 0 else a >> 1

            actual = mt.extint_shr_128(a)
            # Cheap comparison first; assert_equal only reports a mismatch
            if expected != actual:
                assert_equal(actual, expected)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_gt_128():
    """``mt.extint_gt_128`` must agree with Python's ``a > b`` for all pairs
    of ``INT128_VALUES``.
    """
    with exc_iter(INT128_VALUES, INT128_VALUES) as cases:
        for a, b in cases:
            expected = a > b
            actual = mt.extint_gt_128(a, b)
            # Cheap comparison first; assert_equal only reports a mismatch
            if expected != actual:
                assert_equal(actual, expected)
项目:radar    作者:amoose136    | 项目源码 | 文件源码
def test_ceildiv_128_64():
    """``mt.extint_ceildiv_128_64`` must agree with ``(a + b - 1) // b`` for
    every pair from ``INT128_VALUES`` x ``INT64_POS_VALUES``.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as cases:
        for a, b in cases:
            expected = (a + b - 1) // b
            actual = mt.extint_ceildiv_128_64(a, b)

            # Cheap comparison first; assert_equal only reports a mismatch
            if expected != actual:
                assert_equal(actual, expected)