Python keras.backend 模块,batch_set_value() 实例源码

我们从 Python 开源项目中提取了以下 3 个代码示例,用于说明如何使用 keras.backend.batch_set_value()。

项目:speech_ml    作者:coopie    | 项目源码 | 文件源码
def get_weights_from_h5_group(model, model_weights, verbose=1):
    """Assign saved weights from ``model_weights`` to the layers of ``model``.

    Despite the ``get_`` prefix, nothing is returned: the collected
    (symbolic weight, saved value) pairs are applied with a single
    ``K.batch_set_value`` call.

    Layers are matched by name. A layer whose name is absent from
    ``model_weights``, or whose entry is empty, is skipped.

    Args:
        model: Object exposing ``layers``; each layer must provide ``name``,
            ``trainable_weights`` and ``non_trainable_weights``.
        model_weights: Mapping from layer name to a group. Each group must
            expose ``attrs['weight_names']`` (a sequence of byte strings)
            and be indexable by the decoded weight names.
        verbose: When truthy, print the name of every layer being loaded.

    Raises:
        Exception: If the number of saved weights for a layer differs from
            the number of weights that layer holds.
    """
    weight_value_tuples = []
    # The position in ``model.layers`` is only needed for the error message.
    # The original code referenced an undefined ``k`` there, which turned
    # the intended diagnostic into a NameError.
    for layer_index, layer in enumerate(model.layers):
        name = layer.name
        if name not in model_weights or len(model_weights[name]) == 0:
            continue

        layer_weights = model_weights[name]
        weight_names = [n.decode('utf8')
                        for n in layer_weights.attrs['weight_names']]
        weight_values = [layer_weights[weight_name]
                         for weight_name in weight_names]
        symbolic_weights = layer.trainable_weights + layer.non_trainable_weights

        if len(weight_values) != len(symbolic_weights):
            raise Exception(
                'Layer #{index} (named "{layer}" in the current model) was '
                'found to correspond to layer {saved} in the save file. '
                'However the new layer {layer} expects {expected} weights, '
                'but the saved weights have {found} elements.'.format(
                    index=layer_index,
                    layer=layer.name,
                    saved=name,
                    expected=len(symbolic_weights),
                    found=len(weight_values)))

        if verbose:
            print('Setting_weights for layer:', name)

        weight_value_tuples.extend(zip(symbolic_weights, weight_values))

    # One batched backend call instead of one assignment per weight.
    K.batch_set_value(weight_value_tuples)
项目:keras-fcn    作者:JihongJu    | 项目源码 | 文件源码
def load_weights(model, weights_path):
    """Load layer weights from an HDF5 file into ``model``.

    The original docstring describes the file as weights coming from Caffe
    models; the code itself only requires an HDF5 file with a top-level
    ``layer_names`` attribute and one group per layer carrying a
    ``weight_names`` attribute.

    Saved layers are matched to model layers by name. Every model layer
    sharing a saved layer's name receives that layer's values; saved layers
    with no matching model layer are ignored.

    Args:
        model: Object exposing ``layers``; each layer must provide ``name``
            and ``weights``.
        weights_path: Path of the HDF5 weights file, opened read-only.

    Returns:
        The list of layer names stored in the file, decoded from UTF-8.

    Raises:
        ImportError: If the module-level ``h5py`` name is ``None``.
        IndexError: If a saved layer holds more weights than a matching
            model layer has.
    """
    print("Loading weights...")
    if h5py is None:
        raise ImportError('`load_weights` requires h5py.')

    # The context manager closes the file even when loading fails; the
    # original never closed the handle.
    with h5py.File(weights_path, mode='r') as f:
        # New file format.
        layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]

        # Reverse index of layer name to list of layers with name.
        index = {}
        for layer in model.layers:
            if layer.name:
                index.setdefault(layer.name, []).append(layer)

        # We batch weight value assignments in a single backend call
        # which provides a speedup in TensorFlow.
        weight_value_tuples = []
        for name in layer_names:
            g = f[name]
            weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
            weight_values = [g[weight_name] for weight_name in weight_names]

            for layer in index.get(name, []):
                symbolic_weights = layer.weights
                # Index by position so a saved layer with more weights than
                # the model layer still fails loudly instead of truncating.
                for i, weight_value in enumerate(weight_values):
                    weight_value_tuples.append((symbolic_weights[i],
                                                weight_value))

        # Must run while the file is open: the values are objects read
        # from ``f``.
        K.batch_set_value(weight_value_tuples)

    return layer_names
项目:aetros-cli    作者:aetros    | 项目源码 | 文件源码
def load_weights(model, weights_path):
    """Load layer weights from an HDF5 file into ``model``.

    Each layer name stored in the file is looked up with
    ``model.get_layer(name=...)``. The saved values are paired with the
    layer's trainable weights followed by its non-trainable weights, and all
    pairs are applied with one ``K.batch_set_value`` call.

    A warning is printed when the number of saved layers differs from
    ``len(model.layers)``. Saved layers without weight names are skipped, and
    a message is printed for layers lacking a ``trainable_weights`` attribute.

    Args:
        model: Object exposing ``layers`` and ``get_layer(name=...)``.
        weights_path: Path of the HDF5 weights file, opened read-only.

    Raises:
        Exception: If ``weights_path`` is not an existing file, or if the
            number of saved weights for a layer differs from the number of
            weights the layer holds.
    """
    # Check the path before the heavy imports so a bad path fails fast.
    if not os.path.isfile(weights_path):
        raise Exception("File does not exist.")

    from keras import backend as K
    import h5py

    # The context manager closes the file on every exit path; the original
    # only reached ``f.close()`` when nothing raised.
    with h5py.File(weights_path, mode='r') as f:
        # new file format
        layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
        if len(layer_names) != len(model.layers):
            print("Warning: Layer count different")

        # we batch weight value assignments in a single backend call
        # which provides a speedup in TensorFlow.
        weight_value_tuples = []
        for k, name in enumerate(layer_names):
            g = f[name]
            weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]

            layer = model.get_layer(name=name)
            if not layer or not weight_names:
                continue

            weight_values = [g[weight_name] for weight_name in weight_names]
            if not hasattr(layer, 'trainable_weights'):
                print("Layer %s (%s) has no trainable weights, but we tried to load it." % (
                    name, type(layer).__name__))
                continue

            symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
            if len(weight_values) != len(symbolic_weights):
                raise Exception('Layer #' + str(k) +
                                ' (named "' + layer.name +
                                '" in the current model) was found to '
                                'correspond to layer ' + name +
                                ' in the save file. '
                                'However the new layer ' + layer.name +
                                ' expects ' + str(len(symbolic_weights)) +
                                ' weights, but the saved weights have ' +
                                str(len(weight_values)) +
                                ' elements.')

            weight_value_tuples.extend(zip(symbolic_weights, weight_values))

        # Must run while the file is open: the values are objects read
        # from ``f``.
        K.batch_set_value(weight_value_tuples)