Python caffe 模块 caffe.io 实例源码

我们从 Python 开源项目中提取了以下 6 个代码示例，用于说明如何使用 caffe.io 子模块。

项目:temporal-segment-networks    作者:yjxiong    | 项目源码 | 文件源码
def __init__(self, net_proto, net_weights, device_id, input_size=None):
        """Load a Caffe net in TEST phase on a GPU and build its input transformer.

        Args:
            net_proto: network definition passed to ``caffe.Net``.
            net_weights: trained weights passed to ``caffe.Net``.
            device_id: GPU index passed to ``caffe.set_device`` (GPU mode is
                always selected).
            input_size: optional trailing dimensions that replace everything
                after the first two entries of the ``data`` blob shape when
                configuring the transformer. Any sequence (tuple, list, ...)
                is accepted.
        """
        caffe.set_mode_gpu()
        caffe.set_device(device_id)
        self._net = caffe.Net(net_proto, net_weights, caffe.TEST)

        blob_shape = self._net.blobs['data'].data.shape
        input_shape = blob_shape

        if input_size is not None:
            # Keep the first two dims of the blob shape; tuple() lets callers
            # pass a list too (tuple + list would raise TypeError).
            input_shape = tuple(blob_shape[:2]) + tuple(input_size)

        transformer = caffe.io.Transformer({'data': input_shape})

        # Only 3-channel input is transposed and mean-subtracted; any other
        # channel count is left without transformer preprocessing.
        if blob_shape[1] == 3:
            # move image channels to outermost dimension
            transformer.set_transpose('data', (2, 0, 1))
            # subtract the fixed per-channel mean values 104 / 117 / 123
            transformer.set_mean('data', np.array([104, 117, 123]))

        self._transformer = transformer

        self._sample_shape = blob_shape
项目:Video-Classification-Action-Recognition    作者:qijiezhao    | 项目源码 | 文件源码
def __init__(self, net_proto, net_weights, device_id, input_size=None):
        """Load a Caffe net in TEST phase on a GPU and build its input transformer.

        Args:
            net_proto: network definition passed to ``caffe.Net``.
            net_weights: trained weights passed to ``caffe.Net``.
            device_id: GPU index passed to ``caffe.set_device`` (GPU mode is
                always selected).
            input_size: optional trailing dimensions that replace everything
                after the first two entries of the ``data`` blob shape when
                configuring the transformer. Any sequence (tuple, list, ...)
                is accepted.
        """
        caffe.set_mode_gpu()
        caffe.set_device(device_id)
        self._net = caffe.Net(net_proto, net_weights, caffe.TEST)

        blob_shape = self._net.blobs['data'].data.shape
        input_shape = blob_shape

        if input_size is not None:
            # Keep the first two dims of the blob shape; tuple() lets callers
            # pass a list too (tuple + list would raise TypeError).
            input_shape = tuple(blob_shape[:2]) + tuple(input_size)

        transformer = caffe.io.Transformer({'data': input_shape})

        # Only 3-channel input is transposed and mean-subtracted; any other
        # channel count is left without transformer preprocessing.
        if blob_shape[1] == 3:
            # move image channels to outermost dimension
            transformer.set_transpose('data', (2, 0, 1))
            # subtract the fixed per-channel mean values 104 / 117 / 123
            transformer.set_mean('data', np.array([104, 117, 123]))

        self._transformer = transformer

        self._sample_shape = blob_shape
项目:Caffe-Python-Data-Layer    作者:liuxianming    | 项目源码 | 文件源码
def set_mean(self):
        """Populate ``self._mean`` from ``self._mean_file``.

        ``self._mean_file`` may be:

        * a ``str`` path. It is first read with ``np.load`` (``.npy`` / ``.npz``).
          If numpy cannot interpret the file, it is parsed as a Caffe
          ``BlobProto``, and the first entry of the converted array is kept.
        * any other non-empty value (list, tuple, ndarray, ...). It is
          converted with ``np.array`` and used as the mean.
        * ``None`` or an empty / falsy value. ``self._mean`` is set to ``None``.
        """
        mean_src = self._mean_file
        if isinstance(mean_src, np.ndarray):
            # bool(ndarray) raises for multi-element arrays, so test size.
            has_mean = mean_src.size > 0
        else:
            has_mean = bool(mean_src)

        if not has_mean:
            self._mean = None
            return

        if not isinstance(mean_src, str):
            # mean values given directly
            self._mean = np.array(mean_src)
            return

        try:
            # numpy .npy / .npz file
            self._mean = np.load(mean_src)
        except (IOError, ValueError):
            # Modern numpy raises ValueError (not IOError) for files that are
            # neither .npy nor .npz when allow_pickle is False (the default);
            # treat both as "not a numpy file" and fall back to BlobProto.
            blob = caffe_pb2.BlobProto()
            with open(mean_src, 'rb') as blob_file:
                blob.ParseFromString(blob_file.read())
            self._mean = np.array(caffe.io.blobproto_to_array(blob))[0]
项目:anet2016-cuhk    作者:yjxiong    | 项目源码 | 文件源码
def __init__(self, net_proto, net_weights, device_id, input_size=None):
        """Load a Caffe net in TEST phase on a GPU and build its input transformer.

        Args:
            net_proto: network definition passed to ``caffe.Net``.
            net_weights: trained weights passed to ``caffe.Net``.
            device_id: GPU index passed to ``caffe.set_device`` (GPU mode is
                always selected).
            input_size: optional trailing dimensions that replace everything
                after the first two entries of the ``data`` blob shape when
                configuring the transformer. Any sequence (tuple, list, ...)
                is accepted.
        """
        caffe.set_mode_gpu()
        caffe.set_device(device_id)
        self._net = caffe.Net(net_proto, net_weights, caffe.TEST)

        blob_shape = self._net.blobs['data'].data.shape
        input_shape = blob_shape

        if input_size is not None:
            # Keep the first two dims of the blob shape; tuple() lets callers
            # pass a list too (tuple + list would raise TypeError).
            input_shape = tuple(blob_shape[:2]) + tuple(input_size)

        transformer = caffe.io.Transformer({'data': input_shape})

        # Only 3-channel input is transposed and mean-subtracted; any other
        # channel count is left without transformer preprocessing.
        if blob_shape[1] == 3:
            # move image channels to outermost dimension
            transformer.set_transpose('data', (2, 0, 1))
            # subtract the fixed per-channel mean values 104 / 117 / 123
            transformer.set_mean('data', np.array([104, 117, 123]))

        self._transformer = transformer

        self._sample_shape = blob_shape
项目:tc_tripletloss    作者:abby621    | 项目源码 | 文件源码
def getFeats(ims, net, feat_layer, im_size=227):
    """Run ``net`` forward on a batch of image files and return one blob's data.

    Args:
        ims: sequence of image paths, each loaded with ``caffe.io.load_image``.
        net: a Caffe net with an input blob named ``'data'``.
        feat_layer: name of the blob whose data is returned.
        im_size: height and width the ``'data'`` blob is reshaped to
            (default 227, the previously hard-coded value).

    Returns:
        A copy of ``net.blobs[feat_layer].data`` after the forward pass.
    """
    batch_shape = (len(ims), 3, im_size, im_size)
    net.blobs['data'].reshape(*batch_shape)

    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_mean('data', IM_MEAN)
    # move image channels to outermost dimension
    transformer.set_transpose('data', (2, 0, 1))
    # reverse the order of the 3 channels
    transformer.set_channel_swap('data', (2, 1, 0))
    # presumably load_image yields values in [0, 1]; scale to [0, 255]
    transformer.set_raw_scale('data', 255.0)

    caffe_input = np.empty(batch_shape)
    for ix, im_path in enumerate(ims):
        caffe_input[ix] = transformer.preprocess('data', caffe.io.load_image(im_path))
    net.blobs['data'].data[...] = caffe_input

    # The return value of forward() is not needed; features are read from the blob.
    net.forward()
    return net.blobs[feat_layer].data.copy()
项目:Caffe-Python-Data-Layer    作者:liuxianming    | 项目源码 | 文件源码
def load_all(self):
        """Load every record of the LMDB at ``self._source_fn`` into memory.

        Each data value is parsed as a ``caffe_pb2.Datum``. Its raw ``data``
        field (e.g. a raw JPEG string that still needs decoding) is appended
        to ``self._data``.

        Labels are appended to ``self._labels``:

        * if ``self._label_fn`` is set, record i of that LMDB is paired with
          record i of the data LMDB. It is parsed as a ``Datum``, converted
          with ``caffe.io.datum_to_array``, cast to int, and its values are
          joined with ``":"``.
        * otherwise ``str(datum.label)`` of the data record is used.

        Returns:
            ``(self._data, self._labels)``, with ``self._labels`` converted to
            a numpy array.

        Raises:
            Exception: "Error in Parsing input file" if reading or parsing
                either database fails (including when the label database has
                fewer records than the data database).
        """
        start = time.time()
        print("Start Loading Data from LMDB {}".format(self._source_fn))
        db_ = None
        label_db_ = None
        try:
            db_ = lmdb.open(self._source_fn)
            data_cursor_ = db_.begin().cursor()
            label_cursor_ = None
            if self._label_fn:
                label_db_ = lmdb.open(self._label_fn)
                label_cursor_ = label_db_.begin().cursor()
            # next() on an unpositioned py-lmdb cursor moves to the first
            # record. Advancing both cursors with next() therefore keeps them
            # aligned. The old first() + next() pairing skipped label 0.
            while data_cursor_.next():
                datum_ = caffe_pb2.Datum()
                datum_.ParseFromString(data_cursor_.value())
                self._data.append(datum_.data)
                if label_cursor_ is not None:
                    if not label_cursor_.next():
                        raise ValueError(
                            "label database has fewer records than data database")
                    label_datum_ = caffe_pb2.Datum()
                    label_datum_.ParseFromString(label_cursor_.value())
                    label_ = caffe.io.datum_to_array(label_datum_)
                    # datum_to_array may return a multi-dimensional array;
                    # ravel() so every joined element is a scalar.
                    label_ = ":".join(
                        [str(x) for x in label_.astype(int).ravel()])
                else:
                    label_ = str(datum_.label)
                self._labels.append(label_)
        except Exception:
            raise Exception("Error in Parsing input file")
        finally:
            # close all db, including when parsing failed
            if db_ is not None:
                db_.close()
            if label_db_ is not None:
                label_db_.close()
        end = time.time()
        self._labels = np.array(self._labels)
        print("Loading {} samples Done: Time cost {} seconds".format(
            len(self._data), end - start))

        return self._data, self._labels