Python caffe.proto.caffe_pb2 模块,Datum() 实例源码

我们从Python开源项目中,提取了以下8个代码示例,用于说明如何使用caffe.proto.caffe_pb2.Datum()

项目:score-zeroshot    作者:pedro-morgado    | 项目源码 | 文件源码
def loop_records(self, num_records=0, init_key=None):
        """Iterate over the Caffe ``Datum`` records in the LMDB at ``self.fn``.

        Yields ``(data, label, key)`` tuples:
            data  -- ``datum_to_array(datum).squeeze()`` for the record,
            label -- ``datum.label``,
            key   -- the raw LMDB key of the record.

        Args:
            num_records: stop after yielding this many records; 0 reads to the
                end of the database.
            init_key: when given, iteration starts at this key instead of the
                first key in the database.

        Raises:
            ValueError: if ``init_key`` is not present in the database.
        """
        env = lmdb.open(self.fn, readonly=True)
        # try/finally guarantees the environment is closed even when the
        # consumer abandons the generator early (GeneratorExit at ``yield``)
        # or the ValueError below is raised.
        try:
            datum = Datum()
            with env.begin() as txn:
                cursor = txn.cursor()
                if init_key is not None and not cursor.set_key(init_key):
                    # format() rather than '+' so that bytes keys (the type
                    # lmdb works with) do not raise a TypeError here.
                    raise ValueError('key {} not found in lmdb {}.'.format(init_key, self.fn))

                # Iterating the cursor continues from its current position,
                # i.e. from init_key when it was set above.
                for num_read, (key, value) in enumerate(cursor, start=1):
                    datum.ParseFromString(value)
                    label = datum.label
                    data = datum_to_array(datum).squeeze()
                    yield (data, label, key)
                    if num_records != 0 and num_read == num_records:
                        break
        finally:
            env.close()
项目:score-zeroshot    作者:pedro-morgado    | 项目源码 | 文件源码
def _add_record(self, data, label=None, key=None):
        """Serialize ``data`` as a Caffe ``Datum`` and write it into ``self.env``.

        Args:
            data: numpy array with 1 to 3 dimensions whose shape is stored as
                (channels, height, width); missing trailing dimensions are
                recorded as 1. uint8 arrays are stored as raw bytes in
                ``datum.data``; any other dtype is flattened (C order) into
                ``datum.float_data``.
            label: label stored via ``int(label)``; -1 is stored when None.
            key: str key for the record; defaults to ``self.num`` zero-padded
                to 8 digits. It is ASCII-encoded before writing.

        Raises:
            ValueError: if ``data`` has fewer than 1 or more than 3 dimensions.
        """
        if not 1 <= data.ndim <= 3:
            raise ValueError('expected an array with 1 to 3 dimensions, got shape {}'.format(data.shape))
        # Pad the shape with trailing 1s so it always provides (C, H, W).
        channels, height, width = tuple(data.shape) + (1,) * (3 - data.ndim)

        datum = Datum()
        datum.channels, datum.height, datum.width = channels, height, width
        if data.dtype == np.uint8:
            # tobytes() replaces the deprecated tostring(); same bytes.
            datum.data = data.tobytes()
        else:
            # float_data is a flat repeated field: tolist() on an N-d array
            # would produce nested lists, which protobuf rejects.
            datum.float_data.extend(data.ravel().tolist())
        datum.label = int(label) if label is not None else -1

        key = ('{:08}'.format(self.num) if key is None else key).encode('ascii')
        with self.env.begin(write=True) as txn:
            txn.put(key, datum.SerializeToString())
        self.num += 1
项目:fast-image-retrieval    作者:xueeinstein    | 项目源码 | 文件源码
def save_to_lmdb(images, labels, lmdb_file):
    """Encode image files and their labels as Caffe ``Datum`` records in a new LMDB.

    Each image is loaded with ``cv2.imread``, resized to IM_WIDTH x IM_HEIGHT
    and stored under a zero-padded 8-digit key ('00000000', '00000001', ...).
    Records are committed every ``batch_size`` items, and the final partial
    batch is always committed.

    Args:
        images: array of image file paths; ``images.shape[0]`` is the count.
        labels: per-image labels, indexed in parallel with ``images``.
        lmdb_file: path of the LMDB to create. Nothing is done if it exists.

    Raises:
        ValueError: if an image cannot be read by ``cv2.imread``.
    """
    if os.path.exists(lmdb_file):
        print(lmdb_file + ' already exists')
        return

    batch_size = 256
    lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
    try:
        lmdb_txn = lmdb_env.begin(write=True)
        item_id = 0
        datum = caffe_pb2.Datum()

        for i in range(images.shape[0]):
            im = cv2.imread(images[i])
            if im is None:
                # imread reports failure by returning None rather than raising.
                raise ValueError('could not read image {}'.format(images[i]))
            # cv2.resize takes dsize as (width, height).
            im = cv2.resize(im, (IM_WIDTH, IM_HEIGHT))
            datum.channels = im.shape[2]
            datum.height = im.shape[0]
            datum.width = im.shape[1]
            datum.data = im.tobytes()
            datum.label = labels[i]
            keystr = '{:0>8d}'.format(item_id)
            # py-lmdb requires bytes keys on Python 3.
            lmdb_txn.put(keystr.encode('ascii'), datum.SerializeToString())
            item_id += 1

            # write batch
            if item_id % batch_size == 0:
                lmdb_txn.commit()
                lmdb_txn = lmdb_env.begin(write=True)
                print('converted {} images'.format(item_id))

        # Always commit the last transaction: it holds the final partial batch
        # (committing an empty transaction is harmless).
        lmdb_txn.commit()
        if item_id % batch_size != 0:
            print('converted {} images'.format(item_id))
        print('Generated ' + lmdb_file)
    finally:
        lmdb_env.close()
项目:fast-image-retrieval    作者:xueeinstein    | 项目源码 | 文件源码
def save_to_lmdb(images, labels, lmdb_file):
    """Encode image files and their labels as Caffe ``Datum`` records in a new LMDB.

    Each readable image is loaded with ``cv2.imread``, resized to
    IM_WIDTH x IM_HEIGHT and stored under a zero-padded 8-digit key. Keys are
    contiguous over the images actually written. Images that cannot be read
    are skipped with a warning. Records are committed every ``batch_size``
    items, and the final partial batch is always committed.

    Args:
        images: array of image file paths; ``images.shape[0]`` is the count.
        labels: per-image labels, indexed in parallel with ``images``.
        lmdb_file: path of the LMDB to create. Nothing is done if it exists.
    """
    if os.path.exists(lmdb_file):
        print(lmdb_file + ' already exists')
        return

    batch_size = 256
    lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
    try:
        lmdb_txn = lmdb_env.begin(write=True)
        item_id = 0
        datum = caffe_pb2.Datum()

        for i in range(images.shape[0]):
            im = cv2.imread(images[i])
            if im is None:
                # imread returns None for missing/unreadable files.
                print('skipping unreadable image {}'.format(images[i]))
                continue
            # cv2.resize takes dsize as (width, height).
            im = cv2.resize(im, (IM_WIDTH, IM_HEIGHT))
            datum.channels = im.shape[2]
            datum.height = im.shape[0]
            datum.width = im.shape[1]
            datum.data = im.tobytes()
            datum.label = labels[i]
            keystr = '{:0>8d}'.format(item_id)
            # py-lmdb requires bytes keys on Python 3.
            lmdb_txn.put(keystr.encode('ascii'), datum.SerializeToString())
            item_id += 1

            # write batch
            if item_id % batch_size == 0:
                lmdb_txn.commit()
                lmdb_txn = lmdb_env.begin(write=True)
                print('converted {} images'.format(item_id))

        # Always commit the last transaction: it holds the final partial batch
        # (committing an empty transaction is harmless).
        lmdb_txn.commit()
        if item_id % batch_size != 0:
            print('converted {} images'.format(item_id))
        print('Generated ' + lmdb_file)
    finally:
        lmdb_env.close()
项目:deeplearning-cats-dogs-tutorial    作者:adilmoujahid    | 项目源码 | 文件源码
def make_datum(img, label):
    """Wrap a 3-channel image and its label in a Caffe ``Datum``.

    ``img`` is a numpy array expected in (height, width, channel) layout with
    BGR channel order (not RGB). It is rearranged to channel-first (C, H, W)
    before being stored as raw bytes.

    Raises:
        ValueError: if ``img`` is not IMAGE_HEIGHT x IMAGE_WIDTH x 3. Otherwise
            the channels/width/height header would not describe the stored bytes.
    """
    expected_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3)
    if img.shape != expected_shape:
        raise ValueError(
            'expected image of shape {}, got {}'.format(expected_shape, img.shape))
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        # rollaxis(img, 2) moves the channel axis to the front; tobytes()
        # replaces the deprecated tostring() and yields identical bytes.
        data=np.rollaxis(img, 2).tobytes())
项目:Market1501-CVLab    作者:Lizw14    | 项目源码 | 文件源码
def make_datum(img, label):
    """Wrap a 3-channel image and its label in a Caffe ``Datum``.

    ``img`` is a numpy array expected in (height, width, channel) layout with
    BGR channel order (not RGB). It is transposed to channel-first (C, H, W)
    before being stored as raw bytes.

    Raises:
        ValueError: if ``img`` is not IMAGE_HEIGHT x IMAGE_WIDTH x 3. Otherwise
            the channels/width/height header would not describe the stored bytes.
    """
    expected_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3)
    if img.shape != expected_shape:
        raise ValueError(
            'expected image of shape {}, got {}'.format(expected_shape, img.shape))
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        # tobytes() (NumPy >= 1.9) replaces the deprecated tostring();
        # on NumPy < 1.9 only tostring() is available.
        data=np.transpose(img, (2, 0, 1)).tobytes())

# key = 0
# env = lmdb.open(img_lmdb_path, map_size=int(1e12))
# with env.begin(write=True) as txn:
#     for idx in xrange(numSample):
#         info = data[idx].split(" ")
#         OriImg = cv2.imread(datadir + info[0])
#         img = cv2.resize(OriImg,(IMAGE_WIDTH,IMAGE_HEIGHT))
#         label = int(info[1])
#         img = np.transpose(img, (2, 0, 1))
#         datum = caffe.io.array_to_datum(img, label)
#         key_str = '{:08}'.format(key)
# #        txn.put(key_str.encode('ascii'), datum.SerializeToString())
#         txn.put(key_str, datum.SerializeToString())
#         key += 1
#     for idx in xrange(numSample):
#         info = data[idx].split(" ")
#         OriImg = cv2.imread(datadir + info[0])
#         img = cv2.resize(OriImg,(IMAGE_WIDTH,IMAGE_HEIGHT))
#         label = int(info[1])
#         img = cv2.flip(img,1)
#         img = np.transpose(img, (2, 0, 1))
#         datum = caffe.io.array_to_datum(img, label)
#         key_str = '{:08}'.format(key)
# #        txn.put(key_str.encode('ascii'), datum.SerializeToString())
#         txn.put(key_str, datum.SerializeToString())
#         key += 1
# print key
项目:SpindleNet    作者:yokattame    | 项目源码 | 文件源码
def main(args):
  """Dump the ``float_data`` of Caffe Datum records from an LMDB into a .npy file.

  Reads at most ``args.truncate`` records from ``args.input_lmdb`` in cursor
  (key) order. It stacks their ``float_data`` into a numpy array, squeezes
  singleton dimensions and saves the result to ``args.output_npy``.
  """
  datum = Datum()
  data = []
  # readonly: this script only reads, and a mistyped path should fail instead
  # of silently creating an empty database.
  env = lmdb.open(args.input_lmdb, readonly=True)
  try:
    with env.begin() as txn:
      for i, (_, value) in enumerate(txn.cursor()):
        if i >= args.truncate:
          break
        datum.ParseFromString(value)
        # ``datum`` is reused for every record, so store a plain copy of the
        # values rather than a reference to its repeated-field container.
        data.append(list(datum.float_data))
  finally:
    env.close()
  data = np.squeeze(np.asarray(data))
  np.save(args.output_npy, data)
项目:GitImpact    作者:ludovicdmt    | 项目源码 | 文件源码
def make_datum(img, label):
    """Wrap a single-channel (black-and-white) image and its label in a Caffe ``Datum``.

    ``img`` is a numpy array whose raw bytes are stored as-is in ``datum.data``.
    The datum header declares 1 channel of IMAGE_HEIGHT x IMAGE_WIDTH.

    Raises:
        ValueError: if ``img`` does not hold exactly IMAGE_HEIGHT * IMAGE_WIDTH
            elements, since the header would not describe the stored bytes.
    """
    if img.size != IMAGE_HEIGHT * IMAGE_WIDTH:
        raise ValueError(
            'expected {} pixels ({}x{}), got image of shape {}'.format(
                IMAGE_HEIGHT * IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_WIDTH, img.shape))
    return caffe_pb2.Datum(
        channels=1,  # images are in black and white
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        # tobytes() replaces the deprecated tostring(); identical bytes.
        data=img.tobytes())