Python caffe.proto.caffe_pb2 模块,TRAIN 实例源码

我们从Python开源项目中,提取了以下12个代码示例,用于说明如何使用caffe.proto.caffe_pb2.TRAIN

项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def orth_loss_v2(self, bottom_name):
        """Append a TRAIN-phase loss branch that starts from ``bottom_name``.

        The branch is built with three builder calls, in order:

        1. ``self.Python`` with the ``TransposeLayer`` module/layer, reading
           ``bottom_name`` and producing ``<bottom_name>_TransposeLayer``.
        2. ``self.Matmul()`` with its default arguments.
        3. ``self.EuclideanLoss`` named ``<bottom_name>_euclidean`` with loss
           weight 0.1, whose bottom is ``self.this.name`` as it stands after
           the ``Matmul`` call.

        ``self.bottom`` is read before the branch is built and that value is
        assigned to ``self.cur`` afterwards (the "save / restore bottom" step).

        Args:
            bottom_name: name of the blob the branch is attached to.
        """
        # Remember where the main path was before adding the side branch.
        saved_path = self.bottom

        transpose = "TransposeLayer"
        transpose_name = bottom_name + '_' + transpose
        self.Python(transpose, transpose,
                    top=[transpose_name],
                    bottom=[bottom_name],
                    name=transpose_name,
                    phase='TRAIN')
        self.Matmul()

        self.EuclideanLoss(name=bottom_name + '_euclidean',
                           bottom=[self.this.name],
                           loss_weight=1e-1,
                           phase='TRAIN')

        # Put the main path back.
        self.cur = saved_path
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def resnet(n=3, num_output=16):
    """Generate the solver(s) and net for a CIFAR-10 ResNet of depth 6n+2.

    n = 3, 9, 18 corresponds to 20, 56, 110 layers.

    Everything is written under ``<cwd>/resnet-<6n+2>``; the net itself is
    named ``resnet-<6n+2>-cifar10``.  For ``n > 18`` an extra warm-up solver
    (``solver_warm.prototxt``, fixed policy, base_lr 0.01, 500 iterations) is
    written before the regular one.

    Args:
        n: block-count parameter; depth is ``6*n+2``.
        num_output: forwarded to ``Net.resnet_cifar``.
    """
    net_name = "resnet-"
    depth_tag = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth_tag)

    if n > 18:
        # warm up
        warmup = Solver(solver_name="solver_warm.prototxt", folder=pt_folder, lr_policy=Solver.policy.fixed)
        warmup.p.base_lr = 0.01
        warmup.set_max_iter(500)
        warmup.write()
        del warmup

    # Regular solver with the Solver defaults.
    Solver(folder=pt_folder).write()

    builder = Net(net_name + depth_tag + '-cifar10')
    builder.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    builder.Data('cifar-10-batches-py/test', phase='TEST')
    builder.resnet_cifar(n, num_output=num_output)
    builder.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def resnet_orth_v2(n=3):
    """Generate the solver(s) and net for the "orth v2" CIFAR-10 ResNet.

    Depth is 6n+2: n = 3, 9, 18 corresponds to 20, 56, 110 layers.

    Output goes to ``<cwd>/resnet-orth-v2<6n+2>`` and the net is named
    ``resnet-orth-v2<6n+2>-cifar10``.  For ``n > 18`` an extra warm-up solver
    (``solver_warm.prototxt``, fixed policy, base_lr 0.01, 500 iterations) is
    written before the regular one.  The net body is produced by
    ``Net.resnet_cifar(n, orth=True, v2=True)``.

    Args:
        n: block-count parameter; depth is ``6*n+2``.
    """
    net_name = "resnet-orth-v2"
    depth_tag = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth_tag)

    if n > 18:
        # warm up
        warmup = Solver(solver_name="solver_warm.prototxt", folder=pt_folder, lr_policy=Solver.policy.fixed)
        warmup.p.base_lr = 0.01
        warmup.set_max_iter(500)
        warmup.write()
        del warmup

    # Regular solver with the Solver defaults.
    Solver(folder=pt_folder).write()

    builder = Net(net_name + depth_tag + '-cifar10')
    builder.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    builder.Data('cifar-10-batches-py/test', phase='TEST')
    builder.resnet_cifar(n, orth=True, v2=True)
    builder.write(folder=pt_folder)
项目:pre-resnet-gen-caffe    作者:Cysu    | 项目源码 | 文件源码
def _get_include(phase):
    """Return a ``caffe_pb2.NetStateRule`` restricted to the given phase.

    Args:
        phase: ``'train'`` or ``'test'`` (lowercase).

    Returns:
        A new ``NetStateRule`` whose ``phase`` field is ``caffe_pb2.TRAIN``
        or ``caffe_pb2.TEST`` respectively.

    Raises:
        ValueError: if ``phase`` is neither ``'train'`` nor ``'test'``.
    """
    rule = caffe_pb2.NetStateRule()
    if phase not in ('train', 'test'):
        raise ValueError("Unknown phase {}".format(phase))
    rule.phase = caffe_pb2.TRAIN if phase == 'train' else caffe_pb2.TEST
    return rule
项目:resnet-cifar10-caffe    作者:yihui-he    | 项目源码 | 文件源码
def include(self, phase='TRAIN'):
        """Add an ``include`` rule for ``phase`` to the layer in ``self.this``.

        Args:
            phase: ``'TRAIN'``, ``'TEST'`` or ``None``.  With ``None`` nothing
                is added and the layer is left untouched.

        Raises:
            ValueError: if ``phase`` is any other value.  Previously such a
                value silently appended an include rule with no phase set,
                and the bare ``NotImplementedError`` expression in the
                ``else`` branch never raised anything.
        """
        if phase is None:
            return

        # Resolve the enum value before touching the layer so that a bad
        # argument cannot leave a half-initialised rule behind.
        if phase == 'TRAIN':
            phase_value = caffe_pb2.TRAIN
        elif phase == 'TEST':
            phase_value = caffe_pb2.TEST
        else:
            raise ValueError("Unknown phase {!r}".format(phase))

        rule = self.this.include.add()
        rule.phase = phase_value


    #************************** inplace **************************
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def include(self, phase='TRAIN'):
        """Add an ``include`` rule for ``phase`` to the layer in ``self.this``.

        Args:
            phase: ``'TRAIN'``, ``'TEST'`` or ``None``.  With ``None`` nothing
                is added and the layer is left untouched.

        Raises:
            ValueError: if ``phase`` is any other value.  Previously such a
                value silently appended an include rule with no phase set,
                and the bare ``NotImplementedError`` expression in the
                ``else`` branch never raised anything.
        """
        if phase is None:
            return

        # Resolve the enum value before touching the layer so that a bad
        # argument cannot leave a half-initialised rule behind.
        if phase == 'TRAIN':
            phase_value = caffe_pb2.TRAIN
        elif phase == 'TEST':
            phase_value = caffe_pb2.TEST
        else:
            raise ValueError("Unknown phase {!r}".format(phase))

        rule = self.this.include.add()
        rule.phase = phase_value


    #************************** inplace **************************
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def MVN(self, name=None, bottom=[], normalize_variance=True, across_channels=False, phase='TRAIN'):
        """Add an ``MVN`` layer that is included in the TRAIN phase only.

        Only the default configuration is supported: ``normalize_variance``
        and ``across_channels`` are never written to the layer, and the
        include rule is always added via ``self.include()`` with its default
        phase.  Requests for anything else are rejected instead of being
        silently ignored (the original bare ``NotImplementedError``
        expressions were no-ops).

        Args:
            name: optional layer name, combined with ``'MVN'`` by
                ``self.suffix``.
            bottom: list of bottom blob names, forwarded unchanged to
                ``self.setup``.
            normalize_variance: must be true.
            across_channels: must be false.
            phase: must be ``'TRAIN'``.

        Raises:
            NotImplementedError: if a non-default option is requested.  The
                checks run before ``self.setup`` so no layer is created on
                failure.
        """
        if across_channels:
            raise NotImplementedError("MVN: across_channels=True is not supported")
        if not normalize_variance:
            raise NotImplementedError("MVN: normalize_variance=False is not supported")
        if phase != 'TRAIN':
            raise NotImplementedError("MVN: only phase='TRAIN' is supported")

        self.setup(self.suffix('MVN', name), bottom=bottom, layer_type='MVN')
        self.include()
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def plain_func(self, name, num_output, up=False, **kwargs):
        """Emit two consecutive ``conv_bn_relu`` units for one plain block.

        The units are named ``<name>_conv0`` and ``<name>_conv1``.  The first
        gets ``stride=2`` when ``up`` is true and ``stride=1`` otherwise; the
        second is called without an explicit stride.

        Args:
            name: prefix for the two unit names.
            num_output: forwarded to both ``conv_bn_relu`` calls.
            up: whether the first unit uses stride 2.
            **kwargs: forwarded unchanged to both ``conv_bn_relu`` calls.
        """
        first_stride = 1 + int(up)
        self.conv_bn_relu(name + '_conv0', num_output=num_output, stride=first_stride, **kwargs)
        self.conv_bn_relu(name + '_conv1', num_output=num_output, **kwargs)

    # def orth_loss(self, bottom_name):
    #     # self.Python('orth_loss', 'orthLossLayer', loss_weight=1, bottom=[bottom_name], top=[name], name=name)
    #     # , bottom=[bottom+'_MVN']

    #     # save bottom
    #     mainpath = self.bottom

    #     bottom = bottom_name #'NormLayer', 
    #     # self.MVN(bottom=[bottom])
    #     layer = "TransposeLayer"
    #     layername = bottom_name+'_' + layer
    #     outputs = [layername]#, bottom_name+'_zerolike']
    #     self.Python(layer, layer, top=outputs, bottom=[bottom], name=layername, phase='TRAIN')
    #     self.Matmul()
    #     # layer="diagLayer"
    #     # layername = bottom_name+'_' + layer

    #     # self.Python(layer, layer, top=[layername], name=layername, phase='TRAIN')
    #     outputs = [self.this.name]#, bottom_name+'_zerolike']
    #     self.EuclideanLoss(name=bottom_name+'_euclidean', bottom=outputs, loss_weight=1e-3, phase='TRAIN')

    #     # restore bottom
    #     self.cur = mainpath
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def plain(n=3):
    """Generate the solver and net for a plain CIFAR-10 network of depth 6n+2.

    n = 3, 9, 18 corresponds to 20, 56, 110 layers.

    Output goes to ``<cwd>/plain<6n+2>`` and the net is named
    ``plain<6n+2>-cifar10``.  The net body is produced by
    ``Net.plain_cifar(n, num_output=16)``.

    Args:
        n: block-count parameter; depth is ``6*n+2``.
    """
    net_name = "plain"
    depth_tag = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth_tag)

    # Solver with the Solver defaults.
    Solver(folder=pt_folder).write()

    builder = Net(net_name + depth_tag + '-cifar10')
    builder.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    builder.Data('cifar-10-batches-py/test', phase='TEST')
    builder.plain_cifar(n, num_output=16)
    builder.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def plain_orth(n=3):
    """Generate the solver and net for the "orth" plain CIFAR-10 network.

    Depth is 6n+2: n = 3, 5, 7, 9, 18 corresponds to 20, 32, 44, 56, 110
    layers.

    Output goes to ``<cwd>/plain-orth<6n+2>`` and the net is named
    ``plain-orth<6n+2>-cifar10``.  The net body is produced by
    ``Net.plain_cifar(n, orth=True)``.

    Args:
        n: block-count parameter; depth is ``6*n+2``.
    """
    net_name = "plain-orth"
    depth_tag = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth_tag)

    # Solver with the Solver defaults.
    Solver(folder=pt_folder).write()

    builder = Net(net_name + depth_tag + '-cifar10')
    builder.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    builder.Data('cifar-10-batches-py/test', phase='TEST')
    builder.plain_cifar(n, orth=True)
    builder.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def plain_orth_v1(n=3):
    """Generate the solver and net for the "orth v1" plain CIFAR-10 network.

    Depth is 6n+2: n = 3, 5, 7, 9, 18 corresponds to 20, 32, 44, 56, 110
    layers.

    Output goes to ``<cwd>/plain-orth-v1-<6n+2>`` and the net is named
    ``plain-orth-v1-<6n+2>-cifar10``.  The net body is produced by
    ``Net.plain_cifar(n, orth=True, inplace=False, num_output=16)``.

    Args:
        n: block-count parameter; depth is ``6*n+2``.
    """
    net_name = "plain-orth-v1-"
    depth_tag = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth_tag)

    # Solver with the Solver defaults.
    Solver(folder=pt_folder).write()

    builder = Net(net_name + depth_tag + '-cifar10')
    builder.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    builder.Data('cifar-10-batches-py/test', phase='TEST')
    builder.plain_cifar(n, orth=True, inplace=False, num_output=16)
    builder.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def acc(n=3):
    """Generate the solver and a non-inplace plain CIFAR-10 net of depth 6n+2.

    n = 3, 9, 18 corresponds to 20, 56, 110 layers.

    Output goes to ``<cwd>/plain<6n+2>`` and the net is named
    ``plain<6n+2>-cifar10``.  The net body is produced by
    ``Net.plain_cifar(n, num_output=16, inplace=False)``.

    NOTE(review): this uses the same "plain" prefix, and therefore the same
    output folder and net name, as ``plain`` - confirm that is intended.

    Args:
        n: block-count parameter; depth is ``6*n+2``.
    """
    net_name = "plain"
    depth_tag = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth_tag)

    # Solver with the Solver defaults.
    Solver(folder=pt_folder).write()

    builder = Net(net_name + depth_tag + '-cifar10')
    builder.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    builder.Data('cifar-10-batches-py/test', phase='TEST')
    builder.plain_cifar(n, num_output=16, inplace=False)
    builder.write(folder=pt_folder)