Python tensorflow.python.framework.ops module, get_collection() example source code

The following 50 code examples, extracted from open-source Python projects, illustrate how to use tensorflow.python.framework.ops.get_collection().
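
Before the individual examples, here is a minimal sketch of the API itself. This assumes TensorFlow 1.x graph mode, where ops.get_collection is normally reached through the public aliases tf.get_collection and tf.add_to_collection:

import tensorflow as tf  # assumes TensorFlow 1.x

# A collection is a named list of graph objects stored on the Graph.
tf.add_to_collection('my_outputs', tf.constant(1.0, name='a'))
tf.add_to_collection('my_outputs', tf.constant(2.0, name='b'))

# get_collection returns a copy of the list registered under the key;
# unknown keys simply yield an empty list.
print(tf.get_collection('my_outputs'))   # [<tf.Tensor 'a:0' ...>, <tf.Tensor 'b:0' ...>]
print(tf.get_collection('no_such_key'))  # []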

Project: lsdc    Author: febert    | project source code | file source code
def get_variables(scope=None, suffix=None, collection=ops.GraphKeys.VARIABLES):
  """Gets the list of variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return.
    suffix: an optional suffix for filtering the variables to return.
    collection: the collection to search in. Defaults to GraphKeys.VARIABLES.

  Returns:
    a list of variables in collection with scope and suffix.
  """
  if suffix is not None:
    if ':' not in suffix:
      suffix += ':'
    scope = (scope or '') + '.*' + suffix
  return ops.get_collection(collection, scope)
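
A hypothetical caller of the helper above, as a sketch assuming TF 1.x (the variable names are made up, and the newer GraphKeys.GLOBAL_VARIABLES key stands in for the deprecated VARIABLES alias). The suffix is rewritten into the regex '.*suffix:', which get_collection then re.match'es against variable names:

import tensorflow as tf

with tf.variable_scope('layer1'):
  tf.get_variable('weights', shape=[3, 3])
  tf.get_variable('biases', shape=[3])

# Equivalent to get_variables(scope='layer1', suffix='weights'):
matches = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'layer1/.*weights:')
print([v.name for v in matches])  # ['layer1/weights:0']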
Project: lsdc    Author: febert    | project source code | file source code
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
  """Exports graph via session_bundle, by creating a Session."""
  with graph.as_default():
    with tf_session.Session('') as session:
      variables.initialize_local_variables()
      data_flow_ops.initialize_all_tables()
      saver.restore(session, checkpoint_path)

      export = exporter.Exporter(saver)
      export.init(init_op=control_flow_ops.group(
          variables.initialize_local_variables(),
          data_flow_ops.initialize_all_tables()),
                  default_graph_signature=default_graph_signature,
                  named_graph_signatures=named_graph_signatures,
                  assets_collection=ops.get_collection(
                      ops.GraphKeys.ASSET_FILEPATHS))
      return export.export(export_dir, contrib_variables.get_global_step(),
                           session, exports_to_keep=exports_to_keep)
Project: lsdc    Author: febert    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: lsdc    Author: febert    | project source code | file source code
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
  """Gets the list of variables, filtered by scope and/or suffix.

  Args:
    scope: an optional scope for filtering the variables to return. Can be a
      variable scope or a string.
    suffix: an optional suffix for filtering the variables to return.
    collection: the collection to search in. Defaults to
      `GraphKeys.GLOBAL_VARIABLES`.

  Returns:
    a list of variables in collection with scope and suffix.
  """
  if isinstance(scope, variable_scope.VariableScope):
    scope = scope.name
  if suffix is not None:
    if ':' not in suffix:
      suffix += ':'
    scope = (scope or '') + '.*' + suffix
  return ops.get_collection(collection, scope)
Project: lsdc    Author: febert    | project source code | file source code
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
  """Exports graph via session_bundle, by creating a Session."""
  with graph.as_default():
    with tf_session.Session('') as session:
      variables.local_variables_initializer()
      data_flow_ops.initialize_all_tables()
      saver.restore(session, checkpoint_path)

      export = exporter.Exporter(saver)
      export.init(init_op=control_flow_ops.group(
          variables.local_variables_initializer(),
          data_flow_ops.initialize_all_tables()),
                  default_graph_signature=default_graph_signature,
                  named_graph_signatures=named_graph_signatures,
                  assets_collection=ops.get_collection(
                      ops.GraphKeys.ASSET_FILEPATHS))
      return export.export(export_dir, contrib_variables.get_global_step(),
                           session, exports_to_keep=exports_to_keep)
Project: lsdc    Author: febert    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: ChineseNER    Author: zjy-ucas    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: LSTM-CRF-For-Named-Entity-Recognition    Author: zpppy    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: DL-Benchmarks    Author: DL-Benchmarks    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: PLSTM    Author: Enny1991    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor."""
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
Project: diversity_based_attention    Author: PrekshaNema25    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: ROLO    Author: Guanghan    | project source code | file source code
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def testNoUpdatesWhenIsTrainingFalse(self):
    height, width = 3, 3
    with self.test_session() as sess:
      image_shape = (10, height, width, 3)
      image_values = np.random.rand(*image_shape)
      images = constant_op.constant(
          image_values, shape=image_shape, dtype=dtypes.float32)
      output = _layers.batch_norm(images, decay=0.1, is_training=False)
      update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
      # update ops are not added to the UPDATE_OPS collection.
      self.assertEqual(len(update_ops), 0)
      # Initialize all variables
      sess.run(variables_lib.global_variables_initializer())
      moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
      moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
      mean, variance = sess.run([moving_mean, moving_variance])
      # After initialization moving_mean == 0 and moving_variance == 1.
      self.assertAllClose(mean, [0] * 3)
      self.assertAllClose(variance, [1] * 3)
      # When is_training is False batch_norm doesn't update moving_vars.
      for _ in range(10):
        sess.run([output])
      self.assertAllClose(moving_mean.eval(), [0] * 3)
      self.assertAllClose(moving_variance.eval(), [1] * 3)
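
For contrast with this test, a sketch of the usual training-side pattern, assuming TF 1.x with tf.contrib available: with is_training=True the moving-average updates land in UPDATE_OPS and must be wired into the train op by hand.

import tensorflow as tf

images = tf.random_uniform((10, 3, 3, 3))
net = tf.contrib.layers.batch_norm(images, decay=0.1, is_training=True)
loss = tf.reduce_mean(net)

# The update ops are only collected, never run automatically:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)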
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def testNoneUpdatesCollectionNoTraining(self):
    height, width = 3, 3
    with self.test_session() as sess:
      image_shape = (10, height, width, 3)
      image_values = np.random.rand(*image_shape)
      images = constant_op.constant(
          image_values, shape=image_shape, dtype=dtypes.float32)
      output = _layers.batch_norm(
          images, decay=0.1, updates_collections=None, is_training=False)
      # update ops are not added to the UPDATE_OPS collection.
      self.assertEqual(ops.get_collection(ops.GraphKeys.UPDATE_OPS), [])
      # Initialize all variables
      sess.run(variables_lib.global_variables_initializer())
      moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
      moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
      mean, variance = sess.run([moving_mean, moving_variance])
      # After initialization moving_mean == 0 and moving_variance == 1.
      self.assertAllClose(mean, [0] * 3)
      self.assertAllClose(variance, [1] * 3)
      # When is_training is False batch_norm doesn't update moving_vars.
      for _ in range(10):
        sess.run([output])
      self.assertAllClose(moving_mean.eval(), [0] * 3)
      self.assertAllClose(moving_variance.eval(), [1] * 3)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def testCreateConvWithWeightDecay(self):
    random_seed.set_random_seed(0)
    height, width = 3, 3
    with self.test_session() as sess:
      images = random_ops.random_uniform((5, height, width, 3), seed=1)
      regularizer = regularizers.l2_regularizer(0.01)
      layers_lib.separable_conv2d(
          images, 32, [3, 3], 2, weights_regularizer=regularizer)
      self.assertEqual(
          len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 2)
      weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
      self.assertEqual(
          weight_decay.op.name,
          'SeparableConv2d/depthwise_kernel/Regularizer/l2_regularizer')
      sess.run(variables_lib.global_variables_initializer())
      self.assertLessEqual(sess.run(weight_decay), 0.05)
      weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[1]
      self.assertEqual(
          weight_decay.op.name,
          'SeparableConv2d/pointwise_kernel/Regularizer/l2_regularizer')
      self.assertLessEqual(sess.run(weight_decay), 0.05)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def testReuseConvWithWeightDecay(self):
    height, width = 3, 3
    with self.test_session():
      images = random_ops.random_uniform((5, height, width, 3), seed=1)
      regularizer = regularizers.l2_regularizer(0.01)
      layers_lib.separable_conv2d(
          images, 32, [3, 3], 2, weights_regularizer=regularizer, scope='conv1')
      self.assertEqual(
          len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 2)
      layers_lib.separable_conv2d(
          images,
          32, [3, 3],
          2,
          weights_regularizer=regularizer,
          scope='conv1',
          reuse=True)
      self.assertEqual(
          len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 2)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def test_relu_layer_basic_use(self):
    output = layers_lib.legacy_relu(self.input, 8)

    with session.Session() as sess:
      with self.assertRaises(errors_impl.FailedPreconditionError):
        sess.run(output)

      variables_lib.global_variables_initializer().run()
      out_value = sess.run(output)

    self.assertEqual(output.get_shape().as_list(), [2, 8])
    self.assertTrue(np.all(out_value >= 0), 'Relu should have all values >= 0.')

    self.assertEqual(2,
                     len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)))
    self.assertEqual(
        0, len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)))
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def test_regularizer_with_variable_reuse(self):
    cnt = [0]
    tensor = constant_op.constant(5.0)

    def test_fn(_):
      cnt[0] += 1
      return tensor

    with variable_scope.variable_scope('test') as vs:
      _layers.legacy_fully_connected(self.input, 2, weight_regularizer=test_fn)

    with variable_scope.variable_scope(vs, reuse=True):
      _layers.legacy_fully_connected(self.input, 2, weight_regularizer=test_fn)

    self.assertEqual([tensor],
                     ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES))
    self.assertEqual(1, cnt[0])
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def testScatteredEmbeddingColumnSucceedsForDNN(self):
    wire_tensor = sparse_tensor.SparseTensor(
        values=["omar", "stringer", "marlo", "omar"],
        indices=[[0, 0], [1, 0], [1, 1], [2, 0]],
        dense_shape=[3, 2])

    features = {"wire": wire_tensor}
    # Big enough hash space so that hopefully there is no collision
    embedded_sparse = feature_column.scattered_embedding_column(
        "wire", 1000, 3, layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
    output = feature_column_ops.input_from_feature_columns(
        features, [embedded_sparse], weight_collections=["my_collection"])
    weights = ops.get_collection("my_collection")
    grad = gradients_impl.gradients(output, weights)
    with self.test_session():
      variables_lib.global_variables_initializer().run()
      gradient_values = []
      # Collect the gradient from the different partitions (one in this test)
      for p in range(len(grad)):
        gradient_values.extend(grad[p].values.eval())
      gradient_values.sort()
      self.assertAllEqual(gradient_values, [0.5] * 6 + [2] * 3)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def testInputLayerWithCollectionsForDNN(self):
    real_valued = feature_column.real_valued_column("price")
    bucket = feature_column.bucketized_column(
        real_valued, boundaries=[0., 10., 100.])
    hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
    features = {
        "price":
            constant_op.constant([[20.], [110], [-3]]),
        "wire":
            sparse_tensor.SparseTensor(
                values=["omar", "stringer", "marlo"],
                indices=[[0, 0], [1, 0], [2, 0]],
                dense_shape=[3, 1])
    }
    embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
    feature_column_ops.input_from_feature_columns(
        features, [real_valued, bucket, embeded_sparse],
        weight_collections=["my_collection"])
    weights = ops.get_collection("my_collection")
    # one variable for embeded sparse
    self.assertEqual(1, len(weights))
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def testVariablesAddedToCollection(self):
    price_bucket = feature_column.bucketized_column(
        feature_column.real_valued_column("price"), boundaries=[0., 10., 100.])
    country = feature_column.sparse_column_with_hash_bucket(
        "country", hash_bucket_size=5)
    country_price = feature_column.crossed_column(
        [country, price_bucket], hash_bucket_size=10)
    with ops.Graph().as_default():
      features = {
          "price":
              constant_op.constant([[20.]]),
          "country":
              sparse_tensor.SparseTensor(
                  values=["US", "SV"],
                  indices=[[0, 0], [0, 1]],
                  dense_shape=[1, 2])
      }
      feature_column_ops.weighted_sum_from_feature_columns(
          features, [country_price, price_bucket],
          num_outputs=1,
          weight_collections=["my_collection"])
      weights = ops.get_collection("my_collection")
      # 3 = bias + price_bucket + country_price
      self.assertEqual(3, len(weights))
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source code | file source code
def benchmarkTfRNNLSTMBlockCellTraining(self):
    test_configs = self._GetTestConfig()
    for config_name, config in test_configs.items():
      num_layers = config["num_layers"]
      num_units = config["num_units"]
      batch_size = config["batch_size"]
      seq_length = config["seq_length"]

      with ops.Graph().as_default(), ops.device("/gpu:0"):
        inputs = seq_length * [
            array_ops.zeros([batch_size, num_units], dtypes.float32)
        ]
        cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units)  # pylint: disable=cell-var-from-loop

        multi_cell = core_rnn_cell_impl.MultiRNNCell(
            [cell() for _ in range(num_layers)])
        outputs, final_state = core_rnn.static_rnn(
            multi_cell, inputs, dtype=dtypes.float32)
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        gradients = gradients_impl.gradients([outputs, final_state],
                                             trainable_variables)
        training_op = control_flow_ops.group(*gradients)
        self._BenchmarkOp(training_op, "tf_rnn_lstm_block_cell %s %s" %
                          (config_name, self._GetConfigDesc(config)))
Project: Tensormodels    Author: asheshjain399    | project source code | file source code
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
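
The same "mutable singleton stored in a collection" idea in isolation, as a hedged sketch (TF 1.x; the key name is hypothetical): the collection holds a single Python list, and callers mutate that list in place rather than the collection itself.

import tensorflow as tf

_STACK_KEY = ('__my_arg_stack',)  # hypothetical key, mirroring _ARGSTACK_KEY

def get_stack():
  stack = tf.get_collection(_STACK_KEY)
  if stack:
    return stack[0]
  stack = [{}]
  tf.add_to_collection(_STACK_KEY, stack)
  return stack

get_stack().append({'padding': 'SAME'})        # mutate the shared list
assert get_stack()[-1] == {'padding': 'SAME'}  # the same object comes back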
Project: piecewisecrf    Author: Vaan5    | project source code | file source code
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
Project: terngrad    Author: wenwei202    | project source code | file source code
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
Project: lsdc    Author: febert    | project source code | file source code
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer variable. We first try to find it
  in the collection `GLOBAL_STEP`, or by name `global_step:0`.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is not
      a `Variable`.
  """
  graph = ops.get_default_graph() if graph is None else graph
  global_step_tensor = None
  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(global_step_tensors) == 1:
    global_step_tensor = global_step_tensors[0]
  elif not global_step_tensors:
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None
  else:
    logging.error('Multiple tensors in global_step collection.')
    return None

  assert_global_step(global_step_tensor)
  return global_step_tensor
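
A sketch of how a graph might register the variable this helper finds, assuming TF 1.x (tf.train.create_global_step performs essentially the same registration):

import tensorflow as tf

global_step = tf.get_variable(
    'global_step', shape=[], dtype=tf.int64, trainable=False,
    initializer=tf.zeros_initializer(),
    collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP])

# get_global_step() now finds it via the GLOBAL_STEP collection, and would
# also find it by the fallback tensor name 'global_step:0'.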
Project: lsdc    Author: febert    | project source code | file source code
def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
Project: lsdc    Author: febert    | project source code | file source code
def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None):
  """Map stochastic tensors to the fixed losses that depend on them.

  Args:
    fixed_losses: a list of `Tensor`s.
    stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses.
      If `None`, all `StochasticTensor`s in the graph will be used.

  Returns:
    A dict `dependencies` that maps `StochasticTensor` objects to subsets of
    `fixed_losses`.

    If `loss in dependencies[st]` for some `loss` in `fixed_losses`, then
    there is a direct path from `st.value()` to `loss` in the graph.
  """
  stoch_value_collection = stochastic_tensors or ops.get_collection(
      stochastic_tensor.STOCHASTIC_TENSOR_COLLECTION)

  if not stoch_value_collection:
    return {}

  stoch_value_map = dict(
      (node.value(), node) for node in stoch_value_collection)

  # Step backwards through the graph to see which surrogate losses correspond
  # to which fixed_losses.
  #
  # TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the
  # same frame.
  stoch_dependencies_map = collections.defaultdict(set)
  for loss in fixed_losses:
    boundary = set([loss])
    while boundary:
      edge = boundary.pop()
      edge_stoch_node = stoch_value_map.get(edge, None)
      if edge_stoch_node:
        stoch_dependencies_map[edge_stoch_node].add(loss)
      boundary.update(edge.op.inputs)

  return stoch_dependencies_map
Project: lsdc    Author: febert    | project source code | file source code
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Gets the list of losses from the loss_collection.

  Args:
    scope: an optional scope for filtering the losses to return.
    loss_collection: Optional losses collection.

  Returns:
    a list of loss tensors.
  """
  return ops.get_collection(loss_collection, scope)
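
Hypothetical usage, assuming TF 1.x: the tf.losses helpers append each loss they build to GraphKeys.LOSSES, which is exactly the collection this helper reads back.

import tensorflow as tf

logits = tf.constant([[2.0, 1.0]])
labels = tf.constant([0])
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

# The loss op registered itself in the LOSSES collection:
print(tf.get_collection(tf.GraphKeys.LOSSES))  # [<tf.Tensor ...>]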
Project: lsdc    Author: febert    | project source code | file source code
def get_regularization_losses(scope=None):
  """Gets the regularization losses.

  Args:
    scope: an optional scope for filtering the losses to return.

  Returns:
    A list of loss variables.
  """
  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
Project: lsdc    Author: febert    | project source code | file source code
def QueueRunners(session):
  """Creates a context manager that handles starting and stopping queue runners.

  Args:
    session: the currently running session.

  Yields:
    a context in which queues are run.

  Raises:
    NestedQueueRunnerError: if a QueueRunners context is nested within another.
  """
  if not _queue_runner_lock.acquire(False):
    raise NestedQueueRunnerError('QueueRunners cannot be nested')

  coord = coordinator.Coordinator()
  threads = []
  for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
    threads.extend(qr.create_threads(session,
                                     coord=coord,
                                     daemon=True,
                                     start=True))
  try:
    yield
  finally:
    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=120)

    _queue_runner_lock.release()
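
A sketch of how runners end up in the collection this context manager drains, assuming TF 1.x: input pipelines register them through tf.train.add_queue_runner.

import tensorflow as tf

queue = tf.FIFOQueue(10, tf.float32)
enqueue_op = queue.enqueue(tf.constant(1.0))
qr = tf.train.QueueRunner(queue, [enqueue_op])
tf.train.add_queue_runner(qr)  # lands in GraphKeys.QUEUE_RUNNERS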
Project: lsdc    Author: febert    | project source code | file source code
def dnn(tensor_in, hidden_units, activation=nn.relu, dropout=None):
  """Creates fully connected deep neural network subgraph.
  This is deprecated. Please use contrib.layers.dnn instead.

  Args:
    tensor_in: tensor or placeholder for input features.
    hidden_units: list of counts of hidden units in each layer.
    activation: activation function between layers. Can be None.
    dropout: if not None, will add a dropout layer with given probability.

  Returns:
    A tensor which would be a deep neural network.
  """
  logging.warning("learn.ops.dnn is deprecated, "
                  "please use contrib.layers.dnn.")
  with vs.variable_scope('dnn'):
    for i, n_units in enumerate(hidden_units):
      with vs.variable_scope('layer%d' % i):
        # Weight initializer was set to None to replicate the behavior of
        # rnn_cell.linear. Using fully_connected's default initializer gets
        # slightly worse quality results on unit tests.
        tensor_in = layers.legacy_fully_connected(
            tensor_in,
            n_units,
            weight_init=None,
            weight_collections=['dnn_weights'],
            bias_collections=['dnn_biases'])
        if activation is not None:
          tensor_in = activation(tensor_in)
        if dropout is not None:
          is_training = array_ops_.squeeze(ops.get_collection('IS_TRAINING'))
          tensor_in = control_flow_ops.cond(
              is_training,
              lambda: dropout_ops.dropout(tensor_in, prob=(1.0 - dropout)),
              lambda: tensor_in)
    return tensor_in
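
A sketch of the contract the dropout branch above assumes (TF 1.x; the names are illustrative): the caller stores a scalar boolean tensor under the 'IS_TRAINING' key before building the network.

import tensorflow as tf

is_training = tf.placeholder_with_default(True, shape=(), name='is_training')
tf.add_to_collection('IS_TRAINING', is_training)
# ops.get_collection('IS_TRAINING') inside dnn() then retrieves this tensor.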
Project: lsdc    Author: febert    | project source code | file source code
def every_n_step_begin(self, step):
    super(LoggingTrainable, self).every_n_step_begin(step)
    # Get a list of trainable variables at the beginning of every N steps.
    # We cannot get this in __init__ because train_op has not been generated.
    trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=self._scope)
    self._names = {}
    for var in trainables:
      self._names[var.name] = var.value().name
    return list(self._names.values())
Project: lsdc    Author: febert    | project source code | file source code
def _centered_bias_step(self, targets, features):
    centered_bias = ops.get_collection(self._centered_bias_weight_collection)
    batch_size = array_ops.shape(targets)[0]
    logits = array_ops.reshape(
        array_ops.tile(centered_bias[0], [batch_size]),
        [batch_size, self._target_column.num_label_columns])
    with ops.name_scope(None, "centered_bias", (targets, features)):
      training_loss = self._target_column.training_loss(
          logits, targets, features)
    # Learn the centered bias with an optimizer. 0.1 is a conservative lr for
    # a single variable.
    return training.AdagradOptimizer(0.1).minimize(
        training_loss, var_list=centered_bias)
Project: lsdc    Author: febert    | project source code | file source code
def _centered_bias_step(targets, loss_fn, num_label_columns):
  centered_bias = ops.get_collection("centered_bias")
  batch_size = array_ops.shape(targets)[0]
  logits = array_ops.reshape(
      array_ops.tile(centered_bias[0], [batch_size]),
      [batch_size, num_label_columns])
  loss = loss_fn(logits, targets)
  return train.AdagradOptimizer(0.1).minimize(loss, var_list=centered_bias)
Project: lsdc    Author: febert    | project source code | file source code
def _get_vars(self):
    if self._get_feature_columns():
      return ops.get_collection(self._scope)
    return []
Project: lsdc    Author: febert    | project source code | file source code
def _get_first_op_from_collection(collection_name):
  """Get first element from the collection."""
  elements = ops.get_collection(collection_name)
  if elements:
    return elements[0]
  return None
Project: lsdc    Author: febert    | project source code | file source code
def _get_first_op_from_collection(collection_name):
  elements = ops.get_collection(collection_name)
  if elements:
    return elements[0]
  return None
Project: lsdc    Author: febert    | project source code | file source code
def _get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError('More than one item in the collection "%s". '
                           'Please indicate which one to use by passing it to '
                           'the tf.Scaffold constructor as: '
                           'tf.Scaffold(%s=item to use)' %
                           (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op
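
Stripped of the Scaffold machinery, the get-or-create pattern looks roughly like this (a sketch, assuming TF 1.x):

import tensorflow as tf

elements = tf.get_collection(tf.GraphKeys.INIT_OP)
if not elements:
  tf.add_to_collection(tf.GraphKeys.INIT_OP, tf.global_variables_initializer())
init_op = tf.get_collection(tf.GraphKeys.INIT_OP)[0]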
Project: lsdc    Author: febert    | project source code | file source code
def apply_regularization(regularizer, weights_list=None):
  """Returns the summed penalty by applying `regularizer` to the `weights_list`.

  Adding a regularization penalty over the layer weights and embedding weights
  can help prevent overfitting the training data. Regularization over layer
  biases is less common/useful, but assuming proper data preprocessing/mean
  subtraction, it usually shouldn't hurt much either.

  Args:
    regularizer: A function that takes a single `Tensor` argument and returns
      a scalar `Tensor` output.
    weights_list: List of weights `Tensors` or `Variables` to apply
      `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
      `None`.

  Returns:
    A scalar representing the overall regularization penalty.

  Raises:
    ValueError: If `regularizer` does not return a scalar output, or if we find
        no weights.
  """
  if not weights_list:
    weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
  if not weights_list:
    raise ValueError('No weights to regularize.')
  with ops.name_scope('get_regularization_penalty',
                      values=weights_list) as scope:
    penalties = [regularizer(w) for w in weights_list]
    for p in penalties:
      if p.get_shape().ndims != 0:
        raise ValueError('regularizer must return a scalar Tensor instead of a '
                         'Tensor with rank %d.' % p.get_shape().ndims)

    summed_penalty = math_ops.add_n(penalties, name=scope)
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty)
    return summed_penalty
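
Hypothetical driving code for the function above, assuming TF 1.x where it ships as tf.contrib.layers.apply_regularization: put the layer weights into GraphKeys.WEIGHTS, then sum an L2 penalty over them.

import tensorflow as tf

w = tf.get_variable('w', shape=[10, 10])
tf.add_to_collection(tf.GraphKeys.WEIGHTS, w)

l2 = tf.contrib.layers.l2_regularizer(scale=1e-4)
penalty = tf.contrib.layers.apply_regularization(l2)  # defaults to WEIGHTS
# As documented above, the summed penalty is also added to
# REGULARIZATION_LOSSES.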
Project: lsdc    Author: febert    | project source code | file source code
def is_summary_tag_unique(tag):
  """Checks if a summary tag is unique.

  Args:
    tag: The tag to use

  Returns:
    True if the summary tag is unique.
  """
  existing_tags = [tensor_util.constant_value(summary.op.inputs[0])
                   for summary in ops.get_collection(ops.GraphKeys.SUMMARIES)]
  existing_tags = [name.tolist() if isinstance(name, np.ndarray) else name
                   for name in existing_tags]
  return tag.encode() not in existing_tags
Project: lsdc    Author: febert    | project source code | file source code
def summarize_collection(collection, name_filter=None,
                         summarizer=summarize_tensor):
  """Summarize a graph collection of tensors, possibly filtered by name."""
  tensors = []
  for op in ops.get_collection(collection):
    if name_filter is None or re.match(name_filter, op.op.name):
      tensors.append(op)
  return summarize_tensors(tensors, summarizer)


# Utility functions for commonly used collections
Project: lsdc    Author: febert    | project source code | file source code
def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
Project: lsdc    Author: febert    | project source code | file source code
def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None):
  """Map stochastic tensors to the fixed losses that depend on them.

  Args:
    fixed_losses: a list of `Tensor`s.
    stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses.
      If `None`, all `StochasticTensor`s in the graph will be used.

  Returns:
    A dict `dependencies` that maps `StochasticTensor` objects to subsets of
    `fixed_losses`.

    If `loss in dependencies[st]` for some `loss` in `fixed_losses`, then
    there is a direct path from `st.value()` to `loss` in the graph.
  """
  stoch_value_collection = stochastic_tensors or ops.get_collection(
      stochastic_tensor.STOCHASTIC_TENSOR_COLLECTION)

  if not stoch_value_collection:
    return {}

  stoch_value_map = dict(
      (node.value(), node) for node in stoch_value_collection)

  # Step backwards through the graph to see which surrogate losses correspond
  # to which fixed_losses.
  #
  # TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the
  # same frame.
  stoch_dependencies_map = collections.defaultdict(set)
  for loss in fixed_losses:
    boundary = set([loss])
    while boundary:
      edge = boundary.pop()
      edge_stoch_node = stoch_value_map.get(edge, None)
      if edge_stoch_node:
        stoch_dependencies_map[edge_stoch_node].add(loss)
      boundary.update(edge.op.inputs)

  return stoch_dependencies_map
Project: lsdc    Author: febert    | project source code | file source code
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Gets the list of losses from the loss_collection.

  Args:
    scope: an optional scope for filtering the losses to return.
    loss_collection: Optional losses collection.

  Returns:
    a list of loss tensors.
  """
  return ops.get_collection(loss_collection, scope)
Project: lsdc    Author: febert    | project source code | file source code
def get_regularization_losses(scope=None):
  """Gets the regularization losses.

  Args:
    scope: an optional scope for filtering the losses to return.

  Returns:
    A list of loss variables.
  """
  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
Project: lsdc    Author: febert    | project source code | file source code
def QueueRunners(session):
  """Creates a context manager that handles starting and stopping queue runners.

  Args:
    session: the currently running session.

  Yields:
    a context in which queues are run.

  Raises:
    NestedQueueRunnerError: if a QueueRunners context is nested within another.
  """
  if not _queue_runner_lock.acquire(False):
    raise NestedQueueRunnerError('QueueRunners cannot be nested')

  coord = coordinator.Coordinator()
  threads = []
  for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
    threads.extend(qr.create_threads(session,
                                     coord=coord,
                                     daemon=True,
                                     start=True))
  try:
    yield
  finally:
    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=120)

    _queue_runner_lock.release()
Project: lsdc    Author: febert    | project source code | file source code
def every_n_step_begin(self, step):
    super(LoggingTrainable, self).every_n_step_begin(step)
    # Get a list of trainable variables at the beginning of every N steps.
    # We cannot get this in __init__ because train_op has not been generated.
    trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=self._scope)
    self._names = {}
    for var in trainables:
      self._names[var.name] = var.value().name
    return list(self._names.values())
Project: lsdc    Author: febert    | project source code | file source code
def _get_qr(self, name):
    for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      if qr.name == name:
        return qr
Project: lsdc    Author: febert    | project source code | file source code
def _get_vars(self):
    if self._get_feature_columns():
      return ops.get_collection(self._scope)
    return []