Python tensorflow.python.framework.ops 模块，Operation() 实例源码

我们从Python开源项目中，提取了以下50个代码示例，用于说明如何使用tensorflow.python.framework.ops.Operation()

项目:lsdc    作者:febert    | 项目源码 | 文件源码
def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """Swap the consumer ends of each tensor pair (t0, t1).

  B0 B1     B0 B1
  |  |    =>  X
  A0 A1     A0 A1

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  mode = _RerouteMode.swap
  return _reroute_ts(ts0, ts1, mode, can_modify, cannot_modify)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """Make the consumers of each t1 read from the matching t0 instead.

  B0 B1     B0 B1
  |  |    => |/
  A0 A1     A0 A1

  The end of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  mode = _RerouteMode.a2b
  return _reroute_ts(ts0, ts1, mode, can_modify, cannot_modify)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  r"""Make the consumers of each t0 read from the matching t1 instead.

  B0 B1     B0 B1
  |  |    =>  \|
  A0 A1     A0 A1

  The end of the tensors in ts0 are left dangling.

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  mode = _RerouteMode.b2a
  return _reroute_ts(ts0, ts1, mode, can_modify, cannot_modify)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    op: a tf.Operation from which to remove the control inputs.
    cops: an object convertible to a list of tf.Operation.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message; passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  # Read the property once; nothing in the validation loop modifies it.
  current_control_inputs = op.control_inputs
  for cop in cops:
    if cop not in current_control_inputs:
      # The offending control input (cop) comes first, then the op itself.
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
  # pylint: enable=protected-access
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each consumer appears once, in the order it is first encountered.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of all the consuming tf.Operation of the tensors in ts.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  # A set gives O(1) duplicate detection; the list preserves first-seen order.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_copied_op(org_instance, graph, scope=""):
  """Look up, in `graph`, the copy of an `Operation` from another `Graph`.

  The copy is found by name: the name of `org_instance`, prefixed with
  `scope + "/"` when `scope` is not the empty string.

  Args:
    org_instance: An `Operation` from some `Graph`.
    graph: The `Graph` to be searched for a copy of `org_instance`.
    scope: The name scope under which the copy lives (default `""`).

  Returns:
    The `Operation` copy from `graph`.
  """
  if scope == '':
    copied_name = org_instance.name
  else:
    # Copies made under a scope carry that scope as a name prefix.
    prefix = scope + '/'
    copied_name = prefix + org_instance.name

  return graph.as_graph_element(copied_name, allow_tensor=True,
                                allow_operation=True)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def testKFeatureTrainingConstruction(self):
    """training_graph() should build and return a tf.Operation."""
    num_features = self.params.num_features
    data = constant_op.constant(
        [[random.uniform(-1, 1) for _ in range(num_features)]
         for _ in range(100)])

    labels = [1] * 100

    with variable_scope.variable_scope(
        "KFeatureDecisionsToDataThenNNTest.testKFeatureTrainingContruction"):
      graph_builder = (
          k_feature_decisions_to_data_then_nn.KFeatureDecisionsToDataThenNN(
              self.params))
      graph = graph_builder.training_graph(data, labels, None)

      # assertIsInstance reports the offending type on failure, whereas
      # assertTrue(isinstance(...)) only reports "False is not true".
      self.assertIsInstance(graph, Operation)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """Make the consumers of each t1 read from the matching t0 instead.

  B0 B1     B0 B1
  |  |    => |/
  A0 A1     A0 A1

  The end of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  mode = _RerouteMode.a2b
  return _reroute_ts(ts0, ts1, mode, can_modify, cannot_modify)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  r"""Make the consumers of each t0 read from the matching t1 instead.

  B0 B1     B0 B1
  |  |    =>  \|
  A0 A1     A0 A1

  The end of the tensors in ts0 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  mode = _RerouteMode.b2a
  return _reroute_ts(ts0, ts1, mode, can_modify, cannot_modify)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function is directly manipulating the internals of the
  `tf.Graph`.

  Args:
    op: a `tf.Operation` from which to remove the control inputs.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a `tf.Operation`.
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message; passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  # Read the property once; nothing in the validation loop modifies it.
  current_control_inputs = op.control_inputs
  for cop in cops:
    if cop not in current_control_inputs:
      # The offending control input (cop) comes first, then the op itself.
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
  # pylint: enable=protected-access
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def add_control_inputs(op, cops):
  """Add the control inputs cops to op.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    op: a tf.Operation to which the control inputs are added.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is already a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message; passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  # Read the property once; nothing in the validation loop modifies it.
  current_control_inputs = op.control_inputs
  for cop in cops:
    if cop in current_control_inputs:
      # The offending control input (cop) comes first, then the op itself.
      raise ValueError("{} is already a control_input of {}".format(cop.name,
                                                                    op.name))
  # pylint: disable=protected-access
  op._control_inputs += cops
  op._recompute_node_def()
  # pylint: enable=protected-access
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each consumer appears once, in the order it is first encountered.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  # A set gives O(1) duplicate detection; the list preserves first-seen order.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def __init__(self, graph):
    """Set up a control-output dependency index for `graph`.

    The index is stored in `self._control_outputs` and filled by
    `self._build()`. Per the original documentation, each key is a
    `tf.Operation` and its value lists the ops that have the key as one of
    their control-input dependencies.

    Args:
      graph: a `tf.Graph`.
    Raises:
      TypeError: graph is not a `tf.Graph`.
    """
    if isinstance(graph, tf_ops.Graph):
      self._control_outputs = {}
      self._graph = graph
      self._version = None
      self._build()
    else:
      raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_copied_op(org_instance, graph, scope=""):
  """Look up, in `graph`, the copy of an `Operation` from another `Graph`.

  The copy is found by name: the name of `org_instance`, prefixed with
  `scope + "/"` when `scope` is not the empty string.

  Args:
    org_instance: An `Operation` from some `Graph`.
    graph: The `Graph` to be searched for a copy of `org_instance`.
    scope: The name scope under which the copy lives (default `""`).

  Returns:
    The `Operation` copy from `graph`.
  """
  if scope == '':
    copied_name = org_instance.name
  else:
    # Copies made under a scope carry that scope as a name prefix.
    prefix = scope + '/'
    copied_name = prefix + org_instance.name

  return graph.as_graph_element(copied_name, allow_tensor=True,
                                allow_operation=True)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def assign_renamed_collections_handler(info, elem, elem_):
  """Register the transformed element in the (possibly renamed) collections.

  For every collection in `info.collections` that contains `elem`, `elem_` is
  added to a collection of `info.graph_`. Collection names returned by
  `util.get_predefined_collection_names()` (the known keys described in
  `tf.GraphKeys`) are kept unchanged; any other name is mapped through
  `info.new_name`.

  Args:
    info: Transform._TmpInfo instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  predefined_names = util.get_predefined_collection_names()
  for collection_name, members in iteritems(info.collections):
    if elem in members:
      target_name = (collection_name if collection_name in predefined_names
                     else info.new_name(collection_name))
      info.graph_.add_to_collection(target_name, elem_)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """Swap the consumer ends of each tensor pair (t0, t1).

  B0 B1     B0 B1
  |  |    =>  X
  A0 A1     A0 A1

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  mode = _RerouteMode.swap
  return _reroute_ts(ts0, ts1, mode, can_modify, cannot_modify)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """Make the consumers of each t1 read from the matching t0 instead.

  B0 B1     B0 B1
  |  |    => |/
  A0 A1     A0 A1

  The end of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  # This is the a2b direction: ts0's producers feed ts1's former consumers.
  mode = _RerouteMode.a2b
  return _reroute_ts(ts0, ts1, mode, can_modify, cannot_modify)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function is directly manipulating the internals of the
  `tf.Graph`.

  Args:
    op: a `tf.Operation` from which to remove the control inputs.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a `tf.Operation`.
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message; passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  # Read the property once; nothing in the validation loop modifies it.
  current_control_inputs = op.control_inputs
  for cop in cops:
    if cop not in current_control_inputs:
      # The offending control input (cop) comes first, then the op itself.
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
  # pylint: enable=protected-access
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def add_control_inputs(op, cops):
  """Add the control inputs cops to op.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    op: a tf.Operation to which the control inputs are added.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is already a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message; passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  # Read the property once; nothing in the validation loop modifies it.
  current_control_inputs = op.control_inputs
  for cop in cops:
    if cop in current_control_inputs:
      # The offending control input (cop) comes first, then the op itself.
      raise ValueError("{} is already a control_input of {}".format(cop.name,
                                                                    op.name))
  # pylint: disable=protected-access
  op._control_inputs += cops
  op._recompute_node_def()
  # pylint: enable=protected-access
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each consumer appears once, in the order it is first encountered.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  # A set gives O(1) duplicate detection; the list preserves first-seen order.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def __init__(self, graph):
    """Set up a control-output dependency index for `graph`.

    The index is stored in `self._control_outputs` and filled by
    `self._build()`. Per the original documentation, each key is a
    `tf.Operation` and its value lists the ops that have the key as one of
    their control-input dependencies.

    Args:
      graph: a `tf.Graph`.
    Raises:
      TypeError: graph is not a `tf.Graph`.
    """
    if isinstance(graph, tf_ops.Graph):
      self._control_outputs = {}
      self._graph = graph
      self._version = None
      self._build()
    else:
      raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def get_copied_op(org_instance, graph, scope=""):
  """Look up, in `graph`, the copy of an `Operation` from another `Graph`.

  The copy is found by name: the name of `org_instance`, prefixed with
  `scope + "/"` when `scope` is not the empty string.

  Args:
    org_instance: An `Operation` from some `Graph`.
    graph: The `Graph` to be searched for a copy of `org_instance`.
    scope: The name scope under which the copy lives (default `""`).

  Returns:
    The `Operation` copy from `graph`.
  """
  if scope == '':
    copied_name = org_instance.name
  else:
    # Copies made under a scope carry that scope as a name prefix.
    prefix = scope + '/'
    copied_name = prefix + org_instance.name

  return graph.as_graph_element(copied_name, allow_tensor=True,
                                allow_operation=True)
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def build_results(self, values):
    """Combine fetched `values` with `self._contraction_fn`.

    A falsy `values` (the 'Operation' case, where nothing is returned)
    yields None.
    """
    if values:
      return self._contraction_fn(values)
    return None
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def __init__(self, graph, fetches, feeds, feed_handles=None):
    """Creates a fetch handler.

    Every unique fetch is checked with `self._assert_fetchable`. Operations
    go to `self._targets`; everything else goes to `self._fetches`.
    `self._ops` records, per unique fetch and in order, whether it was an
    Operation. Tensors produced by a GetSessionHandle/GetSessionHandleV2 op
    are remembered in `self._fetch_handles` with the dtype of that op's first
    input.

    Args:
      graph: Graph of the fetches.   Used to check for fetchability
        and to convert all fetches to tensors or ops as needed.
      fetches: An arbitrary fetch structure: singleton, list, tuple,
        namedtuple, or dict.
      feeds: A feed dict where keys are Tensors.
      feed_handles: A dict from feed Tensors to TensorHandle objects used as
        direct feeds.
    """
    with graph.as_default():
      self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    self._fetches = []
    self._targets = []
    self._feeds = feeds
    self._feed_handles = feed_handles or {}
    self._ops = []
    self._fetch_handles = {}
    handle_op_types = ('GetSessionHandle', 'GetSessionHandleV2')
    for fetch in self._fetch_mapper.unique_fetches():
      is_operation = isinstance(fetch, ops.Operation)
      if is_operation:
        self._assert_fetchable(graph, fetch)
        self._targets.append(fetch)
      else:
        self._assert_fetchable(graph, fetch.op)
        self._fetches.append(fetch)
      self._ops.append(is_operation)
      # Remember the fetch if it is for a tensor handle.
      if isinstance(fetch, ops.Tensor) and fetch.op.type in handle_op_types:
        self._fetch_handles[fetch] = fetch.op.inputs[0].dtype
    # Fetches that are also fed do not need to be computed.
    self._final_fetches = [x for x in self._fetches if x not in feeds]
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def _assert_fetchable(self, graph, op):
    if not graph.is_fetchable(op):
      raise ValueError(
          'Operation %r has been marked as not fetchable.' % op.name)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _finalize_positive_filter(self, elem):
    """Turn a positive-filter specification into a predicate on ops.

    Accepted specifications, checked in this order:
      * anything `select.can_be_regex` accepts: matches ops whose name
        contains a match of the regex;
      * a `tf_ops.Operation`: matches exactly that op (identity);
      * a callable: used as the predicate unchanged;
      * `True`: matches every op.

    Raises:
      ValueError: for any other `elem`.
    """
    if select.can_be_regex(elem):
      compiled = select.make_regex(elem)

      def _name_matches(op, regex=compiled):
        return regex.search(op.name) is not None
      return _name_matches
    if isinstance(elem, tf_ops.Operation):

      def _is_same_op(op, match_op=elem):
        return op is match_op
      return _is_same_op
    if callable(elem):
      return elem
    if elem is True:
      return lambda op: True
    raise ValueError("Cannot finalize the positive filter: {}".format(elem))
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def __call__(self, op):
    """Evaluate if the op matches or not.

    The checks run in this order and stop at the first failure:
      1. every callable in `self.positive_filters` must accept `op`;
      2. if `self.input_op_matches` is not None, `op` must have exactly that
         many inputs, and each matcher must accept the op that produces the
         corresponding input tensor;
      3. if `self.control_input_op_matches` is not None, `op` must have
         exactly that many control inputs, each accepted by its matcher;
      4. if `self.output_op_matches` is not None, `op` must have exactly that
         many outputs; for each output whose entry is not None, the entry
         must hold one matcher per consumer of that output tensor, and each
         matcher must accept its consumer.
    A `None` matcher at any position acts as a wildcard.

    Args:
      op: the `tf.Operation` to test.
    Returns:
      True if all configured checks pass, False otherwise.
    Raises:
      TypeError: if `op` is not a `tf.Operation`.
    """
    if not isinstance(op, tf_ops.Operation):
      raise TypeError("Expect tf.Operation, got: {}".format(type(op)))
    for positive_filter in self.positive_filters:
      if not positive_filter(op):
        return False
    # Inputs: each matcher is applied to the op producing the input tensor.
    if self.input_op_matches is not None:
      if len(op.inputs) != len(self.input_op_matches):
        return False
      for input_t, input_op_match in zip(op.inputs, self.input_op_matches):
        if input_op_match is None:
          continue
        if not input_op_match(input_t.op):
          return False
    # Control inputs are ops already, so matchers receive them directly.
    if self.control_input_op_matches is not None:
      if len(op.control_inputs) != len(self.control_input_op_matches):
        return False
      for cinput_op, cinput_op_match in zip(op.control_inputs,
                                            self.control_input_op_matches):
        if cinput_op_match is None:
          continue
        if not cinput_op_match(cinput_op):
          return False
    # Outputs: one list of consumer matchers per output tensor.
    if self.output_op_matches is not None:
      if len(op.outputs) != len(self.output_op_matches):
        return False
      for output_t, output_op_matches in zip(op.outputs,
                                             self.output_op_matches):
        if output_op_matches is None:
          continue
        if len(output_t.consumers()) != len(output_op_matches):
          return False
        for consumer_op, consumer_op_match in zip(output_t.consumers(),
                                                  output_op_matches):
          if consumer_op_match is None:
            continue
          if not consumer_op_match(consumer_op):
            return False
    return True
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  Every collection of `elem.graph` that contains `elem` has its name mapped
  through `info.transformer.new_name`, and `elem_` is added to the
  collection of that name in `info.graph_`.

  Args:
    info: Transform._Info instance.
    elem: the original element (tf.Tensor or tf.Operation)
    elem_: the transformed element
  """
  # TODO(fkp): handle known special cases
  # Snapshot the (name, collection) pairs before adding anything. When
  # `info.graph_` is `elem.graph` (copying within one graph),
  # add_to_collection() can insert a new key into the very dict being
  # iterated, which raises "dictionary changed size during iteration".
  snapshot = list(
      elem.graph._collections.items())  # pylint: disable=protected-access
  for name, collection in snapshot:
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _get_transformed_map(self, top):
    """Pick the container that tracks transformed items of `top`'s kind.

    Returns `self._transformed_ops` for a `tf.Operation` and
    `self._transformed_ts` for a `tf.Tensor`.

    Raises:
      TypeError: if `top` is neither.
    """
    if isinstance(top, tf_ops.Operation):
        return self._transformed_ops
    if isinstance(top, tf_ops.Tensor):
        return self._transformed_ts
    raise TypeError(
        "Expected a tf.Tensor or a tf.Operation, got a {}".format(type(top)))
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _transform_op(self, op):
    """Transform a tf.Operation.

    Args:
      op: the operation to be transformed.
    Returns:
      The transformed operation.
    """
    if op in self._info.transformed_ops:
      return self._info.transformed_ops[op]

    op_ = self.transform_op_handler(self._info, op)

    # Add to all the active control dependencies
    # pylint: disable=protected-access
    self._info.graph_._record_op_seen_by_control_dependencies(op_)

    # All to all the active devices
    for device_function in reversed(self._info.graph_._device_function_stack):
      if device_function is None:
        break
      op_._set_device(device_function(op_))
    # pylint: enable=protected-access

    # TODO(fkp): Establish clear policy about what context managers are allowed.

    # assign to collection
    if op is not op_:
      self.assign_collections_handler(self._info, op, op_)

    self._info.transformed_ops[op] = op_
    return op_
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by the all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the element in tops are of given type(s). May be
      a single type or an iterable of types. If None, the types
      (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable of tf.Operation.
    ValueError: if the graph is not unique.
  """
  if isinstance(tops, tf_ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (tf_ops.Operation, tf_ops.Tensor)
  elif is_iterable(check_types):
    # isinstance() only accepts a type or a *tuple* of types, so a list or
    # set of types must be converted before it is used below.
    check_types = tuple(check_types)
  else:
    check_types = (check_types,)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
          t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g is not op.graph:
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of tf.Operation.

  Args:
    ops: can be an iterable of tf.Operation, a tf.Graph or a single operation.
    check_graph: if True check if all the operations belong to the same graph.
    allow_graph: if False a tf.Graph cannot be converted.
    ignore_ts: if True, silently ignore tf.Tensor.
  Returns:
    A newly created list of tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation or,
     if check_graph is True, if all the ops do not belong to the same graph.
  """
  if isinstance(ops, tf_ops.Graph):
    if allow_graph:
      return ops.get_operations()
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ops):
      ops = [ops]
    else:
      # Materialize once. A one-shot iterable (e.g. a generator) would
      # otherwise be exhausted by get_unique_graph() below and the final
      # filtering would silently return []; `not ops` is also only
      # meaningful on a sized container.
      ops = list(ops)
    if not ops:
      return []
    if check_graph:
      check_types = None if ignore_ts else tf_ops.Operation
      get_unique_graph(ops, check_types=check_types)
    return [op for op in ops if isinstance(op, tf_ops.Operation)]


# TODO(fkp): move this function in tf.Graph?
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of tf.Tensor.

  Args:
    ts: can be an iterable of tf.Tensor, a tf.Graph or a single tensor.
    check_graph: if True check if all the tensors belong to the same graph.
    allow_graph: if False a tf.Graph cannot be converted.
    ignore_ops: if True, silently ignore tf.Operation.
  Returns:
    A newly created list of tf.Tensor.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor or,
     if check_graph is True, if all the tensors do not belong to the same
     graph.
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    else:
      # Materialize once. A one-shot iterable (e.g. a generator) would
      # otherwise be exhausted by get_unique_graph() below and the final
      # filtering would silently return []; `not ts` is also only
      # meaningful on a sized container.
      ts = list(ts)
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else tf_ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, tf_ops.Tensor)]
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_generating_ops(ts):
  """List the op that produces each tensor in ts, in tensor order.

  Ops are not de-duplicated: two outputs of the same op yield that op twice.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of all the generating tf.Operation of the tensors in ts.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  return [tensor.op for tensor in make_list_of_t(ts, allow_graph=False)]
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def testTrainingConstruction(self):
    """training_graph() should build and return a tf.Operation."""
    num_features = self.params.num_features
    data = constant_op.constant(
        [[random.uniform(-1, 1) for _ in range(num_features)]
         for _ in range(100)])

    labels = [1] * 100

    with variable_scope.variable_scope(
        "DecisionsToDataThenNNTest_testTrainingContruction"):
      graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
          self.params)
      graph = graph_builder.training_graph(data, labels, None)

      # assertIsInstance reports the offending type on failure, whereas
      # assertTrue(isinstance(...)) only reports "False is not true".
      self.assertIsInstance(graph, Operation)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def testTrainingConstruction(self):
    """training_graph() should build and return a tf.Operation."""
    num_features = self.params.num_features
    data = constant_op.constant(
        [[random.uniform(-1, 1) for _ in range(num_features)]
         for _ in range(100)])

    labels = [1] * 100

    with variable_scope.variable_scope(
        "ForestToDataThenNNTest.testTrainingContruction"):
      graph_builder = forest_to_data_then_nn.ForestToDataThenNN(self.params)
      graph = graph_builder.training_graph(data, labels, None)

      # assertIsInstance reports the offending type on failure, whereas
      # assertTrue(isinstance(...)) only reports "False is not true".
      self.assertIsInstance(graph, Operation)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _finalize_positive_filter(self, elem):
    """Turn a positive-filter specification into a predicate on ops.

    Accepted specifications, checked in this order:
      * anything `select.can_be_regex` accepts: matches ops whose name
        contains a match of the regex;
      * a `tf_ops.Operation`: matches exactly that op (identity);
      * a callable: used as the predicate unchanged;
      * `True`: matches every op.

    Raises:
      ValueError: for any other `elem`.
    """
    if select.can_be_regex(elem):
      compiled = select.make_regex(elem)

      def _name_matches(op, regex=compiled):
        return regex.search(op.name) is not None
      return _name_matches
    if isinstance(elem, tf_ops.Operation):

      def _is_same_op(op, match_op=elem):
        return op is match_op
      return _is_same_op
    if callable(elem):
      return elem
    if elem is True:
      return lambda op: True
    raise ValueError("Cannot finalize the positive filter: {}".format(elem))
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  Every collection of `elem.graph` that contains `elem` has its name mapped
  through `info.transformer.new_name`, and `elem_` is added to the
  collection of that name in `info.graph_`.

  Args:
    info: Transform._Info instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  # TODO(fkp): handle known special cases
  # Snapshot the (name, collection) pairs before adding anything. When
  # `info.graph_` is `elem.graph` (copying within one graph),
  # add_to_collection() can insert a new key into the very dict being
  # iterated, which raises "dictionary changed size during iteration".
  snapshot = list(
      elem.graph._collections.items())  # pylint: disable=protected-access
  for name, collection in snapshot:
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _get_transformed_map(self, top):
      """Select the container of transformed items matching `top`'s kind.

      Args:
        top: a `tf_ops.Operation` or a `tf_ops.Tensor`.
      Returns:
        `self._transformed_ops` for an operation, `self._transformed_ts` for a
        tensor.
      Raises:
        TypeError: if `top` is neither an operation nor a tensor.
      """
      if isinstance(top, tf_ops.Operation):
        return self._transformed_ops
      if isinstance(top, tf_ops.Tensor):
        return self._transformed_ts
      message = "Expected a tf.Tensor or a tf.Operation, got a {}".format(
          type(top))
      raise TypeError(message)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def __init__(self):
    """Create a Transformer wired with its default handlers.

    Each handler is a public attribute and can be replaced after construction:

    transform_op_handler: transforms a `tf.Operation`. Defaults to a simple
      copy (copy_op_handler).
    transform_original_op_handler: transforms original_op. Defaults to
      transform_op_if_inside_handler, i.e. original_op is transformed only if
      it is in the subgraph, otherwise it is ignored.
    transform_control_input_handler: transforms control inputs. Defaults to
      the same transform_op_if_inside_handler used for original_op.
    assign_collections_handler: assigns collections. Defaults to assigning
      new collections created under the given name-scope.
    transform_external_input_handler: transforms the inputs of the given
      subgraph. Defaults to creating placeholders in place of the ops just
      before the input tensors of the subgraph.
    transform_external_hidden_input_handler: transforms the hidden inputs of
      the subgraph, i.e. inputs not listed in sgv.inputs. Defaults to keeping
      the same input if the source and destination graphs are the same,
      otherwise using placeholders.
    """
    # Op-level handlers.
    self.transform_op_handler = copy_op_handler
    self.transform_original_op_handler = transform_op_if_inside_handler
    self.transform_control_input_handler = transform_op_if_inside_handler

    # Collection handling.
    self.assign_collections_handler = assign_renamed_collections_handler

    # Tensors entering the subgraph (listed and hidden inputs).
    self.transform_external_input_handler = replace_t_with_placeholder_handler
    self.transform_external_hidden_input_handler = keep_t_if_possible_handler

    # Temporary per-call state; starts out as None.
    self._info = None
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def _transform_op(self, op):
    """Transform a tf.Operation, caching the result per op.

    If `op` is already in `self._info.transformed_ops`, the cached result is
    returned. Otherwise the op is transformed with `self.transform_op_handler`
    and the result is registered with the destination graph's active control
    dependencies. It is then passed through the destination graph's device
    functions and, when it is a new object, handed to
    `self.assign_collections_handler`. Finally it is stored in
    `self._info.transformed_ops`.

    Args:
      op: the operation to be transformed.
    Returns:
      The transformed operation.
    """
    # Cache hit: each op is transformed at most once per self._info.
    if op in self._info.transformed_ops:
      return self._info.transformed_ops[op]

    op_ = self.transform_op_handler(self._info, op)

    # Add to all the active control dependencies
    # pylint: disable=protected-access
    self._info.graph_._record_op_seen_by_control_dependencies(op_)

    # Add to all the active devices: walk the destination graph's
    # device-function stack from its end towards its start, setting op_'s
    # device from each function, and stop at the first None entry.
    for device_function in reversed(self._info.graph_._device_function_stack):
      if device_function is None:
        break
      op_._set_device(device_function(op_))
    # pylint: enable=protected-access

    # TODO(fkp): Establish clear policy about what context managers are allowed.

    # assign to collection: only when the handler produced a new object
    # (op_ is op means nothing new was created).
    if op is not op_:
      self.assign_collections_handler(self._info, op, op_)

    self._info.transformed_ops[op] = op_
    return op_
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: a type, or an iterable of types, that every element of tops
      must be an instance of. If None, (tf.Operation, tf.Tensor) is used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops, or None if tops is empty and
    none_if_empty is True.
  Raises:
    TypeError: if tops is not iterable, or if one of its elements is not an
      instance of check_types.
    ValueError: if the graph is not unique, or if tops is empty and
      none_if_empty is False.
  """
  if isinstance(tops, tf_ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (tf_ops.Operation, tf_ops.Tensor)
  elif isinstance(check_types, type) or not is_iterable(check_types):
    # A single type. It is checked with isinstance(..., type) first because
    # some classes are themselves iterable (e.g. Enum classes).
    check_types = (check_types,)
  else:
    # isinstance() only accepts a type or a *tuple* of types, so normalize
    # any other iterable of types (list, set, generator, ...) to a tuple.
    check_types = tuple(check_types)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
          t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g is not op.graph:
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
    ValueError: if `check_graph` is `True` and the elements do not all belong
      to the same graph (raised by get_unique_graph).
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  # Materialize the input exactly once. For a generator, `not ts` is always
  # False, and get_unique_graph would exhaust it before the final filtering
  # pass, silently producing an empty result.
  ts = list(ts) if is_iterable(ts) else [ts]
  if not ts:
    return []
  if check_graph:
    check_types = None if ignore_ops else tf_ops.Tensor
    get_unique_graph(ts, check_types=check_types)
  return [t for t in ts if isinstance(t, tf_ops.Tensor)]
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def get_generating_ops(ts):
  """Collect the operation that produces each tensor in `ts`.

  Args:
    ts: a list of `tf.Tensor` (anything make_list_of_t accepts, except a
      `tf.Graph`).
  Returns:
    A list holding the generating `tf.Operation` of every tensor in `ts`, in
    the same order as the tensors.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  return [tensor.op for tensor in make_list_of_t(ts, allow_graph=False)]
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def testTrainingConstruction(self):
    """Check that DecisionsToDataThenNN.training_graph returns an Operation."""
    # 100 rows of features, each drawn with random.uniform(-1, 1).
    data = constant_op.constant(
        [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
         for _ in range(100)])

    labels = [1] * 100

    with variable_scope.variable_scope(
        "DecisionsToDataThenNNTest_testTrainingContruction"):
      graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
          self.params)
      graph = graph_builder.training_graph(data, labels, None)

      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)).
      self.assertIsInstance(graph, Operation)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def testTrainingConstruction(self):
    """Check that ForestToDataThenNN.training_graph returns an Operation."""
    # 100 rows of features, each drawn with random.uniform(-1, 1).
    data = constant_op.constant(
        [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
         for _ in range(100)])

    labels = [1] * 100

    with variable_scope.variable_scope(
        "ForestToDataThenNNTest.testTrainingContruction"):
      graph_builder = forest_to_data_then_nn.ForestToDataThenNN(self.params)
      graph = graph_builder.training_graph(data, labels, None)

      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)).
      self.assertIsInstance(graph, Operation)
项目:imperative    作者:yaroslavvb    | 项目源码 | 文件源码
def op(self):
    """Method for compatibility with Tensor.

    Builds and returns a dummy `tf_ops.Operation` named
    "imperative-dummy-node" whose three inputs are placeholders of
    `self.dtype`. A new operation and three new placeholders are created on
    every call.

    NOTE(review): `node_def.op` is never set, and the placeholders are
    created by array_ops.placeholder (presumably in the current default
    graph) while the Operation is attached to a brand-new tf_ops.Graph().
    Confirm this cross-graph construction is intended.
    """
    node_def = graph_pb2.NodeDef()
    node_def.name = "imperative-dummy-node"
    # Input names only need to line up with the three inputs passed below.
    node_def.input.extend(["dummy1", "dummy2", "dummy3"])

    dummy_input1 = array_ops.placeholder(self.dtype)
    dummy_input2 = array_ops.placeholder(self.dtype)
    dummy_input3 = array_ops.placeholder(self.dtype)
    dummy_op = tf_ops.Operation(node_def, tf_ops.Graph(), inputs=[dummy_input1,
                                                                  dummy_input2,
                                                                  dummy_input3])

    return dummy_op
项目:imperative    作者:yaroslavvb    | 项目源码 | 文件源码
def op(self):
    """Method for compatibility with Tensor.

    Builds and returns a dummy `tf_ops.Operation` named
    "imperative-dummy-node" whose three inputs are placeholders of
    `self.dtype`. A new operation and three new placeholders are created on
    every call.

    NOTE(review): `node_def.op` is never set, and the placeholders are
    created by array_ops.placeholder (presumably in the current default
    graph) while the Operation is attached to a brand-new tf_ops.Graph().
    Confirm this cross-graph construction is intended.
    """
    node_def = graph_pb2.NodeDef()
    node_def.name = "imperative-dummy-node"
    # Input names only need to line up with the three inputs passed below.
    node_def.input.extend(["dummy1", "dummy2", "dummy3"])

    dummy_input1 = array_ops.placeholder(self.dtype)
    dummy_input2 = array_ops.placeholder(self.dtype)
    dummy_input3 = array_ops.placeholder(self.dtype)
    dummy_op = tf_ops.Operation(node_def, tf_ops.Graph(), inputs=[dummy_input1,
                                                                  dummy_input2,
                                                                  dummy_input3])

    return dummy_op
项目:polyaxon    作者:polyaxon    | 项目源码 | 文件源码
def _check_is_tensor_or_operation(x, name):
    """Raise TypeError unless `x` is an `ops.Operation` or an `ops.Tensor`.

    Args:
        x: the object to validate.
        name: label for `x` used in the error message.

    Raises:
        TypeError: if `x` is neither an Operation nor a Tensor.
    """
    accepted_types = (ops.Operation, ops.Tensor)
    if isinstance(x, accepted_types):
        return
    raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def _finalize_positive_filter(self, elem):
    """Build a predicate on ops from a positive-filter specification.

    Args:
      elem: a regex-like value (per select.can_be_regex), a tf_ops.Operation,
        a callable, or the literal True.
    Returns:
      A callable taking an op: a regex search on op.name, an identity test
      against the given operation, the callable itself, or a predicate that
      always returns True.
    Raises:
      ValueError: if elem is none of the supported forms.
    """
    if select.can_be_regex(elem):
      name_pattern = select.make_regex(elem)
      return lambda op, regex=name_pattern: regex.search(op.name) is not None

    if isinstance(elem, tf_ops.Operation):
      return lambda op, match_op=elem: op is match_op

    if callable(elem):
      return elem

    if elem is True:
      return lambda op: True

    raise ValueError("Cannot finalize the positive filter: {}".format(elem))