Python tensorflow.contrib.slim 模块,one_hot_encoding() 实例源码

我们从Python开源项目中,提取了以下17个代码示例,用于说明如何使用tensorflow.contrib.slim.one_hot_encoding()

项目:RaspberryPi-Robot    作者:timestocome    | 项目源码 | 文件源码
def __init__(self, lr, s_size, a_size):
        """Build a single-layer policy network and its training op.

        Args:
            lr: learning rate handed to the gradient-descent optimizer.
            s_size: depth of the one-hot encoding applied to the state index.
            a_size: number of units in the fully connected output layer.
        """
        # Forward pass: state index -> one-hot vector -> bias-free sigmoid layer
        # whose weights all start at one.
        self.state_in = tf.placeholder(shape=[1], dtype=tf.int32)
        one_hot_state = slim.one_hot_encoding(self.state_in, s_size)
        action_scores = slim.fully_connected(
            one_hot_state,
            a_size,
            biases_initializer=None,
            activation_fn=tf.nn.sigmoid,
            weights_initializer=tf.ones_initializer())
        self.output = tf.reshape(action_scores, [-1])
        self.chosen_action = tf.argmax(self.output, 0)

        # Training inputs: the reward received and the action that was taken.
        self.reward_holder = tf.placeholder(shape=[1], dtype=tf.float32)
        self.action_holder = tf.placeholder(shape=[1], dtype=tf.int32)

        # Pick the single output that belongs to the taken action.
        self.responsible_weight = tf.slice(
            self.output, self.action_holder, [1])

        # Loss is -log(output of taken action) * reward, minimized with SGD.
        log_weight = tf.log(self.responsible_weight)
        self.loss = -(log_weight * self.reward_holder)
        sgd = tf.train.GradientDescentOptimizer(learning_rate=lr)
        self.update = sgd.minimize(self.loss)
项目:RaspberryPi-Robot    作者:timestocome    | 项目源码 | 文件源码
def __init__(self, lr, s_size, a_size):
        """Create the agent's graph: a one-hot state fed through one dense layer.

        Args:
            lr: step size for the gradient-descent optimizer.
            s_size: number of one-hot slots used to encode the state index.
            a_size: width of the dense layer (one output per action).
        """
        # --- feed-forward part -------------------------------------------
        self.state_in = tf.placeholder(shape=[1], dtype=tf.int32)
        encoded_state = slim.one_hot_encoding(self.state_in, s_size)
        layer_kwargs = dict(
            biases_initializer=None,
            activation_fn=tf.nn.sigmoid,
            weights_initializer=tf.ones_initializer())
        dense_out = slim.fully_connected(encoded_state, a_size, **layer_kwargs)

        # Flatten to a vector so individual actions can be indexed directly.
        self.output = tf.reshape(dense_out, [-1])
        self.chosen_action = tf.argmax(self.output, 0)

        # --- training part -----------------------------------------------
        self.reward_holder = tf.placeholder(shape=[1], dtype=tf.float32)
        self.action_holder = tf.placeholder(shape=[1], dtype=tf.int32)

        # One-element slice of the output at the position of the fed action.
        self.responsible_weight = tf.slice(
            self.output, self.action_holder, [1])

        self.loss = -(tf.log(self.responsible_weight) * self.reward_holder)
        self.update = tf.train.GradientDescentOptimizer(
            learning_rate=lr).minimize(self.loss)
项目:social-scene-understanding    作者:cvlab-epfl    | 项目源码 | 文件源码
def det_net_loss(seg_masks_in, reg_masks_in,
                 seg_preds, reg_preds,
                 reg_loss_weight=10.0,
                 epsilon=1e-5):
  """Combine a segmentation log-loss with a masked regression loss.

  Args:
    seg_masks_in: segmentation target masks; a trailing channel axis is added
      before resizing.
    reg_masks_in: regression target maps.
    seg_preds: segmentation predictions; dimensions 1 and 2 of its static
      shape define the size the targets are resized to.
    reg_preds: regression predictions.
    reg_loss_weight: multiplier applied to the regression term.
    epsilon: constant added to seg_preds before taking the log.

  Returns:
    seg_loss + reg_loss_weight * reg_loss.
  """
  with tf.variable_scope('loss'):
    target_hw = seg_preds.get_shape()[1:3]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    # Bring both targets to the prediction resolution.
    # NOTE(review): resize_images is called as (images, height, width, method),
    # which looks like the signature of older TensorFlow releases - confirm
    # against the TF version this project pins.
    seg_target = tf.image.resize_images(
        seg_masks_in[:, :, :, tf.newaxis],
        target_hw[0], target_hw[1],
        nearest)
    reg_target = tf.image.resize_images(
        reg_masks_in,
        target_hw[0], target_hw[1],
        nearest)

    # Segmentation term: two-class one-hot target times log-prediction.
    seg_onehot = slim.one_hot_encoding(seg_target[:, :, :, 0], 2)
    log_preds = tf.log(seg_preds + epsilon)
    seg_loss = - tf.reduce_mean(seg_onehot * log_preds)

    # Regression term: squared error weighted by the segmentation mask and
    # normalized by the mask sum (+1.0 keeps the denominator non-zero).
    weight = tf.to_float(seg_target)
    squared_error = (reg_preds - reg_target)**2
    reg_loss = tf.reduce_sum(weight * squared_error)
    reg_loss = reg_loss / (tf.reduce_sum(weight) + 1.0)

  return seg_loss + reg_loss_weight * reg_loss
项目:DNN_Recsys_demo    作者:ShouldChan    | 项目源码 | 文件源码
def __init__(self, lr, s_size, a_size):
        """Set up the agent's feed-forward network and its update operation.

        Args:
            lr: learning rate for gradient descent.
            s_size: one-hot depth used for the state index.
            a_size: number of outputs produced by the network.
        """
        # Feed-forward part: the agent takes a state and produces an action.
        self.state_in = tf.placeholder(shape=[1], dtype=tf.int32)
        state_one_hot = slim.one_hot_encoding(self.state_in, s_size)
        raw_output = slim.fully_connected(
            state_one_hot,
            a_size,
            biases_initializer=None,
            activation_fn=tf.nn.sigmoid,
            weights_initializer=tf.ones_initializer(),
        )
        self.output = tf.reshape(raw_output, [-1])
        self.chosen_action = tf.argmax(self.output, 0)

        # Training part: the reward and the chosen action are fed in, the loss
        # is computed from them and used to update the network.
        self.reward_holder = tf.placeholder(shape=[1], dtype=tf.float32)
        self.action_holder = tf.placeholder(shape=[1], dtype=tf.int32)
        self.responsible_weight = tf.slice(
            self.output, self.action_holder, [1])
        weighted_log = tf.log(self.responsible_weight) * self.reward_holder
        self.loss = -(weighted_log)
        trainer = tf.train.GradientDescentOptimizer(learning_rate=lr)
        self.update = trainer.minimize(self.loss)
项目:tutorial_mnist    作者:machine-learning-challenge    | 项目源码 | 文件源码
def prepare_serialized_examples(self, serialized_examples):
    """Parse serialized examples into scaled images and one-hot labels.

    Args:
      serialized_examples: batch of serialized tf.Example protos holding a
        784-element int64 'image_raw' feature and a scalar int64 'label'.

    Returns:
      A tuple (images, labels): images are float32 values multiplied by
      1/255, labels are int32 one-hot rows with 10 classes.
    """
    feature_map = {
        'image_raw': tf.FixedLenFeature([784], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    features = tf.parse_example(serialized_examples, features=feature_map)

    images = tf.cast(features["image_raw"], tf.float32) * (1. / 255)
    labels = tf.cast(features['label'], tf.int32)

    def dense_to_one_hot(label_batch, num_classes):
      # slim.one_hot_encoding already handles a whole batch of labels, so the
      # per-element tf.map_fn loop is unnecessary; encode the batch in one op.
      one_hot = tf.cast(slim.one_hot_encoding(label_batch, num_classes), tf.int32)
      # Kept so the result is always rank 2: [batch, num_classes].
      one_hot = tf.reshape(one_hot, [-1, num_classes])
      return one_hot

    labels = dense_to_one_hot(labels, 10)
    return images, labels
项目:mlc2017-online    作者:machine-learning-challenge    | 项目源码 | 文件源码
def prepare_serialized_examples(self, serialized_examples, width=50, height=50):
    """Parse serialized examples into normalized grayscale images and one-hot labels.

    Args:
      serialized_examples: batch of serialized tf.Example protos with a
        PNG-encoded 'image' string and a scalar int64 'label'.
      width: width the decoded images are resized to.
      height: height the decoded images are resized to.

    Returns:
      A tuple (images, labels): images are float32, converted from uint8 and
      then shifted by -0.5 and scaled by 2.0; labels are int32 one-hot rows
      with 10 classes.
    """
    # set the mapping from the fields to data types in the proto
    feature_map = {
           'image': tf.FixedLenFeature((), tf.string, default_value=''),
           'label': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_example(serialized_examples, features=feature_map)

    def decode_and_resize(image_str_tensor):
      """Decodes png string, resizes it and returns a uint8 tensor."""

      # Output a grayscale (channels=1) image
      image = tf.image.decode_png(image_str_tensor, channels=1)

      # Note resize expects a batch_size, but tf.map_fn suppresses that index,
      # thus we have to expand then squeeze.  Resize returns float32 in the
      # range [0, uint8_max]
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(
          image, [height, width], align_corners=False)
      image = tf.squeeze(image, squeeze_dims=[0])
      image = tf.cast(image, dtype=tf.uint8)
      return image

    images_str_tensor = features["image"]
    images = tf.map_fn(
        decode_and_resize, images_str_tensor, back_prop=False, dtype=tf.uint8)
    images = tf.image.convert_image_dtype(images, dtype=tf.float32)
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.0)

    def dense_to_one_hot(label_batch, num_classes):
      # Encode the whole label batch in one op; slim.one_hot_encoding accepts
      # a batch directly, so no per-element tf.map_fn loop is needed.
      one_hot = tf.cast(slim.one_hot_encoding(label_batch, num_classes), tf.int32)
      # Kept so the result is always rank 2: [batch, num_classes].
      one_hot = tf.reshape(one_hot, [-1, num_classes])
      return one_hot

    labels = tf.cast(features['label'], tf.int32)
    labels = dense_to_one_hot(labels, 10)

    return images, labels
项目:SSD_tensorflow_VOC    作者:LevinJ    | 项目源码 | 文件源码
def fine_tune_inception(self):
        """Fine-tune Inception v4 on the flowers training split for one epoch.

        Loads batches via self.load_batch, builds Inception v4 under its
        default arg scope, attaches a softmax cross-entropy loss, and trains
        with Adam, initializing from the checkpoint through self.get_init_fn.
        Logs and checkpoints are written to the hard-coded train_dir.
        """
        train_dir = '/tmp/inception_finetuned/'
        image_size = inception.inception_v4.default_image_size
        checkpoint_path = "../../data/trained_models/inception_v4/inception_v4.ckpt"
        flowers_data_dir = "../../data/flower"
        # NOTE(review): 32 is presumably the batch size self.load_batch uses by
        # default - confirm against load_batch.
        batch_size = 32
        num_epochs = 1

        with tf.Graph().as_default():
            tf.logging.set_verbosity(tf.logging.INFO)

            dataset = flowers.get_split('train', flowers_data_dir)
            images, _, labels = self.load_batch(dataset, height=image_size, width=image_size)

            # Create the model, use the default arg scope to configure the batch norm parameters.
            with slim.arg_scope(inception.inception_v4_arg_scope()):
                logits, _ = inception.inception_v4(images, num_classes=dataset.num_classes, is_training=True)

            # Specify the loss function:
            one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
            total_loss = slim.losses.softmax_cross_entropy(logits, one_hot_labels)

            # Create some summaries to visualize the training process:
            tf.summary.scalar('losses/Total_Loss', total_loss)

            # Specify the optimizer and create the train op:
            optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
            train_op = slim.learning.create_train_op(total_loss, optimizer)

            # Run the training. Divide as floats so the ceil is not defeated by
            # Python 2 integer division, and cast back to int because
            # math.ceil returns a float on Python 2.
            steps_per_epoch = int(math.ceil(dataset.num_samples / float(batch_size)))
            number_of_steps = steps_per_epoch * num_epochs
            final_loss = slim.learning.train(
                train_op,
                logdir=train_dir,
                init_fn=self.get_init_fn(checkpoint_path),
                number_of_steps=number_of_steps)

            print('Finished training. Last batch loss %f' % final_loss)
        return
项目:SSD_tensorflow_VOC    作者:LevinJ    | 项目源码 | 文件源码
def train_save_model():
    """Train my_cnn on the flowers training split for a single step.

    Builds the graph in a fresh tf.Graph, registers a softmax cross-entropy
    loss, trains with Adam through slim.learning.train (which writes to the
    module-level train_dir), and prints the final batch loss.
    """
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        dataset = flowers.get_split('train', flowers_data_dir)
        images, _, labels = load_batch(dataset)

        # Create the model:
        logits = my_cnn(images, num_classes=dataset.num_classes, is_training=True)

        # Specify the loss function. The cross-entropy call registers the loss
        # in the slim losses collection; get_total_loss then gathers it.
        one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
        slim.losses.softmax_cross_entropy(logits, one_hot_labels)
        total_loss = slim.losses.get_total_loss()

        # Create some summaries to visualize the training process:
        tf.summary.scalar('losses/Total Loss', total_loss)

        # Specify the optimizer and create the train op:
        optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        # Run the training:
        final_loss = slim.learning.train(
          train_op,
          logdir=train_dir,
          number_of_steps=1, # For speed, we run just 1 training step
          save_summaries_secs=1)

        # The loss is a float; %d would truncate it to an integer.
        print('Finished training. Final batch loss %f' % final_loss)
    return
项目:SSD_tensorflow_VOC    作者:LevinJ    | 项目源码 | 文件源码
def __get_images_labels(self):
        """Build the input pipeline and return one prefetched training batch.

        Reads (image, label) pairs from the configured dataset, shifts labels
        by self.labels_offset, preprocesses the image for training, batches
        the pairs, one-hot encodes the labels and routes the result through a
        prefetch queue.

        Returns:
            A tuple (images, labels) dequeued from the prefetch queue.
        """
        dataset = dataset_factory.get_dataset(
            self.dataset_name, self.dataset_split_name, self.dataset_dir)
        num_classes = dataset.num_classes - self.labels_offset

        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=self.num_readers,
            common_queue_capacity=20 * self.batch_size,
            common_queue_min=10 * self.batch_size)
        [image, label] = provider.get(['image', 'label'])
        label -= self.labels_offset

        # The network function is only consulted here for its default input size.
        network_fn = nets_factory.get_network_fn(
            self.model_name,
            num_classes=num_classes,
            weight_decay=self.weight_decay,
            is_training=True)
        side = self.train_image_size or network_fn.default_image_size

        # Fall back to the model name when no preprocessing name is configured.
        preprocess = preprocessing_factory.get_preprocessing(
            self.preprocessing_name or self.model_name,
            is_training=True)
        image = preprocess(image, side, side)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=self.batch_size,
            num_threads=self.num_preprocessing_threads,
            capacity=5 * self.batch_size)
        labels = slim.one_hot_encoding(labels, num_classes)

        prefetch = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2)
        images, labels = prefetch.dequeue()

        return images, labels
项目:sact    作者:mfigurnov    | 项目源码 | 文件源码
def _runBatch(self,
                is_training,
                model_type,
                model=None):
    """Build the ImageNet network on random input and run one step.

    Args:
      is_training: when True, attach losses and run a single training step;
        otherwise run a forward pass and check the logits shape.
      model_type: network variant forwarded to imagenet_model.get_network;
        'act', 'act_early_stopping' and 'sact' additionally get ACT/FLOPS
        metrics and ponder costs.
      model: model specification forwarded to get_network; defaults to
        [2, 2, 2, 2].
    """
    # None sentinel instead of a mutable list default shared across calls.
    if model is None:
      model = [2, 2, 2, 2]
    batch_size = 2
    height, width = 128, 128
    num_classes = 10
    uses_act = model_type in ('act', 'act_early_stopping', 'sact')

    with self.test_session() as sess:
      images = tf.random_uniform((batch_size, height, width, 3))
      with slim.arg_scope(
          imagenet_model.resnet_arg_scope(is_training=is_training)):
        # Forward the requested model_type; it used to be hard-coded to
        # 'sact', so every call exercised the same network variant.
        logits, end_points = imagenet_model.get_network(
            images, model, num_classes, model_type=model_type, base_channels=1)
        if uses_act:
          metrics = summary_utils.act_metric_map(end_points,
              not is_training)
          metrics.update(summary_utils.flops_metric_map(end_points,
              not is_training))
        else:
          metrics = {}

      if is_training:
        labels = tf.random_uniform(
            (batch_size,), maxval=num_classes, dtype=tf.int32)
        one_hot_labels = slim.one_hot_encoding(labels, num_classes)
        tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels, logits=logits,
            label_smoothing=0.1, weights=1.0)
        if uses_act:
          training_utils.add_all_ponder_costs(end_points, weights=1.0)
        total_loss = tf.losses.get_total_loss()
        optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
        train_op = slim.learning.create_train_op(total_loss, optimizer)
        sess.run(tf.global_variables_initializer())
        sess.run((train_op, metrics))
      else:
        sess.run([tf.local_variables_initializer(),
            tf.global_variables_initializer()])
        logits_out, metrics_out = sess.run((logits, metrics))
        self.assertEqual(logits_out.shape, (batch_size, num_classes))
项目:sact    作者:mfigurnov    | 项目源码 | 文件源码
def _runBatch(self, is_training, model_type, model=None):
    """Build the CIFAR ResNet on random input and run one step.

    Args:
      is_training: when True, attach losses and run a single training step;
        otherwise run a forward pass and check the logits shape.
      model_type: network variant forwarded to cifar_model.resnet; 'act',
        'act_early_stopping' and 'sact' additionally get ACT/FLOPS metrics
        and ponder costs.
      model: model specification forwarded to cifar_model.resnet; defaults
        to [2].
    """
    # None sentinel instead of a mutable list default shared across calls.
    if model is None:
      model = [2]
    batch_size = 2
    height, width = 32, 32
    num_classes = 10
    uses_act = model_type in ('act', 'act_early_stopping', 'sact')

    with slim.arg_scope(
        cifar_model.resnet_arg_scope(is_training=is_training)):
      with self.test_session() as sess:
        images = tf.random_uniform((batch_size, height, width, 3))
        logits, end_points = cifar_model.resnet(
            images,
            model=model,
            num_classes=num_classes,
            model_type=model_type,
            base_channels=1)
        if uses_act:
          metrics = summary_utils.act_metric_map(end_points,
              not is_training)
          metrics.update(summary_utils.flops_metric_map(end_points,
              not is_training))
        else:
          metrics = {}

        if is_training:
          labels = tf.random_uniform(
              (batch_size,), maxval=num_classes, dtype=tf.int32)
          one_hot_labels = slim.one_hot_encoding(labels, num_classes)
          tf.losses.softmax_cross_entropy(
              onehot_labels=one_hot_labels, logits=logits)
          if uses_act:
            training_utils.add_all_ponder_costs(end_points, weights=1.0)
          total_loss = tf.losses.get_total_loss()
          optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
          train_op = slim.learning.create_train_op(total_loss, optimizer)
          sess.run(tf.global_variables_initializer())
          sess.run((train_op, metrics))
        else:
          sess.run([tf.local_variables_initializer(),
              tf.global_variables_initializer()])
          logits_out, metrics_out = sess.run((logits, metrics))
          self.assertEqual(logits_out.shape, (batch_size, num_classes))
项目:google_ml_challenge    作者:SSUHan    | 项目源码 | 文件源码
def prepare_serialized_examples(self, serialized_examples, width=50, height=50):
    """Turn serialized examples into normalized grayscale images and one-hot labels.

    Args:
      serialized_examples: batch of serialized tf.Example protos carrying a
        PNG-encoded 'image' string and a scalar int64 'label'.
      width: width each decoded image is resized to.
      height: height each decoded image is resized to.

    Returns:
      A tuple (images, labels): images are float32, converted from uint8 and
      then shifted by -0.5 and scaled by 2.0; labels are int32 one-hot rows
      with 10 classes.
    """
    # set the mapping from the fields to data types in the proto
    feature_map = {
           'image': tf.FixedLenFeature((), tf.string, default_value=''),
           'label': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_example(serialized_examples, features=feature_map)

    def decode_and_resize(image_str_tensor):
      """Decodes png string, resizes it and returns a uint8 tensor."""

      # Output a grayscale (channels=1) image
      image = tf.image.decode_png(image_str_tensor, channels=1)

      # Note resize expects a batch_size, but tf.map_fn suppresses that index,
      # thus we have to expand then squeeze.  Resize returns float32 in the
      # range [0, uint8_max]
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(
          image, [height, width], align_corners=False)
      image = tf.squeeze(image, squeeze_dims=[0])
      image = tf.cast(image, dtype=tf.uint8)
      return image

    images_str_tensor = features["image"]
    images = tf.map_fn(
        decode_and_resize, images_str_tensor, back_prop=False, dtype=tf.uint8)
    images = tf.image.convert_image_dtype(images, dtype=tf.float32)
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.0)

    def dense_to_one_hot(label_batch, num_classes):
      # One batched call replaces the former per-label tf.map_fn loop;
      # slim.one_hot_encoding operates on the full label vector.
      one_hot = tf.cast(slim.one_hot_encoding(label_batch, num_classes), tf.int32)
      # Kept so the result is always rank 2: [batch, num_classes].
      one_hot = tf.reshape(one_hot, [-1, num_classes])
      return one_hot

    labels = tf.cast(features['label'], tf.int32)
    labels = dense_to_one_hot(labels, 10)

    return images, labels
项目:mlc2017-online    作者:machine-learning-challenge    | 项目源码 | 文件源码
def prepare_serialized_examples(self, serialized_examples, width=32, height=32, channels=3):
    """Parse serialized examples into file names, normalized images and labels.

    Args:
      serialized_examples: batch of serialized tf.Example protos with
        'image/encoded' (JPEG string), 'image/filename' (string) and
        'image/class/label' (int64) features.
      width: width each decoded image is resized to.
      height: height each decoded image is resized to.
      channels: number of color channels requested from the JPEG decoder.

    Returns:
      A tuple (image_ids, images, labels): image_ids are the file-name
      strings; images are float32, converted from uint8 and then shifted by
      -0.5 and scaled by 2.0; labels are int32 class ids reshaped to
      [-1, 1] (they are not one-hot encoded here).
    """
    # set the mapping from the fields to data types in the proto
    feature_map = {
           'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
           'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
           'image/class/label': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_example(serialized_examples, features=feature_map)

    def decode_and_resize(image_str_tensor):
      """Decodes jpeg string, resizes it and returns a uint8 tensor."""

      image = tf.image.decode_jpeg(image_str_tensor, channels=channels)

      # Note resize expects a batch_size, but tf.map_fn suppresses that index,
      # thus we have to expand then squeeze.  Resize returns float32 in the
      # range [0, uint8_max]
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(
          image, [height, width], align_corners=False)
      image = tf.squeeze(image, squeeze_dims=[0])
      image = tf.cast(image, dtype=tf.uint8)
      return image

    images_str_tensor = features["image/encoded"]
    images = tf.map_fn(
        decode_and_resize, images_str_tensor, back_prop=False, dtype=tf.uint8)
    images = tf.image.convert_image_dtype(images, dtype=tf.float32)
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.0)

    # The unused nested dense_to_one_hot helper was removed: labels are
    # returned as a column of class ids, never one-hot encoded in this method.
    labels = tf.cast(features['image/class/label'], tf.int32)
    labels = tf.reshape(labels, [-1, 1])

    image_ids = features['image/filename']

    return image_ids, images, labels
项目:deepmodels    作者:learningsociety    | 项目源码 | 文件源码
def train(self,
            train_input_batch,
            train_label_batch,
            train_params,
            preprocessed=True):
    """Build the training graph for the classifier and start training.

    Builds the model on the input batch, registers losses against one-hot
    labels, optionally prepares a checkpoint-restore function, and hands the
    total loss to base_model.train_model_given_loss.

    Args:
      train_input_batch: input batch for training.
      train_label_batch: class id for training.
      train_params: commons.TrainTestParams object.
      preprocessed: if train data has been preprocessed; when False the
        inputs are first passed through self.dm_model.preprocess.
    """
    assert train_input_batch is not None, "train input batch is none"
    assert train_label_batch is not None, "train label batch is none"
    assert isinstance(
        train_params,
        commons.TrainTestParams), "train params is not a valid type"
    self.check_dm_model_exist()
    # self.dm_model.use_graph()
    model_params = self.dm_model.net_params
    if not preprocessed:
      train_input_batch = self.dm_model.preprocess(train_input_batch)
    pred_logits, endpoints = self.build_model(train_input_batch)
    self.set_key_vars(train_params.restore_scopes_exclude,
                      train_params.train_scopes)
    comp_train_accuracy(pred_logits, train_label_batch)
    # NOTE(review): the op returned by tf.assert_equal is discarded and no
    # control dependency on it is visible in this block, so the check may never
    # run - confirm. Also confirm that "max label == cls_num" is the intended
    # condition.
    tf.assert_equal(
        tf.reduce_max(train_label_batch),
        tf.convert_to_tensor(
            model_params.cls_num, dtype=tf.int64))
    onehot_labels = tf.one_hot(
        train_label_batch, model_params.cls_num, on_value=1.0, off_value=0.0)
    # onehot_labels = slim.one_hot_encoding(train_label_batch,
    #                                       model_params.cls_num)
    # NOTE(review): squeeze without an axis removes every size-1 dimension,
    # which would include a batch dimension of size 1 - confirm batch sizes.
    onehot_labels = tf.squeeze(onehot_labels)
    self.compute_losses(onehot_labels, pred_logits, endpoints)
    # Restore weights from the checkpoint only when fine-tuning a fresh run
    # (not when resuming an earlier training session).
    init_fn = None
    if train_params.fine_tune and not train_params.resume_training:
      init_fn = slim.assign_from_checkpoint_fn(train_params.custom["model_fn"],
                                               self.vars_to_restore)
    # A fresh run removes the previous log directory first.
    # this would not work if a tensorboard is running...
    if not train_params.resume_training:
      data_manager.remove_dir(train_params.train_log_dir)
    # display regularization loss.
    if train_params.use_regularization:
      regularization_loss = tf.add_n(tf.losses.get_regularization_losses())
      tf.summary.scalar("losses/regularization_loss", regularization_loss)

    # Regularization losses are folded into the total only when enabled.
    total_loss = tf.losses.get_total_loss(
        add_regularization_losses=train_params.use_regularization)
    base_model.train_model_given_loss(
        total_loss, self.vars_to_train, train_params, init_fn=init_fn)
项目:SSD_tensorflow_VOC    作者:LevinJ    | 项目源码 | 文件源码
def __get_images_labels(self):
        """Create the training input pipeline and remember the network/dataset.

        Reads (image, label) pairs from the configured dataset, offsets the
        labels, preprocesses and batches the images, one-hot encodes the
        labels and pushes the batch through a prefetch queue. On success the
        network function and dataset are stored on self.

        Returns:
            A tuple (images, labels) dequeued from the prefetch queue.
        """
        dataset = dataset_factory.get_dataset(
            self.dataset_name, self.dataset_split_name, self.dataset_dir)
        effective_classes = dataset.num_classes - self.labels_offset

        reader_kwargs = dict(
            num_readers=self.num_readers,
            common_queue_capacity=20 * self.batch_size,
            common_queue_min=10 * self.batch_size)
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, **reader_kwargs)
        [image, label] = provider.get(['image', 'label'])
        label -= self.labels_offset

        network_fn = nets_factory.get_network_fn(
            self.model_name,
            num_classes=effective_classes,
            weight_decay=self.weight_decay,
            is_training=True)

        # Prefer the configured size; otherwise use the network's default.
        train_image_size = self.train_image_size or network_fn.default_image_size

        # Prefer the configured preprocessing; otherwise use the model's own.
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            self.preprocessing_name or self.model_name,
            is_training=True)
        image = image_preprocessing_fn(image, train_image_size, train_image_size)

        images, labels = tf.train.batch(
            [image, label],
            batch_size=self.batch_size,
            num_threads=self.num_preprocessing_threads,
            capacity=5 * self.batch_size)
        labels = slim.one_hot_encoding(labels, effective_classes)
        images, labels = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2).dequeue()

        # Keep these for the code that sets up the network later.
        self.network_fn = network_fn
        self.dataset = dataset

        return images, labels
项目:google_ml_challenge    作者:SSUHan    | 项目源码 | 文件源码
def prepare_serialized_examples(self, serialized_examples, width=32, height=32, channels=3):
    """Decode serialized examples into file names, normalized images and labels.

    Args:
      serialized_examples: batch of serialized tf.Example protos with
        'image/encoded' (JPEG string), 'image/filename' (string) and
        'image/class/label' (int64) features.
      width: width each decoded image is resized to.
      height: height each decoded image is resized to.
      channels: number of color channels requested from the JPEG decoder.

    Returns:
      A tuple (image_ids, images, labels): image_ids are the file-name
      strings; images are float32, converted from uint8 and then shifted by
      -0.5 and scaled by 2.0; labels are int32 class ids reshaped to
      [-1, 1] (they are not one-hot encoded here).
    """
    # set the mapping from the fields to data types in the proto
    feature_map = {
           'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
           'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
           'image/class/label': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_example(serialized_examples, features=feature_map)

    def decode_and_resize(image_str_tensor):
      """Decodes jpeg string, resizes it and returns a uint8 tensor."""

      image = tf.image.decode_jpeg(image_str_tensor, channels=channels)

      # Note resize expects a batch_size, but tf.map_fn suppresses that index,
      # thus we have to expand then squeeze.  Resize returns float32 in the
      # range [0, uint8_max]
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(
          image, [height, width], align_corners=False)
      image = tf.squeeze(image, squeeze_dims=[0])
      image = tf.cast(image, dtype=tf.uint8)
      return image

    images_str_tensor = features["image/encoded"]
    images = tf.map_fn(
        decode_and_resize, images_str_tensor, back_prop=False, dtype=tf.uint8)
    images = tf.image.convert_image_dtype(images, dtype=tf.float32)
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.0)

    # Dead code removed: a nested dense_to_one_hot helper was defined here but
    # never called; labels leave this method as a column of class ids.
    labels = tf.cast(features['image/class/label'], tf.int32)
    labels = tf.reshape(labels, [-1, 1])

    image_ids = features['image/filename']

    return image_ids, images, labels
项目:GestureRecognition    作者:gkchai    | 项目源码 | 文件源码
def load_batch(dataset, batch_size, is_2D=False, preprocess_fn=None, shuffle=False):
    """Load one batch of series with integer and one-hot labels.

    Args:
        dataset: dataset object (created by get_split) handed to the slim
            DatasetDataProvider; its num_classes sets the one-hot depth.
        batch_size: number of examples per batch; also sizes the provider's
            common queue.
        is_2D: when true, the 'series/x', 'series/y' and 'series/z' items are
            stacked and given a trailing axis; otherwise the single 'series'
            item is used.
        preprocess_fn: optional callable applied to the raw series.
        shuffle: forwarded to the data provider.

    Returns:
        A tuple (series_batch, labels, labels_one_hot) from tf.train.batch.
    """
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=shuffle,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size,
        num_epochs=None,
    )

    # Fetch the raw series and its label in the requested layout.
    if not is_2D:
        raw_series, label = provider.get(['series', 'label'])
    else:
        x, y, z, label = provider.get(
            ['series/x', 'series/y', 'series/z', 'label'])
        stacked = tf.stack([x, y, z])
        raw_series = tf.expand_dims(stacked, -1)

    # Labels arrive as int64; both label forms are emitted as int32.
    label = tf.to_int32(label)
    encoded = slim.one_hot_encoding(label, dataset.num_classes)
    label_one_hot = tf.to_int32(encoded)

    # Optional preprocessing of the series (training or evaluation specific).
    series = preprocess_fn(raw_series) if preprocess_fn else raw_series

    # tf.train.batch enqueues the tensors in a FIFO queue and dequeues
    # batch_size elements at a time.
    series_batch, labels, labels_one_hot = tf.train.batch(
        [series, label, label_one_hot],
        batch_size=batch_size,
        allow_smaller_final_batch=True,
        num_threads=1,
    )
    return series_batch, labels, labels_one_hot