Python tensorflow.python.util.compat 模块,as_bytes() 实例源码

我们从Python开源项目中,提取了以下29个代码示例,用于说明如何使用tensorflow.python.util.compat.as_bytes()

项目:lsdc    作者:febert    | 项目源码 | 文件源码
def testBasic(self):
    """Loads the half_plus_two session bundle and checks assets and signatures.

    The bundle is loaded from the test source tree into a fresh default graph.
    The test then verifies that the two asset filename tensors resolve to
    files under the bundle's assets directory, that exactly one signatures
    entry is stored in the meta graph's collection, and hands the unpacked
    signatures to the class's regression / named signature checkers.
    """
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))

    self.assertTrue(sess)
    asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    with sess.as_default():
      path1, path2 = sess.run(["filename1:0", "filename2:0"])
      # The fetched values are compared against byte strings, hence the
      # as_bytes conversion of the expected paths.
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      # assertEqual replaces the deprecated assertEquals alias, which no
      # longer exists in Python 3.12+.
      self.assertEqual(len(signatures_any), 1)

      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      self._checkRegressionSignature(signatures, sess)
      self._checkNamedSigantures(signatures, sess)
项目:stuff    作者:yaroslavvb    | 项目源码 | 文件源码
def main():
  """Opens the shared events writer and runs this process's task role.

  Reads the log directory from the ``LOGDIR`` environment variable, binds the
  module-level ``writer`` to an ``EventsWriter`` on ``<LOGDIR>/events`` and
  dispatches on ``config.task_type``. The writer is closed on every exit
  path, including when the task function raises.

  Raises:
    KeyError: if ``LOGDIR`` is not set in the environment.
    AssertionError: if ``config.task_type`` is neither 'worker' nor 'ps'.
  """
  global writer
  config = load_config()

  # todo: factor out common logic
  logdir = os.environ["LOGDIR"]
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir+'/events'))

  try:
    if config.task_type == 'worker':
      run_worker()
    elif config.task_type == 'ps':
      run_ps()
    else:
      # Raised explicitly instead of `assert False` so that the check is not
      # stripped when Python runs with -O; the exception type is unchanged.
      raise AssertionError("Unknown task type "+str(config.task_type))
  finally:
    writer.Close()
项目:taas-examples    作者:caicloud    | 项目源码 | 文件源码
def _write_assets(assets_directory, assets_filename):
  """Writes a fixed-content asset file, creating its directory when missing.

  Args:
    assets_directory: Directory that should contain the asset file. It is
      created recursively if it does not exist yet.
    assets_filename: Name of the asset file inside `assets_directory`.

  Returns:
    The path of the written file, joined from the byte-string forms of the
    two arguments.
  """
  directory_missing = not file_io.file_exists(assets_directory)
  if directory_missing:
    file_io.recursive_create_dir(assets_directory)

  directory_bytes = compat.as_bytes(assets_directory)
  filename_bytes = compat.as_bytes(assets_filename)
  asset_path = os.path.join(directory_bytes, filename_bytes)
  file_io.write_string_to_file(asset_path, "asset-file-contents")
  return asset_path
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def _maybe_export(self, eval_result):  # pylint: disable=unused-argument
    """Runs every configured export strategy against the estimator.

    Each strategy exports into ``<model_dir>/export/<strategy.name>``; the
    path components are converted to bytes before being joined.

    Args:
      eval_result: Ignored by this implementation.

    Returns:
      A list holding the result of each strategy's `export` call, in the
      order of `self._export_strategies`.
    """
    export_root = os.path.join(
        compat.as_bytes(self._estimator.model_dir),
        compat.as_bytes("export"))

    # TODO(soergel): possibly, allow users to decide whether to export here
    # based on the eval_result (e.g., to keep the best export).
    return [
        strategy.export(
            self._estimator,
            os.path.join(
                compat.as_bytes(export_root),
                compat.as_bytes(strategy.name)))
        for strategy in self._export_strategies
    ]
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def get_timestamped_export_dir(export_dir_base):
  """Returns a path for a new, time-stamped subdirectory of the base directory.

  The subdirectory name is the current time truncated to whole seconds, as
  reported by `time.time()`. Nothing is created on disk; only the path is
  computed.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.

  Returns:
    The full path of the new subdirectory, joined from the byte-string forms
    of the base directory and the timestamp.
  """
  seconds_now = int(time.time())
  base_bytes = compat.as_bytes(export_dir_base)
  version_bytes = compat.as_bytes(str(seconds_now))
  return os.path.join(base_bytes, version_bytes)


# create a simple parser that pulls the export_version from the directory.
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def test_get_most_recent_export(self):
    """Checks that get_most_recent_export returns the last export created.

    Four export directories are created under a fresh base directory; the
    most recent one reported must be the fourth, and its path must equal the
    base directory joined with the reported version.
    """
    # Join instead of concatenating: `mkdtemp() + "export/"` produced a
    # sibling directory such as ".../tmpabcexport/" outside the temporary
    # directory. The empty last component keeps the trailing separator the
    # original string had.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export", "")
    gfile.MkDir(export_dir_base)
    _create_test_export_dir(export_dir_base)
    _create_test_export_dir(export_dir_base)
    _create_test_export_dir(export_dir_base)
    export_dir_4 = _create_test_export_dir(export_dir_base)

    (most_recent_export_dir, most_recent_export_version) = (
        saved_model_export_utils.get_most_recent_export(export_dir_base))

    self.assertEqual(compat.as_bytes(export_dir_4),
                     compat.as_bytes(most_recent_export_dir))
    self.assertEqual(compat.as_bytes(export_dir_4),
                     os.path.join(compat.as_bytes(export_dir_base),
                                  compat.as_bytes(
                                      str(most_recent_export_version))))
项目:yolov2    作者:datlife    | 项目源码 | 文件源码
def visualize_graph_in_tfboard(filename, output='./log'):
    """Writes the graph of a serialized SavedModel to a TensorBoard log dir.

    The file is parsed as a `SavedModel` protocol buffer, its single meta
    graph is imported into the session's graph, and that graph is written
    with a summary `FileWriter`.

    Args:
        filename: Path of the serialized SavedModel file to read.
        output: Directory handed to the summary writer.

    The process exits with status 1 unless the SavedModel holds exactly one
    meta graph.
    """
    with tf.Session() as sess:
        with gfile.FastGFile(filename, 'rb') as f:
            data = compat.as_bytes(f.read())

        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        num_graphs = len(sm.meta_graphs)
        if num_graphs != 1:
            if num_graphs:
                print('More than one graph found. Not sure which to write')
            else:
                # The original message was misleading for an empty model.
                print('No graph found. Nothing to write')
            sys.exit(1)

        tf.import_graph_def(sm.meta_graphs[0].graph_def)

        train_writer = tf.summary.FileWriter(output)
        try:
            train_writer.add_graph(sess.graph)
        finally:
            # Close the writer so the event file is flushed and released.
            train_writer.close()
        print("Please execute `tensorboard --logdir {}` to view graph".format(output))
项目:distributional_perspective_on_RL    作者:Kiwoo    | 项目源码 | 文件源码
def __init__(self, dir):
    """Creates `dir` if needed and opens a TensorFlow events writer in it.

    The TensorFlow modules are imported inside the constructor and kept on
    the instance (`tf`, `event_pb2`, `pywrap_tensorflow`).

    Args:
      dir: Output directory; the writer uses `<abspath(dir)>/events` as its
        path argument.
    """
    os.makedirs(dir, exist_ok=True)
    self.dir, self.step = dir, 1
    events_path = osp.join(osp.abspath(dir), 'events')

    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat

    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    path_bytes = compat.as_bytes(events_path)
    self.writer = pywrap_tensorflow.EventsWriter(path_bytes)
项目:baselines    作者:openai    | 项目源码 | 文件源码
def __init__(self, dir):
    """Creates `dir` if needed and opens a TensorFlow events writer in it.

    The TensorFlow modules are imported inside the constructor and kept on
    the instance (`tf`, `event_pb2`, `pywrap_tensorflow`).

    Args:
      dir: Output directory; the writer uses `<abspath(dir)>/events` as its
        path argument.
    """
    os.makedirs(dir, exist_ok=True)
    self.dir, self.step = dir, 1
    events_path = osp.join(osp.abspath(dir), 'events')

    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat

    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    path_bytes = compat.as_bytes(events_path)
    self.writer = pywrap_tensorflow.EventsWriter(path_bytes)
项目:tensoronspark    作者:liangfengsid    | 项目源码 | 文件源码
def reset(target, containers=None, config=None):
    """Converts the arguments to bytes and forwards them to `TF_Reset`.

    Args:
      target: Target to reset; passed through as `None` when not given,
        otherwise converted to bytes.
      containers: Optional iterable of container names; each is converted to
        bytes. `None` is forwarded as an empty list.
      config: Passed to `TF_Reset` unchanged.
    """
    target_bytes = None if target is None else compat.as_bytes(target)
    if containers is None:
        container_names = []
    else:
        container_names = [compat.as_bytes(name) for name in containers]
    tf_session.TF_Reset(target_bytes, container_names, config)
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def _name_list(tensor_list):
  """Collects the names of the given tensors as byte strings.

  Utility function for transitioning to the new session API.

  Args:
    tensor_list: a list of `Tensor`s.

  Returns:
    A list with the `name` of each element, converted to bytes, in the same
    order as `tensor_list`.
  """
  names = []
  for tensor in tensor_list:
    names.append(compat.as_bytes(tensor.name))
  return names
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def reset(target, containers=None, config=None):
    """Resets resource containers on `target`, and close all connected sessions.

    A resource container is distributed across all workers in the
    same cluster as `target`.  When a resource container on `target`
    is reset, resources associated with that container will be cleared.
    In particular, all Variables in the container will become undefined:
    they lose their values and shapes.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` if
        all the containers are to be reset.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    target_bytes = None if target is None else compat.as_bytes(target)
    if containers is None:
      container_names = []
    else:
      container_names = [compat.as_bytes(name) for name in containers]
    tf_session.TF_Reset(target_bytes, container_names, config)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def testBasic(self):
    """Loads the half_plus_two session bundle and runs its default signature.

    After loading the bundle into a fresh default graph, the test checks the
    two asset filename tensors, verifies that exactly one signatures entry is
    stored, and evaluates the default regression signature on the inputs
    0..3, expecting y = 0.5 * x + 2.
    """
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))

    self.assertTrue(sess)
    asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    with sess.as_default():
      path1, path2 = sess.run(["filename1:0", "filename2:0"])
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      # assertEqual replaces the deprecated assertEquals alias, which no
      # longer exists in Python 3.12+.
      self.assertEqual(len(signatures_any), 1)

      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      input_name = default_signature.regression_signature.input.tensor_name
      output_name = default_signature.regression_signature.output.tensor_name
      y = sess.run([output_name], {input_name: np.array([[0], [1], [2], [3]])})
      # The operation is y = 0.5 * x + 2
      self.assertEqual(y[0][0], 2)
      self.assertEqual(y[0][1], 2.5)
      self.assertEqual(y[0][2], 3)
      self.assertEqual(y[0][3], 3.5)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def gfile_copy_callback(files_to_copy, export_dir_path):
  """Callback to copy files using `gfile.Copy` to an export directory.

  This method is used as the default `assets_callback` in `Exporter.init` to
  copy assets from the `assets_collection`. It can also be invoked directly to
  copy additional supplementary files into the export directory (in which case
  it is not a callback).

  The export directory is created if necessary. A destination file that
  already exists is removed before the copy.

  Args:
    files_to_copy: A dictionary that maps original file paths to desired
      basename in the export directory.
    export_dir_path: Directory to copy the files to.
  """
  # Log message typo fixed ("assest" -> "assets").
  logging.info("Write assets into: %s using gfile_copy.", export_dir_path)
  gfile.MakeDirs(export_dir_path)
  for source_filepath, basename in files_to_copy.items():
    new_path = os.path.join(
        compat.as_bytes(export_dir_path), compat.as_bytes(basename))
    logging.info("Copying asset %s to path %s.", source_filepath, new_path)

    if gfile.Exists(new_path):
      # Guard against being restarted while copying assets, and the file
      # existing and being in an unknown state.
      # TODO(b/28676216): Do some file checks before deleting.
      logging.info("Removing file %s.", new_path)
      gfile.Remove(new_path)
    gfile.Copy(source_filepath, new_path)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def gfile_copy_callback(files_to_copy, export_dir_path):
  """Copies each listed file into the export directory via `gfile.Copy`.

  Serves as the default `assets_callback` of `Exporter.init` for the files in
  the `assets_collection`, and may also be called directly to place extra
  files in the export directory.

  The export directory is created if necessary. A destination file that
  already exists is removed before the copy.

  Args:
    files_to_copy: A dictionary that maps original file paths to desired
      basename in the export directory.
    export_dir_path: Directory to copy the files to.
  """
  logging.info("Write assets into: %s using gfile_copy.", export_dir_path)
  gfile.MakeDirs(export_dir_path)
  for origin, target_name in files_to_copy.items():
    destination = os.path.join(
        compat.as_bytes(export_dir_path), compat.as_bytes(target_name))
    logging.info("Copying asset %s to path %s.", origin, destination)

    # Guard against being restarted while copying assets, and the file
    # existing and being in an unknown state.
    # TODO(b/28676216): Do some file checks before deleting.
    leftover_present = gfile.Exists(destination)
    if leftover_present:
      logging.info("Removing file %s.", destination)
      gfile.Remove(destination)
    gfile.Copy(origin, destination)
项目:stuff    作者:yaroslavvb    | 项目源码 | 文件源码
def run_benchmark(sess, init_op, add_op):
  """Times repeated runs of `add_op` and logs the MB/s rate as events.

  For each of `FLAGS.iters` steps, `add_op.op` is run `FLAGS.iters_per_step`
  times; the measured rate is written as a 'rate' event to
  `<FLAGS.logdir_prefix>/<FLAGS.name>/events`. Nothing is returned.

  Args:
    sess: Session used to run the ops.
    init_op: Op run once before timing starts.
    add_op: Object whose `.op` attribute is the op being timed.
  """
  logdir = FLAGS.logdir_prefix+'/'+FLAGS.name
  # Create the directory directly rather than through `os.system('mkdir -p ...')`
  # so that spaces or shell metacharacters in the flag values are not
  # interpreted by a shell.
  os.makedirs(logdir, exist_ok=True)

  # TODO: make events follow same format as eager writer
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir+'/events'))
  try:
    # Kept from the original; the value is not used below.
    filename = compat.as_text(writer.FileName())
    training_util.get_or_create_global_step()

    sess.run(init_op)

    for step in range(FLAGS.iters):
      start_time = time.time()
      for _ in range(FLAGS.iters_per_step):
        sess.run(add_op.op)

      elapsed_time = time.time() - start_time
      # The timed interval covers iters_per_step runs, so that count (not the
      # number of outer steps) is the numerator.
      # NOTE(review): assumes FLAGS.data_mb is the MB handled per run - confirm.
      rate = float(FLAGS.iters_per_step)*FLAGS.data_mb/elapsed_time
      event = make_event('rate', rate, step)
      writer.WriteEvent(event)
      writer.Flush()
  finally:
    # Close the writer even if a session run fails part-way through.
    writer.Close()
  # add event
项目:ray    作者:ray-project    | 项目源码 | 文件源码
def __init__(self, dir, prefix):
    """Opens an events writer on the path `dir`/`prefix`.

    Args:
      dir: Directory part of the writer path; also stored on the instance.
      prefix: Last path component handed to the events writer.
    """
    # Start at 1, because EvWriter automatically generates an object with
    # step = 0.
    self.dir, self.step = dir, 1
    events_path = compat.as_bytes(os.path.join(dir, prefix))
    self.evwriter = pywrap_tensorflow.EventsWriter(events_path)
项目:evolution-strategies-starter    作者:openai    | 项目源码 | 文件源码
def __init__(self, dir, prefix):
    """Stores `dir` and opens an events writer on the path `dir`/`prefix`.

    Args:
      dir: Directory part of the writer path.
      prefix: Last path component handed to the events writer.
    """
    self.dir = dir
    # Start at 1, because EvWriter automatically generates an object with step=0
    self.step = 1
    joined_path = os.path.join(dir, prefix)
    self.evwriter = pywrap_tensorflow.EventsWriter(compat.as_bytes(joined_path))
项目:rl-teacher    作者:nottombrown    | 项目源码 | 文件源码
def __init__(self, dir):
    """Creates `dir` if needed and opens a TensorFlow events writer in it.

    The TensorFlow modules are imported inside the constructor and kept on
    the instance (`tf`, `event_pb2`, `pywrap_tensorflow`).

    Args:
      dir: Output directory; the writer uses `<abspath(dir)>/events` as its
        path argument.
    """
    os.makedirs(dir, exist_ok=True)
    self.dir, self.step = dir, 1
    events_path = osp.join(osp.abspath(dir), 'events')

    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat

    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    path_bytes = compat.as_bytes(events_path)
    self.writer = pywrap_tensorflow.EventsWriter(path_bytes)
项目:gym-sandbox    作者:suqi    | 项目源码 | 文件源码
def __init__(self, dir):
    """Creates `dir` if needed and opens a TensorFlow events writer in it.

    The TensorFlow modules are imported inside the constructor and kept on
    the instance (`tf`, `event_pb2`, `pywrap_tensorflow`).

    Args:
      dir: Output directory; the writer uses `<abspath(dir)>/events` as its
        path argument.
    """
    os.makedirs(dir, exist_ok=True)
    self.dir, self.step = dir, 1
    events_path = osp.join(osp.abspath(dir), 'events')

    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat

    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    path_bytes = compat.as_bytes(events_path)
    self.writer = pywrap_tensorflow.EventsWriter(path_bytes)
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def export_savedmodel(self, export_dir_base, serving_input_fn, **kwargs):
    """Records the call and returns a placeholder export path.

    Nothing is exported: the call is logged, `self.export_count` is
    incremented, and a fixed 'bogus_timestamp' subdirectory path is returned.

    Args:
      export_dir_base: Base directory of the returned path.
      serving_input_fn: Only logged.
      **kwargs: Only logged.

    Returns:
      `export_dir_base` joined with 'bogus_timestamp', both as bytes.
    """
    # Pass the values as logging arguments so the message is only formatted
    # when the record is actually emitted.
    tf_logging.info('export_savedmodel called with args: %s, %s, %s',
                    export_dir_base, serving_input_fn, kwargs)
    self.export_count += 1
    return os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))
项目:transform    作者:tensorflow    | 项目源码 | 文件源码
def test_stale_asset_collections_are_cleaned(self):
    """Re-saves a transform with an asset and checks its asset collections.

    A saved transform whose lookup table is loaded from a vocabulary file is
    written once, then applied and written again three times. On every
    round the graph must hold exactly one ASSET_FILEPATHS entry, no
    ASSETS_KEY entries, and every asset path must name a tensor that exists
    in the graph.
    """
    temp_dir_bytes = compat.as_bytes(test.get_temp_dir())
    vocabulary_file = os.path.join(temp_dir_bytes, compat.as_bytes('asset'))
    file_io.write_string_to_file(vocabulary_file, 'foo bar baz')

    export_path = os.path.join(tempfile.mkdtemp(), 'export')

    # create a SavedModel including assets
    with tf.Graph().as_default(), tf.Session().as_default() as session:
      input_string = tf.placeholder(tf.string)
      # Map string through a table loaded from an asset file
      table = lookup.index_table_from_file(
          vocabulary_file, num_oov_buckets=12, default_value=12)
      output = table.lookup(input_string)
      saved_transform_io.write_saved_transform_from_session(
          session, {'input': input_string}, {'output': output}, export_path)

    # Load it and save it again repeatedly, verifying that the asset collections
    # remain valid.
    for _ in range(3):
      with tf.Graph().as_default() as g, tf.Session().as_default() as session:
        inputs = {'input': tf.constant('dog')}
        outputs = saved_transform_io.apply_saved_transform(export_path,
                                                           inputs)

        asset_filepaths = g.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        self.assertEqual(1, len(asset_filepaths))
        legacy_assets = g.get_collection(tf.saved_model.constants.ASSETS_KEY)
        self.assertEqual(0, len(legacy_assets))

        # Check that every ASSET_FILEPATHS refers to a Tensor in the graph.
        # If not, get_tensor_by_name() raises KeyError.
        for asset_path in asset_filepaths:
          g.get_tensor_by_name(asset_path.name)

        export_path = os.path.join(tempfile.mkdtemp(), 'export')
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def _do_run(self, handle, target_list, fetch_list, feed_dict,
              options, run_metadata):
    """Runs a step based on the given fetches and feeds.

    Feeds, fetches and targets are first translated for the underlying call:
    into `_as_tf_output()` / `_c_op` values when the session was created with
    the new API, otherwise into byte-string names. The step is then executed
    through `self._do_call`, as a full run when `handle` is None and as a
    partial run otherwise.

    Args:
      handle: a handle for partial_run. None if this is just a call to run().
      target_list: A list of operations to be run, but not fetched.
      fetch_list: A list of tensors to be fetched.
      feed_dict: A dictionary that maps tensors to numpy ndarrays.
      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`.  If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
    """
    if self._created_with_new_api:
      # pylint: disable=protected-access
      feeds = {t._as_tf_output(): v for t, v in feed_dict.items()}
      fetches = [t._as_tf_output() for t in fetch_list]
      targets = [op._c_op for op in target_list]
      # pylint: enable=protected-access
    else:
      feeds = {compat.as_bytes(t.name): v for t, v in feed_dict.items()}
      fetches = _name_list(fetch_list)
      targets = _name_list(target_list)

    def _run_fn(session, run_feeds, run_fetches, run_targets, run_options,
                metadata):
      # Ensure any changes to the graph are reflected in the runtime.
      self._extend_graph()
      with errors.raise_exception_on_not_ok_status() as status:
        if not self._created_with_new_api:
          return tf_session.TF_Run(session, run_options,
                                   run_feeds, run_fetches, run_targets,
                                   status, metadata)
        return tf_session.TF_SessionRun_wrapper(
            session, run_options, run_feeds, run_fetches, run_targets,
            metadata, status)

    def _prun_fn(session, prun_handle, prun_feeds, prun_fetches):
      assert not self._created_with_new_api, ('Partial runs don\'t work with '
                                              'C API')
      # `target_list` here is the argument of the enclosing _do_run call.
      if target_list:
        raise RuntimeError('partial_run() requires empty target_list.')
      with errors.raise_exception_on_not_ok_status() as status:
        return tf_session.TF_PRun(session, prun_handle, prun_feeds,
                                  prun_fetches, status)

    if handle is not None:
      return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
    return self._do_call(_run_fn, self._session, feeds, fetches, targets,
                         options, run_metadata)
项目:PlantImageRecognition    作者:HeavenMin    | 项目源码 | 文件源码
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Split the images in each sub-directory of `imageDir` into three sets.

    Every directory yielded by walking `imageDir`, except the first one, is
    treated as one label. Its JPEG files are assigned to the validation,
    testing or training set according to a SHA-1 hash of the file path (with
    anything from '_nohash_' onwards removed), so the assignment of a given
    path is stable between calls.

    Args:
        imageDir: Root directory containing one sub-directory per label.
        testingPercentage: Percentage of images assigned to the testing set.
        validationPercventage: Percentage of images assigned to the
            validation set.

    Returns:
        None if `imageDir` does not exist. Otherwise a dict keyed by label
        name (the lower-cased directory name with runs of characters outside
        a-z0-9 replaced by a space); each value is a dict with the keys
        'dir', 'training', 'testing' and 'validation', the last three being
        lists of file base names.
    """
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None

    # Fold the extensions with os.path.normcase and drop repeats, keeping the
    # original order. Where normcase folds case, 'jpg' and 'JPG' would
    # otherwise glob the same files twice and list every image twice.
    extensions = []
    for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
        extension = os.path.normcase(extension)
        if extension not in extensions:
            extensions.append(extension)

    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first walked entry is skipped as the root directory.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Hash the path without its '_nohash_...' tail, so files that
            # differ only after that marker land in the same set.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
项目:PlantImageRecognition    作者:HeavenMin    | 项目源码 | 文件源码
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Split the images in each sub-directory of `imageDir` into three sets.

    Every directory yielded by walking `imageDir`, except the first one, is
    treated as one label. Its JPEG files are assigned to the validation,
    testing or training set according to a SHA-1 hash of the file path (with
    anything from '_nohash_' onwards removed), so the assignment of a given
    path is stable between calls.

    Args:
        imageDir: Root directory containing one sub-directory per label.
        testingPercentage: Percentage of images assigned to the testing set.
        validationPercventage: Percentage of images assigned to the
            validation set.

    Returns:
        None if `imageDir` does not exist. Otherwise a dict keyed by label
        name (the lower-cased directory name with runs of characters outside
        a-z0-9 replaced by a space); each value is a dict with the keys
        'dir', 'training', 'testing' and 'validation', the last three being
        lists of file base names.
    """
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None

    # Fold the extensions with os.path.normcase and drop repeats, keeping the
    # original order. Where normcase folds case, 'jpg' and 'JPG' would
    # otherwise glob the same files twice and list every image twice.
    extensions = []
    for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
        extension = os.path.normcase(extension)
        if extension not in extensions:
            extensions.append(extension)

    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first walked entry is skipped as the root directory.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Hash the path without its '_nohash_...' tail, so files that
            # differ only after that marker land in the same set.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
项目:PlantImageRecognition    作者:HeavenMin    | 项目源码 | 文件源码
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Split the images in each sub-directory of `imageDir` into three sets.

    Every directory yielded by walking `imageDir`, except the first one, is
    treated as one label. Its JPEG files are assigned to the validation,
    testing or training set according to a SHA-1 hash of the file path (with
    anything from '_nohash_' onwards removed), so the assignment of a given
    path is stable between calls.

    Args:
        imageDir: Root directory containing one sub-directory per label.
        testingPercentage: Percentage of images assigned to the testing set.
        validationPercventage: Percentage of images assigned to the
            validation set.

    Returns:
        None if `imageDir` does not exist. Otherwise a dict keyed by label
        name (the lower-cased directory name with runs of characters outside
        a-z0-9 replaced by a space); each value is a dict with the keys
        'dir', 'training', 'testing' and 'validation', the last three being
        lists of file base names.
    """
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None

    # Fold the extensions with os.path.normcase and drop repeats, keeping the
    # original order. Where normcase folds case, 'jpg' and 'JPG' would
    # otherwise glob the same files twice and list every image twice.
    extensions = []
    for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
        extension = os.path.normcase(extension)
        if extension not in extensions:
            extensions.append(extension)

    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first walked entry is skipped as the root directory.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Hash the path without its '_nohash_...' tail, so files that
            # differ only after that marker land in the same set.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
项目:PlantImageRecognition    作者:HeavenMin    | 项目源码 | 文件源码
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Split the images in each sub-directory of `imageDir` into three sets.

    Every directory yielded by walking `imageDir`, except the first one, is
    treated as one label. Its JPEG files are assigned to the validation,
    testing or training set according to a SHA-1 hash of the file path (with
    anything from '_nohash_' onwards removed), so the assignment of a given
    path is stable between calls.

    Args:
        imageDir: Root directory containing one sub-directory per label.
        testingPercentage: Percentage of images assigned to the testing set.
        validationPercventage: Percentage of images assigned to the
            validation set.

    Returns:
        None if `imageDir` does not exist. Otherwise a dict keyed by label
        name (the lower-cased directory name with runs of characters outside
        a-z0-9 replaced by a space); each value is a dict with the keys
        'dir', 'training', 'testing' and 'validation', the last three being
        lists of file base names.
    """
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None

    # Fold the extensions with os.path.normcase and drop repeats, keeping the
    # original order. Where normcase folds case, 'jpg' and 'JPG' would
    # otherwise glob the same files twice and list every image twice.
    extensions = []
    for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
        extension = os.path.normcase(extension)
        if extension not in extensions:
            extensions.append(extension)

    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first walked entry is skipped as the root directory.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Hash the path without its '_nohash_...' tail, so files that
            # differ only after that marker land in the same set.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
项目:PlantImageRecognition    作者:HeavenMin    | 项目源码 | 文件源码
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Split the images in each sub-directory of `imageDir` into three sets.

    Every directory yielded by walking `imageDir`, except the first one, is
    treated as one label. Its JPEG files are assigned to the validation,
    testing or training set according to a SHA-1 hash of the file path (with
    anything from '_nohash_' onwards removed), so the assignment of a given
    path is stable between calls.

    Args:
        imageDir: Root directory containing one sub-directory per label.
        testingPercentage: Percentage of images assigned to the testing set.
        validationPercventage: Percentage of images assigned to the
            validation set.

    Returns:
        None if `imageDir` does not exist. Otherwise a dict keyed by label
        name (the lower-cased directory name with runs of characters outside
        a-z0-9 replaced by a space); each value is a dict with the keys
        'dir', 'training', 'testing' and 'validation', the last three being
        lists of file base names.
    """
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None

    # Fold the extensions with os.path.normcase and drop repeats, keeping the
    # original order. Where normcase folds case, 'jpg' and 'JPG' would
    # otherwise glob the same files twice and list every image twice.
    extensions = []
    for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
        extension = os.path.normcase(extension)
        if extension not in extensions:
            extensions.append(extension)

    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first walked entry is skipped as the root directory.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Hash the path without its '_nohash_...' tail, so files that
            # differ only after that marker land in the same set.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
项目:Machine-Learning    作者:sfeng15    | 项目源码 | 文件源码
def _add_collection_def(meta_graph_def, key):
  """Serializes the collection named `key` into the MetaGraphDef.

  Keys that are not strings or bytes are skipped with a warning, and empty
  collections are skipped silently. If anything fails while the items are
  serialized, a warning is logged and any entry for `key` is removed from
  `meta_graph_def.collection_def` again.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  has_string_key = isinstance(key, six.string_types) or isinstance(key, bytes)
  if not has_string_key:
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s" % type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if not to_proto:
      # No registered converter: the field is chosen from the first item.
      kind = _get_kind_name(collection_list[0])
      values = getattr(col_def, kind).value
      if kind == "node_list":
        values.extend([item.name for item in collection_list])
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        values.extend([compat.as_bytes(item) for item in collection_list])
      else:
        values.extend([item for item in collection_list])
    else:
      kind = "bytes_list"
      for item in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(item)
        assert isinstance(proto, proto_type)
        getattr(col_def, kind).value.append(proto.SerializeToString())
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s" % (key, str(e)))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return