Python keras.backend 模块,round() 实例源码

我们从Python开源项目中,提取了以下37个代码示例,用于说明如何使用keras.backend.round()

项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def generate_gpu(configs,**kwargs):
    """Render -1/1 configurations into puzzle images with a Keras graph, then swirl them.

    Each configuration vector of length ``size*size`` selects, per grid cell,
    one of the ``P = 2`` tiles in the module-level ``panels`` array.  The tiles
    are laid out on a ``size x size`` grid and zero-padded by the module-level
    ``pad`` on every side.  The predicted batch is then passed through
    ``batch_swirl`` and ``preprocess``.

    Args:
        configs: sequence of configuration vectors (entries -1 or 1); the
            length of the first one is assumed to be a perfect square.
        **kwargs: forwarded to ``Model.predict`` (e.g. ``batch_size``).

    Returns:
        ``preprocess(batch_swirl(images))`` for the rendered image batch.
    """
    import math
    configs = np.array(configs)
    # side length of the grid of cells
    size = int(math.sqrt(len(configs[0])))
    # tile side length in pixels; panels presumably has shape (P, base, base)
    base = panels.shape[1]

    def build():
        P = 2  # number of distinct tiles
        configs = Input(shape=(size*size,))
        _configs = 1 - K.round((configs/2)+0.5) # from -1/1 to 1/0
        # one-hot tile index per cell, flattened to (batch*size*size, P)
        configs_one_hot = K.one_hot(K.cast(_configs,'int32'), P)
        configs_one_hot = K.reshape(configs_one_hot, [-1,P])
        _panels = K.variable(panels)
        _panels = K.reshape(_panels, [P, base*base])
        # select the flattened tile for every cell
        states = tf.matmul(configs_one_hot, _panels)
        # swap the cell-column and pixel-row axes so the final reshape yields
        # one contiguous (size*base) x (size*base) image per sample
        states = K.reshape(states, [-1, size, size, base, base])
        states = K.permute_dimensions(states, [0, 1, 3, 2, 4])
        states = K.reshape(states, [-1, size*base, size*base, 1])
        states = K.spatial_2d_padding(states, padding=((pad,pad),(pad,pad)))
        states = K.squeeze(states, -1)
        return Model(configs, wrap(configs, states))

    return preprocess(batch_swirl(build().predict(configs,**kwargs)))
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def generate_gpu2(configs,**kwargs):
    """Render -1/1 configurations into swirled puzzle images inside a Keras graph.

    Each configuration vector of length ``size*size`` selects, per grid cell,
    one of the ``P = 2`` tiles stored in the module-level ``panels`` array.
    The tiles are laid out on a ``size x size`` grid, zero-padded by the
    module-level ``pad`` on every side, and swirled in-graph with
    ``tensor_swirl(..., **swirl_args)``.  The predictions are passed through
    ``preprocess``.

    Args:
        configs: sequence of configuration vectors (entries -1 or 1); the
            length of the first one is assumed to be a perfect square.
        **kwargs: forwarded to ``Model.predict`` (e.g. ``batch_size``).

    Returns:
        ``preprocess`` applied to the predicted image batch.
    """
    configs = np.array(configs)
    import math
    # side length of the grid of cells
    size = int(math.sqrt(len(configs[0])))
    # tile side length in pixels; panels presumably has shape (P, base, base)
    base = panels.shape[1]
    # unpadded image side length
    dim = base*size

    def build():
        P = 2
        configs = Input(shape=(size*size,))
        _configs = 1 - K.round((configs/2)+0.5) # from -1/1 to 1/0
        # one-hot tile index per cell, flattened to (batch*size*size, P)
        configs_one_hot = K.one_hot(K.cast(_configs,'int32'), P)
        configs_one_hot = K.reshape(configs_one_hot, [-1,P])
        _panels = K.variable(panels)
        _panels = K.reshape(_panels, [P, base*base])
        # select the flattened tile for every cell
        states = tf.matmul(configs_one_hot, _panels)
        # swap the cell-column and pixel-row axes so the final reshape yields
        # one contiguous (size*base) x (size*base) image per sample
        states = K.reshape(states, [-1, size, size, base, base])
        states = K.permute_dimensions(states, [0, 1, 3, 2, 4])
        states = K.reshape(states, [-1, size*base, size*base, 1])
        states = K.spatial_2d_padding(states, padding=((pad,pad),(pad,pad)))
        states = K.squeeze(states, -1)
        # NOTE(review): operator precedence makes the radius
        # ``dim + (2*pad*relative_swirl_radius)``, not
        # ``(dim + 2*pad) * relative_swirl_radius`` as the name suggests.
        # to_configs un-swirls with the same expression, so change both
        # together if this is unintended.
        states = tensor_swirl(states, radius=dim+2*pad * relative_swirl_radius, **swirl_args)
        return Model(configs, wrap(configs, states))

    return preprocess(build().predict(configs,**kwargs))
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def to_configs(states, verbose=True, **kwargs):
    """Recover -1/1 configuration vectors from swirled, padded puzzle images.

    Builds a Keras model that un-swirls each image with
    ``tensor_swirl(..., **unswirl_args)``, measures per-cell errors against the
    two tiles in the module-level ``panels`` via ``build_errors``, and marks a
    tile as matching when its error is <= ``threshold``.  A match with tile
    index 0 maps to 1, a match with tile index 1 maps to -1.  This assumes
    exactly one tile matches per cell.

    Args:
        states: batch of images; ``states.shape[1]`` is the padded side
            length ``dim + 2*pad``.
        verbose: not used by this function.
        **kwargs: forwarded to ``Model.predict``.

    Returns:
        The model prediction, presumably of shape (n, size*size) with
        rounded -1/1 values (depends on ``wrap`` — confirm).
    """
    base = panels.shape[1]
    dim  = states.shape[1] - pad*2
    size = dim // base

    def build():
        states = Input(shape=(dim+2*pad,dim+2*pad))
        # NOTE(review): same precedence as in generate_gpu2:
        # ``dim + (2*pad*relative_swirl_radius)``; keep both in sync.
        s = tensor_swirl(states, radius=dim+2*pad * relative_swirl_radius, **unswirl_args)
        error = build_errors(s,base,pad,dim,size)
        # 1 where error <= threshold, 0 where error > threshold
        matches = 1 - K.clip(K.sign(error - threshold),0,1)
        # a, h, w, panel
        matches = K.reshape(matches, [K.shape(states)[0], size * size, -1])
        # a, pos, panel
        # weight each match by its tile index and sum -> matching tile index
        config = matches * K.arange(2,dtype='float')
        config = K.sum(config, axis=-1)
        # this is 0,1 configs; for compatibility, we need -1 and 1
        config = - (config - 0.5)*2
        return Model(states, wrap(states, K.round(config)))

    return build().predict(states, **kwargs)
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def to_configs(states, verbose=True, **kwargs):
    """Recover -1/1 configuration vectors from unpadded puzzle images.

    Builds a Keras model that measures per-cell errors against the two tiles
    in the module-level ``panels`` via ``build_errors`` and marks a tile as
    matching when its error is <= ``threshold``.  A match with tile index 0
    maps to 1, a match with tile index 1 maps to -1.  This assumes exactly
    one tile matches per cell.

    Args:
        states: batch of square images; ``states.shape[1]`` is the side length.
        verbose: not used by this function.
        **kwargs: forwarded to ``Model.predict``.

    Returns:
        The model prediction, presumably of shape (n, size*size) with
        rounded -1/1 values (depends on ``wrap`` — confirm).
    """
    base = panels.shape[1]
    size = states.shape[1]//base
    dim  = states.shape[1]

    def build():
        states = Input(shape=(dim,dim))
        error = build_errors(states,base,dim,size)
        # 1 where error <= threshold, 0 where error > threshold
        matches = 1 - K.clip(K.sign(error - threshold),0,1)
        # a, h, w, panel
        matches = K.reshape(matches, [K.shape(states)[0], size * size, -1])
        # a, pos, panel
        # weight each match by its tile index and sum -> matching tile index
        config = matches * K.arange(2,dtype='float')
        config = K.sum(config, axis=-1)
        # this is 0,1 configs; for compatibility, we need -1 and 1
        config = - (config - 0.5)*2
        return Model(states, wrap(states, K.round(config)))

    model = build()
    return model.predict(states, **kwargs)
项目:deepcpg    作者:cangermueller    | 项目源码 | 文件源码
def contingency_table(y, z):
    """Compute a binary contingency table between ``y`` and ``z``.

    Both tensors are rounded to the nearest integer first.  Positions whose
    rounded value is neither 0 nor 1 fall into none of the four cells.

    Args:
        y: reference tensor (e.g. targets).
        z: tensor compared against ``y`` (e.g. predictions), same shape as ``y``.

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of scalar tensors of dtype ``K.floatx()``
        counting positions where (y, z) is (1, 1), (0, 0), (0, 1) and (1, 0).
    """
    y = K.round(y)
    z = K.round(z)

    def count_matches(a, b):
        # Element-wise AND of two boolean masks, then count.  Concatenating
        # along the last axis and reducing with K.all would AND over the whole
        # axis, which only coincides with an element-wise AND when that axis
        # has size 1 (and collapses 1-D inputs into a single value).
        both = K.cast(a, K.floatx()) * K.cast(b, K.floatx())
        return K.sum(both)

    ones = K.ones_like(y)
    zeros = K.zeros_like(y)
    y_ones = K.equal(y, ones)
    y_zeros = K.equal(y, zeros)
    z_ones = K.equal(z, ones)
    z_zeros = K.equal(z, zeros)

    tp = count_matches(y_ones, z_ones)
    tn = count_matches(y_zeros, z_zeros)
    fp = count_matches(y_zeros, z_ones)
    fn = count_matches(y_ones, z_zeros)

    return (tp, tn, fp, fn)
项目:kaggle_amazon    作者:asanakoy    | 项目源码 | 文件源码
def fscore(y_true, y_pred, average='samples', beta=2):
    """Mean F-beta score, computed per sample or per label.

    Args:
        y_true: target tensor; assumed 2-D (batch, labels) — confirm.
        y_pred: prediction tensor with the same shape as ``y_true``.  True
            positives use ``round(clip(y_true * y_pred, 0, 1))``; the
            predicted-positive count uses ``y_pred`` as given.
        average: 'samples' sums counts over axis 1 (one score per sample);
            any other value sums over axis 0 (one score per label).
        beta: weight of recall relative to precision.

    Returns:
        Scalar tensor: the mean of the per-sample (or per-label) F-beta scores.
    """
    reduce_axis = 1 if average == 'samples' else 0
    eps = K.epsilon()

    tp = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)), axis=reduce_axis)
    predicted = K.sum(y_pred, axis=reduce_axis)
    actual = K.sum(y_true, axis=reduce_axis)

    precision = tp / (predicted + eps)
    recall = tp / (actual + eps)

    b2 = beta ** 2
    score = (1 + b2) * precision * recall / (b2 * precision + recall + eps)
    return K.mean(score)
项目:keras_bn_library    作者:bnsnapper    | 项目源码 | 文件源码
def call(self, x, mask=None):
    """Turn Bernoulli probabilities ``x`` into samples according to ``self.mode``.

    Modes:
      * 'maximum_likelihood': ``K.round(x)`` (1 above 0.5, 0 below; the tie
        at exactly 0.5 follows the backend's rounding mode).
      * 'random': a random Bernoulli draw with success probability ``x``.
      * 'mean_field': ``x`` itself (the expectation of Bern(x)).
      * 'nrlu': ``nrlu(x)``.

    Args:
        x: tensor of probabilities.
        mask: not used here.

    Raises:
        NotImplementedError: if ``self.mode`` is none of the above.
    """
    if self.mode == 'maximum_likelihood':
        return K.round(x)
    elif self.mode == 'random':
        # Use the dynamic shape: under TensorFlow ``x.shape`` is a static
        # shape whose batch dimension is usually None, which cannot be used to
        # draw a sample.  K.shape works for every backend.
        return K.random_binomial(K.shape(x), p=x, dtype=K.floatx())
    elif self.mode == 'mean_field':
        return x
    elif self.mode == 'nrlu':
        return nrlu(x)
    else:
        raise NotImplementedError('Unknown sample mode!')
项目:nn_playground    作者:DingKe    | 项目源码 | 文件源码
def round_through(x):
    '''Round ``x`` element-wise while passing the gradient straight through.

    The forward value equals ``K.round(x)``.  The rounding residual is wrapped
    in ``stop_gradient``, so the backward pass sees the identity.  This is a
    trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182).
    '''
    residual = K.round(x) - x
    return x + K.stop_gradient(residual)
项目:nn_playground    作者:DingKe    | 项目源码 | 文件源码
def round_through(x):
    '''Straight-through rounding: forward ``K.round(x)``, identity gradient.

    Only ``x`` itself carries a gradient; the difference ``round(x) - x`` is
    treated as a constant via ``stop_gradient``.  This is a trick from
    [Sergey Ioffe](http://stackoverflow.com/a/36480182).
    '''
    offset_to_integer = K.stop_gradient(K.round(x) - x)
    return x + offset_to_integer
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def build_error(s, height, width, base):
    """Per-cell mismatch between images ``s`` and every panel in ``setting['panels']``.

    ``s`` is reshaped to (-1, height, base, width, base), i.e. each image is
    treated as a height x width grid of base x base cells.  Every cell is
    compared with each of the P panels: the absolute pixel difference is
    average-hashed over 2x2 blocks and rounded (see ``hash``), then averaged
    over the cell.  ``base`` must be even for the 2x2 hashing reshape.

    Returns:
        Tensor of shape (batch, height, width, P).
    """
    P = len(setting['panels'])
    # (batch, cell_row, pixel_y, cell_col, pixel_x) -> one row per cell,
    # replicated once per panel
    s = K.reshape(s,[-1,height,base,width,base])
    s = K.permute_dimensions(s, [0,1,3,2,4])
    s = K.reshape(s,[-1,height,width,1,base,base])
    s = K.tile(s, [1,1,1,P,1,1,])

    # panels broadcast to every cell of every image in the batch
    allpanels = K.variable(np.array(setting['panels']))
    allpanels = K.reshape(allpanels, [1,1,1,P,base,base])
    allpanels = K.tile(allpanels, [K.shape(s)[0], height, width, 1, 1, 1])

    # NOTE: this local name shadows the builtin ``hash`` inside build_error.
    def hash(x):
        ## 2x2 average hashing: mean over each 2x2 pixel block, then round
        x = K.reshape(x, [-1,height,width,P, base//2, 2, base//2, 2])
        x = K.mean(x, axis=(5,7))
        return K.round(x)
        ## Alternatives kept for experimentation (unreachable):
        ## diff hashing (horizontal diff)
        # x1 = x[:,:,:,:,:,:-1]
        # x2 = x[:,:,:,:,:,1:]
        # d = x1 - x2
        # return K.round(d)
        ## just rounding
        # return K.round(x)
        ## do nothing
        # return x

    # s         = hash(s)
    # allpanels = hash(allpanels)

    # error = K.binary_crossentropy(s, allpanels)
    error = K.abs(s - allpanels)
    error = hash(error)
    # average the hashed error over each cell -> (batch, height, width, P)
    error = K.mean(error, axis=(4,5))
    return error
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def build_errors(states,base,pad,dim,size):
    """Per-cell mismatch between padded images and the two tiles in ``panels``.

    The images are shifted by ``viscosity_adjustment`` and rounded, reshaped
    to (dim+2*pad, dim+2*pad, 1) and cropped by ``pad`` on each side.  They
    are then split into a size x size grid of base x base cells.  Each cell is
    compared with both tiles: the absolute difference is average-hashed over
    3x3 blocks and rounded (``base`` must be divisible by 3), then averaged
    over the cell.

    Returns:
        Tensor of shape (batch, size, size, 2).
    """
    # address the numerical viscosity in swirling
    s = K.round(states+viscosity_adjustment)
    s = Reshape((dim+2*pad,dim+2*pad,1))(s)
    s = Cropping2D(((pad,pad),(pad,pad)))(s)
    # one row per cell, replicated once per tile
    s = K.reshape(s,[-1,size,base,size,base])
    s = K.permute_dimensions(s, [0,1,3,2,4])
    s = K.reshape(s,[-1,size,size,1,base,base])
    s = K.tile   (s,[1, 1, 1, 2, 1, 1,]) # number of panels : 2

    # tiles broadcast to every cell of every image in the batch
    allpanels = K.variable(panels)
    allpanels = K.reshape(allpanels, [1,1,1,2,base,base])
    allpanels = K.tile(allpanels, [K.shape(s)[0], size,size, 1, 1, 1])

    # NOTE: this local name shadows the builtin ``hash`` inside build_errors.
    def hash(x):
        ## 3x3 average hashing: mean over each 3x3 pixel block, then round
        x = K.reshape(x, [-1,size,size,2, base//3, 3, base//3, 3])
        x = K.mean(x, axis=(5,7))
        return K.round(x)
        ## Alternatives kept for experimentation (unreachable):
        ## diff hashing (horizontal diff)
        # x1 = x[:,:,:,:,:,:-1]
        # x2 = x[:,:,:,:,:,1:]
        # d = x1 - x2
        # return K.round(d)
        ## just rounding
        # return K.round(x)
        ## do nothing
        # return x

    # s         = hash(s)
    # allpanels = hash(allpanels)

    # error = K.binary_crossentropy(s, allpanels)
    error = K.abs(s - allpanels)
    error = hash(error)
    # average the hashed error over each cell -> (batch, size, size, 2)
    error = K.mean(error, axis=(4,5))
    return error
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def build_errors(states,base,dim,size):
    """Per-cell mismatch between images ``states`` and the two tiles in ``panels``.

    ``states`` is reshaped to (-1, size, base, size, base), i.e. a size x size
    grid of base x base cells.  Each cell is compared with both tiles: the
    absolute pixel difference is rounded (see ``hash``) and averaged over the
    cell.  ``dim`` is accepted but not used here.

    Returns:
        Tensor of shape (batch, size, size, 2).
    """
    # one row per cell, replicated once per tile
    s = K.reshape(states,[-1,size,base,size,base])
    s = K.permute_dimensions(s, [0,1,3,2,4])
    s = K.reshape(s,[-1,size,size,1,base,base])
    s = K.tile   (s,[1, 1, 1, 2, 1, 1,]) # number of panels : 2

    # tiles broadcast to every cell of every image in the batch
    allpanels = K.variable(panels)
    allpanels = K.reshape(allpanels, [1,1,1,2,base,base])
    allpanels = K.tile(allpanels, [K.shape(s)[0], size,size, 1, 1, 1])

    # NOTE: this local name shadows the builtin ``hash`` inside build_errors.
    def hash(x):
        ## Only "just rounding" is active; the rest are alternatives kept for
        ## experimentation.
        ## 2x2 average hashing
        # x = K.reshape(x, [-1,size,size,2, base//2, 2, base//2, 2])
        # x = K.mean(x, axis=(5,7))
        # return K.round(x)
        ## diff hashing (horizontal diff)
        # x1 = x[:,:,:,:,:,:-1]
        # x2 = x[:,:,:,:,:,1:]
        # d = x1 - x2
        # return K.round(d)
        ## just rounding
        return K.round(x)
        ## do nothing
        # return x

    # s         = hash(s)
    # allpanels = hash(allpanels)

    # error = K.binary_crossentropy(s, allpanels)
    error = K.abs(s - allpanels)
    error = hash(error)
    # average the rounded error over each cell -> (batch, size, size, 2)
    error = K.mean(error, axis=(4,5))
    return error
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def build_error(s, disks, towers, tower_width, panels):
    """Per-slot mismatch between Hanoi-style images ``s`` and each of ``disks+1`` panels.

    ``s`` is reshaped to (-1, disks, disk_height, towers, tower_width), i.e. a
    disks x towers grid of disk_height x tower_width slots.  ``disk_height``
    is a module-level value.  Both the slots and the panels are rounded (see
    ``hash``), and the mean absolute difference over each slot is returned.
    The meaning of each panel (e.g. empty slot vs. a particular disk) is
    presumably defined by the caller — confirm.

    Returns:
        Tensor of shape (batch, disks, towers, disks+1).
    """
    # one row per slot, replicated once per panel
    s = K.reshape(s,[-1,disks, disk_height, towers, tower_width])
    s = K.permute_dimensions(s, [0,1,3,2,4])
    s = K.reshape(s,[-1,disks,towers,1,    disk_height,tower_width])
    s = K.tile   (s,[1, 1, 1, disks+1,1, 1,])

    # panels broadcast to every slot of every image in the batch
    allpanels = K.variable(panels)
    allpanels = K.reshape(allpanels, [1,1,1,disks+1,disk_height,tower_width])
    allpanels = K.tile(allpanels, [K.shape(s)[0], disks, towers, 1, 1, 1])

    # NOTE: this local name shadows the builtin ``hash`` inside build_error.
    def hash(x):
        ## Only "just rounding" is active; the rest are alternatives kept for
        ## experimentation.
        ## 2x2 average hashing (now it does not work since disks have 1 pixel height)
        # x = K.reshape(x, [-1,disks,towers,disks+1, disk_height,tower_width//2,2])
        # x = K.mean(x, axis=(4,))
        # return K.round(x)
        ## diff hashing (horizontal diff)
        # x1 = x[:,:,:,:,:,:-1]
        # x2 = x[:,:,:,:,:,1:]
        # d = x1 - x2
        # return K.round(d)
        ## just rounding
        return K.round(x)
        ## do nothing
        # return x

    # unlike the sibling error builders, the inputs (not the error) are hashed
    s         = hash(s)
    allpanels = hash(allpanels)

    # error = K.binary_crossentropy(s, allpanels)
    error = K.abs(s - allpanels)
    # average over each slot -> (batch, disks, towers, disks+1)
    error = K.mean(error, axis=(4,5))
    return error
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def plot(self,data,path,verbose=False):
    """Plot reconstructions of ``data``, both clean and under three input corruptions.

    For every input row, 16 images are laid out side by side:
    input, squarified latent, reconstruction, reconstruction - input,
    squarified rounded latent, its decoding, decoding - input, and then for
    each of the gaussian / salt / pepper corruptions: corrupted input, its
    autoencoding, and autoencoding - clean input.

    Args:
        data: batch of inputs.
        path: output file name, resolved with ``self.local``.
        verbose: forwarded to ``plot_grid``.

    Returns:
        Tuple ``(x, z, y, b, by)``: inputs, latent codes, reconstructions,
        rounded latent codes and their decodings.
    """
    self.load()
    x = data
    z = self.encode_binary(x)
    y = self.decode_binary(z)
    b = np.round(z)
    by = self.decode_binary(b)

    corrupted = [gaussian(x), salt(x), pepper(x)]
    restored = [self.autoencode(c) for c in corrupted]

    from .util.plot import plot_grid, squarify

    columns = [x, squarify(z), y, y - x, squarify(b), by, by - x]
    for noisy, recon in zip(corrupted, restored):
        # differences are always taken against the clean input
        columns.extend([noisy, recon, recon - x])

    images = []
    for row in zip(*columns):
        images.extend(row)
    plot_grid(images, w=16, path=self.local(path), verbose=verbose)
    return x, z, y, b, by
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def plot_autodecode(self,data,path,verbose=False):
    """Plot two rounds of decode -> re-encode starting from latent codes ``data``.

    Each row shows: squarified z, its decoding x, then for two successive
    re-encodings the raw latent, the rounded latent, and the decodings of both.

    Args:
        data: batch of latent codes.
        path: output file name, resolved with ``self.local``.
        verbose: forwarded to ``plot_grid``.

    Returns:
        Tuple of squarified ``z``, decoded ``x``, and the squarified raw and
        rounded latents of the first re-encoding.
    """
    self.load()
    z = data
    x = self.decode_binary(z)

    def _reencode(img):
        # encode, round, then decode both the raw and the rounded latent
        code = self.encode_binary(img)
        code_rounded = code.round()
        return code, code_rounded, self.decode_binary(code), self.decode_binary(code_rounded)

    z2, z2r, x2, x2r = _reencode(x)
    z3, z3r, x3, x3r = _reencode(x2)

    M, N = self.parameters['M'], self.parameters['N']  # currently unused

    from .util.plot import plot_grid, squarify
    sq_z, sq_z2, sq_z2r, sq_z3, sq_z3r = (squarify(a) for a in (z, z2, z2r, z3, z3r))

    images = []
    for row in zip(sq_z, x, sq_z2, sq_z2r, x2, x2r, sq_z3, sq_z3r, x3, x3r):
        images.extend(row)
    plot_grid(images, w=10, path=self.local(path), verbose=verbose)
    return sq_z, x, sq_z2, sq_z2r
项目:latplan    作者:guicho271828    | 项目源码 | 文件源码
def _build(self,input_shape):
    """Build and compile a bagged ensemble of discriminators.

    Creates ``self.parameters['bagging']`` ``Discriminator`` instances (stored
    under ``self.path + "/" + index``), builds each for ``input_shape`` and
    averages their outputs on a shared input.  The model output is
    ``wrap(y, K.round(y))``.  Sets ``self.discriminators`` and ``self.net``;
    the net is compiled with Adam and ``bce``.
    """
    x = Input(shape=input_shape)

    self.discriminators = []
    for idx in range(self.parameters['bagging']):
        member = Discriminator(self.path + "/" + str(idx), self.parameters)
        member.build(input_shape)
        self.discriminators.append(member)

    averaged = average([member.net(x) for member in self.discriminators])
    self.net = Model(x, wrap(averaged, K.round(averaged)))
    self.net.compile(optimizer='adam', loss=bce)
项目:dream2016_dm    作者:lishen    | 项目源码 | 文件源码
def sensitivity(y_true, y_pred):
    """True positive rate: TP / (actual positives + epsilon).

    Both tensors are clipped into [0, 1] and rounded before counting.
    """
    pred_pos = K.round(K.clip(y_pred, 0, 1))
    true_pos = K.round(K.clip(y_true, 0, 1))
    return K.sum(true_pos * pred_pos) / (K.sum(true_pos) + K.epsilon())
项目:dream2016_dm    作者:lishen    | 项目源码 | 文件源码
def specificity(y_true, y_pred):
    """True negative rate: TN / (actual negatives + epsilon).

    Both tensors are clipped into [0, 1] and rounded; negatives are
    ``1 - rounded``.
    """
    pred_neg = 1 - K.round(K.clip(y_pred, 0, 1))
    true_neg = 1 - K.round(K.clip(y_true, 0, 1))
    true_negatives = K.sum(true_neg * pred_neg)
    return true_negatives / (K.sum(true_neg) + K.epsilon())
项目:HighwayNetwork    作者:trangptm    | 项目源码 | 文件源码
def precision(y_true, y_pred):
    """Batch-wise precision: TP / (predicted positives + epsilon).

    ``y_pred`` is rounded to 0/1 before counting.  ``K.epsilon()`` in the
    denominator makes a batch with no predicted positives return 0 instead
    of NaN (0 / 0).
    """
    y_p = K.round(y_pred)
    tp = K.sum(y_true * y_p)
    return tp / (K.sum(y_p) + K.epsilon())
项目:HighwayNetwork    作者:trangptm    | 项目源码 | 文件源码
def recall(y_true, y_pred):
    """Batch-wise recall: TP / (actual positives + epsilon).

    ``y_pred`` is rounded to 0/1 before counting.  ``K.epsilon()`` in the
    denominator makes a batch with no positive targets return 0 instead of
    NaN (0 / 0).
    """
    y_p = K.round(y_pred)
    tp = K.sum(y_true * y_p)
    return tp / (K.sum(y_true) + K.epsilon())
项目:kaggle_amazon    作者:asanakoy    | 项目源码 | 文件源码
def precision(y_true, y_pred):
    """Batch-wise precision for multi-label classification.

    True positives count entries where ``clip(y_true * y_pred, 0, 1)`` rounds
    to 1.  Predicted positives count entries where ``clip(y_pred, 0, 1)``
    rounds to 1.  Returns TP / (predicted positives + epsilon), computed for
    the current batch only (not globally).
    """
    def _count(t):
        return K.sum(K.round(K.clip(t, 0, 1)))

    hits = _count(y_true * y_pred)
    flagged = _count(y_pred)
    return hits / (flagged + K.epsilon())
项目:kaggle_amazon    作者:asanakoy    | 项目源码 | 文件源码
def recall(y_true, y_pred):
    """Batch-wise recall for multi-label classification.

    True positives count entries where ``clip(y_true * y_pred, 0, 1)`` rounds
    to 1.  Possible positives count entries where ``clip(y_true, 0, 1)``
    rounds to 1.  Returns TP / (possible positives + epsilon), computed for
    the current batch only (not globally).
    """
    def _count(t):
        return K.sum(K.round(K.clip(t, 0, 1)))

    hits = _count(y_true * y_pred)
    relevant = _count(y_true)
    return hits / (relevant + K.epsilon())
项目:kaggle_amazon    作者:asanakoy    | 项目源码 | 文件源码
def fbeta_score(y_true, y_pred, beta=1):
    """Computes the F score.

    The F score is the weighted harmonic mean of precision and recall.
    Here it is only computed as a batch-wise average, not globally.

    This is useful for multi-label classification, where input samples can be
    classified as sets of labels. By only using accuracy (precision) a model
    would achieve a perfect score by simply assigning every class to every
    input. In order to avoid this, a metric should penalize incorrect class
    assignments as well (recall). The F-beta score (ranged from 0.0 to 1.0)
    computes this, as a weighted mean of the proportion of correct class
    assignments vs. the proportion of incorrect class assignments.

    With beta = 1, this is equivalent to a F-measure. With beta < 1, assigning
    correct classes becomes more important, and with beta > 1 the metric is
    instead weighted towards penalizing incorrect class assignments.

    When ``y_true`` has no positives (0/1 targets, predictions in [0, 1]),
    the true-positive count is 0, so precision and recall are both 0 and the
    result is ``0 / K.epsilon() == 0``, matching sklearn.

    Raises:
        ValueError: if ``beta`` is negative.
    """
    if beta < 0:
        raise ValueError('The lowest choosable beta is zero (only precision).')

    # There is deliberately no Python-level "if K.sum(...) == 0" guard: on a
    # symbolic tensor that comparison is not evaluated element-wise.  It is an
    # identity check on TF1/Theano tensors, so it never fired, and it raises
    # when turned into a bool during TF2 graph tracing.  The formula below
    # already yields 0 in the no-positives case.
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    bb = beta ** 2
    score = (1 + bb) * (p * r) / (bb * p + r + K.epsilon())
    return score
项目:kaggle_amazon    作者:asanakoy    | 项目源码 | 文件源码
def f2score_samples(y_true, y_pred, thresh=0.2):
    """Sample-averaged F2 score after binarising predictions with an offset.

    ``thresh`` is added before clipping into [0, 1] and rounding.  With the
    default 0.2, predictions above roughly 0.3 become 1.
    """
    binarized = K.round(K.clip(y_pred + thresh, 0, 1))
    return fscore(y_true, binarized, beta=2, average='samples')
项目:neural-decoder    作者:Krastanov    | 项目源码 | 文件源码
def exact_reversal(self, y_true, y_pred):
    """Fraction of samples whose rounded prediction matches ``y_true`` exactly.

    A sample counts only if every entry along the last axis matches.  When
    ``self.p`` is truthy, both tensors are first passed through
    ``undo_normcentererr(..., self.p)``.  ``F`` presumably casts the boolean
    result to float (defined elsewhere — confirm).
    """
    if self.p:
        y_pred = undo_normcentererr(y_pred, self.p)
        y_true = undo_normcentererr(y_true, self.p)
    all_match = K.all(K.equal(y_true, K.round(y_pred)), axis=-1)
    return K.mean(F(all_match))
项目:neural-decoder    作者:Krastanov    | 项目源码 | 文件源码
def non_triv_stab_expanded(self, y_true, y_pred):
    """Per sample, whether any row of ``self.H`` has odd parity with the corrected pattern.

    The pattern is ``(round(y_pred) + y_true) mod 2``.  It is multiplied by
    ``self.H`` (mod 2), and ``K.any`` is taken over axis 0.  When ``self.p`` is
    truthy, both tensors are first passed through
    ``undo_normcentererr(..., self.p)``.
    """
    if self.p:
        y_pred = undo_normcentererr(y_pred, self.p)
        y_true = undo_normcentererr(y_true, self.p)
    residual = (K.round(y_pred) + y_true) % 2
    parity = K.dot(self.H, K.transpose(residual)) % 2
    return K.any(parity, axis=0)
项目:neural-decoder    作者:Krastanov    | 项目源码 | 文件源码
def logic_error_expanded(self, y_true, y_pred):
    """Per sample, whether any row of ``self.E`` has odd parity with the corrected pattern.

    The pattern is ``(round(y_pred) + y_true) mod 2``.  It is multiplied by
    ``self.E`` (mod 2), and ``K.any`` is taken over axis 0.  When ``self.p`` is
    truthy, both tensors are first passed through
    ``undo_normcentererr(..., self.p)``.
    """
    if self.p:
        y_pred = undo_normcentererr(y_pred, self.p)
        y_true = undo_normcentererr(y_true, self.p)
    residual = (K.round(y_pred) + y_true) % 2
    parity = K.dot(self.E, K.transpose(residual)) % 2
    return K.any(parity, axis=0)
项目:kaggle-dstl-satellite-imagery-feature-detection    作者:alno    | 项目源码 | 文件源码
def jaccard_coef_int(y_true, y_pred, smooth=1e-12, class_weights=1):
    """Jaccard index (IoU) with predictions thresholded to 0/1.

    ``y_pred`` is clipped into [0, 1] and rounded.  The thresholded mask is
    then used for both the intersection and the union
    (|A| + |B| - |A and B|).  Sums run over axes 0, -1 and -2, leaving one
    value per remaining axis (presumably the class axis — confirm layout).
    Those values are smoothed, multiplied by ``class_weights`` and averaged.
    """
    # __author__ = Vladimir Iglovikov
    y_pred_pos = K.round(K.clip(y_pred, 0, 1))

    intersection = K.sum(y_true * y_pred_pos, axis=[0, -1, -2])
    # Use the thresholded prediction here as well; mixing raw probabilities
    # into the union made it inconsistent with the binary intersection.
    union = K.sum(y_true + y_pred_pos, axis=[0, -1, -2]) - intersection

    return K.mean((intersection + smooth) / (union + smooth) * class_weights)
项目:FingerNet    作者:felixTY    | 项目源码 | 文件源码
def seg_acc_pos(y_true, y_pred):
    """Accuracy of rounded predictions, weighted by the positive targets.

    Negative entries of ``y_true`` are replaced by 0 first.  Returns
    sum(correct * y_true) / (sum(y_true) + epsilon).
    """
    y_true = K.tf.where(K.tf.less(y_true, 0.0), K.tf.zeros_like(y_true), y_true)
    correct = K.cast(K.equal(y_true, K.round(y_pred)), dtype=K.tf.float32)
    weighted_hits = K.sum(correct * y_true)
    return weighted_hits / (K.sum(y_true) + K.epsilon())
项目:FingerNet    作者:felixTY    | 项目源码 | 文件源码
def seg_acc_neg(y_true, y_pred):
    """Accuracy of rounded predictions, weighted by the background (1 - y_true).

    Negative entries of ``y_true`` are replaced by 0 first.  Returns
    sum(correct * (1 - y_true)) / (sum(1 - y_true) + epsilon).
    """
    y_true = K.tf.where(K.tf.less(y_true, 0.0), K.tf.zeros_like(y_true), y_true)
    correct = K.cast(K.equal(y_true, K.round(y_pred)), dtype=K.tf.float32)
    background = 1 - y_true
    return K.sum(correct * background) / (K.sum(background) + K.epsilon())
项目:FingerNet    作者:felixTY    | 项目源码 | 文件源码
def seg_acc_all(y_true, y_pred):
    """Mean agreement between ``y_true`` (negatives set to 0) and rounded ``y_pred``."""
    cleaned = K.tf.where(K.tf.less(y_true, 0.0), K.tf.zeros_like(y_true), y_true)
    agree = K.equal(cleaned, K.round(y_pred))
    return K.mean(agree)
项目:FingerNet    作者:felixTY    | 项目源码 | 文件源码
def seg_acc_pos(y_true, y_pred):
    """Recall-style accuracy on positive targets.

    Negative entries of ``y_true`` are replaced by 0.  Returns the number of
    correctly rounded predictions weighted by ``y_true``, divided by
    ``sum(y_true) + epsilon``.
    """
    positive_mask = K.tf.where(K.tf.less(y_true, 0.0), K.tf.zeros_like(y_true), y_true)
    match = K.equal(positive_mask, K.round(y_pred))
    match = K.cast(match, dtype=K.tf.float32)
    numerator = K.sum(match * positive_mask)
    denominator = K.sum(positive_mask) + K.epsilon()
    return numerator / denominator
项目:FingerNet    作者:felixTY    | 项目源码 | 文件源码
def seg_acc_all(y_true, y_pred):
    """Fraction of positions where rounded ``y_pred`` equals ``y_true``.

    Negative values in ``y_true`` are treated as 0.
    """
    is_negative = K.tf.less(y_true, 0.0)
    target = K.tf.where(is_negative, K.tf.zeros_like(y_true), y_true)
    return K.mean(K.equal(target, K.round(y_pred)))
项目:Ultras-Sound-Nerve-Segmentation---Kaggle    作者:Simoncarbo    | 项目源码 | 文件源码
def dice_tresh(y_true, y_pred):
    """Negated mean Dice coefficient of the rounded prediction.

    Sums run over the last two axes.  The module-level ``smooth`` is added to
    both the numerator and the denominator.  The value is negated so that
    better overlap gives a lower value.
    """
    y_pred = K.round(y_pred)

    def _sum_last_two(t):
        return K.sum(K.sum(t, axis=-1), axis=-1)

    overlap = _sum_last_two(y_true * y_pred)
    total = _sum_last_two(y_true) + _sum_last_two(y_pred)
    return -K.mean((2. * overlap + smooth) / (total + smooth))
项目:Ultras-Sound-Nerve-Segmentation---Kaggle    作者:Simoncarbo    | 项目源码 | 文件源码
def pres_acc(y_true, y_pred):
    """Presence accuracy: agreement of the rounded maxima over the last two axes.

    For each leading index (presumably per sample), the maximum over the last
    two axes is rounded for both tensors.  The fraction of equal values is
    returned.
    """
    def _max_last_two(t):
        return K.max(K.max(t, axis=-1), axis=-1)

    present_true = _max_last_two(y_true)
    present_pred = _max_last_two(y_pred)
    return K.mean(K.equal(K.round(present_pred), K.round(present_true)))
项目:FingerNet    作者:felixTY    | 项目源码 | 文件源码
def deploy(deploy_set, set_name=None):
    """Run the main network over every ``.bmp`` image in ``deploy_set`` and save results.

    For each image this writes the following under ``<output_dir>/<set_name>/``:
      * a ``.mnt`` minutiae file;
      * ``_ori.png`` and ``_mnt.png`` visualisations;
      * ``_enh.png`` and ``_seg.png`` images;
      * a ``.mat`` file holding the orientation field.
    Per-image and average timings are logged.  It relies on the module-level
    ``output_dir``, ``pretrain`` and ``sess``.

    Args:
        deploy_set: directory path ending in '/'.  If it holds no ``.bmp``
            files, its ``images/`` subfolder is used instead.
        set_name: name of the output subfolder; defaults to the last path
            component of ``deploy_set``.
    """
    if set_name is None:
        set_name = deploy_set.split('/')[-2]
    mkdir(output_dir+'/'+set_name+'/')
    logging.info("Predicting %s:"%(set_name))
    _, img_name = get_files_in_folder(deploy_set, '.bmp')
    if len(img_name) == 0:
        deploy_set = deploy_set+'images/'
        _, img_name = get_files_in_folder(deploy_set, '.bmp')
    # The model is built once, for the first image's size cropped down to a
    # multiple of 8.  Floor division keeps the sizes integral on Python 3 too,
    # where ``/`` would produce floats and break the slicing below.
    img_size = misc.imread(deploy_set+img_name[0]+'.bmp', mode='L').shape
    img_size = np.array(img_size, dtype=np.int32)//8*8
    main_net_model = get_main_net((img_size[0],img_size[1],1), pretrain)
    # structuring element for cleaning the segmentation (loop-invariant)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5, 5))
    time_c = []
    for i in range(len(img_name)):
        logging.info("%s %d / %d: %s"%(set_name, i+1, len(img_name), img_name[i]))
        time_start = time()
        image = misc.imread(deploy_set+img_name[i]+'.bmp', mode='L') / 255.0
        image = image[:img_size[0],:img_size[1]]
        image = np.reshape(image,[1, image.shape[0], image.shape[1], 1])
        enhance_img, ori_out_1, ori_out_2, seg_out, mnt_o_out, mnt_w_out, mnt_h_out, mnt_s_out = main_net_model.predict(image)
        time_afterconv = time()
        # binarise the segmentation map and clean it with a 5x5 opening
        round_seg = np.round(np.squeeze(seg_out))
        seg_out = cv2.morphologyEx(round_seg, cv2.MORPH_OPEN, kernel)
        seg_mask = np.round(np.squeeze(seg_out))
        mnt = label2mnt(np.squeeze(mnt_s_out)*seg_mask, mnt_w_out, mnt_h_out, mnt_o_out, thresh=0.5)
        mnt_nms = nms(mnt)
        ori = sess.run(ori_highest_peak(ori_out_1))
        # orientation bin index -> angle in radians
        ori = (np.argmax(ori, axis=-1)*2-90)/180.*np.pi
        time_afterpost = time()
        mnt_writer(mnt_nms, img_name[i], img_size, "%s/%s/%s.mnt"%(output_dir, set_name, img_name[i]))
        draw_ori_on_img(image, ori, np.ones_like(seg_out), "%s/%s/%s_ori.png"%(output_dir, set_name, img_name[i]))
        draw_minutiae(image, mnt_nms[:,:3], "%s/%s/%s_mnt.png"%(output_dir, set_name, img_name[i]))
        # segmentation mask upsampled 8x (nearest neighbour), computed once
        seg_zoom = ndimage.zoom(seg_mask, [8,8], order=0)
        misc.imsave("%s/%s/%s_enh.png"%(output_dir, set_name, img_name[i]), np.squeeze(enhance_img)*seg_zoom)
        misc.imsave("%s/%s/%s_seg.png"%(output_dir, set_name, img_name[i]), seg_zoom)
        io.savemat("%s/%s/%s.mat"%(output_dir, set_name, img_name[i]), {'orientation':ori, 'orientation_distribution_map':ori_out_1})
        time_afterdraw = time()
        time_c.append([time_afterconv-time_start, time_afterpost-time_afterconv, time_afterdraw-time_afterpost])
        logging.info("load+conv: %.3fs, seg-postpro+nms: %.3f, draw: %.3f"%(time_c[-1][0],time_c[-1][1],time_c[-1][2]))
    time_c = np.mean(np.array(time_c),axis=0)
    logging.info("Average: load+conv: %.3fs, oir-select+seg-post+nms: %.3f, draw: %.3f"%(time_c[0],time_c[1],time_c[2]))
    return
项目:FingerNet    作者:felixTY    | 项目源码 | 文件源码
def deploy(deploy_set, set_name=None):
    """Run the main network over every ``.bmp`` image in ``deploy_set`` and save results.

    For each image this writes the following under ``<output_dir>/<set_name>/``:
      * a ``.mnt`` minutiae file;
      * ``_ori.png`` and ``_mnt.png`` visualisations;
      * ``_enh.png`` and ``_seg.png`` images;
      * a ``.mat`` file holding the orientation field.
    Per-image and average timings are logged.  It relies on the module-level
    ``output_dir``, ``pretrain`` and ``sess``.

    Args:
        deploy_set: directory path ending in '/'.  If it holds no ``.bmp``
            files, its ``images/`` subfolder is used instead.
        set_name: name of the output subfolder; defaults to the last path
            component of ``deploy_set``.
    """
    if set_name is None:
        set_name = deploy_set.split('/')[-2]
    mkdir(output_dir+'/'+set_name+'/')
    logging.info("Predicting %s:"%(set_name))
    _, img_name = get_files_in_folder(deploy_set, '.bmp')
    if len(img_name) == 0:
        deploy_set = deploy_set+'images/'
        _, img_name = get_files_in_folder(deploy_set, '.bmp')
    # The model is built once, for the first image's size cropped down to a
    # multiple of 8.  Floor division keeps the sizes integral on Python 3 too,
    # where ``/`` would produce floats and break the slicing below.
    img_size = misc.imread(deploy_set+img_name[0]+'.bmp', mode='L').shape
    img_size = np.array(img_size, dtype=np.int32)//8*8
    main_net_model = get_main_net((img_size[0],img_size[1],1), pretrain)
    # structuring element for cleaning the segmentation (loop-invariant)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5, 5))
    time_c = []
    for i in range(len(img_name)):
        logging.info("%s %d / %d: %s"%(set_name, i+1, len(img_name), img_name[i]))
        time_start = time()
        image = misc.imread(deploy_set+img_name[i]+'.bmp', mode='L') / 255.0
        image = image[:img_size[0],:img_size[1]]
        image = np.reshape(image,[1, image.shape[0], image.shape[1], 1])
        enhance_img, ori_out_1, ori_out_2, seg_out, mnt_o_out, mnt_w_out, mnt_h_out, mnt_s_out = main_net_model.predict(image)
        time_afterconv = time()
        # binarise the segmentation map and clean it with a 5x5 opening
        round_seg = np.round(np.squeeze(seg_out))
        seg_out = cv2.morphologyEx(round_seg, cv2.MORPH_OPEN, kernel)
        seg_mask = np.round(np.squeeze(seg_out))
        mnt = label2mnt(np.squeeze(mnt_s_out)*seg_mask, mnt_w_out, mnt_h_out, mnt_o_out, thresh=0.5)
        mnt_nms = nms(mnt)
        ori = sess.run(ori_highest_peak(ori_out_1))
        # orientation bin index -> angle in radians
        ori = (np.argmax(ori, axis=-1)*2-90)/180.*np.pi
        time_afterpost = time()
        mnt_writer(mnt_nms, img_name[i], img_size, "%s/%s/%s.mnt"%(output_dir, set_name, img_name[i]))
        draw_ori_on_img(image, ori, np.ones_like(seg_out), "%s/%s/%s_ori.png"%(output_dir, set_name, img_name[i]))
        draw_minutiae(image, mnt_nms[:,:3], "%s/%s/%s_mnt.png"%(output_dir, set_name, img_name[i]))
        # segmentation mask upsampled 8x (nearest neighbour), computed once
        seg_zoom = ndimage.zoom(seg_mask, [8,8], order=0)
        misc.imsave("%s/%s/%s_enh.png"%(output_dir, set_name, img_name[i]), np.squeeze(enhance_img)*seg_zoom)
        misc.imsave("%s/%s/%s_seg.png"%(output_dir, set_name, img_name[i]), seg_zoom)
        io.savemat("%s/%s/%s.mat"%(output_dir, set_name, img_name[i]), {'orientation':ori, 'orientation_distribution_map':ori_out_1})
        time_afterdraw = time()
        time_c.append([time_afterconv-time_start, time_afterpost-time_afterconv, time_afterdraw-time_afterpost])
        logging.info("load+conv: %.3fs, seg-postpro+nms: %.3f, draw: %.3f"%(time_c[-1][0],time_c[-1][1],time_c[-1][2]))
    time_c = np.mean(np.array(time_c),axis=0)
    logging.info("Average: load+conv: %.3fs, oir-select+seg-post+nms: %.3f, draw: %.3f"%(time_c[0],time_c[1],time_c[2]))
    return