Python tensorflow.python.framework.graph_util 模块,convert_variables_to_constants() 实例源码

我们从Python开源项目中,提取了以下32个代码示例,用于说明如何使用tensorflow.python.framework.graph_util.convert_variables_to_constants()

项目:lsdc    作者:febert    | 项目源码 | 文件源码
def setUp(self):
    """Export a frozen copy of a tiny graph as a meta graph for the tests.

    Creates the "no_vars" directory under the TensorFlow test temp dir when
    it is missing, builds y = w * x - 7 with w stored in a variable, converts
    that variable to a constant and exports the resulting graph together with
    the "meta" collection.
    """
    export_dir = os.path.join(tf.test.get_temp_dir(), "no_vars")
    self.base_path = export_dir
    if not os.path.exists(export_dir):
      os.mkdir(export_dir)

    with tf.Graph().as_default() as graph:
      inputs = tf.placeholder(tf.float32, name="x")
      weight = tf.Variable(3.0)
      # Only the node name "y" is used afterwards, so the tensor is not kept.
      tf.sub(weight * inputs, 7.0, name="y")
      tf.add_to_collection("meta", "this is meta")

      with self.test_session(graph=graph) as session:
        tf.initialize_all_variables().run()
        frozen_graph_def = graph_util.convert_variables_to_constants(
            session, graph.as_graph_def(), ["y"])

      meta_path = os.path.join(
          export_dir, constants.META_GRAPH_DEF_FILENAME)
      tf.train.export_meta_graph(
          meta_path, graph_def=frozen_graph_def, collection_list=["meta"])
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def setUp(self):
    """Prepare the "no_vars" fixture: a meta graph whose variable is frozen.

    The graph computes y = w * x - 7. After the variable is initialized it is
    replaced by a constant, and the frozen graph is exported (with the "meta"
    collection) into `self.base_path`.
    """
    target_dir = os.path.join(tf.test.get_temp_dir(), "no_vars")
    self.base_path = target_dir
    if not os.path.exists(target_dir):
      os.mkdir(target_dir)

    with tf.Graph().as_default() as graph:
      feed = tf.placeholder(tf.float32, name="x")
      scale = tf.Variable(3.0)
      # The result is referenced by its node name "y" only.
      tf.sub(scale * feed, 7.0, name="y")
      tf.add_to_collection("meta", "this is meta")

      with self.test_session(graph=graph) as session:
        tf.global_variables_initializer().run()
        constant_graph_def = graph_util.convert_variables_to_constants(
            session, graph.as_graph_def(), ["y"])

      export_path = os.path.join(
          target_dir, constants.META_GRAPH_DEF_FILENAME)
      tf.train.export_meta_graph(
          export_path, graph_def=constant_graph_def,
          collection_list=["meta"])
项目:hive    作者:metaflow-ai    | 项目源码 | 文件源码
def freeze_graph(model_folder):
    """Freeze the checkpoint recorded in `model_folder` into one .pb file.

    The meta graph stored next to the checkpoint is imported, the weights are
    restored, the variables are converted to constants with
    "Accuracy/predictions" as the output node, and the result is written as
    "frozen_model_2.pb" in the checkpoint's directory.

    Args:
        model_folder: directory holding the checkpoint state file.
    """
    from tensorflow.python.framework import graph_util

    # Use the argument rather than the module-level `args`, so the function
    # honours whatever folder the caller passes in.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path
    # Directory part of the checkpoint path (everything before the last '/').
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_node_names = "Accuracy/predictions"
    output_graph = absolute_model_folder + "/frozen_model_2.pb"
    # Drop device placements so the frozen graph loads on any machine.
    clear_devices = True

    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
项目:TFExperiments    作者:gnperdue    | 项目源码 | 文件源码
def freeze_graph(
        model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """
    reduce a saved model and metadata down to a deployable file

    Args:
        model_dir: directory holding the checkpoint state file.
        output_nodes_list: list of output node names to keep.
        output_graph_name: file name of the frozen graph, written next to
            the checkpoint.

    Returns:
        The path of the written graph file, or None when no checkpoint
        could be found in `model_dir`.
    """
    from tensorflow.python.framework import graph_util

    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when no checkpoint state is found,
    # so guard before touching its attributes; otherwise the error branch
    # below could never be reached.
    input_checkpoint = (
        checkpoint.model_checkpoint_path if checkpoint is not None else None
    )

    if input_checkpoint is None:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None

    # Directory part of the checkpoint path (everything before the last '/').
    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/' + output_graph_name
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        LOGGER.info('Froze graph with {} ops'.format(
            len(output_graph_def.node)
        ))

    return output_graph
项目:TFExperiments    作者:gnperdue    | 项目源码 | 文件源码
def freeze_graph(
        model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """
    reduce a saved model and metadata down to a deployable file; following
    https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    output_nodes_list = e.g., ['softmax_linear/logits']

    Args:
        model_dir: directory holding the checkpoint state file.
        output_nodes_list: list of output node names to keep.
        output_graph_name: file name of the frozen graph, written next to
            the checkpoint.

    Returns:
        The path of the written graph file, or None when no checkpoint
        could be found in `model_dir`.
    """
    from tensorflow.python.framework import graph_util

    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when no checkpoint state is found,
    # so guard before touching its attributes; otherwise the error branch
    # below could never be reached.
    input_checkpoint = (
        checkpoint.model_checkpoint_path if checkpoint is not None else None
    )

    if input_checkpoint is None:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None

    # Directory part of the checkpoint path (everything before the last '/').
    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/' + output_graph_name
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        LOGGER.info('Froze graph with {} ops'.format(
            len(output_graph_def.node)
        ))

    return output_graph
项目:TFExperiments    作者:gnperdue    | 项目源码 | 文件源码
def freeze_graph(
        model_dir, output_nodes_list, output_graph_name='frozen_model.pb'
):
    """
    reduce a saved model and metadata down to a deployable file; following
    https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    output_nodes_list = e.g., ['softmax_linear/logits']

    Args:
        model_dir: directory holding the checkpoint state file.
        output_nodes_list: list of output node names to keep.
        output_graph_name: file name of the frozen graph, written next to
            the checkpoint.

    Returns:
        The path of the written graph file, or None when no checkpoint
        could be found in `model_dir`.
    """
    from tensorflow.python.framework import graph_util

    LOGGER.info('Attempting to freeze graph at {}'.format(model_dir))
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when no checkpoint state is found,
    # so guard before touching its attributes; otherwise the error branch
    # below could never be reached.
    input_checkpoint = (
        checkpoint.model_checkpoint_path if checkpoint is not None else None
    )

    if input_checkpoint is None:
        LOGGER.error('Cannot load checkpoint at {}'.format(model_dir))
        return None

    # Directory part of the checkpoint path (everything before the last '/').
    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/' + output_graph_name
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes_list
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        LOGGER.info('Froze graph with {} ops'.format(
            len(output_graph_def.node)
        ))

    return output_graph
项目:faceNet_RealTime    作者:jack55436001    | 项目源码 | 文件源码
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Patch `input_graph_def` in place and return a frozen copy of it.

    RefSwitch/AssignSub/AssignAdd nodes are rewritten to their non-reference
    counterparts before the variables are converted to constants.

    Args:
        sess: session holding the variable values.
        input_graph_def: GraphDef to patch (mutated) and freeze.
        output_node_names: comma-separated output node names.

    Returns:
        The GraphDef produced by convert_variables_to_constants.
    """
    plain_op_for = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Point inputs whose name contains 'moving_' at their '/read' op.
            for position, input_name in enumerate(list(node.input)):
                if 'moving_' in input_name:
                    node.input[position] = input_name + '/read'
        elif node.op in plain_op_for:
            node.op = plain_op_for[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables under these name prefixes are converted.
    kept_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                     'Bottleneck', 'Logits')
    whitelist_names = [
        node.name for node in input_graph_def.node
        if node.name.startswith(kept_prefixes)
    ]

    return graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
项目:tensorflow-for-poets-2    作者:googlecodelabs    | 项目源码 | 文件源码
def save_graph_to_file(sess, graph, graph_file_name):
  """Freeze `graph` with the values in `sess` and write it to a file.

  The output node is taken from FLAGS.final_tensor_name; the serialized
  GraphDef is written to `graph_file_name`. Returns None.
  """
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
  with gfile.FastGFile(graph_file_name, 'wb') as out_file:
    out_file.write(frozen_graph_def.SerializeToString())
项目:cnn-bnn    作者:jpdz    | 项目源码 | 文件源码
def generatePB(pb_dest = "model_mnist_bnn.pb"):
    """Freeze the module-level session's graph and save it to `pb_dest`.

    The node named 'output' is used as the only output node.
    """
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['output'])
    with gfile.FastGFile(pb_dest, 'wb') as pb_file:
        pb_file.write(frozen.SerializeToString())
    print('pb saved')
项目:cnn-bnn    作者:jpdz    | 项目源码 | 文件源码
def generatePB(pb_dest = "model.pb"):
    """Freeze the module-level session's graph and save it to `pb_dest`.

    The node named 'output' is used as the only output node.
    """
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['output'])
    with gfile.FastGFile(pb_dest, 'wb') as pb_file:
        pb_file.write(frozen.SerializeToString())
    print('pb saved')
项目:cnn-bnn    作者:jpdz    | 项目源码 | 文件源码
def generatePB(pb_dest = "cifar_bnn_new.pb"):
    """Freeze the module-level session's graph and save it to `pb_dest`.

    The node named 'output' is used as the only output node.
    """
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['output'])
    with gfile.FastGFile(pb_dest, 'wb') as pb_file:
        pb_file.write(frozen.SerializeToString())
    print('pb saved')
项目:blog    作者:metaflow-ai    | 项目源码 | 文件源码
def freeze_graph(model_folder):
    """Write a frozen copy of the checkpoint recorded in `model_folder`.

    The frozen graph is saved as "frozen_model.pb" in the checkpoint's
    directory, using "Accuracy/predictions" as the output node.

    Args:
        model_folder: directory holding the checkpoint state file.
    """
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = ckpt_state.model_checkpoint_path

    # Everything before the last '/' is the checkpoint's directory; the
    # frozen graph is written next to the checkpoint.
    checkpoint_dir = input_checkpoint.rpartition('/')[0]
    output_graph = checkpoint_dir + "/frozen_model.pb"

    # Comma-separated so that several output nodes can be listed.
    output_node_names = "Accuracy/predictions"
    wanted_nodes = output_node_names.split(",")

    # Device placements are cleared so TensorFlow is free to place the
    # operations itself when the graph is loaded.
    saver = tf.train.import_meta_graph(
        input_checkpoint + '.meta', clear_devices=True)
    input_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        # The weights live in the session, not in the graph definition.
        saver.restore(sess, input_checkpoint)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, wanted_nodes)

        with tf.gfile.GFile(output_graph, "wb") as out_file:
            out_file.write(frozen_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(frozen_graph_def.node))
项目:facenet    作者:davidsandberg    | 项目源码 | 文件源码
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Patch `input_graph_def` in place and return a frozen copy of it.

    RefSwitch/AssignSub/AssignAdd nodes are rewritten to their non-reference
    counterparts before the variables are converted to constants.

    Args:
        sess: session holding the variable values.
        input_graph_def: GraphDef to patch (mutated) and freeze.
        output_node_names: comma-separated output node names.

    Returns:
        The GraphDef produced by convert_variables_to_constants.
    """
    plain_op_for = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Point inputs whose name contains 'moving_' at their '/read' op.
            for position, input_name in enumerate(list(node.input)):
                if 'moving_' in input_name:
                    node.input[position] = input_name + '/read'
        elif node.op in plain_op_for:
            node.op = plain_op_for[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables under these name prefixes are converted.
    kept_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                     'Bottleneck', 'Logits')
    whitelist_names = [
        node.name for node in input_graph_def.node
        if node.name.startswith(kept_prefixes)
    ]

    return graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
项目:sktacc    作者:jclee81    | 项目源码 | 文件源码
def run():
    """Freeze a small variable, restore it, and add it back onto itself.

    Creates variable 'y1', freezes the session graph with 'y1' as output,
    hands the serialized graph to restore_graph, then looks up
    'restore/y1:0' in the same graph and assigns y + restored value to y.
    Node listings are printed before and after the restore, and the final
    value of y is printed at the end.
    """
    log.info('Run freeze restore')

    variable = tf.Variable([88.8, 5.0], name='y1')
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        graph = sess.graph

        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['y1'])
        serialized = frozen_graph_def.SerializeToString()

        print_nodes(graph.as_graph_def(), 'before restore:')
        # presumably imports the graph under the 'restore' name scope, since
        # the tensor is fetched as 'restore/y1:0' below -- verify in
        # restore_graph.
        restore_graph(serialized)
        print_nodes(graph.as_graph_def(), 'after restore:')
        restored = graph.get_tensor_by_name('restore/y1:0')

        sess.run(variable.assign(variable + restored))
        print(sess.run(variable))

# run()
项目:sktacc    作者:jclee81    | 项目源码 | 文件源码
def _set_variable_and_publish(self, sess, iteration_id, transaction_id,
                                  group_id):
        """Store the frozen values of self.variables and announce them.

        The session graph is frozen with the variables' op names as outputs,
        the serialized graph is stored through `self.rc` under
        `transaction_id`, and a 'set_variable' JSON message describing it is
        published through `self.r`.

        Returns:
            The length of the serialized frozen graph.
        """
        variable_names = [var.op.name for var in self.variables]

        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), variable_names)
        serialized = frozen_graph_def.SerializeToString()

        parallel_count = self.infra_info['parallel_count']
        self.rc.set(transaction_id, serialized)
        payload = {
            'key': 'set_variable',
            'transaction_id': transaction_id,
            'group_id': group_id,
            'variables': variable_names,
            'worker_id': self.worker_id,
            'train_id': self.train_id,
            'iteration_id': iteration_id,
            'parallel_count': parallel_count
        }
        message = json.dumps(payload)
        # `channel` is not defined in this function; it is expected to come
        # from the enclosing module.
        self.r.publish(channel=channel, message=message)
        log.debug('pub %s' % iteration_id)
        return len(serialized)
项目:sktacc    作者:jclee81    | 项目源码 | 文件源码
def _set_variable_and_publish(self, sess, iteration_id, variables,
                                  transaction_id, group_id):
        """Store the frozen values of `variables` and announce them.

        The session graph is frozen with the variables' op names as outputs,
        the serialized graph is stored through `self.rc` under
        `transaction_id`, and a 'set_variable' JSON message describing it is
        published through `self.r`.

        Returns:
            The length of the serialized frozen graph.
        """
        variable_names = [var.op.name for var in variables]

        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), variable_names)
        serialized = frozen_graph_def.SerializeToString()

        parallel_count = self.infra_info['parallel_count']
        self.rc.set(transaction_id, serialized)
        payload = {
            'key': 'set_variable',
            'transaction_id': transaction_id,
            'group_id': group_id,
            'variables': variable_names,
            'worker_id': self.worker_id,
            'train_id': self.train_id,
            'iteration_id': iteration_id,
            'parallel_count': parallel_count
        }
        message = json.dumps(payload)
        # `channel` is not defined in this function; it is expected to come
        # from the enclosing module.
        self.r.publish(channel=channel, message=message)
        self._log('pub %s' % iteration_id)
        return len(serialized)
项目:ZOO-Attack    作者:huanzhang12    | 项目源码 | 文件源码
def save_graph_to_file(sess, graph, graph_file_name):
  """Freeze `graph` with the values in `sess` and write it to a file.

  The output node is taken from FLAGS.final_tensor_name; the serialized
  GraphDef is written to `graph_file_name`. Returns None.
  """
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
  with gfile.FastGFile(graph_file_name, 'wb') as out_file:
    out_file.write(frozen_graph_def.SerializeToString())
项目:image2text    作者:KleinYuan    | 项目源码 | 文件源码
def freeze_graph(model_folder, net_name):
    """Freeze the checkpoint recorded in `model_folder` into "<net_name>.pb".

    The names of all nodes in the restored graph are printed, then the
    variables are converted to constants and the result is written next to
    the checkpoint.

    Args:
        model_folder: directory holding the checkpoint state file.
        net_name: base name (without extension) of the output .pb file.
    """
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # We define the full file name of our frozen graph
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/%s.pb" % net_name
    # Before exporting our graph, we need to specify the output nodes.
    # This is how TF decides which part of the graph to keep and which part
    # it can drop.
    # NOTE: this variable is plural, because you can have multiple output nodes
    output_node_names = "Placeholder,Placeholder_1,Placeholder_2,out/add"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        graph_def = sess.graph.as_graph_def()
        for node in graph_def.node:
            # Parenthesized form is valid on both Python 2 and Python 3 and
            # matches the print call at the end of this function.
            print(node.name)
        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # The output node names are used to select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
项目:vae-style-transfer    作者:sunsided    | 项目源码 | 文件源码
def export_graph(input_path, output_path, output_nodes, debug=False):
    """Freeze the latest checkpoint under `input_path` to `output_path`.

    Args:
        input_path: directory searched for the latest checkpoint.
        output_path: destination file; its directory and base name are
            passed separately to tf.train.write_graph.
        output_nodes: output node names handed to
            convert_variables_to_constants.
        debug: when true, print the graph's op names and a progress message.
    """
    # todo: might want to look at http://stackoverflow.com/a/39578062/195651

    checkpoint = tf.train.latest_checkpoint(input_path)
    importer = tf.train.import_meta_graph(checkpoint + '.meta', clear_devices=True)

    graph = tf.get_default_graph()  # type: tf.Graph
    gd = graph.as_graph_def()  # type: tf.GraphDef

    if debug:
        print([op.name for op in graph.get_operations()])

    # fix batch norm nodes
    # https://github.com/tensorflow/tensorflow/issues/3628
    plain_op_for = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in gd.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for position, input_name in enumerate(list(node.input)):
                if 'moving_' in input_name:
                    node.input[position] = input_name + '/read'
        elif node.op in plain_op_for:
            node.op = plain_op_for[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    if debug:
        print('Freezing the graph ...')
    with tf.Session() as sess:
        importer.restore(sess, checkpoint)
        frozen_graph_def = graph_util.convert_variables_to_constants(sess, gd, output_nodes)
        target_dir = path.dirname(output_path)
        target_name = path.basename(output_path)
        tf.train.write_graph(frozen_graph_def, target_dir, target_name, as_text=False)
项目:face2face-demo    作者:datitran    | 项目源码 | 文件源码
def freeze_graph(model_folder):
    """Write a frozen copy of the checkpoint recorded in `model_folder`.

    The frozen graph is saved as 'frozen_model.pb' in the checkpoint's
    directory, using 'generate_output/output' as the output node.

    Args:
        model_folder: directory holding the checkpoint state file.
    """
    ckpt_state = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = ckpt_state.model_checkpoint_path

    # Everything before the last '/' is the checkpoint's directory; the
    # frozen graph is written next to the checkpoint.
    checkpoint_dir = input_checkpoint.rpartition('/')[0]
    output_graph = checkpoint_dir + '/frozen_model.pb'

    # The output nodes decide which part of the graph is kept.
    # Comma-separated so that several output nodes can be listed.
    output_node_names = 'generate_output/output'
    wanted_nodes = output_node_names.split(",")

    # Device placements are cleared so TensorFlow can choose the device for
    # each operation when the graph is loaded.
    saver = tf.train.import_meta_graph(
        input_checkpoint + '.meta', clear_devices=True)
    input_graph_def = tf.get_default_graph().as_graph_def()

    with tf.Session() as sess:
        # The weights come from the session, the nodes from the graph_def.
        saver.restore(sess, input_checkpoint)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, wanted_nodes)

        with tf.gfile.GFile(output_graph, 'wb') as out_file:
            out_file.write(frozen_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(frozen_graph_def.node))
项目:han    作者:croath    | 项目源码 | 文件源码
def main(_):
    """Freeze the latest checkpoint in FLAGS.checkpoint_dir to model.pb.

    The meta graph is imported and restored, batch-norm related reference
    ops are rewritten, the variables are converted to constants with
    "output_prob" as the output node, and the result is written to
    FLAGS.model_dir/model.pb.

    Raises:
        ValueError: if no checkpoint is found in FLAGS.checkpoint_dir.
    """
    output_node_names = "output_prob"

    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_fraction

    with tf.Session(config=session_config) as sess:
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        # latest_checkpoint returns None when nothing is found; check before
        # building the '.meta' path so the failure is explicit instead of a
        # TypeError from None + str.
        if not ckpt:
            raise ValueError(
                'No checkpoint found in %s' % FLAGS.checkpoint_dir)
        saver = tf.train.import_meta_graph(ckpt + '.meta')
        saver.restore(sess, ckpt)

        # Retrieve the protobuf graph definition and fix the batch norm nodes
        # Fix for bug of BN.
        # Ref 1 Solution: https://github.com/davidsandberg/facenet/issues/161
        # Ref 2 Official Issue: https://github.com/tensorflow/tensorflow/issues/3628
        gd = sess.graph.as_graph_def()
        for node in gd.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index in range(len(node.input)):
                    if 'moving_' in node.input[index]:
                        node.input[index] = node.input[index] + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr: del node.attr['use_locking']
            elif node.op == 'AssignAdd':
                node.op = 'Add'
                if 'use_locking' in node.attr: del node.attr['use_locking']

        output_graph_def = graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            gd, # The graph_def is used to retrieve the nodes
            output_node_names.split(",") # The output node names are used to select the useful nodes
        )

        with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'model.pb'), "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
项目:keras-to-tensorflow    作者:bitbionic    | 项目源码 | 文件源码
def convertGraph( modelPath, outdir, numoutputs, prefix, name):
    '''
    Converts an HDF5 model file to a .pb file for use with Tensorflow.

    A human-readable copy of the graph is also written to
    'graph_def_for_reference.pb.ascii' in the output directory.

    Args:
        modelPath (str): path to the .h5 file
           outdir (str): path to the output directory (created if missing)
       numoutputs (int): number of model outputs to alias and keep
           prefix (str): the prefix of the output aliasing; output i is
                         named '<prefix>_<i>'
             name (str): file name of the binary .pb written into outdir
    Returns:
        None
    '''

    # NOTE: on Python >= 3.2 this could be os.makedirs(outdir, exist_ok=True)
    if not os.path.isdir(outdir):
        os.mkdir(outdir)

    K.set_learning_phase(0)

    net_model = load_model(modelPath)

    # Alias the outputs in the model - this sometimes makes them easier to access in TF
    pred_node_names = [prefix+'_'+str(i) for i in range(numoutputs)]
    for position, alias in enumerate(pred_node_names):
        tf.identity(net_model.output[position], name=alias)
    print('Output nodes names are: ', pred_node_names)

    sess = K.get_session()

    # Human-readable dump of the graph, for reference only.
    ascii_name = 'graph_def_for_reference.pb.ascii'
    tf.train.write_graph(sess.graph.as_graph_def(), outdir, ascii_name, as_text=True)
    print('Saved the graph definition in ascii format at: ', osp.join(outdir, ascii_name))

    # Binary .pb with the variables replaced by constants.
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_node_names)
    graph_io.write_graph(frozen_graph_def, outdir, name, as_text=False)
    print('Saved the constant graph (ready for inference) at: ', osp.join(outdir, name))
项目:tensorflow-vgg16-train-and-test    作者:ppplinday    | 项目源码 | 文件源码
def train_network(graph, batch_size, num_epochs, pb_file_path):
    """Train the network and write a frozen graph after every epoch.

    Reads the module-level x_train / y_train / x_val / y_val data.

    Args:
        graph: mapping providing the 'optimize', 'x', 'y',
            'correct_times_in_batch' and 'cost' graph elements.
        batch_size: used to scale the accumulated cost and as part of the
            accuracy denominators; each sess.run below is fed one sample
            reshaped to (1, 224, 224, 3).
        num_epochs: number of passes over the training samples.
        pb_file_path: destination of the frozen graph; rewritten at the end
            of every epoch.
    """
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Statistics are evaluated and printed every `epoch_delta` epochs.
        epoch_delta = 2
        for epoch_index in range(num_epochs):
            # One optimization step per training sample (12 hard-coded).
            for i in range(12):
                sess.run([graph['optimize']], feed_dict={
                    graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                    # Labels 0 / non-0 are fed as the one-hot rows [1, 0] / [0, 1].
                    graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                })
            if epoch_index % epoch_delta == 0:
                # Accuracy and cost over the 12 training samples.
                total_batches_in_train_set = 0
                total_correct_times_in_train_set = 0
                total_cost_in_train_set = 0.
                for i in range(12):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x']: np.reshape(x_train[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_train[i] == 0 else [[0, 1]])
                    })
                    total_batches_in_train_set += 1
                    total_correct_times_in_train_set += return_correct_times_in_batch
                    total_cost_in_train_set += (mean_cost_in_batch * batch_size)


                # Accuracy and cost over the 3 validation samples.
                total_batches_in_test_set = 0
                total_correct_times_in_test_set = 0
                total_cost_in_test_set = 0.
                for i in range(3):
                    return_correct_times_in_batch = sess.run(graph['correct_times_in_batch'], feed_dict={
                        graph['x']: np.reshape(x_val[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_val[i] == 0 else [[0, 1]])
                    })
                    mean_cost_in_batch = sess.run(graph['cost'], feed_dict={
                        graph['x']: np.reshape(x_val[i], (1, 224, 224, 3)),
                        graph['y']: ([[1, 0]] if y_val[i] == 0 else [[0, 1]])
                    })
                    total_batches_in_test_set += 1
                    total_correct_times_in_test_set += return_correct_times_in_batch
                    total_cost_in_test_set += (mean_cost_in_batch * batch_size)

                # NOTE(review): the denominators multiply the sample count by
                # batch_size although each run is fed a single sample --
                # confirm this is intended when batch_size != 1.
                acy_on_test  = total_correct_times_in_test_set / float(total_batches_in_test_set * batch_size)
                acy_on_train = total_correct_times_in_train_set / float(total_batches_in_train_set * batch_size)
                print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.format(epoch_index, acy_on_test*100.0,total_correct_times_in_test_set,
                                                                                                                                                   total_batches_in_test_set * batch_size,
                                                                                                                                                   total_cost_in_test_set,
                                                                                                                                                   acy_on_train * 100.0,
                                                                                                                                                   total_correct_times_in_train_set,
                                                                                                                                                   total_batches_in_train_set * batch_size,
                                                                                                                                                   total_cost_in_train_set))
            # Runs inside the epoch loop: the frozen graph (output node
            # "output") overwrites pb_file_path after every epoch.
            constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
            with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
项目:deep-style-transfer    作者:albertlai    | 项目源码 | 文件源码
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: path of the GraphDef file to freeze.
    input_saver: optional path of a SaverDef file used to restore the
      checkpoint; when empty, `restore_op_name` is run instead.
    input_binary: whether the input files are binary (True) or text (False)
      protobufs.
    input_checkpoint: checkpoint path (may be a prefix for Saver V2).
    output_node_names: comma-separated names of the output nodes.
    restore_op_name: op run to restore the checkpoint when no saver file
      is given.
    filename_tensor_name: tensor fed with `input_checkpoint` when running
      `restore_op_name`.
    output_graph: path the frozen GraphDef is written to.
    clear_devices: when true, device assignments are removed from all nodes.
    initializer_nodes: run after the restore op when non-empty (only in the
      branch without a saver file).

  Returns:
    -1 when an input is missing or no output node names were given;
    otherwise None.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not tf.train.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      # NOTE(review): .decode() assumes f.read() returns bytes in text mode;
      # the saver branch below passes f.read() straight to text_format.Merge
      # -- confirm the behavior on Python 3.
      text_format.Merge(f.read().decode("utf-8"), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  # Import with an empty name scope so node names stay unchanged.
  _ = tf.import_graph_def(input_graph_def, name="")

  with tf.Session() as sess:
    if input_saver:
      with tf.gfile.FastGFile(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      # No saver file: run the graph's own restore op, feeding it the
      # checkpoint path.
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
      if initializer_nodes:
        sess.run(initializer_nodes)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","))

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
项目:sktacc    作者:jclee81    | 项目源码 | 文件源码
def _calculate_average_and_put(self, group_id, item, m):
    """Average variables across several stored graphs and store the result.

    For each variable name in ``item['variables']`` the tensor
    ``'<key>/<variable>:0'`` is looked up once per key in ``item['keys']``,
    the tensors are summed with ``tf.foldl`` and divided by their count, and
    the result initializes a new variable called ``average_<variable>``.
    The averaged variables are then converted to constants and the
    serialized GraphDef is written with ``self.rc.set(group_id, ...)``.

    Args:
        group_id: Key under which the serialized frozen graph is stored.
        item: Mapping providing ``'keys'`` (storage keys read through
            ``self.rc.get``) and ``'variables'`` (variable names to average).
        m: Handed unchanged to every ``SimpleMeasurement`` created here.
    """
    keys = item['keys']
    tf.reset_default_graph()
    sess = tf.Session()
    # The session owns native resources; make sure it is released even when
    # restoring or averaging one of the graphs raises.
    try:
        new_vars = []

        m_cal_and_put = SimpleMeasurement('cal_and_put', m)

        m_init = SimpleMeasurement('init', m)
        sess.run(tf.global_variables_initializer())
        m_init.end_measure()

        for v in item['variables']:
            name = 'average_%s' % v
            ts = []
            for key in keys:
                raw = self.rc.get(key)
                # TODO: check raw is not None
                # NOTE(review): restore_graph presumably imports the stored
                # graph under the `key` name scope, since the tensor is
                # fetched as '<key>/<v>:0' right after - confirm in util.
                util.restore_graph(key, raw)
                ts.append(sess.graph.get_tensor_by_name('%s/%s:0' % (key, v)))

            m_cal = SimpleMeasurement('cal', m)
            # One tensor is collected per key, so len(ts) is the divisor.
            avg = tf.foldl(tf.add, ts) / len(ts)
            new_var = tf.Variable(avg, name=name)
            sess.run(new_var.initializer)
            sess.run(new_var)
            new_vars.append(name)
            m_cal.end_measure()

        # Freeze only the freshly created average variables.
        frozen_def = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), new_vars)
        self.rc.set(group_id, frozen_def.SerializeToString())
    finally:
        sess.close()

    m_cal_and_put.end_measure()
项目:zhihu_kanshanbei    作者:No-account    | 项目源码 | 文件源码
def freeze_graph(model_folder):
    """Freeze the checkpoint recorded for `model_folder` into a single file.

    The checkpoint's meta graph is imported, its variables are restored and
    converted to constants, and the resulting GraphDef is written as
    ``frozen_model.pb`` in the same directory as the checkpoint.

    Args:
        model_folder: Directory passed to ``tf.train.get_checkpoint_state``.

    Raises:
        ValueError: If no checkpoint is recorded for `model_folder`.
    """
    # Local import keeps this block self-contained.
    import os

    # Retrieve the full path of the checkpoint.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("No checkpoint found in '%s'." % model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # The frozen graph is written next to the checkpoint. os.path avoids
    # resolving to "/frozen_model.pb" when the checkpoint path has no
    # directory component.
    output_graph = os.path.join(
        os.path.dirname(input_checkpoint), "frozen_model.pb")

    # Output node(s) handed to convert_variables_to_constants. The value is
    # a comma-separated string so that several nodes can be listed.
    output_node_names = "softmaxLayer/Softmax"

    # Clear the devices so TensorFlow can decide at load time where the
    # operations should run.
    clear_devices = True

    # Import the meta graph and retrieve a Saver.
    saver = tf.train.import_meta_graph(
        input_checkpoint + '.meta', clear_devices=clear_devices)

    # Retrieve the protobuf graph definition.
    input_graph_def = tf.get_default_graph().as_graph_def()

    # Start a session and restore the graph weights.
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # Use the built-in TF helper to export variables to constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(",")  # Split on comma for convenience
        )

        # Finally serialize and dump the output graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
项目:neural_style    作者:metaflow-ai    | 项目源码 | 文件源码
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes, verbose=True):
  """Converts all variables in a graph and checkpoint into constants.

  Loads the GraphDef stored at `input_graph`, restores variable values from
  `input_checkpoint`, converts the variables to constants with
  `graph_util.convert_variables_to_constants` and writes the result to
  `output_graph`.

  Args:
    input_graph: Path of the GraphDef file to freeze.
    input_saver: Optional path of a SaverDef file. When falsy, the op named
      `restore_op_name` is run instead of a Saver.
    input_binary: If true, the input files are parsed as binary protos,
      otherwise as text protos.
    input_checkpoint: Checkpoint file, glob pattern or Saver V2 prefix that
      holds the variable values.
    output_node_names: Comma-separated names of the output nodes.
    restore_op_name: Name of the restore op run when no saver is given.
    filename_tensor_name: Name of the tensor fed with `input_checkpoint`
      when `restore_op_name` is run.
    output_graph: Path the frozen GraphDef is written to.
    clear_devices: If true, the device field of every node is blanked.
    initializer_nodes: Optional fetch run after `restore_op_name`.
    verbose: If truthy, print the number of ops in the frozen graph.

  Returns:
    -1 when a required input is missing, otherwise None.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if the Saver V2 format is used; then
  # no file carries that exact name and only '<prefix>.index' can be matched.
  if not (tf.gfile.Glob(input_checkpoint) or
          tf.gfile.Glob(input_checkpoint + ".index")):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")

  # Tolerate blanks around the commas ("a, b") - they are never part of a
  # node name.
  output_nodes = [name.strip() for name in output_node_names.split(",")]

  with tf.Session() as sess:
    if input_saver:
      with tf.gfile.FastGFile(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
      if initializer_nodes:
        sess.run(initializer_nodes)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_nodes)

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  if verbose:
    print("%d ops in the final graph." % len(output_graph_def.node))
项目:FCN-TensorFlow    作者:shoaibahmed    | 项目源码 | 文件源码
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
  """Converts all variables in a graph and checkpoint into constants.

  Reads the GraphDef at `input_graph`, restores the variable values from
  `input_checkpoint`, freezes them through
  `graph_util.convert_variables_to_constants` and serializes the frozen
  GraphDef to `output_graph`.

  Args:
    input_graph: Path of the GraphDef file to freeze.
    input_saver: Optional SaverDef file; when falsy the op named
      `restore_op_name` is run to restore the variables.
    input_binary: Selects binary (true) or text (false) proto parsing for
      the input files.
    input_checkpoint: Checkpoint file, glob pattern or Saver V2 prefix.
    output_node_names: Comma-separated names of the output nodes.
    restore_op_name: Name of the restore op used without a saver.
    filename_tensor_name: Tensor name fed with `input_checkpoint` when the
      restore op is run.
    output_graph: Destination path of the frozen GraphDef.
    clear_devices: If true, every node's device field is reset to "".
    initializer_nodes: Optional fetch run right after the restore op.

  Returns:
    -1 when a required input is missing, otherwise None.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # With the Saver V2 format 'input_checkpoint' is only a prefix, so also
  # accept it when the matching '<prefix>.index' file is present.
  checkpoint_found = (tf.gfile.Glob(input_checkpoint) or
                      tf.gfile.Glob(input_checkpoint + ".index"))
  if not checkpoint_found:
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")

  # Strip blanks so that "a, b" selects the nodes "a" and "b".
  output_nodes = [name.strip() for name in output_node_names.split(",")]

  with tf.Session() as sess:
    if input_saver:
      with tf.gfile.FastGFile(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
      if initializer_nodes:
        sess.run(initializer_nodes)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_nodes)

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
项目:TextCNN    作者:ivancruzbht    | 项目源码 | 文件源码
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
  """Converts all variables in a graph and checkpoint into constants.

  Reads the GraphDef at `input_graph`, restores the variable values from
  `input_checkpoint`, freezes them through
  `graph_util.convert_variables_to_constants` and serializes the frozen
  GraphDef to `output_graph`.

  Args:
    input_graph: Path of the GraphDef file to freeze.
    input_saver: Optional SaverDef file; when falsy the op named
      `restore_op_name` is run to restore the variables.
    input_binary: Selects binary (true) or text (false) proto parsing for
      the input files.
    input_checkpoint: Checkpoint path, or prefix for the Saver V2 format.
    output_node_names: Comma-separated names of the output nodes.
    restore_op_name: Name of the restore op used without a saver.
    filename_tensor_name: Tensor name fed with `input_checkpoint` when the
      restore op is run.
    output_graph: Destination path of the frozen GraphDef.
    clear_devices: If true, every node's device field is reset to "".
    initializer_nodes: Optional fetch run right after the restore op.

  Returns:
    -1 when a required input is missing, otherwise None.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not tf.train.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      graph_text = f.read()
      # A text-mode read yields `str`, which has no decode() on Python 3;
      # only decode when the read actually produced bytes (Python 2).
      if isinstance(graph_text, bytes):
        graph_text = graph_text.decode("utf-8")
      text_format.Merge(graph_text, input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")

  # Strip blanks so that "a, b" selects the nodes "a" and "b".
  output_nodes = [name.strip() for name in output_node_names.split(",")]

  with tf.Session() as sess:
    if input_saver:
      with tf.gfile.FastGFile(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
      if initializer_nodes:
        sess.run(initializer_nodes)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_nodes)

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
项目:TensorFlow_DCIGN    作者:yselivonchyk    | 项目源码 | 文件源码
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
  """Converts all variables in a graph and checkpoint into constants.

  The GraphDef stored at `input_graph` is imported, its variables receive
  the values found in `input_checkpoint`, and the graph produced by
  `graph_util.convert_variables_to_constants` is written to `output_graph`.

  Args:
    input_graph: Path of the GraphDef file to freeze.
    input_saver: Optional path of a SaverDef file used to build the Saver
      that restores the checkpoint; when falsy `restore_op_name` is run.
    input_binary: True to parse the input files as binary protos, false to
      parse them as text protos.
    input_checkpoint: Checkpoint file, glob pattern or Saver V2 prefix.
    output_node_names: Comma-separated names of the output nodes.
    restore_op_name: Name of the op run to restore without a saver.
    filename_tensor_name: Name of the tensor that receives
      `input_checkpoint` when `restore_op_name` is run.
    output_graph: File the frozen GraphDef is serialized to.
    clear_devices: If true, device assignments are removed from all nodes.
    initializer_nodes: Optional fetch run after `restore_op_name`.

  Returns:
    -1 when a required input is missing, otherwise None.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # A Saver V2 checkpoint is addressed by a prefix that matches no file by
  # itself, so fall back to looking for its '<prefix>.index' companion.
  if not (tf.gfile.Glob(input_checkpoint) or
          tf.gfile.Glob(input_checkpoint + ".index")):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")

  # Blanks around the commas are not part of the node names; drop them.
  output_nodes = [name.strip() for name in output_node_names.split(",")]

  with tf.Session() as sess:
    if input_saver:
      with tf.gfile.FastGFile(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(f.read(), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
      if initializer_nodes:
        sess.run(initializer_nodes)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_nodes)

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
项目:lego_yolo    作者:dEcmir    | 项目源码 | 文件源码
def freeze_graph(model_folder):
    """Freeze the checkpoint recorded for `model_folder` into a single file.

    Imports the checkpoint's meta graph, restores its variables, converts
    them to constants and writes the resulting GraphDef as
    ``frozen_model.pb`` in the directory that holds the checkpoint.

    Args:
        model_folder: Directory passed to ``tf.train.get_checkpoint_state``.

    Raises:
        ValueError: If no checkpoint is recorded for `model_folder`.
    """
    # Local import keeps this block self-contained.
    import os

    # Retrieve the full path of the checkpoint.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("No checkpoint found in '%s'." % model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Name of the frozen graph file, placed next to the checkpoint.
    # os.path avoids resolving to "/frozen_model.pb" when the checkpoint
    # path has no directory component.
    output_graph = os.path.join(
        os.path.dirname(input_checkpoint), "frozen_model.pb")

    # Before exporting the graph we must name its output node(s): this is
    # how TF decides which part of the graph to keep and which to drop.
    # NOTE: the variable is plural because several comma-separated output
    # nodes may be listed.
    output_node_names = "29_fully_connected"

    # Clear the devices so TensorFlow can decide at load time where the
    # operations should run.
    clear_devices = True

    # Import the meta graph and retrieve a Saver.
    saver = tf.train.import_meta_graph(
        input_checkpoint + '.meta', clear_devices=clear_devices)

    # Retrieve the protobuf graph definition.
    input_graph_def = tf.get_default_graph().as_graph_def()

    # Start a session and restore the graph weights.
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # Use the built-in TF helper to export variables to constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            # The output node names are used to select the useful nodes
            output_node_names.split(",")
        )

        # Finally serialize and dump the output graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
项目:pgnet    作者:galeone    | 项目源码 | 文件源码
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    The GraphDef stored at `input_graph` is imported, its variables receive
    the values found in `input_checkpoint`, and the graph produced by
    `graph_util.convert_variables_to_constants` is written to
    `output_graph`.

    Args:
        input_graph: Path of the GraphDef file to freeze.
        input_saver: Optional path of a SaverDef file used to build the
            Saver that restores the checkpoint; when falsy
            `restore_op_name` is run instead.
        input_binary: True to parse the input files as binary protos,
            false to parse them as text protos.
        input_checkpoint: Checkpoint file, glob pattern or Saver V2 prefix.
        output_node_names: Comma-separated names of the output nodes.
        restore_op_name: Name of the op run to restore without a saver.
        filename_tensor_name: Name of the tensor that receives
            `input_checkpoint` when `restore_op_name` is run.
        output_graph: File the frozen GraphDef is serialized to.
        clear_devices: If true, device assignments are removed from all
            nodes.
        initializer_nodes: Optional fetch run after `restore_op_name`.

    Returns:
        -1 when a required input is missing, otherwise None.
    """

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # A Saver V2 checkpoint is addressed by a prefix that matches no file by
    # itself, so fall back to looking for its '<prefix>.index' companion.
    checkpoint_found = (tf.gfile.Glob(input_checkpoint) or
                        tf.gfile.Glob(input_checkpoint + ".index"))
    if not checkpoint_found:
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    # Blanks around the commas are not part of the node names; drop them.
    output_nodes = [name.strip() for name in output_node_names.split(",")]

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            if initializer_nodes:
                sess.run(initializer_nodes)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_nodes)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))