Python tensorflow.python.framework.ops 模块，get_default_graph() 实例源码

我们从Python开源项目中，提取了以下43个代码示例，用于说明如何使用tensorflow.python.framework.ops.get_default_graph()。

项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def get_uid(prefix=''):
  """Associates a string prefix with an integer counter in a TensorFlow graph.

  The counters are kept in `tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS`, keyed
  by the current default graph, so each graph numbers its prefixes
  independently.

  Arguments:
    prefix: String prefix to index.

  Returns:
    Unique integer ID.

  Example:

  ```
  >>> get_uid('dense')
  1
  >>> get_uid('dense')
  2
  ```
  """
  graph = ops.get_default_graph()
  # NOTE(review): assumes PER_GRAPH_LAYER_NAME_UIDS creates a zeroed counter
  # for an unseen graph/prefix on first access (e.g. nested defaultdicts);
  # the `+= 1` below would raise KeyError otherwise — confirm in
  # tf_base_layers.
  layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph]
  layer_name_uids[prefix] += 1
  return layer_name_uids[prefix]

```

项目:lsdc    作者:febert    | 项目源码 | 文件源码
def create_global_step(graph=None):
  """Creates the global step variable in `graph`.

  Args:
    graph: Graph that receives the global step. The default graph is used
        when this is omitted.

  Returns:
    The newly created global step variable.

  Raises:
    ValueError: If `graph` already has a global step.
  """
  if graph is None:
    graph = ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Build inside the target graph and at the root name scope, so the caller's
  # current name scope does not prefix the variable.
  with graph.as_default() as g, g.name_scope(None):
    step_collections = [ops.GraphKeys.VARIABLES, ops.GraphKeys.GLOBAL_STEP]
    return variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=step_collections)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _as_graph_element(obj):
  """Resolves `obj` to an element of the current default graph.

  Args:
    obj: Either a string name, or an object carrying a `graph` attribute.
      A name without an output index is resolved as output 0 and must not
      belong to an op that also has an output 1.

  Returns:
    `obj` itself when it is not a string; otherwise the graph element that
    the name resolves to.

  Raises:
    ValueError: If a non-string `obj` has no `graph` attribute or belongs to
      another graph, or if a bare name is ambiguous.
  """
  graph = ops.get_default_graph()

  if isinstance(obj, six.string_types):
    if ":" in obj:
      return graph.as_graph_element(obj)
    element = graph.as_graph_element(obj + ":0")
    # If ":1" also resolves, the op has several outputs and the bare name
    # cannot identify a single one.
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      return element
    raise ValueError("Name %s is ambiguous, "
                     "as this `Operation` has multiple outputs "
                     "(at least 2)." % obj)

  if not hasattr(obj, "graph") or obj.graph != graph:
    raise ValueError("Passed %s should have graph attribute that is equal "
                     "to current graph %s." % (obj, graph))
  return obj
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get(logdir):
    """Returns the cached SummaryWriter for `logdir`, creating it on first use.

    Lookup and creation both happen while holding `SummaryWriterCache._lock`.
    A newly created writer is attached to the current default graph.

    Args:
      logdir: str, name of the directory.

    Returns:
      A `SummaryWriter`.
    """
    with SummaryWriterCache._lock:
      cache = SummaryWriterCache._cache
      if logdir in cache:
        return cache[logdir]
      writer = summary_io.SummaryWriter(logdir, graph=ops.get_default_graph())
      cache[logdir] = writer
      return writer


# Backward compatible interface.  Remove?
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _as_graph_element(obj):
  """Resolves `obj` to an element of the current default graph.

  Args:
    obj: Either a string name, or an object carrying a `graph` attribute.
      A name without an output index is resolved as output 0 and must not
      belong to an op that also has an output 1.

  Returns:
    `obj` itself when it is not a string; otherwise the graph element that
    the name resolves to.

  Raises:
    ValueError: If a non-string `obj` has no `graph` attribute or belongs to
      another graph, or if a bare name is ambiguous.
  """
  graph = ops.get_default_graph()

  if isinstance(obj, six.string_types):
    if ":" in obj:
      return graph.as_graph_element(obj)
    element = graph.as_graph_element(obj + ":0")
    # If ":1" also resolves, the op has several outputs and the bare name
    # cannot identify a single one.
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      return element
    raise ValueError("Name %s is ambiguous, "
                     "as this `Operation` has multiple outputs "
                     "(at least 2)." % obj)

  if not hasattr(obj, "graph") or obj.graph != graph:
    raise ValueError("Passed %s should have graph attribute that is equal "
                     "to current graph %s." % (obj, graph))
  return obj
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def create_global_step(graph=None):
  """Creates the global step variable in `graph`.

  Args:
    graph: Graph that receives the global step. The default graph is used
        when this is omitted.

  Returns:
    The newly created global step variable.

  Raises:
    ValueError: If `graph` already has a global step.
  """
  if graph is None:
    graph = ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Build inside the target graph and at the root name scope, so the caller's
  # current name scope does not prefix the variable.
  with graph.as_default() as g, g.name_scope(None):
    step_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.GLOBAL_STEP]
    return variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=step_collections)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_or_create_eval_step():
  """Returns the eval step counter, creating it if the graph has none.

  The counter is looked up in the `tf.GraphKeys.EVAL_STEP` collection of the
  default graph. If the collection is empty, a local variable named
  'eval_step' with initial value 0.0 is created and added to it.

  Returns:
    A `Tensor` representing a counter for the evaluation step.

  Raises:
    ValueError: If multiple `Tensors` have been added to the
      `tf.GraphKeys.EVAL_STEP` collection.
  """
  graph = ops.get_default_graph()
  existing = graph.get_collection(ops.GraphKeys.EVAL_STEP)
  if len(existing) > 1:
    raise ValueError(
        'Multiple tensors added to tf.GraphKeys.EVAL_STEP')
  if existing:
    return existing[0]

  counter = variables.local_variable(0.0, name='eval_step')
  graph.add_to_collection(ops.GraphKeys.EVAL_STEP, counter)
  return counter
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _as_graph_element(obj):
  """Resolves `obj` to an element of the current default graph.

  Args:
    obj: Either a string name, or an object carrying a `graph` attribute.
      A name without an output index is resolved as output 0 and must not
      belong to an op that also has an output 1.

  Returns:
    `obj` itself when it is not a string; otherwise the graph element that
    the name resolves to.

  Raises:
    ValueError: If a non-string `obj` has no `graph` attribute or belongs to
      another graph, or if a bare name is ambiguous.
  """
  graph = ops.get_default_graph()

  if isinstance(obj, six.string_types):
    if ":" in obj:
      return graph.as_graph_element(obj)
    element = graph.as_graph_element(obj + ":0")
    # If ":1" also resolves, the op has several outputs and the bare name
    # cannot identify a single one.
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      return element
    raise ValueError("Name %s is ambiguous, "
                     "as this `Operation` has multiple outputs "
                     "(at least 2)." % obj)

  if not hasattr(obj, "graph") or obj.graph != graph:
    raise ValueError("Passed %s should have graph attribute that is equal "
                     "to current graph %s." % (obj, graph))
  return obj
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _get_train_ops(self, features, labels):
    """See base class.

    Builds training-mode logits and hands the head a callback. Given the
    training loss, the callback runs the linear (and, if present, DNN) train
    steps and then increments the global step.
    """
    features, labels = self._feature_engineering_fn(
        self._get_feature_dict(features), labels)
    logits = self._logits(features, is_training=True)

    def _make_training_op(training_loss):
      global_step = contrib_variables.get_global_step()
      assert global_step

      linear_step = self._linear_model.get_train_step(training_loss)
      if self._dnn_model:
        dnn_step = self._dnn_model.get_train_step(training_loss)
      else:
        dnn_step = []

      # Increment only after every sub-model train step has run, and keep the
      # increment colocated with the global step variable.
      with ops.control_dependencies(linear_step + dnn_step):
        with ops.get_default_graph().colocate_with(global_step):
          return state_ops.assign_add(global_step, 1).op

    return self._head.head_ops(features, labels, model_fn.ModeKeys.TRAIN,
                               _make_training_op, logits=logits)
项目:Machine-Learning    作者:sfeng15    | 项目源码 | 文件源码
def export_meta_graph(self, filename=None, collection_list=None,
                        as_text=False):
    """Writes a `MetaGraphDef` for the default graph and this saver.

    The module-level `export_meta_graph` does the work. It receives the
    current default graph's `GraphDef` and this object's `saver_def`.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.

    Returns:
      A `MetaGraphDef` proto.
    """
    current_graph_def = ops.get_default_graph().as_graph_def()
    return export_meta_graph(
        filename=filename,
        graph_def=current_graph_def,
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def _as_graph_element(obj):
  """Resolves `obj` to an element of the current default graph.

  Args:
    obj: Either a string name, or an object carrying a `graph` attribute.
      A name without an output index is resolved as output 0 and must not
      belong to an op that also has an output 1.

  Returns:
    `obj` itself when it is not a string; otherwise the graph element that
    the name resolves to.

  Raises:
    ValueError: If a non-string `obj` has no `graph` attribute or belongs to
      another graph, or if a bare name is ambiguous.
  """
  graph = ops.get_default_graph()

  if isinstance(obj, six.string_types):
    if ":" in obj:
      return graph.as_graph_element(obj)
    element = graph.as_graph_element(obj + ":0")
    # If ":1" also resolves, the op has several outputs and the bare name
    # cannot identify a single one.
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      return element
    raise ValueError("Name %s is ambiguous, "
                     "as this `Operation` has multiple outputs "
                     "(at least 2)." % obj)

  if not hasattr(obj, "graph") or obj.graph != graph:
    raise ValueError("Passed %s should have graph attribute that is equal "
                     "to current graph %s." % (obj, graph))
  return obj
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def _assert_summaries(self,
                        output_dir,
                        writer,
                        expected_summaries=None,
                        expected_graphs=None,
                        expected_meta_graphs=None,
                        expected_session_logs=None):
    """Checks what a fake summary writer recorded.

    First asserts that `writer` is a `testing.FakeSummaryWriter`. Then it
    delegates to the writer's own `assert_summaries`, passing `output_dir`
    as the expected logdir and the current default graph as the expected
    graph.
    """
    self.assertTrue(isinstance(writer, testing.FakeSummaryWriter))
    expectations = dict(
        expected_logdir=output_dir,
        expected_graph=ops.get_default_graph(),
        expected_summaries=expected_summaries,
        expected_added_graphs=expected_graphs,
        expected_added_meta_graphs=expected_meta_graphs,
        expected_session_logs=expected_session_logs)
    writer.assert_summaries(self, **expectations)

  # TODO(ptucker): Test number and contents of checkpoint files.
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def testCustomConfig(self):
    """Checks that `RunConfig.tf_random_seed` becomes the input graph's seed."""
    test_random_seed = 5783452

    class TestInput(object):
      """Input-function holder that records the graph seed it observes."""

      def __init__(self):
        self.random_seed = 0

      def config_test_input_fn(self):
        # Capture the seed of the graph in which the estimator calls input_fn.
        self.random_seed = ops.get_default_graph().seed
        return constant_op.constant([[1.]]), constant_op.constant([1.])

    config = run_config.RunConfig(tf_random_seed=test_random_seed)
    test_input = TestInput()
    est = estimator.Estimator(model_fn=linear_model_fn, config=config)
    est.fit(input_fn=test_input.config_test_input_fn, steps=1)
    # If input_fn ran, it will have given us the random seed set on the graph.
    # assertEqual, not the deprecated assertEquals alias removed in 3.12.
    self.assertEqual(test_random_seed, test_input.random_seed)
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def __init__(self, fetches, contraction_fn):
    """Creates an _ElementFetchMapper.

    This is the fetch mapper for leaves of the fetch structure. Because of
    the expansions mechanism, one leaf may fetch several tensors.

    A fetch may be a string (tensor or op name) or any object the graph can
    convert to a tensor, such as a Variable. Each fetch is therefore passed
    through the default graph's `as_graph_element()` to get the
    corresponding tensor or op.

    Args:
      fetches: List of objects, as returned by a fetch_fn defined
        in _REGISTERED_EXPANSIONS.
      contraction_fn: Callable as returned by a fetch_fn.

    Raises:
      TypeError: If a fetch has a type the graph cannot interpret.
      ValueError: If a fetch cannot be interpreted as a tensor, including
        unknown names (the graph's KeyError is re-raised as ValueError).
    """
    self._unique_fetches = []
    for item in fetches:
      try:
        element = ops.get_default_graph().as_graph_element(
            item, allow_tensor=True, allow_operation=True)
      except TypeError as err:
        raise TypeError('Fetch argument %r has invalid type %r, '
                        'must be a string or Tensor. (%s)'
                        % (item, type(item), str(err)))
      except (ValueError, KeyError) as err:
        raise ValueError('Fetch argument %r cannot be interpreted as a '
                         'Tensor. (%s)' % (item, str(err)))
      self._unique_fetches.append(element)
    self._contraction_fn = contraction_fn
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def clear_session():
  """Destroys the current TF graph and creates a new one.

  Useful to avoid clutter from old models / layers. The function also resets
  the layer UIDs and drops the cached session. It then registers a fresh
  `keras_learning_phase` placeholder as the learning phase of the new default
  graph, replacing every previously cached entry.
  """
  global _SESSION
  global _GRAPH_LEARNING_PHASES
  ops.reset_default_graph()
  reset_uids()
  _SESSION = None
  # Created after the reset, so the placeholder lives in the new graph.
  fresh_phase = array_ops.placeholder(dtype='bool',
                                      name='keras_learning_phase')
  _GRAPH_LEARNING_PHASES = {ops.get_default_graph(): fresh_phase}
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def learning_phase():
  """Returns the learning phase flag for the current default graph.

  The flag is passed as input to any Keras function that behaves differently
  at train time and test time (0 = test, 1 = train). When the default graph
  has no cached entry, this function creates a bool placeholder named
  `keras_learning_phase` and caches it in `_GRAPH_LEARNING_PHASES`. Otherwise
  the cached value is returned unchanged.

  NOTE(review): per the original docstring, the cached value may also be a
  Python integer stored by other code — confirm against its writers.

  Returns:
      Learning phase (scalar tensor or Python integer).
  """
  graph = ops.get_default_graph()
  if graph in _GRAPH_LEARNING_PHASES:
    return _GRAPH_LEARNING_PHASES[graph]
  placeholder = array_ops.placeholder(dtype='bool',
                                      name='keras_learning_phase')
  _GRAPH_LEARNING_PHASES[graph] = placeholder
  return placeholder
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def __init__(self, layers=None, name=None):
    """Creates a sequential model, optionally pre-populated with layers.

    Args:
      layers: Optional iterable of layers; each is passed to `self.add` in
        order.
      name: Optional model name. When falsy, a name of the form
        'sequential_<n>' is generated via `K.get_uid('sequential_')`.
    """
    self.layers = []  # Stack of layers.
    self.model = None  # Internal Model instance.
    self.inputs = []  # List of input tensors.
    self.outputs = []  # List of length 1: the output tensor (unique).
    self._trainable = True
    self._initial_weights = None

    # Model attributes.
    self.inbound_nodes = []
    self.outbound_nodes = []
    self.built = False

    if not name:
      uid_prefix = 'sequential_'
      name = uid_prefix + str(K.get_uid(uid_prefix))
    self.name = name

    # Not used by Keras itself; kept for compatibility with TF's variable
    # scoping mechanism.
    self._updates = []
    self._scope = None
    self._reuse = None
    self._base_name = name
    self._graph = ops.get_default_graph()

    for layer in layers or ():
      self.add(layer)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_global_step(graph=None):
  """Looks up the global step tensor of a graph.

  The `GLOBAL_STEP` collection is consulted first. If it is empty, the
  tensor named `global_step:0` is used instead. The result is validated with
  `assert_global_step`.

  Args:
    graph: The graph to search. Defaults to the default graph.

  Returns:
    The global step variable, or `None` if none was found. `None` is also
    returned, after logging an error, when the collection holds more than one
    tensor.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is not
      a `Variable`.
  """
  if graph is None:
    graph = ops.get_default_graph()

  candidates = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(candidates) > 1:
    logging.error('Multiple tensors in global_step collection.')
    return None

  if candidates:
    found = candidates[0]
  else:
    try:
      found = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None

  assert_global_step(found)
  return found
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_or_create_global_step(graph=None):
  """Returns the graph's global step, creating it first if it does not exist.

  Args:
    graph: The graph to search and, if needed, create the global step in.
        Defaults to the default graph.

  Returns:
    the tensor representing the global step variable.
  """
  target = ops.get_default_graph() if graph is None else graph
  existing = get_global_step(target)
  if existing is not None:
    return existing
  return create_global_step(target)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def begin(self, max_steps=None):
    """Collects the first output tensor of every non-ignored op in the graph.

    Each considered node is logged. Nodes whose op type is in
    `self._ignore_ops` are skipped, as are nodes whose ":0" tensor cannot be
    found.
    """
    super(GraphDump, self).begin(max_steps=max_steps)
    self._tensors = []
    graph = ops.get_default_graph()
    kept_nodes = (n for n in graph.as_graph_def().node
                  if n.op not in self._ignore_ops)
    for node in kept_nodes:
      logging.info("op=%s name=%s.", node.op, node.name)
      try:
        first_output = graph.get_tensor_by_name(node.name + ":0")
      except KeyError:
        # Presumably an op with no outputs; nothing to collect.
        continue
      self._tensors.append(first_output)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _get_train_ops(self, features, targets):
    """See base class.

    Returns:
      A `(train_op, weighted_average_loss)` pair. `train_op` increments the
      global step after the linear (and, if present, DNN) train steps run.
    """
    global_step = contrib_variables.get_global_step()
    assert global_step

    features = self._get_feature_dict(features)
    logits = self._logits(features, is_training=True)
    centered_bias_step = ([self._centered_bias_step(targets, features)]
                          if self._enable_centered_bias else [])
    # The loss ops depend on the optional centered-bias update.
    with ops.control_dependencies(centered_bias_step):
      training_loss = self._target_column.training_loss(logits, targets,
                                                        features)
      weighted_average_loss = self._target_column.loss(logits, targets,
                                                       features)

    logging_ops.scalar_summary("loss", weighted_average_loss)

    linear_train_step = self._linear_model.get_train_step(training_loss)
    if self._dnn_model:
      dnn_train_step = self._dnn_model.get_train_step(training_loss)
    else:
      dnn_train_step = []

    with ops.control_dependencies(linear_train_step + dnn_train_step):
      with ops.get_default_graph().colocate_with(global_step):
        increment_op = state_ops.assign_add(global_step, 1).op
    return increment_op, weighted_average_loss
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _get_model_fn(self, model_fn):
    """Wraps `model_fn` with class-weight and IS_TRAINING handling (legacy).

    TODO(ipolosukhin): Remove this function after new layers are available.
    Specifically:
     * dropout and batch norm should work via update ops.
     * class weights should be retrieved from weights column or hparams.

    Args:
      model_fn: Core model function, called as `model_fn(features, targets)`
        and expected to return `(predictions, loss)`.

    Returns:
      Model function taking `(features, targets, mode)` and returning
      `(predictions, loss, train_op)`.
    """
    def _model_fn(features, targets, mode):
      """Model function."""
      graph = ops.get_default_graph()
      graph.add_to_collection('IS_TRAINING', mode == 'train')
      if self.class_weight is not None:
        # Built only for its side effect: a 'class_weight' constant is added
        # to the graph. The returned tensor is not used here.
        constant_op.constant(self.class_weight, name='class_weight')
      predictions, loss = model_fn(features, targets)

      # learning_rate and optimizer may be given as factories.
      learning_rate = self.learning_rate
      if isinstance(learning_rate, types.FunctionType):
        learning_rate = learning_rate(contrib_framework.get_global_step())
      optimizer = self.optimizer
      if isinstance(optimizer, types.FunctionType):
        optimizer = optimizer(learning_rate)

      train_op = layers.optimize_loss(
          loss,
          contrib_framework.get_global_step(),
          learning_rate=learning_rate,
          optimizer=optimizer,
          clip_gradients=self.clip_gradients)
      return predictions, loss, train_op

    return _model_fn
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _loss_to_train_op(self, loss):
    """Map `loss` to a training op.

    Gradients of `loss` are computed for every variable in the default
    graph's TRAINABLE_VARIABLES collection and post-processed by
    `self._process_gradients`. They are then applied with the global step
    passed to the optimizer.
    """
    with ops.name_scope('loss_to_train_op'):
      graph = ops.get_default_graph()
      var_list = graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      step = contrib_framework.get_global_step()
      grads_and_vars = self._optimizer.compute_gradients(
          loss=loss, var_list=var_list)
      return self._optimizer.apply_gradients(
          self._process_gradients(grads_and_vars), global_step=step)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def finalize(self):
    """Creates operations if needed and finalizes the graph.

    Every scaffold op still set to None is filled in through
    `Scaffold._get_or_default(arg_name, collection_key, default_fn)`. This is
    done in the order init, ready, local-init, summary, saver. The saver is
    then built and the default graph is finalized.

    Returns:
      `self`.
    """
    # (attribute, argument name, collection key, default factory), in the
    # order they must be resolved.
    op_specs = [
        ('_init_op', 'init_op', ops.GraphKeys.INIT_OP,
         variables.initialize_all_variables),
        ('_ready_op', 'ready_op', ops.GraphKeys.READY_OP,
         variables.report_uninitialized_variables),
        ('_local_init_op', 'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
         Scaffold._default_local_init_op),
        ('_summary_op', 'summary_op', ops.GraphKeys.SUMMARY_OP,
         logging_ops.merge_all_summaries),
        ('_saver', 'saver', ops.GraphKeys.SAVERS,
         lambda: training_saver.Saver(sharded=True, allow_empty=True)),
    ]
    for attr_name, arg_name, collection_key, default_fn in op_specs:
      if getattr(self, attr_name) is None:
        setattr(self, attr_name,
                Scaffold._get_or_default(arg_name, collection_key, default_fn))
    self._saver.build()

    ops.get_default_graph().finalize()
    return self
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _get_session_manager(self):
    """Returns the cached `SessionManager`, creating it on first use.

    A new manager uses the scaffold's local-init and ready ops and the
    current default graph.
    """
    if not self._session_manager:
      self._session_manager = sm.SessionManager(
          local_init_op=self._scaffold.local_init_op,
          ready_op=self._scaffold.ready_op,
          graph=ops.get_default_graph())
    return self._session_manager
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def before_run(self, run_context):  # pylint: disable=unused-argument
    """Requests the global step and, on the first call, writes out the graph.

    While `self._last_saved_time` is None, the graph is written to
    `<checkpoint_dir>/graph.pbtxt` with shapes. It is also added to the
    summary writer.
    """
    if self._last_saved_time is None:
      graph = ops.get_default_graph()
      training_util.write_graph(graph.as_graph_def(add_shapes=True),
                                self._checkpoint_dir,
                                "graph.pbtxt")
      self._summary_writer.add_graph(graph)

    return SessionRunArgs(self._global_step_tensor)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def begin(self, max_steps=None):
    """Collects the first output tensor of every non-ignored op in the graph.

    Each considered node is logged. Nodes whose op type is in
    `self._ignore_ops` are skipped, as are nodes whose ":0" tensor cannot be
    found.
    """
    super(GraphDump, self).begin(max_steps=max_steps)
    self._tensors = []
    graph = ops.get_default_graph()
    kept_nodes = (n for n in graph.as_graph_def().node
                  if n.op not in self._ignore_ops)
    for node in kept_nodes:
      logging.info("op=%s name=%s.", node.op, node.name)
      try:
        first_output = graph.get_tensor_by_name(node.name + ":0")
      except KeyError:
        # Presumably an op with no outputs; nothing to collect.
        continue
      self._tensors.append(first_output)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _get_train_ops(self, features, labels):
    """Builds the training op for the wrapped model.

    Returns:
      A `(train_op, loss)` pair. `train_op` increments the global step after
      the model's train step has run.
    """
    global_step = contrib_variables.get_global_step()
    assert global_step

    logits = self._model.build_model(
        features, self._feature_columns, is_training=True)
    head_ops = self._head.head_ops(features, labels,
                                   tf.contrib.learn.ModeKeys.TRAIN,
                                   _noop_training_fn, logits=logits)
    loss = head_ops.loss
    train_step = self._model.get_train_step(loss)

    with ops.control_dependencies(train_step):
      with ops.get_default_graph().colocate_with(global_step):
        increment_op = state_ops.assign_add(global_step, 1).op
    return increment_op, loss
项目:self-supervision    作者:gustavla    | 项目源码 | 文件源码
def average_name(self, var):
    """Returns the name of the shadow `Variable` that holds the average of `var`.

    A typical use of `ExponentialMovingAverage` is to average variables during
    training and to restore the averaged values during evaluation. Restoring
    requires the shadow variable names, which can be mapped back to the
    originals through a `Saver()`:
      `saver = tf.train.Saver({ema.average_name(var): var})`

    This can be called whether or not `apply()` has been called. If `var` was
    already averaged, the existing shadow variable's op name is returned.
    Otherwise the name the default graph would assign next to
    "<var op name>/<self._name>" is returned, without reserving it.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage class` to hold the moving average of
      `var`.
    """
    if var not in self._averages:
      return ops.get_default_graph().unique_name(
          var.op.name + "/" + self._name, mark_as_used=False)
    return self._averages[var].op.name
项目:sciencebeam-gym    作者:elifesciences    | 项目源码 | 文件源码
def __init__(self, session_init_fn, graph=None):
    """Stores the session-init callback and the graph it applies to.

    Args:
      session_init_fn: Callback stored for later use.
      graph: Graph to use; defaults to the default graph at construction time.
    """
    self._session_init_fn = session_init_fn
    self._graph = ops.get_default_graph() if graph is None else graph
项目:sktacc    作者:jclee81    | 项目源码 | 文件源码
def restore_graph(s):
    """Parses a serialized `GraphDef` and imports it into the default graph.

    The imported nodes are placed under the 'restore' name prefix.

    Args:
      s: Serialized `GraphDef` bytes, as accepted by `GraphDef.ParseFromString`.

    Returns:
      The parsed `GraphDef` proto.
    """
    log.info('restore_graph')
    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(s)
    # import_graph_def adds the nodes to the current default graph; no explicit
    # graph handle is needed (the former unused `g` local and commented-out
    # debug prints were removed).
    importer.import_graph_def(graph_def, name='restore')
    return graph_def
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def create_global_step(graph=None):
  """Creates the global step variable in `graph`.

  Args:
    graph: Graph that receives the global step. The default graph is used
        when this is omitted.

  Returns:
    The newly created global step variable.

  Raises:
    ValueError: If `graph` already has a global step.
  """
  if graph is None:
    graph = ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Build inside the target graph and at the root name scope, so the caller's
  # current name scope does not prefix the variable.
  with graph.as_default() as g, g.name_scope(None):
    step_collections = [ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.GLOBAL_STEP]
    return variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=step_collections)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def get_or_create_global_step(graph=None):
  """Returns the graph's global step, creating it first if it does not exist.

  Args:
    graph: The graph to search and, if needed, create the global step in.
        Defaults to the default graph.

  Returns:
    the tensor representing the global step variable.
  """
  target = ops.get_default_graph() if graph is None else graph
  existing = get_global_step(target)
  if existing is not None:
    return existing
  return create_global_step(target)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def begin(self, max_steps=None):
    """Collects the first output tensor of every non-ignored op in the graph.

    Each considered node is logged. Nodes whose op type is in
    `self._ignore_ops` are skipped, as are nodes whose ":0" tensor cannot be
    found.
    """
    super(GraphDump, self).begin(max_steps=max_steps)
    self._tensors = []
    graph = ops.get_default_graph()
    kept_nodes = (n for n in graph.as_graph_def().node
                  if n.op not in self._ignore_ops)
    for node in kept_nodes:
      logging.info("op=%s name=%s.", node.op, node.name)
      try:
        first_output = graph.get_tensor_by_name(node.name + ":0")
      except KeyError:
        # Presumably an op with no outputs; nothing to collect.
        continue
      self._tensors.append(first_output)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def _base_model_fn(features, labels, mode, params):
  """Builds model-fn ops for TRAIN or EVAL from the model and head in `params`.

  Args:
    features: Input features passed to the model and head.
    labels: Labels passed to the head.
    mode: `model_fn_lib.ModeKeys.TRAIN` or `model_fn_lib.ModeKeys.EVAL`.
    params: Dict with 'model', 'feature_columns' and 'head' entries.

  Returns:
    Whatever `head.create_model_fn_ops` returns.

  Raises:
    NotImplementedError: For any other `mode`.
  """
  model = params['model']
  feature_columns = params['feature_columns']
  head = params['head']

  if mode == model_fn_lib.ModeKeys.TRAIN:
    is_training = True
  elif mode == model_fn_lib.ModeKeys.EVAL:
    is_training = False
  else:
    raise NotImplementedError
  logits = model.build_model(features, feature_columns,
                             is_training=is_training)

  def _train_op_fn(loss):
    """Runs the model's train step, then increments the global step."""
    global_step = contrib_variables.get_global_step()
    assert global_step
    train_step = model.get_train_step(loss)
    with ops.control_dependencies(train_step):
      with ops.get_default_graph().colocate_with(global_step):
        return state_ops.assign_add(global_step, 1).op

  return head.create_model_fn_ops(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def begin(self):
    """Looks up the k-means loss tensor in the default graph.

    Raises:
      KeyError: If the default graph has no tensor named
        `KMeansClustering.LOSS_OP_NAME + ':0'`. The error comes from
        `Graph.get_tensor_by_name`, which raises instead of returning None,
        so no separate None check is needed.
    """
    self._loss_tensor = ops.get_default_graph().get_tensor_by_name(
        KMeansClustering.LOSS_OP_NAME + ':0')
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def as_default(self):
    """Returns a context manager that makes this object the default session.

    Use with the `with` keyword to specify that calls to
    @{tf.Operation.run} or @{tf.Tensor.eval} should be executed in
    this session.

    ```python
    c = tf.constant(..)
    sess = tf.Session()

    with sess.as_default():
      assert tf.get_default_session() is sess
      print(c.eval())
    ```

    To get the current default session, use @{tf.get_default_session}.

    *N.B.* The `as_default` context manager *does not* close the
    session when you exit the context, and you must close the session
    explicitly.

    ```python
    c = tf.constant(...)
    sess = tf.Session()
    with sess.as_default():
      print(c.eval())
    # ...
    with sess.as_default():
      print(c.eval())

    sess.close()
    ```

    Alternatively, you can use `with tf.Session():` to create a
    session that is automatically closed on exiting the context,
    including when an uncaught exception is raised.

    *N.B.* The default session is a property of the current thread. If you
    create a new thread, and wish to use the default session in that
    thread, you must explicitly add a `with sess.as_default():` in that
    thread's function.

    *N.B.* Entering a `with sess.as_default():` block does not affect
    the current default graph. If you are using multiple graphs, and
    `sess.graph` is different from the value of @{tf.get_default_graph},
    you must explicitly enter a `with sess.graph.as_default():` block
    to make `sess.graph` the default graph.

    Returns:
      A context manager using this session as the default session.
    """
    return ops.default_session(self)

```

项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
  """Creates a constant tensor.

  The resulting tensor is populated with values of type `dtype`, as
  specified by arguments `value` and (optionally) `shape` (see examples
  below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the `shape` argument (if
  specified). In the case where the list length is less than the number of
  elements specified by `shape`, the last element in the list will be used
  to fill the remaining entries.

  The argument `shape` is optional. If present, it specifies the dimensions of
  the resulting tensor. If not present, the shape of `value` is used.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  Args:
    value:          A constant value (or list) of output type `dtype`.
    dtype:          The type of the elements of the resulting tensor.
    shape:          Optional dimensions of resulting tensor.
    name:           Optional name for the tensor.
    verify_shape:   Boolean that enables verification of a shape of values.

  Returns:
    A Constant Tensor.
  """
  g = ops.get_default_graph()
  # Serialize `value` into a TensorProto wrapped in an AttrValue.
  tensor_value = attr_value_pb2.AttrValue()
  tensor_value.tensor.CopyFrom(
      tensor_util.make_tensor_proto(
          value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  # Take the op's dtype from the proto, so it matches whatever
  # make_tensor_proto resolved (including inference when `dtype` is None).
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
  const_tensor = g.create_op(
      "Const", [], [dtype_value.type],
      attrs={"value": tensor_value,
             "dtype": dtype_value},
      name=name).outputs[0]
  return const_tensor

项目:lsdc    作者:febert    | 项目源码 | 文件源码
def run(self,
          num_batches=None,
          graph=None,
          session=None,
          start_queues=True,
          initialize_variables=True,
          **kwargs):
    """Builds and runs the columns of the `DataFrame` and yields batches.

    This is a generator that yields a dictionary mapping column names to
    evaluated columns.

    Cleanup runs however iteration ends: on exhaustion, on an exception, or
    when the caller closes the generator early (e.g. `break` out of a `for`
    loop). Queue runner threads started here are stopped and joined. A
    session created here (because `session` was None) is closed.

    Args:
      num_batches: the maximum number of batches to produce. If None, batches
        are produced until `session.run` raises `OutOfRangeError`.
      graph: the `Graph` in which the `DataFrame` should be built. Defaults to
        the current default graph.
      session: the `Session` in which to run the columns of the `DataFrame`.
        If None, a new session is created and closed when iteration ends.
      start_queues: if true, queues will be started before running and halted
        after producing `num_batches` batches.
      initialize_variables: if true, variables will be initialized.
      **kwargs: Additional keyword arguments e.g. `num_epochs`, forwarded to
        `self.build`.

    Yields:
      A dictionary, mapping column names to the values resulting from running
      each column for a single batch.
    """
    if graph is None:
      graph = ops.get_default_graph()
    with graph.as_default():
      # Only close a session this call created; a caller-supplied session
      # stays the caller's responsibility.
      owns_session = session is None
      if owns_session:
        session = sess.Session()
      coord = None
      threads = []
      try:
        self_built = self.build(**kwargs)
        keys = list(self_built.keys())
        cols = list(self_built.values())
        if initialize_variables:
          if variables.local_variables():
            session.run(variables.initialize_local_variables())
          if variables.all_variables():
            session.run(variables.initialize_all_variables())
        if start_queues:
          coord = coordinator.Coordinator()
          threads = qr.start_queue_runners(sess=session, coord=coord)
        i = 0
        while num_batches is None or i < num_batches:
          i += 1
          try:
            values = session.run(cols)
          except errors.OutOfRangeError:
            # The input pipeline is exhausted.
            break
          yield collections.OrderedDict(zip(keys, values))
      finally:
        # Also reached via GeneratorExit when the consumer stops early, so
        # queue threads and an owned session are never leaked.
        try:
          if coord is not None:
            coord.request_stop()
            coord.join(threads)
        finally:
          if owns_session:
            session.close()
项目:Machine-Learning    作者:sfeng15    | 项目源码 | 文件源码
def _as_meta_graph_def(meta_info_def=None, graph_def=None, saver_def=None,
                       collection_list=None):
  """Constructs and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer, merged into the result if
      given.
    graph_def: `GraphDef` protocol buffer. If not given, the `GraphDef` of the
      current default graph is used.
    saver_def: `SaverDef` protocol buffer, merged into the result if given.
    collection_list: List of string keys to collect. If empty or not given,
      every key returned by `ops.get_all_collection_keys()` is collected.

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # Type check. The messages are %-formatted here; passing the type as a
  # second TypeError argument would leave "%s" uninterpolated.
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if meta_info_def:
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(ops.get_default_graph().as_graph_def())
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def,
  # unless the supplied meta_info_def already provided one.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list; an empty or missing list means "all collections".
  clist = collection_list or ops.get_all_collection_keys()
  for ctype in clist:
    _add_collection_def(meta_graph_def, ctype)
  return meta_graph_def
项目:Machine-Learning    作者:sfeng15    | 项目源码 | 文件源码
def _import_meta_graph_def(meta_graph_def):
  """Rebuilds the graph serialized in `meta_graph_def` in the current graph.

  All nodes of `meta_graph_def.graph_def` are imported into the default
  graph without a name prefix, every recorded collection is repopulated, and
  a `Saver` is returned.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer.

  Returns:
    A `Saver` built from `meta_graph_def.saver_def` when that field is set.
    Otherwise a default `Saver()` when the graph has variables, or None when
    it has none (nothing to restore).
  """
  # Import every node of the serialized GraphDef, keeping original names.
  importer.import_graph_def(meta_graph_def.graph_def, name="")

  # Repopulate each collection recorded in the MetaGraphDef.
  for key, col_def in meta_graph_def.collection_def.items():
    kind = col_def.WhichOneof("kind")
    if kind is None:
      logging.error("Cannot identify data type for collection %s. Skipping."
                    % key)
      continue

    from_proto = ops.get_from_proto_function(key)
    if from_proto:
      # A registered deserializer exists: entries are serialized protos.
      assert kind == "bytes_list"
      proto_type = ops.get_collection_proto_type(key)
      for serialized in col_def.bytes_list.value:
        message = proto_type()
        message.ParseFromString(serialized)
        ops.add_to_collection(key, from_proto(message))
      continue

    entries = getattr(col_def, kind).value
    if kind == "node_list":
      # Names are resolved against the default graph populated above.
      for node_name in entries:
        node = ops.get_default_graph().as_graph_element(node_name)
        ops.add_to_collection(key, node)
    elif kind == "int64_list":
      # NOTE(opensource): Forcing int() avoids the Python 2 int/long split;
      # Python 3 has only int.
      for number in entries:
        ops.add_to_collection(key, int(number))
    else:
      for entry in entries:
        ops.add_to_collection(key, entry)

  if meta_graph_def.HasField("saver_def"):
    return Saver(saver_def=meta_graph_def.saver_def)
  if not variables.all_variables():
    # If no graph variables exist, then a Saver cannot be constructed.
    logging.info("Saver not created because there are no variables in the"
                 " graph to restore")
    return None
  # Default saver covering every graph variable.
  return Saver()
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def run(self,
          num_batches=None,
          graph=None,
          session=None,
          start_queues=True,
          initialize_variables=True,
          **kwargs):
    """Builds and runs the columns of the `DataFrame` and yields batches.

    This is a generator that yields a dictionary mapping column names to
    evaluated columns.

    Cleanup runs however iteration ends: on exhaustion, on an exception, or
    when the caller closes the generator early (e.g. `break` out of a `for`
    loop). Queue runner threads started here are stopped and joined. A
    session created here (because `session` was None) is closed.

    Args:
      num_batches: the maximum number of batches to produce. If None, batches
        are produced until `session.run` raises `OutOfRangeError`.
      graph: the `Graph` in which the `DataFrame` should be built. Defaults to
        the current default graph.
      session: the `Session` in which to run the columns of the `DataFrame`.
        If None, a new session is created and closed when iteration ends.
      start_queues: if true, queues will be started before running and halted
        after producing `num_batches` batches.
      initialize_variables: if true, variables will be initialized.
      **kwargs: Additional keyword arguments e.g. `num_epochs`, forwarded to
        `self.build`.

    Yields:
      A dictionary, mapping column names to the values resulting from running
      each column for a single batch.
    """
    if graph is None:
      graph = ops.get_default_graph()
    with graph.as_default():
      # Only close a session this call created; a caller-supplied session
      # stays the caller's responsibility.
      owns_session = session is None
      if owns_session:
        session = sess.Session()
      coord = None
      threads = []
      try:
        self_built = self.build(**kwargs)
        keys = list(self_built.keys())
        cols = list(self_built.values())
        if initialize_variables:
          if variables.local_variables():
            session.run(variables.local_variables_initializer())
          if variables.global_variables():
            session.run(variables.global_variables_initializer())
        if start_queues:
          coord = coordinator.Coordinator()
          threads = qr.start_queue_runners(sess=session, coord=coord)
        i = 0
        while num_batches is None or i < num_batches:
          i += 1
          try:
            values = session.run(cols)
          except errors.OutOfRangeError:
            # The input pipeline is exhausted.
            break
          yield collections.OrderedDict(zip(keys, values))
      finally:
        # Also reached via GeneratorExit when the consumer stops early, so
        # queue threads and an owned session are never leaked.
        try:
          if coord is not None:
            coord.request_stop()
            coord.join(threads)
        finally:
          if owns_session:
            session.close()
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def experimental_jit_scope(compile_ops=True):
  """Enable or disable JIT compilation of operators within the scope.

  NOTE: This is an experimental feature.

  The compilation is a hint and only supported on a best-effort basis.

  Example usage:
    with tf.contrib.compiler.experimental_jit_scope():
      c = tf.matmul(a, b)  # compiled
    with tf.contrib.compiler.experimental_jit_scope(compile_ops=False):
      d = tf.matmul(a, c)  # not compiled
    with tf.contrib.compiler.experimental_jit_scope(
        compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
      e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.

  NOTE(review): the `with` usage above requires this generator function to
  be wrapped by `contextlib.contextmanager`; no such decorator is visible on
  this definition here — confirm it is applied.

  Args:
    compile_ops: Whether to enable or disable compilation in the scope.
      Either a Python bool, or a callable that accepts the parameter
      `node_def` and returns a python bool.

  Yields:
    None. While suspended at the yield, the default graph's `_attr_scope`
    holds the `"_XlaCompile"` attribute built from `compile_ops`.
  """
  if not callable(compile_ops):
    # Fixed flag: the same AttrValue is used for the whole scope.
    xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
  else:
    # Predicate: wrap the bool it returns for a given node_def in an AttrValue.
    def xla_compile(node_def):
      return attr_value_pb2.AttrValue(b=compile_ops(node_def))
  scoped_attrs = {"_XlaCompile": xla_compile}

  # TODO(ebrevdo): Keep a global XlaScope counter and here create a
  # special scope that checks if already within a xla scope or creates
  # a new one with a new scope string.  Add a new attr _XlaScope
  # taking this string.  Modify the xla fusion to respect scope
  # boundaries.  Modify gradients_impl to either create a new gradient
  # scope with a suffix from the fw scope or to try to fuse with
  # the fw scope of the given op.  Should be backwards compatible to
  # avoid having to modify Defun compilation attributes.

  # pylint: disable=protected-access
  graph = ops.get_default_graph()
  with graph._attr_scope(scoped_attrs):
    yield
  # pylint: enable=protected-access