Python lmdb 模块,open() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例,用于说明如何使用 lmdb.open()。

项目:ml-pyxis    作者:vicolab    | 项目源码 | 文件源码
def __init__(self, dirpath, map_size_limit, ram_gb_limit=2):
        """Open the LMDB environment at `dirpath` and its two named databases.

        Args:
            dirpath: Path handed to `lmdb.open`.
            map_size_limit: Maximum LMDB map size, in megabytes (MB).
            ram_gb_limit: RAM limit per write, in gigabytes (GB).

        Raises:
            ValueError: If either limit is not strictly positive.
        """
        self.map_size_limit = int(map_size_limit)  # Megabytes (MB)
        self.ram_gb_limit = float(ram_gb_limit)  # Gigabytes (GB)
        self.keys = []
        self.nb_samples = 0

        # Minor sanity checks
        if self.map_size_limit <= 0:
            raise ValueError('The LMDB map size must be positive: '
                             '{}'.format(self.map_size_limit))
        if self.ram_gb_limit <= 0:
            raise ValueError('The RAM limit (GB) per write must be '
                             'positive: {}'.format(self.ram_gb_limit))

        # Convert from MB to B. Shift the validated integer rather than the
        # raw argument, which may be a float or a string and cannot be shifted.
        map_size_bytes = self.map_size_limit << 20

        # Open LMDB environment
        self._lmdb_env = lmdb.open(dirpath,
                                   map_size=map_size_bytes,
                                   max_dbs=NB_DBS)

        # Open the default database(s) associated with the environment
        self.data_db = self._lmdb_env.open_db(DATA_DB)
        self.meta_db = self._lmdb_env.open_db(META_DB)
项目:crnn    作者:wulivicte    | 项目源码 | 文件源码
def __init__(self, root=None, transform=None, target_transform=None):
        """Open the LMDB at `root` read-only and read the stored sample count.

        Args:
            root: Path of the LMDB environment.
            transform: Optional callable stored for later use on images.
            target_transform: Optional callable stored for later use on labels.
        """
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            # This is a failure path, so report a non-zero exit status.
            sys.exit(1)

        with self.env.begin(write=False) as txn:
            # LMDB keys are bytes; a bytes literal works on Python 2 and 3.
            self.nSamples = int(txn.get(b'num-samples'))

        self.transform = transform
        self.target_transform = target_transform
项目:sceneReco    作者:bear63    | 项目源码 | 文件源码
def __init__(self, root=None, transform=None, target_transform=None):
        """Open the LMDB at `root` read-only and read the stored sample count.

        Args:
            root: Path of the LMDB environment.
            transform: Optional callable stored for later use on images.
            target_transform: Optional callable stored for later use on labels.
        """
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            # This is a failure path, so report a non-zero exit status.
            sys.exit(1)

        with self.env.begin(write=False) as txn:
            # LMDB keys are bytes; a bytes literal works on Python 2 and 3.
            self.nSamples = int(txn.get(b'num-samples'))

        self.transform = transform
        self.target_transform = target_transform
项目:sceneReco    作者:bear63    | 项目源码 | 文件源码
def __getitem__(self, index):
        """Return the `(image, label)` pair for the zero-based `index`.

        The image is decoded as grayscale ('L'). If decoding raises IOError,
        the following sample is returned instead.

        Args:
            index: Zero-based sample position; must be below `len(self)`.

        Returns:
            Tuple of the (optionally transformed) image and label.
        """
        # Valid zero-based positions are 0 .. len(self) - 1.
        assert index < len(self), 'index range error'
        # Record keys are formatted from index + 1.
        index += 1
        with self.env.begin(write=False) as txn:
            # LMDB keys are bytes; encoding keeps the lookup valid on Python 3.
            img_key = ('image-%09d' % index).encode('ascii')
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                # `index` was already incremented, so it is now the zero-based
                # position of the next sample; adding 1 again would skip one.
                return self[index]

            if self.transform is not None:
                img = self.transform(img)

            label_key = ('label-%09d' % index).encode('ascii')
            raw_label = txn.get(label_key)
            if isinstance(raw_label, bytes) and not isinstance(raw_label, str):
                # Python 3: str(bytes) would give "b'...'", so decode instead.
                # NOTE(review): assumes labels are UTF-8 encoded - confirm.
                label = raw_label.decode('utf-8')
            else:
                label = str(raw_label)
            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
项目:FCN-VOC2012-Training-Config    作者:voidrank    | 项目源码 | 文件源码
def gen_input(lmdbname, file_list):
    """Store every JPEG named in `file_list` as a Caffe Datum in an LMDB.

    Each image is read from `DIR/JPEGImages/<name>.jpg`, transposed to
    channel-first order and written under the key `<name>`.

    Args:
        lmdbname: Path of the LMDB environment to write.
        file_list: Iterable of image base names (without extension).
    """
    # Same budget as five float32 arrays of shape (N, 3, HEIGHT, WIDTH),
    # computed arithmetically instead of allocating a zero array to measure.
    bytes_per_image = 3 * HEIGHT * WIDTH * np.dtype(np.float32).itemsize
    map_size = len(file_list) * bytes_per_image * 5

    env = lmdb.open(lmdbname, map_size=map_size)
    try:
        for count, name in enumerate(file_list):
            print(count)
            # One write transaction per image, committed on leaving the block.
            with env.begin(write=True) as txn:
                filename = os.path.join(DIR, "JPEGImages", name + ".jpg")
                m = np.asarray(Image.open(filename)).transpose((2, 0, 1))
                datum = caffe.proto.caffe_pb2.Datum()
                datum.channels = m.shape[0]
                datum.height = m.shape[1]
                datum.width = m.shape[2]
                datum.data = m.tobytes()
                txn.put(name.encode("ascii"), datum.SerializeToString())
    finally:
        # Release the environment even if an image fails to load.
        env.close()
项目:FCN-VOC2012-Training-Config    作者:voidrank    | 项目源码 | 文件源码
def gen_output(lmdbname, file_list):
    """Store every label PNG named in `file_list` as a Caffe Datum in an LMDB.

    Each image is read from `DIR/SegmentationClass/<name>.png`; pixels equal
    to 255 are rewritten to 0 before the array is stored under key `<name>`.

    Args:
        lmdbname: Path of the LMDB environment to write.
        file_list: Iterable of image base names (without extension).
    """
    # Same budget as three uint8 arrays of shape (N, 1, HEIGHT, WIDTH),
    # computed arithmetically instead of allocating a zero array to measure.
    bytes_per_image = HEIGHT * WIDTH * np.dtype(np.uint8).itemsize
    map_size = len(file_list) * bytes_per_image * 3

    env = lmdb.open(lmdbname, map_size=map_size)
    try:
        for count, name in enumerate(file_list):
            print(count)
            with env.begin(write=True) as txn:
                filename = os.path.join(DIR, "SegmentationClass", name + ".png")
                # np.array makes a writable copy of the decoded image.
                m = np.array(Image.open(filename))
                # Vectorised replacement of the per-pixel Python loop.
                m[m == 255] = 0
                datum = caffe.proto.caffe_pb2.Datum()
                datum.channels = 1
                datum.height = m.shape[0]
                datum.width = m.shape[1]
                datum.data = m.tobytes()
                txn.put(name.encode("ascii"), datum.SerializeToString())
    finally:
        # Release the environment even if an image fails to load.
        env.close()
项目:train-CRF-RNN    作者:martinkersner    | 项目源码 | 文件源码
def split_train_test_imgs(class_names, test_ratio):
  """Split the lines of each `<class_name>.txt` into test and train lists.

  For every file, lines whose 1-based position is below
  `test_ratio * get_num_lines(file)` go to the test list; the rest go to the
  train list. Lines are stripped before being stored.

  Returns:
    Tuple `(train_imgs, test_imgs)`.
  """
  test_imgs, train_imgs = [], []

  for class_name in class_names:
    list_path = class_name + '.txt'
    test_quota = test_ratio * get_num_lines(list_path)

    with open(list_path, 'rb') as list_file:
      for line_no, raw_line in enumerate(list_file, 1):
        bucket = test_imgs if line_no < test_quota else train_imgs
        bucket.append(raw_line.strip())

  print(str(len(train_imgs)) + ' train images')
  print(str(len(test_imgs)) + ' test images')

  return train_imgs, test_imgs
项目:train-CRF-RNN    作者:martinkersner    | 项目源码 | 文件源码
def convert2lmdb(path_src, src_imgs, ext, path_dst, class_ids, preprocess_mode, im_sz, data_mode):
  """Convert a list of images into an LMDB of Caffe Datum records.

  Does nothing (and returns None) when `path_dst` already exists as a
  directory. Otherwise each image at `path_src + name + ext` is loaded as
  uint8, preprocessed according to `data_mode` ('label' or 'image'), and
  stored under a zero-padded 10-digit index key.
  """
  if os.path.isdir(path_dst):
    print('DB ' + path_dst + ' already exists.\n'
          'Skip creating ' + path_dst + '.', file=sys.stderr)
    return None

  # The lookup table is only needed for label preprocessing.
  lut = create_lut(class_ids) if data_mode == 'label' else None

  db = lmdb.open(path_dst, map_size=int(1e12))
  try:
    with db.begin(write=True) as in_txn:
      for idx, img_name in enumerate(src_imgs):
        # Plain concatenation: `path_src` must carry its own trailing
        # separator (os.path.join with one argument was a no-op).
        img = np.array(Image.open(path_src + img_name + ext))
        img = img.astype(np.uint8)

        if data_mode == 'label':
          img = preprocess_label(img, lut, preprocess_mode, im_sz)
        elif data_mode == 'image':
          img = preprocess_image(img, preprocess_mode, im_sz)

        img_dat = caffe.io.array_to_datum(img)
        # LMDB keys are bytes; encoding keeps put() valid on Python 3.
        key = '{:0>10d}'.format(idx).encode('ascii')
        in_txn.put(key, img_dat.SerializeToString())
  finally:
    # Release the environment whether or not the conversion succeeded.
    db.close()
项目:score-zeroshot    作者:pedro-morgado    | 项目源码 | 文件源码
def loop_records(self, num_records=0, init_key=None):
        """Yield `(data, label, key)` for records of the LMDB at `self.fn`.

        Args:
            num_records: Stop after this many records; 0 means no limit.
            init_key: If given, start iterating at this key.

        Raises:
            ValueError: If `init_key` is given but not present.
        """
        env = lmdb.open(self.fn, readonly=True)
        try:
            datum = Datum()
            with env.begin() as txn:
                cursor = txn.cursor()
                if init_key is not None and not cursor.set_key(init_key):
                    # Formatting (not '+') also works when the key is bytes.
                    raise ValueError('key %s not found in lmdb %s.'
                                     % (init_key, self.fn))

                num_read = 0
                for key, value in cursor:
                    datum.ParseFromString(value)
                    label = datum.label
                    data = datum_to_array(datum).squeeze()
                    yield (data, label, key)
                    num_read += 1
                    if num_records != 0 and num_read == num_records:
                        break
        finally:
            # Also runs when the generator is closed early or an error is
            # raised, so the environment is never left open.
            env.close()
项目:mediachain-indexer    作者:mediachain    | 项目源码 | 文件源码
def verify_img(buf):
    """
    Verify image.

    Returns the fully loaded image when `buf` passes both `verify()` and
    `load()`, or False when either step raises.
    """

    sbuf = StringIO(buf)

    try:
        ## Basic check:
        Image.open(sbuf).verify()

        ## Detect truncated:
        # The image is reopened from the start of the buffer before loading.
        sbuf.seek(0)
        img = Image.open(sbuf)
        img.load()
    except Exception:
        # `Exception` lets KeyboardInterrupt and SystemExit propagate, which
        # a bare `except:` would have swallowed.
        print ('VERIFY_IMG_FAILED', buf[:100])
        return False
    return img
项目:mediachain-indexer    作者:mediachain    | 项目源码 | 文件源码
def verify_img(buf):
    """
    Verify image.

    Returns True when `buf` passes both `verify()` and `load()`, or False
    when either step raises.
    """
    from PIL import Image
    from cStringIO import StringIO

    sbuf = StringIO(buf)

    try:
        ## Basic check:
        Image.open(sbuf).verify()

        ## Detect truncated:
        # The image is reopened from the start of the buffer before loading.
        sbuf.seek(0)
        Image.open(sbuf).load()
    except Exception:
        # `Exception` lets KeyboardInterrupt and SystemExit propagate, which
        # a bare `except:` would have swallowed.
        return False
    return True
项目:pytorch_crowd_count    作者:BingzheWu    | 项目源码 | 文件源码
def process_dump_tohdf5data(X,Y, path, phase):
    """Write `X`/`Y` to HDF5 files in batches and list the files in a text file.

    Samples are processed in batches of 7000. Each batch is written to
    `<path>/<phase>_<start>.h5` with datasets 'data' (images transposed to
    channel-first float32) and 'label' (densities resized by
    `density_resize`). Every file name is appended to `<path>/<phase>.txt`.
    """
    batch_size = 7000
    total = len(X)
    X_process = np.zeros((batch_size, 3, patch_h, patch_w), dtype = np.float32)
    Y_process = np.zeros((batch_size, net_density_h, net_density_w), dtype = np.float32)

    with open(os.path.join(path, phase+'.txt'), 'w') as f:
        for i1 in range(0, total, batch_size):
            # The last batch may be shorter than `batch_size`.
            i2 = min(i1 + batch_size, total)
            count = i2 - i1
            file_name = os.path.join(path, phase+'_'+str(i1)+'.h5')
            with h5py.File(file_name, 'w') as hf:
                for j, img in enumerate(X[i1:i2]):
                    X_process[j] = img.copy().transpose(2,0,1).astype(np.float32)
                    Y_process[j] = density_resize(Y[i1+j], fx = float(net_density_w)/patch_w, fy = float(net_density_h) / patch_h)
                hf['data'] = X_process[:count]
                hf['label'] = Y_process[:count]
            f.write(file_name+'\n')
项目:pytorch_crowd_count    作者:BingzheWu    | 项目源码 | 文件源码
def read_lmdb(lmdb_path):
    """Print and display the first record of an LMDB as a 27x27 heat map.

    The record value is interpreted as float32 data and reshaped to (27, 27).

    Args:
        lmdb_path: Path of the LMDB environment to read.
    """
    env = lmdb.open(lmdb_path)
    try:
        with env.begin() as txn:
            for _key, value in txn.cursor():
                # np.frombuffer replaces the deprecated binary np.fromstring.
                image = np.frombuffer(value, dtype=np.float32).reshape((27, 27))
                print(image)
                plt.imshow(image, cmap = 'hot')
                plt.show()
                # Only the first record is shown.
                break
    finally:
        env.close()
项目:pytorch_crowd_count    作者:BingzheWu    | 项目源码 | 文件源码
def __init__(self, lmdb_image_datapath, lmdb_label_datapath):
        """Open the image and label LMDBs and start a transaction on each.

        Args:
            lmdb_image_datapath: Path of the LMDB holding images.
            lmdb_label_datapath: Path of the LMDB holding labels.
        """
        # The body uses one consistent indentation level; the original
        # `super` line was indented deeper, which is an IndentationError.
        super(UCF_CC_50, self).__init__()
        self.lmdb_image_datapath = lmdb_image_datapath
        self.lmdb_label_datapath = lmdb_label_datapath
        self.images = []
        self.gts = []
        self.total_patches = 0
        self.limits = []
        self.num_files = 0
        self.file_list = []
        self.env_image = lmdb.open(self.lmdb_image_datapath)
        self.env_label = lmdb.open(self.lmdb_label_datapath)
        self.txn_image = self.env_image.begin()
        self.txn_label = self.env_label.begin()
        # The image cursor is wrapped in an iterator; the label cursor is not.
        self.cursor_image = iter(self.txn_image.cursor())
        self.cursor_label = self.txn_label.cursor()
项目:Triplet_Loss_SBIR    作者:TuBui    | 项目源码 | 文件源码
def read(self, in_path):
    """
    read lmdb, return image data and label

    Records are looked up by zero-padded 8-digit index keys. Returns a
    uint8 array of all decoded datums and an int64 array of their labels.
    Raises ValueError when the LMDB has no entries.
    """
    print('Reading ' + in_path)
    env = lmdb.open(in_path, readonly=True)
    try:
      N = env.stat()['entries']
      if N == 0:
        # Without any record the output shape is unknown.
        raise ValueError('LMDB at %s contains no entries' % in_path)
      data = label = None
      with env.begin() as txn:
        for i in range(N):
          # LMDB keys are bytes; encoding keeps get() valid on Python 3.
          str_id = '{:08}'.format(i).encode('ascii')
          raw_datum = txn.get(str_id)
          datum = caffe.proto.caffe_pb2.Datum()
          datum.ParseFromString(raw_datum)
          feature = caffe.io.datum_to_array(datum)
          if i==0:
            # Size the outputs from the first record's shape.
            data = np.zeros((N,feature.shape[0],feature.shape[1],
                             feature.shape[2]),dtype=np.uint8)
            label = np.zeros(N,dtype=np.int64)
          data[i] = feature
          label[i] = datum.label
    finally:
      env.close()
    return data, label
项目:pytorch-yolo2    作者:marvis    | 项目源码 | 文件源码
def __init__(self, lmdb_root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0):
        """Open the LMDB read-only, read the sample count and build the index list.

        Args:
            lmdb_root: Path of the LMDB environment.
            shape: Stored as `self.shape`.
            shuffle: When true, the index list is shuffled in place.
            transform: Stored as `self.transform`.
            target_transform: Stored as `self.target_transform`.
            train: Stored as `self.train`.
            seen: Stored as `self.seen`.
        """
        self.env = lmdb.open(lmdb_root,
                 max_readers=1,
                 readonly=True,
                 lock=False,
                 readahead=False,
                 meminit=False)
        self.txn = self.env.begin(write=False)
        # LMDB keys are bytes; a bytes literal works on Python 2 and 3.
        self.nSamples = int(self.txn.get(b'num-samples'))
        # A real list is required: random.shuffle cannot mutate a Python 3
        # range object.
        self.indices = list(range(self.nSamples))
        if shuffle:
            random.shuffle(self.indices)

        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.shape = shape
        self.seen = seen
项目:DeepID2    作者:chenzeyuczy    | 项目源码 | 文件源码
def detect_lmdb(path):
    """Print the LMDB statistics at `path` and the shape of its first datum.

    Args:
        path: Path of the LMDB environment to inspect.
    """
    # NOTE(review): the environment is opened writable although it is only
    # read here - confirm whether readonly=True is acceptable.
    env = lmdb.open(path, readonly=False)
    try:
        # %-formatting prints the same text on Python 2 and 3.
        print("Info of lmdb at %s" % path)
        for key, value in env.stat().items():
            print("%s : %s" % (key, value))

        datum = caffe.proto.caffe_pb2.Datum()
        with env.begin() as txn:
            cursor = txn.cursor()
            # next() on a fresh cursor moves to the first record and reports
            # whether one exists.
            if not cursor.next():
                print("LMDB at %s is empty" % path)
                return
            datum.ParseFromString(cursor.value())
            data = caffe.io.datum_to_array(datum)
            print("Data shape: %s" % (data.shape,))
    finally:
        env.close()
项目:sun-bcnn    作者:utiasSTARS    | 项目源码 | 文件源码
def readGroundTruth(datasetTxtFilepath):
    """Parse a whitespace-separated listing of file names and sun directions.

    Each line provides a file name followed by three floats. When the
    module-level `azZenTarget` flag is truthy, the three components are
    converted to a two-element list of degrees (atan2 of components 0 and 2,
    then acos of the negated component 1); otherwise they are kept as-is.

    Returns:
        Tuple `(sunDirList, imageFileNames)`.
    """
    imageFileNames = []
    sunDirList = []
    with open(datasetTxtFilepath) as listing:
        for row in listing:
            fields = row.split()
            fname = fields[0]
            direction = [float(component) for component in fields[1:4]]

            if azZenTarget:
                first_angle = math.degrees(math.atan2(direction[0], direction[2]))
                second_angle = math.degrees(math.acos(-direction[1]))
                sunDirList.append([first_angle, second_angle])
            else:
                sunDirList.append(direction)
            imageFileNames.append(fname)

    return sunDirList, imageFileNames
项目:brain-tumor    作者:voidrank    | 项目源码 | 文件源码
def make_lmdb_input(lmdbname, channel_directories, range_set):
    """Pack multi-channel PNG samples into an LMDB of Caffe Datum records.

    For every id in `range_set`, the file `<id>.png` is read from each
    directory in `channel_directories` as one channel (double precision),
    and the stacked channels are stored under a zero-padded 8-digit key
    counting from 0.
    """
    n_channels = len(channel_directories)
    X = np.zeros((len(range_set), n_channels, WIDTH, HEIGHT), dtype=np.double)
    env = lmdb.open(lmdbname, map_size=X.nbytes * 10)

    for count, sample_id in enumerate(range_set):
        with env.begin(write=True) as txn:
            filename = str(sample_id) + ".png"
            datum = caffe.proto.caffe_pb2.Datum()
            # Datum dimensions mirror axes 1..3 of the staging array.
            datum.channels, datum.height, datum.width = X.shape[1:]
            sample = X[count]
            for j, dirname in enumerate(channel_directories):
                channel_path = os.path.join(dirname, filename)
                sample[j] = np.asarray(PIL.Image.open(channel_path), dtype=np.double)
            datum.data = sample.tobytes()
            key = '{:08}'.format(count).encode("ascii")
            txn.put(key, datum.SerializeToString())
项目:caffe-materials    作者:kyehyeon    | 项目源码 | 文件源码
def write_lmdb(db_path, list_filename, height, width):
  """Write JPEG-encoded, resized images and labels into an LMDB.

  Each line of `list_filename` is `<image path> <integer label>`. Images
  are read in colour, resized to `height` x `width`, JPEG-encoded and
  stored as encoded Caffe Datums under zero-padded 10-digit index keys.

  Raises:
    IOError: If an image cannot be read.
  """
  map_size = 9999999999
  db = lmdb.open(db_path, map_size=map_size)
  try:
    datum = caffe.proto.caffe_pb2.Datum()
    # The transaction commits on a clean exit; the list file is closed too.
    with db.begin(write=True) as writer, open(list_filename, 'r') as list_file:
      for index, line in enumerate(list_file):
        img_filename, label = line.strip().split(' ')
        img = cv2.imread(img_filename, 1)
        if img is None:
          # cv2.imread signals failure by returning None.
          raise IOError('cannot read image: %s' % img_filename)
        # cv2.resize expects dsize as (width, height).
        img = cv2.resize(img, (width, height))
        _, img_jpg = cv2.imencode('.jpg', img)
        datum.channels = 3
        datum.height = height
        datum.width = width
        datum.label = int(label)
        datum.encoded = True
        datum.data = img_jpg.tobytes()
        datum_byte = datum.SerializeToString()
        # LMDB keys are bytes; encoding keeps put() valid on Python 3.
        index_byte = ('%010d' % index).encode('ascii')
        writer.put(index_byte, datum_byte, append=True)
  finally:
    db.close()
项目:SSD-Keras_Tensorflow    作者:jedol    | 项目源码 | 文件源码
def __init__(self, source, batch_size, shuffle=True, use_prefetch=True, capacity=32):
        """Open the LMDB at `source` and optionally start a prefetch process.

        Stores the batching options, opens a read-only environment with one
        transaction and cursor, records the entry count, and calls
        `reset_inds()`. With `use_prefetch`, a queue of size `capacity` and a
        process running `self._worker` are created; the process is
        terminated and joined at interpreter exit.
        """
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.use_prefetch = use_prefetch
        self.capacity = capacity

        ## open LMDB
        self.env = lmdb.open(source, readonly=True)
        self.txn = self.env.begin()
        self.cur = self.txn.cursor()
        self.num_data = int(self.txn.stat()['entries'])

        self.reset_inds()

        if not self.use_prefetch:
            return

        import atexit

        self.batch_queue = Queue(capacity)
        self.proc = Process(target=self._worker)
        self.proc.start()

        def _stop_worker():
            self.proc.terminate()
            self.proc.join()

        atexit.register(_stop_worker)
项目:hyperband_benchmarks    作者:lishal    | 项目源码 | 文件源码
def make_test():
    """Convert the hard-coded MRBI test `.amat` file into a Caffe LMDB.

    Data is loaded with `get_data`; each `(X[i], Y[i])` pair is stored as a
    Datum under a zero-padded 10-digit index key.
    """
    print('Loading Matlab data.')
    f = '/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/mnist_rotation_back_image_new/mnist_all_background_images_rotation_normalized_test.amat'

    X,Y=get_data(f)
    N = Y.shape[0]
    map_size = X.nbytes*2
    env = lmdb.open('/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/mrbi_test', map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            # txn is a Transaction object
            for i in range(N):
                im_dat = caffe.io.array_to_datum(X[i],Y[i])
                # LMDB keys are bytes; encoding keeps put() valid on Python 3.
                key = '{:0>10d}'.format(i).encode('ascii')
                txn.put(key, im_dat.SerializeToString())
    finally:
        env.close()
项目:hyperband_benchmarks    作者:lishal    | 项目源码 | 文件源码
def view_lmdb_data():
    """Read every datum of the hard-coded SVHN train LMDB and print the count."""
    lmdb_env = lmdb.open('/home/lisha/school/Projects/hyperband_nnet/hyperband2/svhn/svhn_train/')
    datum = caffe.proto.caffe_pb2.Datum()
    x=[]
    y=[]
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                x.append(caffe.io.datum_to_array(datum))
                y.append(datum.label)
    finally:
        lmdb_env.close()
    print(len(y))
项目:hyperband_benchmarks    作者:lishal    | 项目源码 | 文件源码
def view_lmdb_data():
    """Read every datum of the hard-coded SVHN train LMDB and print the count."""
    lmdb_env = lmdb.open('/home/lisha/school/Projects/hyperband_nnet/hyperband2/svhn/svhn_train/')
    datum = caffe.proto.caffe_pb2.Datum()
    x=[]
    y=[]
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                x.append(caffe.io.datum_to_array(datum))
                y.append(datum.label)
    finally:
        lmdb_env.close()
    print(len(y))
项目:pytorch-playground    作者:aaron-xichen    | 项目源码 | 文件源码
def load_lmdb(lmdb_file, n_records=None):
    """Load `(image, target)` pairs from an LMDB into memory.

    Keys are ASCII strings of three ':'-separated fields whose middle field
    is the integer target; values are encoded images decoded with OpenCV.

    Args:
        lmdb_file: Path of the LMDB (passed through `expand_user`).
        n_records: Stop once this many records are loaded; None loads all.

    Returns:
        List of `(img, target)` tuples, or None when the path does not exist.
    """
    import lmdb
    import numpy as np
    lmdb_file = expand_user(lmdb_file)
    if not os.path.exists(lmdb_file):
        # The message now has a placeholder, so the path is actually shown.
        print("Not found lmdb file {}".format(lmdb_file))
        return None

    data = []
    env = lmdb.open(lmdb_file, readonly=True, max_readers=512)
    try:
        with env.begin() as txn:
            begin_st = time.time()
            print("Loading lmdb file {} into memory".format(lmdb_file))
            for key, value in txn.cursor():
                _, target, _ = key.decode('ascii').split(':')
                target = int(target)
                # np.frombuffer replaces the deprecated binary np.fromstring.
                img = cv2.imdecode(np.frombuffer(value, np.uint8), cv2.IMREAD_COLOR)
                data.append((img, target))
                if n_records is not None and len(data) >= n_records:
                    break
    finally:
        env.close()
    print("=> Done ({:.4f} s)".format(time.time() - begin_st))
    return data
项目:sceneReco    作者:yijiuzai    | 项目源码 | 文件源码
def __init__(self, root=None, transform=None, target_transform=None):
        """Open the LMDB at `root` read-only and read the stored sample count.

        Args:
            root: Path of the LMDB environment.
            transform: Optional callable stored for later use on images.
            target_transform: Optional callable stored for later use on labels.
        """
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            # This is a failure path, so report a non-zero exit status.
            sys.exit(1)

        with self.env.begin(write=False) as txn:
            # LMDB keys are bytes; a bytes literal works on Python 2 and 3.
            self.nSamples = int(txn.get(b'num-samples'))

        self.transform = transform
        self.target_transform = target_transform
项目:sceneReco    作者:yijiuzai    | 项目源码 | 文件源码
def __getitem__(self, index):
        """Return the `(image, label)` pair for the zero-based `index`.

        The image is decoded as grayscale ('L'). If decoding raises IOError,
        the following sample is returned instead.

        Args:
            index: Zero-based sample position; must be below `len(self)`.

        Returns:
            Tuple of the (optionally transformed) image and label.
        """
        # Valid zero-based positions are 0 .. len(self) - 1.
        assert index < len(self), 'index range error'
        # Record keys are formatted from index + 1.
        index += 1
        with self.env.begin(write=False) as txn:
            # LMDB keys are bytes; encoding keeps the lookup valid on Python 3.
            img_key = ('image-%09d' % index).encode('ascii')
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                # `index` was already incremented, so it is now the zero-based
                # position of the next sample; adding 1 again would skip one.
                return self[index]

            if self.transform is not None:
                img = self.transform(img)

            label_key = ('label-%09d' % index).encode('ascii')
            raw_label = txn.get(label_key)
            if isinstance(raw_label, bytes) and not isinstance(raw_label, str):
                # Python 3: str(bytes) would give "b'...'", so decode instead.
                # NOTE(review): assumes labels are UTF-8 encoded - confirm.
                label = raw_label.decode('utf-8')
            else:
                label = str(raw_label)
            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
项目:phocnet    作者:ssudholt    | 项目源码 | 文件源码
def open_single_lmdb_for_write(self, lmdb_path, max_lmdb_size=1024**4, create=True, label_map=None):
        '''
        Prepare one LMDB and a write transaction for storing ndarrays (images).

        Args:
            lmdb_path (str): Location of the LMDB on disk
            max_lmdb_size (int): Upper bound for the LMDB size in bytes (default: 1024**4)
            create (bool): When set, anything already present at lmdb_path is
                           removed before the LMDB is opened
            label_map (dictionary): Optional mapping from string labels to integer indices;
                                    stored on the instance as `label_map`
        '''
        # remove a previous LMDB first when the caller asked for a fresh one
        must_erase = os.path.exists(lmdb_path) and create
        if must_erase:
            self.logger.debug('Erasing previously created LMDB at %s', lmdb_path)
            shutil.rmtree(lmdb_path)

        self.logger.info('Opening single LMDB at %s for writing', lmdb_path)
        environment = lmdb.open(path=lmdb_path, map_size=max_lmdb_size)
        self.database_images = environment
        self.txn_images = environment.begin(write=True)
        self.label_map = label_map
项目:ternarynet    作者:czhu95    | 项目源码 | 文件源码
def __init__(self, lmdb_dir, shuffle=True):
        """Open the LMDB read-only and, when shuffling, collect its keys.

        Args:
            lmdb_dir: Path of the LMDB environment.
            shuffle: When true, `self.keys` is filled either from the
                `__keys__` record or by scanning all keys in the database.
        """
        self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
                map_size=1099511627776 * 2, max_readers=100)
        self._txn = self._lmdb.begin()
        self._shuffle = shuffle
        self._size = self._txn.stat()['entries']
        if shuffle:
            # LMDB keys are bytes; a bytes literal works on Python 2 and 3.
            # NOTE(review): the stored `__keys__` value is used as returned,
            # without any deserialisation - confirm the expected format.
            self.keys = self._txn.get(b'__keys__')
            if not self.keys:
                self.keys = []
                with timed_operation("Loading LMDB keys ...", log_start=True), \
                        tqdm(total=self._size, ascii=True) as pbar:
                    # Iterating a cursor directly yields (key, value) pairs,
                    # so request keys only; otherwise the filter below never
                    # matches and tuples end up in `self.keys`.
                    key_iter = self._txn.cursor().iternext(keys=True, values=False)
                    for k in key_iter:
                        if k != b'__keys__':
                            self.keys.append(k)
                            pbar.update()
项目:visimportance    作者:cvzoya    | 项目源码 | 文件源码
def load_label(maindir, idx, split):
    """
    Read the label PNG for `idx` and return it scaled to [0, 1] with a
    leading singleton axis (the loss requires that extra dimension).

    The 'train' split reads from GDI/gd_imp_train; any other split reads
    from GDI/gd_imp_val.
    """
    path_template = ('{}/GDI/gd_imp_train/{}.png' if split == 'train'
                     else '{}/GDI/gd_imp_val/{}.png')
    im = Image.open(path_template.format(maindir, idx))

    scaled = np.array(im, dtype=np.uint8) / 255.0
    return scaled[np.newaxis, ...]
项目:braid    作者:Arya-ai    | 项目源码 | 文件源码
def load_lmdb_dataset(db_name):
    """Load all Datum records of an LMDB into NumPy arrays.

    Each record's `data` field is read as uint8 and reshaped to
    (channels, height, width).

    Args:
        db_name: Path of the LMDB environment.

    Returns:
        Tuple `(X, Y)` of stacked samples and their labels.
    """
    X = []
    Y = []
    env = lmdb.open(db_name, readonly=True)
    try:
        with env.begin() as txn:
            for _key, value in txn.cursor():
                datum = Datum()
                datum.ParseFromString(value)
                # np.frombuffer replaces the deprecated binary np.fromstring.
                flat_x = np.frombuffer(datum.data, dtype=np.uint8)
                X.append(flat_x.reshape(datum.channels, datum.height, datum.width))
                Y.append(datum.label)
    finally:
        env.close()
    assert len(X) == len(Y)
    X = np.asarray(X)
    Y = np.asarray(Y)
    return X, Y
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y)`: the list of arrays from `caffe.io.datum_to_array`
        and a NumPy array of the corresponding labels.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                X.append(caffe.io.datum_to_array(datum))
                y.append(datum.label)
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y)
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:CaffeSVD    作者:wkcn    | 项目源码 | 文件源码
def read_db(db_name):
    """Read every Caffe Datum stored in the LMDB at `db_name`.

    Returns:
        Tuple `(X, y, cnts)`: the list of arrays from
        `caffe.io.datum_to_array`, a NumPy array of labels, and a dict
        mapping each label to its number of occurrences.
    """
    datum = caffe.proto.caffe_pb2.Datum()
    X = []
    y = []
    cnts = {}
    lmdb_env = lmdb.open(db_name)
    try:
        # The context manager ends the read transaction after iteration.
        with lmdb_env.begin() as lmdb_txn:
            for _key, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                label = datum.label
                X.append(caffe.io.datum_to_array(datum))
                y.append(label)
                cnts[label] = cnts.get(label, 0) + 1
    finally:
        # Release the environment handle even if parsing fails.
        lmdb_env.close()
    return X, np.array(y), cnts
项目:video-tools    作者:achalddave    | 项目源码 | 文件源码
def load_image_datum(image_path, resize_height=None, resize_width=None):
    """Load an image and serialize it as a Caffe datum in BGR order.

    Args:
        image_path (str): Path to an image.
        resize_height (int): Target height. Resizing only happens when both
            this and `resize_width` are truthy.
        resize_width (int): Target width. Resizing only happens when both
            this and `resize_height` are truthy.

    Returns:
        The output of `SerializeToString()` for the datum built from the
        (num_channels, height, width) BGR array.
    """
    pil_image = Image.open(image_path)
    if resize_height and resize_width:
        pil_image = pil_image.resize((resize_width, resize_height))
    # The array is (height, width, num_channels) in RGB order: reverse the
    # channel axis to get BGR, then move the channels to the front.
    bgr_channels_first = np.array(pil_image)[:, :, ::-1].transpose((2, 0, 1))
    return caffe.io.array_to_datum(bgr_channels_first).SerializeToString()
项目:video-tools    作者:achalddave    | 项目源码 | 文件源码
def dump_one_lmdb(path, offset):
    """Save the frame at position `offset` of an LMDB to 'tmp.png'.

    The record is parsed as a `LabeledVideoFrame`; its image is written as
    an RGB PNG, then the record key and the label names are printed.

    Args:
        path: Path of the LMDB environment.
        offset: Number of records to skip after the first one.

    Raises:
        IndexError: If the LMDB is empty or `offset` runs past the last entry.
    """
    with lmdb.open(path, map_size=map_size) as env, \
            env.begin().cursor() as lmdb_cursor:
        num_entries = env.stat()['entries']
        # Unfortunately, it seems the only way to set the cursor to an
        # arbitrary key index (without knowing the key) is to literally call
        # next() repeatedly.
        if not lmdb_cursor.next():
            raise IndexError('LMDB at {} is empty'.format(path))
        for _ in tqdm(range(offset)):
            # next() returns False once the cursor runs off the end.
            if not lmdb_cursor.next():
                raise IndexError('offset {} is past the last of {} entries'
                                 .format(offset, num_entries))
        video_frame = video_frames_pb2.LabeledVideoFrame()
        video_frame.ParseFromString(lmdb_cursor.value())
        image_proto = video_frame.frame.image
        # np.frombuffer replaces the deprecated binary np.fromstring.
        image = np.frombuffer(image_proto.data,
                              dtype=np.uint8).reshape(
                                  image_proto.channels, image_proto.height,
                                  image_proto.width).transpose((1, 2, 0))
        image = Image.fromarray(image, 'RGB')
        image.save('tmp.png')
        print(lmdb_cursor.key())
        print(', '.join([label.name for label in video_frame.label]))
项目:video-tools    作者:achalddave    | 项目源码 | 文件源码
def main():
    """Copy per-frame label vectors from an HDF5 file into an LMDB.

    Command-line arguments:
        labels_hdf5: HDF5 file mapping video names to a matrix with one row
            per frame.
        output_lmdb: Path of the LMDB to write.

    Each row is stored as raw bytes under the key '<video_name>-<n>', where
    `n` is the 1-based frame number.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'labels_hdf5',
        help=('Maps video names to a binary matrix of shape (num_frames, '
              'num_labels).'))
    parser.add_argument('output_lmdb')

    args = parser.parse_args()

    # Maximum size the LMDB may grow to, in bytes.
    map_size = int(2e9)

    # Context managers guarantee the environment and the HDF5 file are closed
    # even on error; the labels file is only read, so open it read-only.
    with lmdb.open(args.output_lmdb, map_size=map_size) as lmdb_environment, \
            h5py.File(args.labels_hdf5, 'r') as labels, \
            lmdb_environment.begin(write=True) as lmdb_transaction:
        for video_name, file_labels in tqdm(labels.items()):
            file_labels = np.asarray(file_labels)
            for frame_number, frame_labels in enumerate(file_labels, start=1):
                key = '{}-{}'.format(video_name, frame_number)
                # LMDB keys must be bytes on Python 3.
                if not isinstance(key, bytes):
                    key = key.encode('utf-8')
                lmdb_transaction.put(key, frame_labels.tobytes())
项目:video-tools    作者:achalddave    | 项目源码 | 文件源码
def load_image(image_path, resize_height=None, resize_width=None):
    """Load an image in video_frames.Image format.

    The image is converted to RGB first, so grayscale, palette, and RGBA
    files also yield three channels.

    Args:
        image_path (str): Path to an image.
        resize_height (int): Height to resize an image to. If 0 or None, the
            image is not resized.
        resize_width (int): Width to resize an image to. If 0 or None, the
            image is not resized.

    Returns:
        image_datum (numpy array): Contains the image in BGR order after
            resizing, with shape (num_channels, height, width).
    """
    # Image.open is lazy; convert() forces the pixel data to be read, so the
    # file handle can be released right away.
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')
    # Resizing happens only when both dimensions are given and non-zero.
    if resize_height and resize_width:
        image = image.resize((resize_width, resize_height))
    # Image has shape (height, width, num_channels), where the
    # channels are in RGB order.
    image = np.array(image)
    # Convert image from RGB to BGR.
    image = image[:, :, ::-1]
    # Convert image to (num_channels, height, width) shape.
    image = image.transpose((2, 0, 1))
    return image
项目:confusion    作者:abhimanyudubey    | 项目源码 | 文件源码
def write_prototxt(output_file,prototxt):
    """Serialize a prototxt dictionary into a text file.

    Usage: write_prototxt(output_file, prototxt_dictionary)

    A `name` entry, when present, is written first as a quoted `name:` line.
    A `layer` entry, when present, is rendered by `get_prototxt_string` and
    written with doubled newlines collapsed to single ones.
    """
    available = prototxt.keys()
    with open(output_file, 'w') as out:
        if 'name' in available:
            header = 'name: "' + prototxt['name'] + '" \n'
            out.write(header)
        if 'layer' in available:
            layer_text = get_prototxt_string(prototxt['layer'], 0, 'layer')
            out.write(layer_text.replace('\n\n', '\n'))
项目:crnn    作者:wulivicte    | 项目源码 | 文件源码
def __getitem__(self, index):
        """Return the `(image, label)` pair for the zero-based `index`.

        Records are stored under the 1-based keys 'image-%09d' and
        'label-%09d'. The image is decoded and converted to grayscale ('L');
        `self.transform` and `self.target_transform` are applied when set.
        If the image cannot be decoded, the next sample is returned instead.

        Raises:
            IndexError: If `index` is outside `[0, len(self))`. This also
                happens when the last sample is corrupted and there is no
                following sample to fall back to.
        """
        # Valid indices are 0 .. len(self) - 1; the old `<=` check let
        # `len(self)` through, which has no record in the database.
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        # Keys in the database are 1-based.
        key_number = index + 1
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % key_number
            # LMDB keys must be bytes on Python 3.
            imgbuf = txn.get(img_key.encode())

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % key_number)
                # Fall back to the immediately following sample. `index` is
                # still zero-based here, so no sample is skipped.
                return self[index + 1]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % key_number
            raw_label = txn.get(label_key.encode())
            if isinstance(raw_label, bytes) and not isinstance(raw_label, str):
                # Python 3: decode instead of str(), which would produce
                # the literal "b'...'" representation.
                label = raw_label.decode('utf-8')
            else:
                label = str(raw_label)

            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
项目:ml-pyxis    作者:vicolab    | 项目源码 | 文件源码
def __init__(self, dirpath):
        """Open the LMDB at `dirpath` for reading and load its sample count.

        The environment is opened read-only with room for `NB_DBS` named
        databases. The data and metadata databases are then opened, and the
        value stored under `NB_SAMPLES` is parsed into `self.nb_samples`.
        """
        # Read-only environment handle.
        environment = lmdb.open(dirpath, max_dbs=NB_DBS, readonly=True)
        self._lmdb_env = environment

        # Handles to the named databases inside the environment.
        self.data_db = environment.open_db(DATA_DB)
        self.meta_db = environment.open_db(META_DB)

        # The sample count is kept as metadata; convert it to an integer.
        stored_count = self.get_meta_str(NB_SAMPLES)
        self.nb_samples = int(stored_count)
项目:ml-pyxis    作者:vicolab    | 项目源码 | 文件源码
def close(self):
        """Shut down the underlying LMDB environment.

        Iterators, cursors, and transactions that are still open on the
        environment become invalid.
        """
        environment = self._lmdb_env
        environment.close()
项目:ml-pyxis    作者:vicolab    | 项目源码 | 文件源码
def close(self):
        """Record the sample count and shut down the LMDB environment.

        `self.nb_samples` is first stored through `set_meta_str` under the
        `NB_SAMPLES` key; the environment is closed afterwards.

        Iterators, cursors, and transactions that are still open on the
        environment become invalid.
        """
        sample_count = self.nb_samples
        self.set_meta_str(NB_SAMPLES, sample_count)

        environment = self._lmdb_env
        environment.close()