Python gym module: Env() example source code

We extracted the following 26 code examples from open-source Python projects to illustrate how to use gym.Env().
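Before the extracted examples, here is a minimal, self-contained sketch of what a gym.Env subclass typically looks like. It is not drawn from any of the projects below; the CoinFlipEnv name and its reward logic are invented for illustration, and it follows the classic four-tuple step API (observation, reward, done, info) used throughout the snippets on this page.

import gym
import numpy as np
from gym import spaces

class CoinFlipEnv(gym.Env):
    """A toy environment: guess the outcome of a coin flip."""

    def __init__(self):
        self.action_space = spaces.Discrete(2)       # 0 = heads, 1 = tails
        self.observation_space = spaces.Discrete(1)  # a single dummy observation
        self._rng = np.random.RandomState()

    def reset(self):
        return 0  # nothing to observe before guessing

    def step(self, action):
        coin = self._rng.randint(2)                  # flip the coin
        reward = 1.0 if action == coin else 0.0      # reward a correct guess
        return 0, reward, True, {}                   # one guess per episode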

Project: distributional_perspective_on_RL    Author: Kiwoo
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env of gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
Project: pytorch.rl.learning    Author: moskomule
def __init__(self, env: gym.Env, gamma, epsilon, final_epsilon, final_exp_step):
        """
        :param env: environment
        :param gamma: discount rate
        :param epsilon: initial exploration rate
        :param final_epsilon: final exploration rate
        :param final_exp_step: the step terminating exploration
        """
        self.env = env
        self.action_size = self.env.action_space.n
        self.net = DQN(self.action_size)
        self.target_net = DQN(self.action_size)
        self._gamma = gamma
        self._initial_epsilon = epsilon
        self.epsilon = epsilon
        self._final_epsilon = final_epsilon
        self._final_exp_step = final_exp_step
        if cuda_available:
            self.net.cuda()
            self.target_net.cuda()
        self.update_target_net()
Project: combine-DT-with-NN-in-RL    Author: Burning-Bear
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env of gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
Project: combine-DT-with-NN-in-RL    Author: Burning-Bear
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env of gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
Project: rl-attack-detection    Author: yenchenlin
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env of gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
Project: baselines    Author: openai
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env of gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
Project: chi    Author: rmst
def print_env(env: Env):
    spec = getattr(env, 'spec', False)
    if spec:
        print(f'Env spec: {vars(spec)}')

    acsp = env.action_space
    obsp = env.observation_space

    print(f'Observation space {obsp}')
    if isinstance(obsp, Box) and len(obsp.high) < 20:
        print(f'low = {obsp.low}\nhigh = {obsp.high}')

    print(f'Action space {acsp}')
    if isinstance(acsp, Box) and len(acsp.high) < 20:
        print(f'low = {acsp.low}\nhigh = {acsp.high}')

    print("")
Project: chi    Author: rmst
def run_episode(self, env: gym.Env):

        meta_wrapper = get_wrapper(env, chi.rl.wrappers.Wrapper)

        done = False
        ob = env.reset()
        a, meta = self.act(ob)

        rs = []
        while not done:
            if meta_wrapper:
                meta_wrapper.set_meta(meta)  # send meta information to wrappers
            ob, r, done, info = env.step(a)
            a, meta = self.act(ob, r, done, info)
            rs.append(r)

        return sum(rs)
Project: ml-utils    Author: LinxiFan
def env_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env of gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
Project: rl-teacher    Author: nottombrown
def get_wrapper_by_name(env, classname):
    """Given an a gym environment possibly wrapped multiple times, returns a wrapper
    of class named classname or raises ValueError if no such wrapper was applied

    Parameters
    ----------
    env: gym.Env of gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname
    """
    currentenv = env
    while True:
        if classname == currentenv.class_name():
            return currentenv
        elif isinstance(currentenv, gym.Wrapper):
            currentenv = currentenv.env
        else:
            raise ValueError("Couldn't find wrapper named %s" % classname)
Project: dqn-tensorflow    Author: DongjunLee
def bot_play(mainDQN: DeepQNetwork, env: gym.Env) -> None:
    """Test runs with rendering and logger.infos the total score
    Args:
        mainDQN (DeepQNetwork): DQN agent to run a test
        env (gym.Env): Gym Environment
    """
    state = env.reset()
    reward_sum = 0

    while True:
        env.render()
        action = np.argmax(mainDQN.predict(state))
        state, reward, done, _ = env.step(action)
        reward_sum += reward

        if done:
            logger.info("Total score: {}".format(reward_sum))
            break
Project: pytorch.rl.learning    Author: moskomule
def __init__(self, agent: Agent, val_env: gym.Env, lr, memory_size, target_update_freq, gradient_update_freq,
                 batch_size, replay_start, val_freq, log_freq_by_step, log_freq_by_ep, val_epsilon,
                 log_dir, weight_dir):
        """
        :param agent: agent object
        :param val_env: environment for validation
        :param lr: learning rate of optimizer
        :param memory_size: size of replay memory
        :param target_update_freq: frequency of update target network in steps
        :param gradient_update_freq: frequency of q-network update in steps
        :param batch_size: batch size for q-net
        :param replay_start: number of random exploration before starting
        :param val_freq: frequency of validation in steps
        :param log_freq_by_step: frequency of logging in steps
        :param log_freq_by_ep: frequency of logging in episodes
        :param val_epsilon: exploration rate for validation
        :param log_dir: directory for saving tensorboard things
        :param weight_dir: directory for saving weights when validated
        """
        self.agent = agent
        self.env = self.agent.env
        self.val_env = val_env
        self.optimizer = optim.RMSprop(params=self.agent.net.parameters(), lr=lr)
        self.memory = Memory(memory_size)
        self.target_update_freq = target_update_freq
        self.batch_size = batch_size
        self.replay_start = replay_start
        self.gradient_update_freq = gradient_update_freq
        self._step = 0
        self._episode = 0
        self._warmed = False
        self._val_freq = val_freq
        self.log_freq_by_step = log_freq_by_step
        self.log_freq_by_ep = log_freq_by_ep
        self._val_epsilon = val_epsilon
        self._writer = SummaryWriter(os.path.join(log_dir, datetime.now().strftime('%b%d_%H-%M-%S')))
        if weight_dir is not None and not os.path.exists(weight_dir):
            os.makedirs(weight_dir)
        self.weight_dir = weight_dir
Project: nesgym    Author: codescv
def __init__(self, **kwargs):
        utils.EzPickle.__init__(self)
        self.curr_seed = 0
        self.screen = np.zeros((SCREEN_HEIGHT, SCREEN_WIDTH, 3), dtype=np.uint8)
        self.closed = False
        self.can_send_command = True
        self.command_cond = Condition()
        self.viewer = None
        self.reward = 0
        episode_time_length_secs = 7
        frame_skip = 5
        fps = 60
        self.episode_length = episode_time_length_secs * fps / frame_skip

        self.actions = [
            'U', 'D', 'L', 'R',
            'UR', 'DR', 'URA', 'DRB',
            'A', 'B', 'RB', 'RA']
        self.action_space = spaces.Discrete(len(self.actions))
        self.frame = 0

        # for communication with emulator
        self.pipe_in = None
        self.pipe_out = None
        self.thread_incoming = None

        self.rom_file_path = None
        self.lua_interface_path = None
        self.emulator_started = False

    ## ---------- gym.Env methods -------------
Project: nesgym    Author: codescv
def _close(self):
        self.closed = True
    ## ------------- end gym.Env --------------

    ## ------------- emulator related ------------
Project: chi    Author: rmst
def list_wrappers(env: Union[Env, gym.Wrapper]):
    while isinstance(env, gym.Wrapper):
        yield env
        env = env.env
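
A usage sketch for the generator above; the wrapper applied here is only an example:

import gym
from gym.wrappers import TimeLimit

env = TimeLimit(gym.make('CartPole-v0'), max_episode_steps=100)
for w in list_wrappers(env):
    print(type(w).__name__)  # prints each wrapper, outermost first
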
Project: chi    Author: rmst
def _step(self, action):
        s, r, t, i = super()._step(action)
        assert isinstance(self.env, Env)
        assert isinstance(self.env.action_space, Box)
        l = self.env.action_space.low
        h = self.env.action_space.high
        m = h - l
        dif = (action - np.clip(action, l - self.slack * m, h + self.slack * m))
        i.setdefault('unwrapped_reward', r)
        r -= self.alpha * np.mean(np.square(dif / m))
        return s, r, t, i
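
This wrapper subtracts a quadratic penalty, scaled by self.alpha, whenever the action strays outside the action-space bounds expanded by self.slack, and records the unpenalized reward under info['unwrapped_reward']. A standalone numpy illustration of the penalty term (the Box bounds, alpha, and slack are made-up values):

import numpy as np

l, h = np.array([-1.0]), np.array([1.0])  # hypothetical Box bounds
alpha, slack = 1.0, 0.1
m = h - l                                 # range of each action dimension: 2.0
action = np.array([1.5])                  # beyond the slack bound h + slack * m = 1.2
dif = action - np.clip(action, l - slack * m, h + slack * m)  # 1.5 - 1.2 = 0.3
penalty = alpha * np.mean(np.square(dif / m))                 # (0.3 / 2.0)^2 = 0.0225
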
Project: gym    Author: openai
def test_double_close():
    class TestEnv(gym.Env):
        def __init__(self):
            self.close_count = 0

        def _close(self):
            self.close_count += 1

    env = TestEnv()
    assert env.close_count == 0
    env.close()
    assert env.close_count == 1
    env.close()
    assert env.close_count == 1
Project: gym    Author: openai
def test_no_monitor_reset_unless_done():
    def assert_reset_raises(env):
        errored = False
        try:
            env.reset()
        except error.Error:
            errored = True
        assert errored, "Env allowed a reset when it shouldn't have"

    with helpers.tempdir() as temp:
        # Make sure we can reset as we please without monitor
        env = gym.make('CartPole-v0')
        env.reset()
        env.step(env.action_space.sample())
        env.step(env.action_space.sample())
        env.reset()

        # can reset once as soon as we start
        env = Monitor(env, temp, video_callable=False)
        env.reset()

        # can reset multiple times in a row
        env.reset()
        env.reset()

        env.step(env.action_space.sample())
        env.step(env.action_space.sample())
        assert_reset_raises(env)

        # should allow resets after the episode is done
        d = False
        while not d:
            _, _, d, _ = env.step(env.action_space.sample())

        env.reset()
        env.reset()

        env.step(env.action_space.sample())
        assert_reset_raises(env)

        env.close()
Project: ray    Author: ray-project
def get_preprocessor_as_wrapper(cls, env, options=dict()):
        """Returns a preprocessor as a gym observation wrapper.

        Args:
            env (gym.Env): The gym environment to wrap.
            options (dict): Options to pass to the preprocessor.

        Returns:
            wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
        """

        preprocessor = cls.get_preprocessor(env, options)
        return _RLlibPreprocessorWrapper(env, preprocessor)
Project: gym-grid-world    Author: leomao
def _get_raw_array(self):
        raise NotImplementedError

    # gym.Env functions
Project: AI-Fight-the-Landlord    Author: YoungGer
def test_double_close():
    class TestEnv(gym.Env):
        def __init__(self):
            self.close_count = 0

        def _close(self):
            self.close_count += 1

    env = TestEnv()
    assert env.close_count == 0
    env.close()
    assert env.close_count == 1
    env.close()
    assert env.close_count == 1
Project: AI-Fight-the-Landlord    Author: YoungGer
def test_no_monitor_reset_unless_done():
    def assert_reset_raises(env):
        errored = False
        try:
            env.reset()
        except error.Error:
            errored = True
        assert errored, "Env allowed a reset when it shouldn't have"

    with helpers.tempdir() as temp:
        # Make sure we can reset as we please without monitor
        env = gym.make('CartPole-v0')
        env.reset()
        env.step(env.action_space.sample())
        env.step(env.action_space.sample())
        env.reset()

        # can reset once as soon as we start
        env = Monitor(env, temp, video_callable=False)
        env.reset()

        # can reset multiple times in a row
        env.reset()
        env.reset()

        env.step(env.action_space.sample())
        env.step(env.action_space.sample())
        assert_reset_raises(env)

        # should allow resets after the episode is done
        d = False
        while not d:
            _, _, d, _ = env.step(env.action_space.sample())

        env.reset()
        env.reset()

        env.step(env.action_space.sample())
        assert_reset_raises(env)

        env.close()
Project: gym-adv    Author: lerrel
def test_double_close():
    class TestEnv(gym.Env):
        def __init__(self):
            self.close_count = 0

        def _close(self):
            self.close_count += 1

    env = TestEnv()
    assert env.close_count == 0
    env.close()
    assert env.close_count == 1
    env.close()
    assert env.close_count == 1
Project: gym-adv    Author: lerrel
def test_no_monitor_reset_unless_done():
    def assert_reset_raises(env):
        errored = False
        try:
            env.reset()
        except error.Error:
            errored = True
        assert errored, "Env allowed a reset when it shouldn't have"

    with helpers.tempdir() as temp:
        # Make sure we can reset as we please without monitor
        env = gym.make('CartPole-v0')
        env.reset()
        env.step(env.action_space.sample())
        env.step(env.action_space.sample())
        env.reset()

        # can reset once as soon as we start
        env.monitor.start(temp, video_callable=False)
        env.reset()

        # can reset multiple times in a row
        env.reset()
        env.reset()

        env.step(env.action_space.sample())
        env.step(env.action_space.sample())
        assert_reset_raises(env)

        # should allow resets after the episode is done
        d = False
        while not d:
            _, _, d, _ = env.step(env.action_space.sample())

        env.reset()
        env.reset()

        env.step(env.action_space.sample())
        assert_reset_raises(env)

        env.monitor.close()
Project: reinforceflow    Author: dbobrenko
def add_observation_summary(obs, env):
    """Adds observation summary.
    Supports observation tensors with 1, 2 and 3 dimensions only.
    1-D tensors logs as histogram summary.
    2-D and 3-D tensors logs as image summary.

    Args:
        obs (Tensor): Observation.
        env (gym.Env): Environment instance.
    """
    from reinforceflow.envs.gym_wrapper import ObservationStackWrap, ImageWrap
    # Get all wrappers
    all_wrappers = {}
    env_wrapper = env
    while True:
        if isinstance(env_wrapper, gym.Wrapper):
            all_wrappers[env_wrapper.__class__] = env_wrapper
            env_wrapper = env_wrapper.env
        else:
            break

    # Check for grayscale
    gray = False
    if ImageWrap in all_wrappers:
        gray = all_wrappers[ImageWrap].grayscale

    # Log each frame of the observation stack as an image summary
    if ObservationStackWrap in all_wrappers:
        channels = 1 if gray else 3
        for obs_id in range(all_wrappers[ObservationStackWrap].obs_stack):
            o = obs[:, :, :, obs_id*channels:(obs_id+1)*channels]
            tf.summary.image('observation%d' % obs_id, o, max_outputs=1)
        return

    # Otherwise, summarize the current observation directly
    if len(env.observation_space.shape) == 1:
        tf.summary.histogram('observation', obs)
    elif len(env.observation_space.shape) == 2:
        tf.summary.image('observation', obs)
    elif len(env.observation_space.shape) == 3 and env.observation_space.shape[2] in (1, 3):
        tf.summary.image('observation', obs)
    else:
        logger.warn('Cannot create summary for observation with shape %s',
                    env.observation_space.shape)
Project: ray    Author: ray-project
def get_preprocessor(cls, env, options=dict()):
        """Returns a suitable processor for the given environment.

        Args:
            env (gym.Env): The gym environment to preprocess.
            options (dict): Options to pass to the preprocessor.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.
        """

        # For older gym versions that don't set shape for Discrete
        if not hasattr(env.observation_space, "shape") and \
                isinstance(env.observation_space, gym.spaces.Discrete):
            env.observation_space.shape = ()

        env_name = env.spec.id
        obs_shape = env.observation_space.shape

        for k in options.keys():
            if k not in MODEL_CONFIGS:
                raise Exception(
                    "Unknown config key `{}`, all keys: {}".format(
                        k, MODEL_CONFIGS))

        print("Observation shape is {}".format(obs_shape))

        if env_name in cls._registered_preprocessor:
            return cls._registered_preprocessor[env_name](
                    env.observation_space, options)

        if obs_shape == ():
            print("Using one-hot preprocessor for discrete envs.")
            preprocessor = OneHotPreprocessor
        elif obs_shape == cls.ATARI_OBS_SHAPE:
            print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
            preprocessor = AtariPixelPreprocessor
        elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
            print("Assuming Atari ram env, using AtariRamPreprocessor.")
            preprocessor = AtariRamPreprocessor
        else:
            print("Non-atari env, not using any observation preprocessor.")
            preprocessor = NoPreprocessor

        return preprocessor(env.observation_space, options)