Python tensorflow.python.framework.graph_util 模块，extract_sub_graph() 实例源码

我们从 Python 开源项目中，提取了以下 5 个代码示例，用于说明如何使用 tensorflow.python.framework.graph_util.extract_sub_graph()。

项目：MMdnn    作者：Microsoft    | 项目源码 | 文件源码
def __init__(self, input_args, dest_nodes=None):
        """Load a TensorFlow model graph and build the parser's network graph.

        Args:
            input_args: either a string (a meta graph path passed to
                ``TensorflowParser._load_meta``), or a tuple whose first item
                is that meta graph path and whose second item is passed to
                ``TensorflowParser._load_weights``; in the tuple case
                ``self.ckpt_data`` and ``self.weight_loaded`` are set.
            dest_nodes: optional comma-separated string of output node names.
                When given, the loaded graph is pruned with
                ``extract_sub_graph`` to the nodes these outputs depend on.
                Whitespace around each name is ignored.

        Raises:
            TypeError: if ``input_args`` is neither a string nor a tuple.
        """
        super(TensorflowParser, self).__init__()

        # Load the model graph (and the weights, when a tuple is given).
        # NOTE(review): `_load_meta`'s result is fed to extract_sub_graph,
        # which expects a GraphDef -- presumably that is what it returns.
        from six import string_types as _string_types
        if isinstance(input_args, _string_types):
            model = TensorflowParser._load_meta(input_args)
        elif isinstance(input_args, tuple):
            model = TensorflowParser._load_meta(input_args[0])
            self.ckpt_data = TensorflowParser._load_weights(input_args[1])
            self.weight_loaded = True
        else:
            # Without this branch `model` stays unbound and the failure only
            # surfaces below as an UnboundLocalError.
            raise TypeError(
                "input_args must be a string or a tuple, got %s"
                % type(input_args).__name__)

        if dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            # Strip so "a, b" names node "b" rather than the nonexistent " b".
            output_names = [name.strip() for name in dest_nodes.split(',')]
            model = extract_sub_graph(model, output_names)

        # Build network graph
        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()
项目：benderthon    作者：xmartlabs    | 项目源码 | 文件源码
def save_graph_only(sess, output_file_path, output_node_names, as_text=False):
    """Save a small version of the graph based on a session and the output node names.

    The session's GraphDef is stripped of device placements and pruned with
    ``graph_util.extract_sub_graph`` to the nodes ``output_node_names``
    depend on, then written with ``graph_io.write_graph``.

    Args:
        sess: the session whose graph is saved.
        output_file_path: destination path; split into a directory and a
            file name for ``graph_io.write_graph``.
        output_node_names: list of output node names to keep.
        as_text: write a text proto instead of a binary one when True.
    """
    # `Session.graph_def` builds a new GraphDef on every access, so devices
    # must be cleared on one captured copy; looping over `sess.graph_def.node`
    # directly only edits a temporary that is immediately discarded.
    graph_def = sess.graph_def
    for node in graph_def.node:
        node.device = ''
    pruned_graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    output_dir, output_filename = os.path.split(output_file_path)
    # A bare file name yields an empty directory; use the current one.
    graph_io.write_graph(pruned_graph_def, output_dir or '.', output_filename,
                         as_text=as_text)
项目：tensorflow-for-poets-2    作者：googlecodelabs    | 项目源码 | 文件源码
def remove_dead_nodes(self, output_names):
    """Prune ``self.output_graph`` to what ``output_names`` need for inference.

    The attribute is replaced by the result of
    ``graph_util.extract_sub_graph``; nodes that none of the requested
    outputs depend on are dropped.
    """
    pruned_graph = graph_util.extract_sub_graph(self.output_graph, output_names)
    self.output_graph = pruned_graph
项目：MobileNet    作者：Zehaos    | 项目源码 | 文件源码
def remove_dead_nodes(self, output_names):
    """Drop graph nodes that are no longer needed for inference.

    Args:
        output_names: names of the output nodes to keep; forwarded to
            ``graph_util.extract_sub_graph``, whose result replaces
            ``self.output_graph``.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph,
        output_names,
    )
项目：TensorFlow_DCIGN    作者：yselivonchyk    | 项目源码 | 文件源码
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph file and writes the pruned graph.

  Every node named in ``input_node_names`` is replaced by a Placeholder, then
  the graph is pruned with ``graph_util.extract_sub_graph`` to the nodes that
  ``output_node_names`` depend on, so anything that only fed the replaced
  inputs is dropped. The result is always written as a binary GraphDef.

  Args:
    input_graph: path of the GraphDef file to read.
    input_binary: True if ``input_graph`` holds a binary proto, False if it
      holds a text proto.
    output_graph: path the pruned binary GraphDef is written to.
    input_node_names: comma-separated names of nodes to replace with
      placeholders.
    output_node_names: comma-separated names of the output nodes to keep.
    placeholder_type_enum: DataType enum value used as the placeholders'
      ``dtype`` attr. Either a single value shared by every input, or a
      list/tuple with one value per entry of ``input_node_names`` (same
      order).

  Returns:
    -1 after printing a message if the arguments are invalid, otherwise None.
  """
  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Map each input node name to the dtype its placeholder will get.
  input_node_names_list = input_node_names.split(",")
  if isinstance(placeholder_type_enum, (list, tuple)):
    if len(placeholder_type_enum) != len(input_node_names_list):
      print("Expected %d placeholder types (one per input node), got %d." %
            (len(input_node_names_list), len(placeholder_type_enum)))
      return -1
    dtype_by_name = dict(zip(input_node_names_list, placeholder_type_enum))
  else:
    dtype_by_name = dict.fromkeys(input_node_names_list, placeholder_type_enum)

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  # Empty names (e.g. from an empty input_node_names) are never reported.
  missing_inputs = set(name for name in dtype_by_name if name)
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    new_node = inputs_replaced_graph_def.node.add()
    if node.name not in dtype_by_name:
      new_node.CopyFrom(node)
      continue
    missing_inputs.discard(node.name)
    new_node.op = "Placeholder"
    new_node.name = node.name
    new_node.attr["dtype"].CopyFrom(
        tf.AttrValue(type=dtype_by_name[node.name]))
    # Carry over recorded output shapes so the placeholder keeps them.
    if "_output_shapes" in node.attr:
      new_node.attr["_output_shapes"].CopyFrom(node.attr["_output_shapes"])

  # A misspelled input name would otherwise be silently ignored and its
  # upstream nodes kept in the output graph.
  if missing_inputs:
    print("Input node(s) not found in the graph: " +
          ", ".join(sorted(missing_inputs)))
    return -1

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names.split(","))

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))