Python google.protobuf.text_format module: Merge() examples

We extracted the following 50 code examples from open-source Python projects to illustrate how to use google.protobuf.text_format.Merge().

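All 50 examples share the same core pattern: build an empty protobuf message object, then call text_format.Merge() to parse a text-format (prototxt/ASCII) serialization into it. Merge() behaves like text_format.Parse(), except that it allows a singular field to appear more than once in the input (the last value wins), where Parse() raises a ParseError. Below is a minimal, self-contained sketch of that pattern; FileDescriptorProto is used only because it ships with the protobuf package, and any generated message class works the same way:

from google.protobuf import descriptor_pb2
from google.protobuf import text_format

# Build an empty message, then merge a text-format serialization into it.
# Merge() modifies the message in place (and also returns it).
proto = descriptor_pb2.FileDescriptorProto()
text_format.Merge('name: "example.proto" package: "demo"', proto)

assert proto.name == 'example.proto'
assert proto.package == 'demo'

# Round trip: MessageToString() produces text that Merge() can re-parse.
print(text_format.MessageToString(proto))
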
Project: facade-segmentation    Author: jfemiani    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: ngraph    Author: NervanaSystems    | Project source | File source
def import_protobuf(self, pb_file, verbose=False):
        """
        Imports graph_def from protobuf file to ngraph.

        Arguments:
            pb_file: Protobuf file path.
            verbose: Prints graph_def at each node if True.
        """
        # read graph_def
        graph_def = tf.GraphDef()
        if mimetypes.guess_type(pb_file)[0] == 'text/plain':
            with open(pb_file, 'r') as f:
                text_format.Merge(f.read(), graph_def)
        else:
            with open(pb_file, 'rb') as f:
                graph_def.ParseFromString(f.read())

        self.import_graph_def(graph_def, verbose=verbose)
Project: transform    Author: tensorflow    | Project source | File source
def _assert_encode_decode(self, coder, expected_proto_text, expected_decoded):
    example = tf.train.Example()
    text_format.Merge(expected_proto_text, example)
    data = example.SerializeToString()

    # Assert the data is decoded into the expected format.
    decoded = coder.decode(data)
    np.testing.assert_equal(expected_decoded, decoded)

    # Assert the decoded data can be encoded back into the original proto.
    encoded = coder.encode(decoded)
    parsed_example = tf.train.Example()
    parsed_example.ParseFromString(encoded)
    self.assertEqual(example, parsed_example)

    # Assert the data can be decoded from the encoded string.
    decoded_again = coder.decode(encoded)
    np.testing.assert_equal(expected_decoded, decoded_again)
Project: transform    Author: tensorflow    | Project source | File source
def _assert_decode_encode(self, coder, expected_proto_text, expected_decoded):
    example = tf.train.Example()
    text_format.Merge(expected_proto_text, example)

    # Assert the expected decoded data can be encoded into the expected proto.
    encoded = coder.encode(expected_decoded)
    parsed_example = tf.train.Example()
    parsed_example.ParseFromString(encoded)
    self.assertEqual(example, parsed_example)

    # Assert the encoded data can be decoded into the original input.
    decoded = coder.decode(encoded)
    np.testing.assert_equal(expected_decoded, decoded)

    # Assert the decoded data can be encoded back into the expected proto.
    encoded_again = coder.encode(decoded)
    parsed_example_again = tf.train.Example()
    parsed_example_again.ParseFromString(encoded_again)
    np.testing.assert_equal(example, parsed_example_again)
Project: transform    Author: tensorflow    | Project source | File source
def test_example_proto_coder_error(self):
    input_schema = dataset_schema.from_feature_spec({
        '2d_vector_feature': tf.FixedLenFeature(shape=[2, 2], dtype=tf.int64),
    })
    coder = example_proto_coder.ExampleProtoCoder(input_schema)

    example_decoded_value = {
        '2d_vector_feature': [1, 2, 3]
    }
    example_proto_text = """
    features {
      feature { key: "1d_vector_feature"
                value { int64_list { value: [ 1, 2, 3 ] } } }
    }
    """
    example = tf.train.Example()
    text_format.Merge(example_proto_text, example)

    # Ensure that we raise an exception for trying to encode invalid data.
    with self.assertRaisesRegexp(ValueError, 'got wrong number of values'):
      _ = coder.encode(example_decoded_value)

    # Ensure that we raise an exception for trying to parse invalid data.
    with self.assertRaisesRegexp(ValueError, 'got wrong number of values'):
      _ = coder.decode(example.SerializeToString())
Project: Vector-Tiles-Reader-QGIS-Plugin    Author: geometalab    | Project source | File source
def testMergeExpandedAnyRepeated(self):
    message = any_test_pb2.TestAny()
    text = ('repeated_any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
            '    data: "string0"\n'
            '  }\n'
            '}\n'
            'repeated_any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
            '    data: "string1"\n'
            '  }\n'
            '}\n')
    text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
    packed_message = unittest_pb2.OneString()
    message.repeated_any_value[0].Unpack(packed_message)
    self.assertEqual('string0', packed_message.data)
    message.repeated_any_value[1].Unpack(packed_message)
    self.assertEqual('string1', packed_message.data)
Project: Vector-Tiles-Reader-QGIS-Plugin    Author: geometalab    | Project source | File source
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
Project: CityHorizon    Author: CityStreetWander    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: cv4ag    Author: worldbank    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: protoc-gen-lua-bin    Author: u0u0    | Project source | File source
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
Project: coremltools    Author: apple    | Project source | File source
def testRoundTripExoticAsOneLine(self):
    message = unittest_pb2.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    message.repeated_double.append(123.456)
    message.repeated_double.append(1.23e22)
    message.repeated_double.append(1.23e-18)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
    message.repeated_string.append(u'\u00fc\ua71f')

    # Test as_utf8 = False.
    wire_text = text_format.MessageToString(
        message, as_one_line=True, as_utf8=False)
    parsed_message = unittest_pb2.TestAllTypes()
    text_format.Merge(wire_text, parsed_message)
    self.assertEqual(message, parsed_message)

    # Test as_utf8 = True.
    wire_text = text_format.MessageToString(
        message, as_one_line=True, as_utf8=True)
    parsed_message = unittest_pb2.TestAllTypes()
    text_format.Merge(wire_text, parsed_message)
    self.assertEqual(message, parsed_message)
Project: coremltools    Author: apple    | Project source | File source
def testMergeMessageSet(self):
    message = unittest_pb2.TestAllTypes()
    text = ('repeated_uint64: 1\n'
            'repeated_uint64: 2\n')
    text_format.Merge(text, message)
    self.assertEqual(1, message.repeated_uint64[0])
    self.assertEqual(2, message.repeated_uint64[1])

    message = unittest_mset_pb2.TestMessageSetContainer()
    text = ('message_set {\n'
            '  [protobuf_unittest.TestMessageSetExtension1] {\n'
            '    i: 23\n'
            '  }\n'
            '  [protobuf_unittest.TestMessageSetExtension2] {\n'
            '    str: \"foo\"\n'
            '  }\n'
            '}\n')
    text_format.Merge(text, message)
    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
    self.assertEqual(23, message.message_set.Extensions[ext1].i)
    self.assertEqual('foo', message.message_set.Extensions[ext2].str)
Project: coremltools    Author: apple    | Project source | File source
def testMergeExotic(self):
    message = unittest_pb2.TestAllTypes()
    text = ('repeated_int64: -9223372036854775808\n'
            'repeated_uint64: 18446744073709551615\n'
            'repeated_double: 123.456\n'
            'repeated_double: 1.23e+22\n'
            'repeated_double: 1.23e-18\n'
            'repeated_string: \n'
            '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
            'repeated_string: "foo" \'corge\' "grault"\n'
            'repeated_string: "\\303\\274\\352\\234\\237"\n'
            'repeated_string: "\\xc3\\xbc"\n'
            'repeated_string: "\xc3\xbc"\n')
    text_format.Merge(text, message)

    self.assertEqual(-9223372036854775808, message.repeated_int64[0])
    self.assertEqual(18446744073709551615, message.repeated_uint64[0])
    self.assertEqual(123.456, message.repeated_double[0])
    self.assertEqual(1.23e22, message.repeated_double[1])
    self.assertEqual(1.23e-18, message.repeated_double[2])
    self.assertEqual(
        '\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0])
    self.assertEqual('foocorgegrault', message.repeated_string[1])
    self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
    self.assertEqual(u'\u00fc', message.repeated_string[3])
Project: coremltools    Author: apple    | Project source | File source
def testMergeBadEnumValue(self):
    message = unittest_pb2.TestAllTypes()
    text = 'optional_nested_enum: BARR'
    self.assertRaisesWithMessage(
        text_format.ParseError,
        ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
         'has no value named BARR.'),
        text_format.Merge, text, message)

    message = unittest_pb2.TestAllTypes()
    text = 'optional_nested_enum: 100'
    self.assertRaisesWithMessage(
        text_format.ParseError,
        ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
         'has no value with number 100.'),
        text_format.Merge, text, message)
Project: coremltools    Author: apple    | Project source | File source
def testMergeStringFieldUnescape(self):
    message = unittest_pb2.TestAllTypes()
    text = r'''repeated_string: "\xf\x62"
               repeated_string: "\\xf\\x62"
               repeated_string: "\\\xf\\\x62"
               repeated_string: "\\\\xf\\\\x62"
               repeated_string: "\\\\\xf\\\\\x62"
               repeated_string: "\x5cx20"'''
    text_format.Merge(text, message)

    SLASH = '\\'
    self.assertEqual('\x0fb', message.repeated_string[0])
    self.assertEqual(SLASH + 'xf' + SLASH + 'x62', message.repeated_string[1])
    self.assertEqual(SLASH + '\x0f' + SLASH + 'b', message.repeated_string[2])
    self.assertEqual(SLASH + SLASH + 'xf' + SLASH + SLASH + 'x62',
                     message.repeated_string[3])
    self.assertEqual(SLASH + SLASH + '\x0f' + SLASH + SLASH + 'b',
                     message.repeated_string[4])
    self.assertEqual(SLASH + 'x20', message.repeated_string[5])
Project: coremltools    Author: apple    | Project source | File source
def testMergeExpandedAny(self):
    message = any_test_pb2.TestAny()
    text = ('any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
            '    data: "string"\n'
            '  }\n'
            '}\n')
    text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
    packed_message = unittest_pb2.OneString()
    message.any_value.Unpack(packed_message)
    self.assertEqual('string', packed_message.data)
    message.Clear()
    text_format.Parse(text, message, descriptor_pool=descriptor_pool.Default())
    packed_message = unittest_pb2.OneString()
    message.any_value.Unpack(packed_message)
    self.assertEqual('string', packed_message.data)
Project: coremltools    Author: apple    | Project source | File source
def testMergeExpandedAnyRepeated(self):
    message = any_test_pb2.TestAny()
    text = ('repeated_any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
            '    data: "string0"\n'
            '  }\n'
            '}\n'
            'repeated_any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
            '    data: "string1"\n'
            '  }\n'
            '}\n')
    text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
    packed_message = unittest_pb2.OneString()
    message.repeated_any_value[0].Unpack(packed_message)
    self.assertEqual('string0', packed_message.data)
    message.repeated_any_value[1].Unpack(packed_message)
    self.assertEqual('string1', packed_message.data)
Project: ENet    Author: TimoSaemann    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: Triplet_Loss_SBIR    Author: TuBui    | Project source | File source
def add_params(self,params):
    """
    Set or update solver parameters
    """
    paramstr = ''
    for key, val in params.items():
      self.sp.ClearField(key) #reset field
      if isinstance(val, str):        # string value: quote it
        paramstr += (key + ': "' + val + '"\n')
      elif isinstance(val, list):     # repeated field: one line per item
        for it in val:
          paramstr += (key + ': ' + str(it) + '\n')
      elif isinstance(val, bool):     # boolean: prototxt uses lowercase
        paramstr += (key + ': ' + ('true' if val else 'false') + '\n')
      else:                           # numerical value
        paramstr += (key + ': ' + str(val) + '\n')
    #apply change
    text_format.Merge(paramstr, self.sp)
Project: DeepArt    Author: jiriroz    | Project source | File source
def __init__(self):
        """Loading DNN model."""
        model_path = '/home/jiri/caffe/models/bvlc_googlenet/'
        net_fn   = model_path + 'deploy.prototxt'
        param_fn = model_path + 'bvlc_googlenet.caffemodel'
        #model_path = '/home/jiri/caffe/models/oxford102/'
        #net_fn   = model_path + 'deploy.prototxt'
        #param_fn = model_path + 'oxford102.caffemodel'

        # Patching model to be able to compute gradients.
        # Note that you can also manually add "force_backward: true" line
        #to "deploy.prototxt".
        model = caffe.io.caffe_pb2.NetParameter()
        with open(net_fn) as f:
            text_format.Merge(f.read(), model)
        model.force_backward = True
        with open('tmp.prototxt', 'w') as f:
            f.write(str(model))

        # ImageNet mean, training set dependent
        mean =  np.float32([104.0, 116.0, 122.0])
        # the reference model has channels in BGR order instead of RGB
        chann_sw = (2,1,0)
        self.net = caffe.Classifier('tmp.prototxt', param_fn, mean=mean, channel_swap=chann_sw)
Project: DepthSegnet    Author: hari-sikchi    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    print "hello"
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        print(len(layer.top))
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: DepthSegnet    Author: hari-sikchi    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        #print layer.type 
        #print type(layer.top)
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: tensorboard    Author: tensorflow    | Project source | File source
def _latest_checkpoints_changed(configs, run_path_pairs):
  """Returns true if the latest checkpoint has changed in any of the runs."""
  for run_name, assets_dir in run_path_pairs:
    if run_name not in configs:
      config = ProjectorConfig()
      config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME)
      if tf.gfile.Exists(config_fpath):
        with tf.gfile.GFile(config_fpath, 'r') as f:
          file_content = f.read()
        text_format.Merge(file_content, config)
    else:
      config = configs[run_name]

    # See if you can find a checkpoint file in the logdir.
    logdir = _assets_dir_to_logdir(assets_dir)
    ckpt_path = _find_latest_checkpoint(logdir)
    if not ckpt_path:
      continue
    if config.model_checkpoint_path != ckpt_path:
      return True
  return False
Project: tensorboard    Author: tensorflow    | Project source | File source
def dump_data(logdir):
  """Dumps plugin data to the log directory."""
  plugin_logdir = plugin_asset_util.PluginDirectory(
      logdir, profile_plugin.ProfilePlugin.plugin_name)
  _maybe_create_directory(plugin_logdir)

  for run in profile_demo_data.RUNS:
    run_dir = os.path.join(plugin_logdir, run)
    _maybe_create_directory(run_dir)
    if run in profile_demo_data.TRACES:
      with open(os.path.join(run_dir, 'trace'), 'w') as f:
        proto = trace_events_pb2.Trace()
        text_format.Merge(profile_demo_data.TRACES[run], proto)
        f.write(proto.SerializeToString())
    shutil.copyfile('tensorboard/plugins/profile/profile_demo.op_profile.json',
                    os.path.join(run_dir, 'op_profile.json'))

  # Unsupported tool data should not be displayed.
  run_dir = os.path.join(plugin_logdir, 'empty')
  _maybe_create_directory(run_dir)
  with open(os.path.join(run_dir, 'unsupported'), 'w') as f:
    f.write('unsupported data')
Project: cv-api    Author: yasunorikudo    | Project source | File source
def __init__(self):
        # load MS COCO labels
        labelmap_file = os.path.join(CAFFE_ROOT, LABEL_MAP)
        self._labelmap = caffe_pb2.LabelMap()
        with open(labelmap_file, 'r') as f:
            text_format.Merge(f.read(), self._labelmap)

        model_def = os.path.join(CAFFE_ROOT, PROTO_TXT)
        model_weights = os.path.join(CAFFE_ROOT, CAFFE_MODEL)

        self._net = caffe.Net(model_def, model_weights, caffe.TEST)
        self._transformer = caffe.io.Transformer(
            {'data': self._net.blobs['data'].data.shape})
        self._transformer.set_transpose('data', (2, 0, 1))
        self._transformer.set_mean('data', np.array([104, 117, 123]))
        self._transformer.set_raw_scale('data', 255)
        self._transformer.set_channel_swap('data', (2, 1, 0))

        # set net to batch size of 1
        image_resize = IMAGE_SIZE
        self._net.blobs['data'].reshape(1, 3, image_resize, image_resize)
Project: DL8803    Author: NanditaDamaraju    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: DL8803    Author: NanditaDamaraju    | Project source | File source
def make_testable(train_model_path):
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
Project: go2mapillary    Author: enricofer    | Project source | File source
def testMergeExpandedAnyRepeated(self):
    message = any_test_pb2.TestAny()
    text = ('repeated_any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
            '    data: "string0"\n'
            '  }\n'
            '}\n'
            'repeated_any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
            '    data: "string1"\n'
            '  }\n'
            '}\n')
    text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
    packed_message = unittest_pb2.OneString()
    message.repeated_any_value[0].Unpack(packed_message)
    self.assertEqual('string0', packed_message.data)
    message.repeated_any_value[1].Unpack(packed_message)
    self.assertEqual('string1', packed_message.data)
Project: go2mapillary    Author: enricofer    | Project source | File source
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
Project: rpcDemo    Author: Tangxinwei    | Project source | File source
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
    msg_descriptor = descriptor.MakeDescriptor(
        file_descriptor.message_type[0])
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
Project: QScode    Author: PierreHao    | Project source | File source
def __init__(self):
        caffe.set_mode_gpu()
        #caffe.set_device(0)
        model_path = '../models/bvlc_googlenet/' # substitute your path here
        net_fn   = model_path + 'deploy.prototxt'
        param_fn = model_path + 'bvlc_googlenet.caffemodel'
        model = caffe.io.caffe_pb2.NetParameter()
        with open(net_fn) as f:
            text_format.Merge(f.read(), model)
        model.force_backward = True  # backward to input layer
        with open('tmp.prototxt', 'w') as f:
            f.write(str(model))
        self.net = caffe.Classifier('tmp.prototxt', param_fn,
                       mean = np.float32([104.0, 116.0, 122.0]), 
                       channel_swap = (2,1,0))
        # for the mode guide, if flag = 1               
        self.flag = 0
        self.epoch = 20
        self.end = 'inception_4c/output'
        #self.end = 'conv4'
Project: QScode    Author: PierreHao    | Project source | File source
def __init__(self, solver_prototxt, pretrained_model=None):
        """Initialize the SolverWrapper."""

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print('Loading pretrained model '
                  'weights from {:s}'.format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)      
        self.solver_param = caffe.io.caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            text_format.Merge(f.read(), self.solver_param)

        if self.solver_param.solver_mode == 1:
            caffe.set_mode_gpu()
            caffe.set_device(params.gpu_id)
            print('Use GPU', params.gpu_id, 'to train')
        else:
            print('Use CPU to train')
        #initial python data layer    
        self.solver.net.layers[0].set_db()
Project: QScode    Author: PierreHao    | Project source | File source
def __init__(self):
        caffe.set_mode_gpu()
        #caffe.set_device(0)
        model_path = '../models/bvlc_googlenet/' # substitute your path here
        net_fn   = model_path + 'deploy.prototxt'
        param_fn = model_path + 'bvlc_googlenet.caffemodel'
        model = caffe.io.caffe_pb2.NetParameter()
        with open(net_fn) as f:
            text_format.Merge(f.read(), model)
        model.force_backward = True  # backward to input layer
        with open('tmp.prototxt', 'w') as f:
            f.write(str(model))
        self.net = caffe.Classifier('tmp.prototxt', param_fn,
                       mean = np.float32([104.0, 116.0, 122.0]), 
                       channel_swap = (2,1,0))
        # for the mode guide, if flag = 1               
        self.flag = 0
        self.epoch = 20
        self.end = 'inception_4c/output'
        #self.end = 'conv4'
Project: protocall    Author: google    | Project source | File source
def parse_proto(text, message_name):
  if message_name in protos:
    p = protos[message_name]()
    text_format.Merge(text, p)
    return p
  import pdb; pdb.set_trace()
  raise RuntimeError("message name is: '" + message_name + "'")
Project: protocall    Author: google    | Project source | File source
def create_proto_expression():
    s= """arithmetic_operator { left { atom { field { component { name: "a" } component { name: "value" } } } } right { arithmetic_operator { left { atom { field { component { name: "xyz" } component { name: "value" } } } } right { atom { field { component { name: "b" } component { name: "value" } } } } operator: PLUS } } operator: MULTIPLY }"""

    e = protocall_pb2.Expression()
    text_format.Merge(s, e)
    return e
Project: emu    Author: mlosch    | Project source | File source
def _load_layer_types(prototxt):
        # Read prototxt with caffe protobuf definitions
        layers = caffe_pb2.NetParameter()
        with open(prototxt, 'r') as f:
            text_format.Merge(str(f.read()), layers)

        # Assign layer parameters to type dictionary
        types = OrderedDict()
        for i in range(len(layers.layer)):
            types[layers.layer[i].name] = layers.layer[i].type

        return types
Project: PSPNet-Keras-tensorflow    Author: Vladkryvoruchko    | Project source | File source
def load(self):
        '''Load the layer definitions from the prototxt.'''
        self.params = get_caffe_resolver().NetParameter()
        with open(self.def_path, 'rb') as def_file:
            text_format.Merge(def_file.read(), self.params)
Project: voc-classification    Author: philkr    | Project source | File source
def parseProtoString(s):
    from google.protobuf import text_format
    proto_net = pb.NetParameter()
    text_format.Merge(s, proto_net)
    return proto_net
Project: nimo    Author: wolfram2012    | Project source | File source
def read_proto_file(file_path, parser_object):
    # open() raises IOError on failure, so no truthiness check is needed;
    # the context manager ensures the file is closed.
    with open(file_path, "r") as f:
        text_format.Merge(f.read(), parser_object)
    return parser_object
Project: magenta    Author: tensorflow    | Project source | File source
def parse_test_proto(proto_type, proto_string):
  instance = proto_type()
  text_format.Merge(proto_string, instance)
  return instance
Project: MMdnn    Author: Microsoft    | Project source | File source
def merge(self, parent, child):
        '''Merge the child node into the parent.'''
        raise NotImplementedError('Must be implemented by subclass')
Project: MMdnn    Author: Microsoft    | Project source | File source
def __init__(self, def_path, data_path, target_toolkit, input_shape=None, phase='test'):
        self.layer_name_map = {}
        self.data_injector = None
        self.is_train_proto = False
        self.input_shape = input_shape
        if def_path is None:
            if self.input_shape is None:
                raise ConversionError('if the graph prototxt is not provided, the input shape should be provided')
            self.input_shape = [1] + self.input_shape
            def_path, self.data_injector = self.gen_prototxt_from_caffemodel(data_path, self.input_shape)
            self.is_train_proto = True
        else:
            model = get_caffe_resolver().NetParameter()
            with open(def_path, 'r') as f:
                text_format.Merge(f.read(), model)
            layers = model.layers or model.layer
            if len([layer for layer in layers if NodeKind.map_raw_kind(layer.type) in LAYER_IN_TRAIN_PROTO]) > 0:
                if self.input_shape is None:
                    raise ConversionError('the train_val.prototxt should be provided with the input shape')
                self.input_shape = [1] + self.input_shape
                self.is_train_proto = True
        graph = GraphBuilder(def_path, self.input_shape, self.is_train_proto, phase).build()
        if self.is_train_proto:
            def_path = graph.prototxt
        if data_path is not None:
            graph = graph.transformed([
                self.data_injector if self.data_injector else DataInjector(def_path, data_path), # Load and associate learned parameters
                BatchNormScaleBiasFuser(),
                BatchNormPreprocessor() # Pre-process batch normalization data
            ])
            target_toolkit = target_toolkit.lower()
            if target_toolkit not in ('caffe', 'caffe2'):
                graph = graph.transformed([DataReshaper({ # Reshape the parameters to TensorFlow's ordering
                    NodeKind.Convolution: (2, 3, 1, 0), # (c_o, c_i, h, w) -> (h, w, c_i, c_o)
                    NodeKind.InnerProduct: (1, 0) # (c_o, c_i) -> (c_i, c_o)
                }),
                    ParameterNamer() # Convert parameters to dictionaries
                ])
        self.graph = graph
        #  self.graph = NodeRenamer()(graph)
        print_stderr(self.graph)
Project: MMdnn    Author: Microsoft    | Project source | File source
def load(self):
        self.model = get_caffe_resolver().NetParameter()
        with open(self.model_path, 'r') as f:
            text_format.Merge(f.read(), self.model)
        if self.is_train_proto:
            self.process_train_proto()
Project: transform    Author: tensorflow    | Project source | File source
def test_example_proto_coder_default_value(self):
    input_schema = dataset_schema.from_feature_spec({
        'scalar_feature_3':
            tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=1.0),
        'scalar_feature_4':
            tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=0.0),
        '1d_vector_feature':
            tf.FixedLenFeature(
                shape=[1], dtype=tf.float32, default_value=[2.0]),
        '2d_vector_feature':
            tf.FixedLenFeature(
                shape=[2, 2],
                dtype=tf.float32,
                default_value=[[1.0, 2.0], [3.0, 4.0]]),
    })
    coder = example_proto_coder.ExampleProtoCoder(input_schema)

    # Python types.
    example_proto_text = """
    features {
    }
    """
    example = tf.train.Example()
    text_format.Merge(example_proto_text, example)
    data = example.SerializeToString()

    # Assert the data is decoded into the expected format.
    expected_decoded = {
        'scalar_feature_3': 1.0,
        'scalar_feature_4': 0.0,
        '1d_vector_feature': [2.0],
        '2d_vector_feature': [[1.0, 2.0], [3.0, 4.0]],
    }
    decoded = coder.decode(data)
    np.testing.assert_equal(expected_decoded, decoded)
Project: dsrf    Author: ddexnet    | Project source | File source
def block_from_ascii(cls, text):
    """Returns Block protobuf parsed from ASCII text."""
    block = block_pb2.Block()
    text_format.Merge(text, block)
    return block
Project: dsrf    Author: ddexnet    | Project source | File source
def block_from_ascii(cls, text):
    """Returns Block protobuf parsed from ASCII text."""
    block = block_pb2.Block()
    text_format.Merge(text, block)
    return block
Project: dsrf    Author: ddexnet    | Project source | File source
def block_from_ascii(cls, text):
    """Returns Block protobuf parsed from ASCII text."""
    block = block_pb2.Block()
    text_format.Merge(text, block)
    return block
Project: dsrf    Author: ddexnet    | Project source | File source
def block_from_ascii(cls, text):
    """Returns Block protobuf parsed from ASCII text."""
    block = block_pb2.Block()
    text_format.Merge(text, block)
    return block
Project: dsrf    Author: ddexnet    | Project source | File source
def block_from_ascii(cls, text):
    """Returns Block protobuf parsed from ASCII text."""
    block = block_pb2.Block()
    text_format.Merge(text, block)
    return block
Project: piecewisecrf    Author: Vaan5    | Project source | File source
def load(self):
        '''Load the layer definitions from the prototxt.'''
        self.params = get_caffe_resolver().NetParameter()
        with open(self.def_path, 'rb') as def_file:
            text_format.Merge(def_file.read(), self.params)