Python multiprocessing 模块：Barrier() 实例源码

我们从 Python 开源项目中提取了以下 14 个代码示例，用于说明如何使用 multiprocessing.Barrier()。

项目:Synkhronos    作者:astooke    | 项目源码 | 文件源码
def test_multi_process_simultaneous(n_gpu=2, worker_func_maker=unpickle_func, bar_loop=False):
    """Run the training function on ``n_gpu`` ranks at the same time.

    Rank 0 runs in this process on ``cuda0``; ranks ``1 .. n_gpu-1`` run in
    child processes.  All ranks share one barrier with ``n_gpu`` parties,
    which is used to line up the setup phase and is handed to
    ``test_the_function`` for the timed phase.
    """
    barrier = mp.Barrier(n_gpu)
    target = sim_profiling_worker if PROFILE else simultaneous_worker

    # Create every worker first, then start them all.
    procs = []
    for rank in range(1, n_gpu):
        worker_args = (rank, worker_func_maker, barrier, bar_loop)
        procs.append(mp.Process(target=target, args=worker_args))
    for proc in procs:
        proc.start()

    # This process acts as rank 0.
    theano.gpuarray.use("cuda0")
    f_train, name = build_train_func()

    barrier.wait()
    # Between these two waits the workers build or unpickle their function.
    time.sleep(1)
    barrier.wait()
    # Every rank is ready at this point.
    test_the_function(f_train, name=name, barrier=barrier, bar_loop=bar_loop)

    for proc in procs:
        proc.join()
项目:Synkhronos    作者:astooke    | 项目源码 | 文件源码
def main(n_pairs=7):
    """Start ``n_pairs`` worker processes and run ``master`` in this process.

    Parameters
    ----------
    n_pairs : int or str
        Number of worker processes.  It is passed through ``int`` so a
        command-line string is accepted.
    """
    n_pairs = int(n_pairs)

    # One barrier party per worker plus one for this (master) process.
    barrier = mp.Barrier(n_pairs + 1)

    # The Manager runs its own server process.  The ``with`` block shuts that
    # server down, but only after every worker (which uses sync_dict) has
    # been joined.
    with mp.Manager() as mgr:
        sync_dict = mgr.dict()

        workers = [mp.Process(target=worker, args=(rank + 1, barrier, sync_dict))
                   for rank in range(n_pairs)]

        for w in workers:
            w.start()

        master(n_pairs, barrier, sync_dict)

        for w in workers:
            w.join()
项目:ringbuffer    作者:bslatkin    | 项目源码 | 文件源码
def test_writer_blocks_multiple_readers(self):
        with self.lock.for_write():
            before_read = multiprocessing.Barrier(3)
            during_read = multiprocessing.Barrier(2)
            after_read = multiprocessing.Barrier(2)

            def test():
                self.assert_writer()

                before_read.wait()

                with self.lock.for_read():
                    during_read.wait()
                    value = self.reader_count()
                    after_read.wait()
                    return value

            r1 = self.async(test)
            r2 = self.async(test)

            # Wait until we can confirm that all readers are locked out
            before_read.wait()
            self.assert_writer()

        self.assertEqual(2, self.get_result(r1))
        self.assertEqual(2, self.get_result(r2))
        self.assert_unlocked()
项目:ringbuffer    作者:bslatkin    | 项目源码 | 文件源码
def test_reader_blocks_writer(self):
        with self.lock.for_read():
            before_write = multiprocessing.Barrier(2)
            during_write = multiprocessing.Barrier(2)
            after_write = multiprocessing.Barrier(2)
            after_unlock = multiprocessing.Barrier(2)

            def test():
                self.assert_readers(1)

                before_write.wait()

                with self.lock.for_write():
                    self.assert_writer()
                    return 'written'

            writer = self.async(test)

            # Wait until we can confirm that all writers are locked out.
            before_write.wait()
            self.assert_readers(1)

        self.assertEqual('written', self.get_result(writer))
        self.assert_unlocked()
项目:ringbuffer    作者:bslatkin    | 项目源码 | 文件源码
def test_multiple_readers_block_writer(self):
        with self.lock.for_read():
            before_read = multiprocessing.Barrier(3)
            after_read = multiprocessing.Barrier(2)

            def test_reader():
                self.assert_readers(1)

                with self.lock.for_read():
                    before_read.wait()
                    value = self.reader_count()
                    after_read.wait()
                    return value

            def test_writer():
                before_read.wait()

                with self.lock.for_write():
                    self.assert_writer()
                    return 'written'

            reader = self.async(test_reader)
            writer = self.async(test_writer)

            # Wait for the write to be blocked by multiple readers.
            before_read.wait()
            self.assert_readers(2)
            after_read.wait()

        self.assertEqual(2, self.get_result(reader))
        self.assertEqual('written', self.get_result(writer))
        self.assert_unlocked()
项目:ringbuffer    作者:bslatkin    | 项目源码 | 文件源码
def test_multiple_writers_block_each_other(self):
        with self.lock.for_write():
            before_write = multiprocessing.Barrier(2)

            def test():
                before_write.wait()

                with self.lock.for_write():
                    self.assert_writer()
                    return 'written'

            writer = self.async(test)

            before_write.wait()
            self.assert_writer()

        self.assertEqual('written', self.get_result(writer))
        self.assert_unlocked()
项目:ouroboros    作者:pybee    | 项目源码 | 文件源码
def setUp(self):
        # Each test starts with a fresh barrier for ``self.N`` parties whose
        # default timeout is ``self.defaultTimeout``.
        parties, default_timeout = self.N, self.defaultTimeout
        self.barrier = self.Barrier(parties, timeout=default_timeout)
项目:ouroboros    作者:pybee    | 项目源码 | 文件源码
def test_action(self):
        """
        Check the barrier's ``action`` callback: after all threads run, the
        shared results list must hold exactly one entry.
        """
        results = self.DummyList()
        on_release = AppendTrue(results)
        barrier = self.Barrier(self.N, action=on_release)
        self.run_threads(self._test_action_f, (barrier, results))
        self.assertEqual(len(results), 1)
项目:ouroboros    作者:pybee    | 项目源码 | 文件源码
def test_abort_and_reset(self):
        """
        Test that a barrier can be reset after being broken.
        """
        results1 = self.DummyList()
        results2 = self.DummyList()
        results3 = self.DummyList()
        barrier2 = self.Barrier(self.N)

        self.run_threads(self._test_abort_and_reset_f,
                         (self.barrier, barrier2, results1, results2, results3))
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)
项目:ouroboros    作者:pybee    | 项目源码 | 文件源码
def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        barrier = self.Barrier(self.N, timeout=0.5)
        results = self.DummyList()
        self.run_threads(self._test_default_timeout_f, (barrier, results))
        self.assertEqual(len(results), barrier.parties)
项目:Synkhronos    作者:astooke    | 项目源码 | 文件源码
def test_multi_process_sequence(n_gpu=2, worker_func_maker=unpickle_func):
    """Run the training function on ``n_gpu`` ranks in barrier-separated rounds.

    Rank 0 (this process, ``cuda0``) builds and pickles the function; ranks
    ``1 .. n_gpu-1`` run in child processes.  There are ``n_gpu`` barrier
    rounds after setup, and rank 0 runs its test in round 0.
    NOTE(review): presumably each worker runs in its own later round --
    confirm against ``sequence_worker``.
    """
    barrier = mp.Barrier(n_gpu)
    target = seq_profiling_worker if PROFILE else sequence_worker

    procs = []
    for rank in range(1, n_gpu):
        worker_args = (rank, n_gpu, barrier, worker_func_maker)
        procs.append(mp.Process(target=target, args=worker_args))
    for proc in procs:
        proc.start()

    theano.gpuarray.use("cuda0")
    f_train, name = build_train_func()
    pickle_func(f_train)

    barrier.wait()
    # Between these two waits the workers make their function (maybe unpickle).
    barrier.wait()
    for turn in range(n_gpu):
        time.sleep(1)
        barrier.wait()
        if turn == 0:
            test_the_function(f_train, name)

    for proc in procs:
        proc.join()
项目:Synkhronos    作者:astooke    | 项目源码 | 文件源码
def main():
    """Share one float32 array across ``G`` ranks and run ``worker`` on each.

    Ranks ``1 .. G-1`` run in child processes; this process runs rank 0.
    """
    # NOTE(review): x is a NumPy view over a RawArray.  It is only shared
    # (not copied) if children are forked -- confirm the start method.
    raw = mp.RawArray('f', N * C * H * W)
    x = np.ctypeslib.as_array(raw).reshape(N, C, H, W)
    print(x.shape)

    barrier = mp.Barrier(G)

    workers = []
    for rank in range(1, G):
        workers.append(mp.Process(target=worker, args=(x, barrier, rank)))
    for proc in workers:
        proc.start()

    # This process participates as rank 0.
    worker(x, barrier, 0)

    for proc in workers:
        proc.join()
项目:ouroboros    作者:pybee    | 项目源码 | 文件源码
def test_event(self):
        """Exercise Event.is_set/set/clear/wait, including a wake-up from a child process."""
        event = self.Event()
        wait = TimingWrapper(event.wait)

        # A freshly created event is cleared.
        self.assertEqual(event.is_set(), False)

        # wait() returns the event's flag, so on a cleared event it returns
        # False once the timeout has elapsed.
        self.assertEqual(wait(0.0), False)
        self.assertTimingAlmostEqual(wait.elapsed, 0.0)
        self.assertEqual(wait(TIMEOUT1), False)
        self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1)

        event.set()

        # Once set, wait() returns True immediately, with or without a timeout.
        self.assertEqual(event.is_set(), True)
        self.assertEqual(wait(), True)
        self.assertTimingAlmostEqual(wait.elapsed, 0.0)
        self.assertEqual(wait(TIMEOUT1), True)
        self.assertTimingAlmostEqual(wait.elapsed, 0.0)
        self.assertEqual(event.is_set(), True)

        event.clear()
        self.assertEqual(event.is_set(), False)

        # The child runs ``_test_event`` (not shown), which is expected to set
        # the event; wait() must then return True.
        p = self.Process(target=self._test_event, args=(event,))
        p.daemon = True
        p.start()
        self.assertEqual(wait(), True)
        # Reap the child so it does not outlive the test.
        p.join()

#
# Tests for Barrier - adapted from tests in test/lock_tests.py
#

# Many of the tests for threading.Barrier use a list as an atomic
# counter: a value is appended to increment the counter, and the
# length of the list gives the value.  We use the class DummyList
# for the same purpose.
项目:RL-Universe    作者:Bifrost-Research    | 项目源码 | 文件源码
def main(args):
    """Start ``args.num_actor_learners`` A3C learner processes and wait for them.

    ``args`` is mutated in place.  Shared state is attached to it: global step
    counter, barrier, queue, action count and the normalised visualize flag.
    The per-learner fields ``pipes``, ``actor_id`` and ``random_seed`` are
    overwritten before each learner is constructed and started.
    """
    logger.debug("CONFIGURATION : %s", args)

    # Global shared counter allocated in shared memory ('i' = signed int).
    args.global_step = Value('i', 0)

    # Barrier used to synchronize the learner processes.
    args.barrier = Barrier(args.num_actor_learners)

    # Process-safe queue used to communicate between the learners.
    args.queue = Queue()

    # Number of actions available at each step of the game.
    args.nb_actions = atari_environment.get_num_actions(args.game)

    # Any non-zero value enables visualisation.
    args.visualize = bool(args.visualize != 0)

    # Learner 0 is connected to every other learner, so n-1 pipes are needed.
    pipes = [Pipe() for _ in range(args.num_actor_learners - 1)]

    # Seed the generator once, outside the loop.  Re-seeding it from
    # int(time.time()) for every learner gives all learners started within
    # the same second an identical seed.  (Draws from range(1000) can still
    # collide occasionally.)
    rng = np.random.RandomState(int(time.time()))

    actor_learners = []

    # Launch every learner in its own process.
    for i in range(args.num_actor_learners):
        if i == 0:
            # Learner 0 holds one end of every pipe.
            args.pipes = [pipe[0] for pipe in pipes]
        else:
            # Every other learner holds the other end of its pipe to learner 0.
            args.pipes = [pipes[i - 1][1]]

        args.actor_id = i
        args.random_seed = rng.randint(1000)

        # NOTE(review): the same ``args`` object is reused for every learner.
        # This relies on each learner capturing it at start(), before the
        # next iteration mutates it -- confirm A3C_Learner does not read
        # ``args`` later in the parent.
        learner = A3C_Learner(args)
        actor_learners.append(learner)
        learner.start()

    # Wait for the processes to finish.
    for learner in actor_learners:
        learner.join()

    logger.debug("All processes are over")