Python tensorflow.contrib.rnn 模块,RNNCell() 实例源码

我们从Python开源项目中,提取了以下14个代码示例,用于说明如何使用tensorflow.contrib.rnn.RNNCell()

项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def apply_dropout(
    cell, input_keep_probability, output_keep_probability, random_seed=None):
  """Wrap `cell` so dropout is applied to its inputs and/or outputs.

  Args:
    cell: An `RNNCell`.
    input_keep_probability: Probability to keep inputs to `cell`. If `None`,
      a keep probability of 1.0 is used for the inputs.
    output_keep_probability: Probability to keep outputs of `cell`. If `None`,
      a keep probability of 1.0 is used for the outputs.
    random_seed: Seed for random dropout.

  Returns:
    `cell` itself when both probabilities are `None`; otherwise a
    `DropoutWrapper` built around `cell`.
  """
  if input_keep_probability is None and output_keep_probability is None:
    return cell
  if input_keep_probability is None:
    input_keep_probability = 1.0
  if output_keep_probability is None:
    output_keep_probability = 1.0
  # Everything after `cell` goes by keyword: the fourth positional parameter
  # of `DropoutWrapper` is not `seed` in every TensorFlow release, so a
  # positional seed could be taken for a keep probability.
  return contrib_rnn.DropoutWrapper(
      cell,
      input_keep_prob=input_keep_probability,
      output_keep_prob=output_keep_probability,
      seed=random_seed)
项目:opinatt    作者:epochx    | 项目源码 | 文件源码
def __init__(self, cell, zoneout_prob, is_training=True):
    """Store the wrapped cell together with its zoneout settings.

    Args:
      cell: The `RNNCell` instance to wrap.
      zoneout_prob: Zoneout probability. When it is a Python float it must
        lie in [0, 1]; values of any other type are stored without a range
        check.
      is_training: Stored as `self.is_training`; not interpreted here.

    Raises:
      TypeError: If `cell` is not an `RNNCell`.
      ValueError: If `zoneout_prob` is a float outside [0, 1].
    """
    if not isinstance(cell, RNNCell):
      raise TypeError("The parameter cell is not an RNNCell.")
    # `self._tuple` packs a sequence into an LSTMStateTuple for a
    # BasicLSTMCell and into a plain tuple for every other cell.
    if isinstance(cell, BasicLSTMCell):
      self._tuple = lambda x: LSTMStateTuple(*x)
    else:
      self._tuple = lambda x: tuple(x)
    if isinstance(zoneout_prob, float) and not 0.0 <= zoneout_prob <= 1.0:
      # %s rather than %d: %d would truncate the offending float (1.5 -> 1).
      raise ValueError("Parameter zoneout_prob must be between 0 and 1: %s"
                       % zoneout_prob)
    self._cell = cell
    self._zoneout_prob = zoneout_prob
    self.is_training = is_training
项目:neuralmonkey    作者:ufal    | 项目源码 | 文件源码
def _make_rnn_cell(spec: RNNSpec) -> Callable[[], RNNCell]:
    """Build a zero-argument factory that creates the RNN cell for `spec`.

    Args:
        spec: Cell specification; `spec.cell_type` selects the cell class
            ("GRU" or "LSTM") and `spec.size` is passed to its constructor.

    Returns:
        A function that creates a new cell each time it is called.

    Raises:
        ValueError: If `spec.cell_type` is neither "GRU" nor "LSTM".
    """
    if spec.cell_type not in ("GRU", "LSTM"):
        raise ValueError("Unknown RNN cell: {}".format(spec.cell_type))

    if spec.cell_type == "GRU":
        def cell():
            return OrthoGRUCell(spec.size)
        return cell

    def cell():
        return tf.contrib.rnn.LSTMCell(spec.size)
    return cell


# pylint: disable=too-many-instance-attributes
项目:neuralmonkey    作者:ufal    | 项目源码 | 文件源码
def _make_rnn_cell(spec: RNNSpec) -> Callable[[], RNNCell]:
    """Return a callable that instantiates the cell described by `spec`.

    "GRU" maps to `OrthoGRUCell` and "LSTM" to `tf.contrib.rnn.LSTMCell`;
    both are constructed with `spec.size`. The cell is only created when
    the returned callable is invoked.

    Raises:
        ValueError: For any other value of `spec.cell_type`.
    """
    if spec.cell_type == "GRU":
        def cell():
            return OrthoGRUCell(spec.size)
        return cell

    if spec.cell_type == "LSTM":
        def cell():
            return tf.contrib.rnn.LSTMCell(spec.size)
        return cell

    raise ValueError("Unknown RNN cell: {}".format(spec.cell_type))


# pylint: disable=too-many-instance-attributes
项目:neuralmonkey    作者:ufal    | 项目源码 | 文件源码
def _make_rnn_cell(spec: RNNSpec) -> Callable[[], RNNCell]:
    """Create the cell factory selected by `spec.cell_type`.

    Args:
        spec: Provides `cell_type` ("GRU" or "LSTM") and `size`.

    Returns:
        A no-argument function producing an `OrthoGRUCell` (for "GRU") or a
        `tf.contrib.rnn.LSTMCell` (for "LSTM") built with `spec.size`.

    Raises:
        ValueError: If the cell type is not one of the two supported names.
    """
    is_gru = spec.cell_type == "GRU"
    if not is_gru and spec.cell_type != "LSTM":
        raise ValueError("Unknown RNN cell: {}".format(spec.cell_type))

    if is_gru:
        def cell():
            return OrthoGRUCell(spec.size)
    else:
        def cell():
            return tf.contrib.rnn.LSTMCell(spec.size)

    return cell


# pylint: disable=too-many-instance-attributes
项目:tf-tutorial    作者:zchen0211    | 项目源码 | 文件源码
def with_batch_norm_control(self, is_training=True, test_local_stats=True):
    """Return this core wrapped so the `BatchNorm` controls are supplied.

    Example usage:

      lstm = nnd.LSTM(4)
      is_training = tf.placeholder(tf.bool)
      rnn_input = ...
      my_rnn = rnn.rnn(lstm.with_batch_norm_control(is_training), rnn_input)

    Args:
      is_training: Boolean selecting training or testing mode. In training
        mode the batch norm statistics come from the current batch and the
        moving statistics are updated. In testing mode the moving statistics
        are left untouched and, when `test_local_stats` is False, they are
        used in place of the batch statistics. See the `BatchNorm` module
        for details.
      test_local_stats: Boolean scalar indicating whether local batch
        statistics are used in test mode.

    Returns:
      An RNNCell wrapping this object with the extra input(s) attached.
    """
    batch_norm_inputs = {
        "is_training": is_training,
        "test_local_stats": test_local_stats,
    }
    return LSTM.CellWithExtraInput(self, **batch_norm_inputs)
项目:tf-tutorial    作者:zchen0211    | 项目源码 | 文件源码
def __init__(self, cell, *args, **kwargs):
      """Remember the wrapped cell and the extra call arguments.

      Args:
        cell: The RNNCell to wrap (typically a nn.RNNCore).
        *args: Extra positional arguments to pass to __call__; kept as a
          tuple in `self._args`.
        **kwargs: Extra keyword arguments to pass to __call__; kept as a
          dict in `self._kwargs`.
      """
      self._cell, self._args, self._kwargs = cell, args, kwargs
项目:emoatt    作者:epochx    | 项目源码 | 文件源码
def __init__(self, cell, zoneout_prob, is_training=True):
    """Validate the arguments and keep them on the instance.

    Args:
      cell: The `RNNCell` instance to wrap.
      zoneout_prob: Zoneout probability. A Python float is range-checked
        against [0, 1]; a value of any other type is stored as given.
      is_training: Stored as `self.is_training`; not interpreted here.

    Raises:
      TypeError: If `cell` is not an `RNNCell`.
      ValueError: If `zoneout_prob` is a float outside [0, 1].
    """
    if not isinstance(cell, RNNCell):
      raise TypeError("The parameter cell is not an RNNCell.")
    # A BasicLSTMCell gets a packer that builds an LSTMStateTuple; every
    # other cell gets one that builds a plain tuple.
    if isinstance(cell, BasicLSTMCell):
      self._tuple = lambda x: LSTMStateTuple(*x)
    else:
      self._tuple = lambda x: tuple(x)
    if isinstance(zoneout_prob, float) and not 0.0 <= zoneout_prob <= 1.0:
      # Format with %s: %d would drop the fractional part of the bad value.
      raise ValueError("Parameter zoneout_prob must be between 0 and 1: %s"
                       % zoneout_prob)
    self._cell = cell
    self._zoneout_prob = zoneout_prob
    self.is_training = is_training
项目:tf-sparql    作者:derdav3    | 项目源码 | 文件源码
def with_batch_norm_control(self, is_training=True, test_local_stats=True):
    """Wrap this RNNCore so the `BatchNorm` control inputs are passed along.

    Example usage:

      lstm = nnd.LSTM(4)
      is_training = tf.placeholder(tf.bool)
      rnn_input = ...
      my_rnn = rnn.rnn(lstm.with_batch_norm_control(is_training), rnn_input)

    Args:
      is_training: Boolean flag for training versus testing mode. While
        training, batch statistics are computed from the given batch and the
        moving statistics are updated. While testing, the moving statistics
        are not updated; if `test_local_stats` is also False they replace
        the batch statistics. See the `BatchNorm` module for more details.
      test_local_stats: Boolean scalar indicating whether to use local batch
        statistics in test mode.

    Returns:
      RNNCell wrapping this object with the extra input(s) added.
    """
    wrapped = LSTM.CellWithExtraInput(
        self, is_training=is_training, test_local_stats=test_local_stats)
    return wrapped
项目:tf-sparql    作者:derdav3    | 项目源码 | 文件源码
def __init__(self, cell, *args, **kwargs):
      """Build the wrapper around `cell`.

      Args:
        cell: The RNNCell to wrap (typically a nn.RNNCore).
        *args: Extra arguments to pass to __call__, stored as `self._args`.
        **kwargs: Extra keyword arguments to pass to __call__, stored as
          `self._kwargs`.
      """
      (self._cell,
       self._args,
       self._kwargs) = (cell, args, kwargs)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def _to_rnn_cell(cell_or_type, num_units, num_layers):
  """Constructs and returns an `RNNCell`.

  Args:
    cell_or_type: Either a string identifying the `RNNCell` type (a key of
      `_CELL_TYPES`), a subclass of `RNNCell` or an instance of an `RNNCell`.
      An instance is returned unchanged.
    num_units: The number of units in the `RNNCell`.
    num_layers: The number of layers in the RNN. Values greater than 1
      produce a `MultiRNNCell` of that many freshly built cells; otherwise a
      single cell is built.
  Returns:
    An initialized `RNNCell`.
  Raises:
    ValueError: `cell_or_type` is an invalid `RNNCell` name.
    TypeError: `cell_or_type` is not a string or a subclass of `RNNCell`.
  """
  if isinstance(cell_or_type, contrib_rnn.RNNCell):
    return cell_or_type
  if isinstance(cell_or_type, str):
    # Keep the requested name: the lookup result replaces `cell_or_type`, and
    # the error message must report what the caller actually asked for.
    cell_name = cell_or_type
    cell_or_type = _CELL_TYPES.get(cell_name)
    if cell_or_type is None:
      raise ValueError('The supported cell types are {}; got {}'.format(
          list(_CELL_TYPES.keys()), cell_name))
  # `issubclass` itself raises a generic TypeError for non-class arguments;
  # check for a class first so those inputs get the descriptive message.
  if not (isinstance(cell_or_type, type) and
          issubclass(cell_or_type, contrib_rnn.RNNCell)):
    raise TypeError(
        'cell_or_type must be a subclass of RNNCell or one of {}.'.format(
            list(_CELL_TYPES.keys())))
  single_cell = lambda: cell_or_type(num_units=num_units)
  if num_layers > 1:
    # Each layer gets its own cell instance.
    cell = contrib_rnn.MultiRNNCell(
        [single_cell() for _ in range(num_layers)], state_is_tuple=True)
  else:
    cell = single_cell()
  return cell
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def dict_to_state_tuple(input_dict, cell):
  """Reconstructs nested `state` from a dict containing state `Tensor`s.

  Args:
    input_dict: A dict of `Tensor`s.
    cell: An instance of `RNNCell`.
  Returns:
    If `input_dict` does not contain keys 'STATE_PREFIX_i' for `0 <= i < n`
    where `n` is the number of nested entries in `cell.state_size`, this
    function returns `None`. Otherwise, returns a `Tensor` if `cell.state_size`
    is an `int` or a nested tuple of `Tensor`s if `cell.state_size` is a nested
    tuple.
  Raises:
    ValueError: State is partially specified. The `input_dict` must contain
      values for all state components or none at all.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  state_tensors = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      state_tensor = input_dict.get(state_name)
      if state_tensor is not None:
        # Each supplied component must be rank 2 with a second dimension
        # equal to the matching entry of the flattened state size.
        rank_check = check_ops.assert_rank(
            state_tensor, 2, name='check_state_{}_rank'.format(i))
        shape_check = check_ops.assert_equal(
            array_ops.shape(state_tensor)[1],
            state_size,
            name='check_state_{}_shape'.format(i))
        # The identity op carries the checks as control dependencies, so they
        # run whenever the returned state tensor is evaluated.
        with ops.control_dependencies([rank_check, shape_check]):
          state_tensor = array_ops.identity(state_tensor, name=state_name)
        state_tensors.append(state_tensor)
    if not state_tensors:
      return None
    elif len(state_tensors) == len(flat_state_sizes):
      # The zero state is used only as a template for the nesting structure.
      dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
      return nest.pack_sequence_as(dummy_state, state_tensors)
    else:
      # Note the space after the first sentence: the adjacent literals are
      # concatenated and previously ran together as "specified.Expected".
      raise ValueError(
          'RNN state was partially specified. '
          'Expected zero or {} state Tensors; got {}'.
          format(len(flat_state_sizes), len(state_tensors)))
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def construct_rnn(initial_state,
                  sequence_input,
                  cell,
                  num_label_columns,
                  dtype=dtypes.float32,
                  parallel_iterations=32,
                  swap_memory=True):
  """Run `cell` over `sequence_input` and project the outputs linearly.

  Args:
    initial_state: The initial state to pass the RNN. If `None`, the
      default starting state for `cell` is used.
    sequence_input: A `Tensor` with shape `[batch_size, padded_length, d]`
      that will be passed as input to the RNN.
    cell: An initialized `RNNCell`.
    num_label_columns: The desired output dimension.
    dtype: dtype of `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
  Returns:
    activations: The output of the RNN, projected to `num_label_columns`
      dimensions.
    final_state: A `Tensor` or nested tuple of `Tensor`s representing the final
      state output by the RNN.
  """
  # Options forwarded unchanged to `dynamic_rnn`; inputs are batch-major.
  unroll_options = dict(
      initial_state=initial_state,
      dtype=dtype,
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory,
      time_major=False)
  with ops.name_scope('RNN'):
    outputs, last_state = rnn.dynamic_rnn(
        cell=cell, inputs=sequence_input, **unroll_options)
    # No activation function: the projection stays linear.
    projected = layers.fully_connected(
        inputs=outputs,
        num_outputs=num_label_columns,
        activation_fn=None,
        trainable=True)
  return projected, last_state
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def __init__(self,
             num_units,
             num_dims=1,
             input_dims=None,
             output_dims=None,
             priority_dims=None,
             non_recurrent_dims=None,
             tied=False,
             cell_fn=None,
             non_recurrent_fn=None):
    """Set up the configuration and the inner cell of a Grid RNN cell.

    Args:
      num_units: int, the number of units in all dimensions of this GridRNN
        cell.
      num_dims: int, number of dimensions of this grid. Must be >= 1.
      input_dims: int or list, dimensions which will receive input data.
      output_dims: int or list, dimensions from which the output will be
        recorded.
      priority_dims: int or list, dimensions to be considered as priority
        dimensions. If None, no dimension is prioritized.
      non_recurrent_dims: int or list, dimensions that are not recurrent.
        Their transfer function is given by `non_recurrent_fn`.
      tied: bool, whether to share the weights among the dimensions of this
        GridRNN cell. If there are non-recurrent dimensions in the grid,
        weights are shared between each group of recurrent and non-recurrent
        dimensions.
      cell_fn: function returning the recurrent cell object. It is called as
        `cell_fn(num_units, input_size)` and must return an `RNNCell`. If
        None, an `LSTMCell` with `state_is_tuple=False` is created.
      non_recurrent_fn: a tensorflow Op used as the transfer function of the
        non-recurrent dimensions; `nn.relu` is used when this is not given.

    Raises:
      ValueError: If `num_dims` is less than 1, or if `cell_fn` returns
        something that is not an `RNNCell`.
    """
    if num_dims < 1:
      raise ValueError('dims must be >= 1: {}'.format(num_dims))

    transfer_fn = non_recurrent_fn or nn.relu
    self._config = _parse_rnn_config(num_dims, input_dims, output_dims,
                                     priority_dims, non_recurrent_dims,
                                     transfer_fn, tied, num_units)

    # The inner cell's input size covers every grid dimension but one.
    cell_input_size = num_units * (self._config.num_dims - 1)
    if cell_fn is not None:
      # The attribute is assigned before the type check.
      self._cell = cell_fn(num_units, cell_input_size)
      if not isinstance(self._cell, rnn.RNNCell):
        raise ValueError('cell_fn must return an object of type RNNCell')
      return
    self._cell = rnn.LSTMCell(
        num_units=num_units, input_size=cell_input_size, state_is_tuple=False)