Python TensorFlow 模块，dynamic_stitch() 实例源码

我们从 Python 开源项目中，提取了以下 21 个代码示例，用于说明如何使用 tensorflow.dynamic_stitch()。

项目:visual_mpc    作者:febert    | 项目源码 | 文件源码
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    """Mix ground-truth and generated examples into a single batch.

    A random permutation of the batch positions ``0 .. batch_size - 1`` is
    drawn. The first ``num_ground_truth`` positions of that permutation are
    filled from ``ground_truth_x`` and the remaining ones from
    ``generated_x``. Each row is read from, and written back to, the same
    batch position, so output row ``p`` is ``ground_truth_x[p]`` or
    ``generated_x[p]``.

    Args:
      ground_truth_x: tensor of ground-truth data points (batch on axis 0).
      generated_x: tensor of generated data points (batch on axis 0).
      batch_size: batch size; passed through ``int()``.
      num_ground_truth: number of ground-truth examples to include in batch.
    Returns:
      New batch with num_ground_truth rows taken from ground_truth_x and the
      rest from generated_x.
    """
    n = int(batch_size)
    perm = tf.random_shuffle(tf.range(n))

    # Batch positions served by each source.
    gt_pos = tf.gather(perm, tf.range(num_ground_truth))
    gen_pos = tf.gather(perm, tf.range(num_ground_truth, n))

    gt_rows = tf.gather(ground_truth_x, gt_pos)
    gen_rows = tf.gather(generated_x, gen_pos)

    # Put every row back at the position it was read from.
    return tf.dynamic_stitch([gt_pos, gen_pos], [gt_rows, gen_rows])
项目:cfrnet    作者:clinicalml    | 项目源码 | 文件源码
def _build_output_graph(self, rep, t, dim_in, dim_out, do_out, FLAGS):
    """Construct the output/regression layers on top of ``rep``.

    When ``FLAGS.split_output`` is set, the rows of ``rep`` with ``t < 1``
    and the rows with ``t > 0`` are fed to two separately built output
    networks (``self._build_output``). Their predictions are stitched back
    into the original row order with ``tf.dynamic_stitch``.

    Otherwise ``t`` is appended to ``rep`` as an extra input column, and a
    single network with ``dim_in + 1`` inputs is built.

    Returns:
      (y, weights_out, weights_pred). In the split case the weight
      collections of the two networks are combined with ``+`` (list
      concatenation if ``_build_output`` returns lists; TODO confirm).
    """
    if not FLAGS.split_output:
        # NOTE(review): tf.concat(axis, values) is the pre-1.0 TensorFlow
        # argument order; TF >= 1.0 expects tf.concat(values, axis).
        h_input = tf.concat(1, [rep, t])
        y, weights_out, weights_pred = self._build_output(
            h_input, dim_in + 1, dim_out, do_out, FLAGS)
        return y, weights_out, weights_pred

    # tf.where returns the coordinates of True entries; column 0 is the row.
    rows_t0 = tf.to_int32(tf.where(t < 1)[:, 0])
    rows_t1 = tf.to_int32(tf.where(t > 0)[:, 0])

    rep_t0 = tf.gather(rep, rows_t0)
    rep_t1 = tf.gather(rep, rows_t1)

    y_t0, w_out_t0, w_pred_t0 = self._build_output(
        rep_t0, dim_in, dim_out, do_out, FLAGS)
    y_t1, w_out_t1, w_pred_t1 = self._build_output(
        rep_t1, dim_in, dim_out, do_out, FLAGS)

    # Scatter each group's predictions back to the rows they came from.
    y = tf.dynamic_stitch([rows_t0, rows_t1], [y_t0, y_t1])
    return y, w_out_t0 + w_out_t1, w_pred_t0 + w_pred_t1
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    """Build a batch that mixes ground-truth and generated data points.

    The batch positions are randomly permuted; the first
    ``num_ground_truth`` permuted positions take their row from
    ``ground_truth_x`` and the rest from ``generated_x``. Rows keep their
    batch position, so output row ``p`` is taken from one of the two inputs
    at the same row ``p``.

    Args:
      ground_truth_x: tensor of ground-truth data points (batch on axis 0).
      generated_x: tensor of generated data points (batch on axis 0).
      batch_size: batch size; passed through ``int()``.
      num_ground_truth: number of ground-truth examples to include in batch.
    Returns:
      New batch with num_ground_truth rows sampled from ground_truth_x and
      the rest from generated_x.
    """
    total = int(batch_size)
    shuffled = tf.random_shuffle(tf.range(total))

    real_slots = tf.gather(shuffled, tf.range(num_ground_truth))
    fake_slots = tf.gather(shuffled, tf.range(num_ground_truth, total))

    real_part = tf.gather(ground_truth_x, real_slots)
    fake_part = tf.gather(generated_x, fake_slots)

    # dynamic_stitch writes each gathered row back into its own slot.
    slot_lists = [real_slots, fake_slots]
    part_lists = [real_part, fake_part]
    return tf.dynamic_stitch(slot_lists, part_lists)
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    """Return a batch drawing some rows from ground truth, the rest generated.

    ``num_ground_truth`` batch positions, chosen by a random permutation,
    are filled from ``ground_truth_x``; all other positions are filled from
    ``generated_x``. Each row stays at its original batch index.

    Args:
      ground_truth_x: tensor of ground-truth data points (batch on axis 0).
      generated_x: tensor of generated data points (batch on axis 0).
      batch_size: batch size; passed through ``int()``.
      num_ground_truth: number of ground-truth examples to include in batch.
    Returns:
      New batch with num_ground_truth rows from ground_truth_x and the rest
      from generated_x.
    """
    size = int(batch_size)
    permutation = tf.random_shuffle(tf.range(size))

    # Leading part of the permutation -> ground truth; trailing -> generated.
    idx_from_truth = tf.gather(permutation, tf.range(num_ground_truth))
    idx_from_model = tf.gather(permutation, tf.range(num_ground_truth, size))

    picked_truth = tf.gather(ground_truth_x, idx_from_truth)
    picked_model = tf.gather(generated_x, idx_from_model)

    return tf.dynamic_stitch(
        [idx_from_truth, idx_from_model],
        [picked_truth, picked_model],
    )
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    """Mix ground-truth and (squeezed) generated data points into one batch.

    ``generated_x`` is first passed through ``tf.squeeze``. A random
    permutation of the batch positions is then split: the first
    ``num_ground_truth`` positions take rows from ``ground_truth_x`` and the
    remainder take rows from the squeezed ``generated_x``. Each row is
    written back to the batch position it was read from.

    Args:
      ground_truth_x: tensor of ground-truth data points (batch on axis 0).
      generated_x: tensor of generated data points (batch on axis 0 after
        squeezing).
      batch_size: batch size; passed through ``int()``.
      num_ground_truth: number of ground-truth examples to include in batch.
    Returns:
      New batch with num_ground_truth rows from ground_truth_x and the rest
      from generated_x.
    """
    # NOTE(review): tf.squeeze without an axis removes *every* size-1
    # dimension, including the batch axis when batch_size == 1; consider
    # passing an explicit axis once the intended input shape is confirmed.
    generated_x = tf.squeeze(generated_x)

    count = int(batch_size)
    order = tf.random_shuffle(tf.range(count))

    truth_at = tf.gather(order, tf.range(num_ground_truth))
    gen_at = tf.gather(order, tf.range(num_ground_truth, count))

    truth_rows = tf.gather(ground_truth_x, truth_at)
    gen_rows = tf.gather(generated_x, gen_at)

    return tf.dynamic_stitch([truth_at, gen_at], [truth_rows, gen_rows])
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
  """Sample a batch mixing ground-truth and generated data points.

  A random permutation of the batch positions decides which
  ``num_ground_truth`` positions come from ``ground_truth_x``; every other
  position comes from ``generated_x``. Rows keep their batch position.

  Args:
    ground_truth_x: tensor of ground-truth data points (batch on axis 0).
    generated_x: tensor of generated data points (batch on axis 0).
    batch_size: batch size; passed through ``int()``.
    num_ground_truth: number of ground-truth examples to include in batch.
  Returns:
    New batch with num_ground_truth rows sampled from ground_truth_x and the
    rest from generated_x.
  """
  bs = int(batch_size)
  mixed_order = tf.random_shuffle(tf.range(bs))

  head = tf.gather(mixed_order, tf.range(num_ground_truth))
  tail = tf.gather(mixed_order, tf.range(num_ground_truth, bs))

  head_rows = tf.gather(ground_truth_x, head)
  tail_rows = tf.gather(generated_x, tail)

  # Reassemble: each row lands back at the index it was gathered from.
  stitched = tf.dynamic_stitch([head, tail], [head_rows, tail_rows])
  return stitched
项目:lsdc    作者:febert    | 项目源码 | 文件源码
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    """Combine ground-truth and generated data points into a single batch.

    The batch positions are shuffled. The first ``num_ground_truth``
    shuffled positions are populated from ``ground_truth_x`` and the others
    from ``generated_x``, each at its own batch index.

    Args:
      ground_truth_x: tensor of ground-truth data points (batch on axis 0).
      generated_x: tensor of generated data points (batch on axis 0).
      batch_size: batch size; passed through ``int()``.
      num_ground_truth: number of ground-truth examples to include in batch.
    Returns:
      New batch with num_ground_truth rows from ground_truth_x and the rest
      from generated_x.
    """
    limit = int(batch_size)
    random_positions = tf.random_shuffle(tf.range(limit))

    pos_gt = tf.gather(random_positions, tf.range(num_ground_truth))
    pos_gen = tf.gather(random_positions, tf.range(num_ground_truth, limit))

    vals_gt = tf.gather(ground_truth_x, pos_gt)
    vals_gen = tf.gather(generated_x, pos_gen)

    positions, values = [pos_gt, pos_gen], [vals_gt, vals_gen]
    return tf.dynamic_stitch(positions, values)
项目:hand3d    作者:lmb-freiburg    | 项目源码 | 文件源码
def _stitch_mat_from_vecs(vector_list):
        """ Stitches a given list of vectors into a 3x3 matrix.

            Input:
                vector_list: list of 9 tensors, which will be stitched into a matrix. list contains matrix elements
                    in a row-first fashion (m11, m12, m13, m21, m22, m23, m31, m32, m33). Length of the vectors has
                    to be the same, because it is interpreted as batch dimension.
        """

        assert len(vector_list) == 9, "There have to be exactly 9 tensors in vector_list."
        batch_size = vector_list[0].get_shape().as_list()[0]
        vector_list = [tf.reshape(x, [1, batch_size]) for x in vector_list]

        trafo_matrix = tf.dynamic_stitch([[0], [1], [2],
                                          [3], [4], [5],
                                          [6], [7], [8]], vector_list)

        trafo_matrix = tf.reshape(trafo_matrix, [3, 3, batch_size])
        trafo_matrix = tf.transpose(trafo_matrix, [2, 0, 1])

        return trafo_matrix
项目:hand3d    作者:lmb-freiburg    | 项目源码 | 文件源码
def _stitch_mat_from_vecs(vector_list):
        """ Stitches a given list of vectors into a 3x3 matrix.

            Input:
                vector_list: list of 9 tensors, which will be stitched into a matrix. list contains matrix elements
                    in a row-first fashion (m11, m12, m13, m21, m22, m23, m31, m32, m33). Length of the vectors has
                    to be the same, because it is interpreted as batch dimension.
        """

        assert len(vector_list) == 9, "There have to be exactly 9 tensors in vector_list."
        batch_size = vector_list[0].get_shape().as_list()[0]
        vector_list = [tf.reshape(x, [1, batch_size]) for x in vector_list]

        trafo_matrix = tf.dynamic_stitch([[0], [1], [2],
                                          [3], [4], [5],
                                          [6], [7], [8]], vector_list)

        trafo_matrix = tf.reshape(trafo_matrix, [3, 3, batch_size])
        trafo_matrix = tf.transpose(trafo_matrix, [2, 0, 1])

        return trafo_matrix
项目:hand3d    作者:lmb-freiburg    | 项目源码 | 文件源码
def _stitch_mat_from_vecs(vector_list):
    """ Stitches a given list of vectors into a 4x4 matrix.

        Input:
            vector_list: list of 16 tensors, which will be stitched into a matrix. list contains matrix elements
                in a row-first fashion (m11, m12, m13, m14, m21, m22, m23, m24, ...). Length of the vectors has
                to be the same, because it is interpreted as batch dimension.
    """

    assert len(vector_list) == 16, "There have to be exactly 16 tensors in vector_list."
    batch_size = vector_list[0].get_shape().as_list()[0]
    vector_list = [tf.reshape(x, [1, batch_size]) for x in vector_list]

    trafo_matrix = tf.dynamic_stitch([[0], [1], [2], [3],
                                      [4], [5], [6], [7],
                                      [8], [9], [10], [11],
                                      [12], [13], [14], [15]], vector_list)

    trafo_matrix = tf.reshape(trafo_matrix, [4, 4, batch_size])
    trafo_matrix = tf.transpose(trafo_matrix, [2, 0, 1])

    return trafo_matrix
项目:hand3d    作者:lmb-freiburg    | 项目源码 | 文件源码
def _stitch_mat_from_vecs(vector_list):
    """ Stitches a given list of vectors into a 3x3 matrix.

        Input:
            vector_list: list of 9 tensors, which will be stitched into a matrix. list contains matrix elements
                in a row-first fashion (m11, m12, m13, m21, m22, m23, m31, m32, m33). Length of the vectors has
                to be the same, because it is interpreted as batch dimension.
    """

    assert len(vector_list) == 9, "There have to be exactly 9 tensors in vector_list."
    batch_size = vector_list[0].get_shape().as_list()[0]
    vector_list = [tf.reshape(x, [1, batch_size]) for x in vector_list]

    trafo_matrix = tf.dynamic_stitch([[0], [1], [2],
                                      [3], [4], [5],
                                      [6], [7], [8]], vector_list)

    trafo_matrix = tf.reshape(trafo_matrix, [3, 3, batch_size])
    trafo_matrix = tf.transpose(trafo_matrix, [2, 0, 1])

    return trafo_matrix
项目:DeepLearning_VirtualReality_BigData_Project    作者:rashmitripathi    | 项目源码 | 文件源码
def split_apply_merge(inp, partitions, fns):
  """Split input according to partitions, pass each part through fns, merge.

  Args:
    inp: the input tensor; rows (axis 0) are routed independently.
    partitions: int32 tensor with one entry per row of ``inp``, each in
      ``[0, len(fns))``.
    fns: sequence of functions; ``fns[p]`` receives the rows whose partition
      is ``p`` and must return one output row per input row.

  Returns:
    the tensor routed, where routed[i] = fns[partitions[i]](inp[i]) for
    row-wise functions.
  """
  num_parts = len(fns)
  new_inputs = tf.dynamic_partition(inp, partitions, num_parts)
  new_outputs = [fn(x) for fn, x in zip(fns, new_inputs)]
  # Partition the row positions the same way, so dynamic_stitch can put every
  # output row back at its original position. tf.shape (instead of the static
  # get_shape) also works when the length of inp is only known at run time.
  row_ids = tf.range(0, tf.shape(inp)[0])
  new_indices = tf.dynamic_partition(row_ids, partitions, num_parts)
  return tf.dynamic_stitch(new_indices, new_outputs)
项目:mann-for-speech-separation    作者:KWTsou1220    | 项目源码 | 文件源码
def circular_convolution(v, k, size):
    """Computes circular convolution.

    Args:
        v: a `Tensor` that is split along axis 0 into ``size`` equal pieces,
            one per position being convolved (presumably shape (size, d);
            TODO confirm).
        k: a `Tensor` kernel. ``k.get_shape()[1]`` is taken as the kernel size
            and ``tf.transpose(k)`` is multiplied against the gathered pieces
            of v, so presumably shape (1, kernel_size); TODO confirm.
        size: python int, the number of positions of v along axis 0 (the
            length being convolved, NOT the kernel size).

    Returns:
        The concatenation along axis 0 of ``size`` reduced rows, one per
        position.
    """
    kernel_size = int(k.get_shape()[1])
    kernel_shift = kernel_size // 2
    # NOTE: tf.split(axis, num, value) / tf.concat(axis, values) is the
    # pre-1.0 TensorFlow argument order, kept to match this project.
    v_list = tf.split(0, size, v)

    kernels = []
    for i in range(size):
        # Positions i+shift, ..., i-shift, wrapped around modulo size; `%`
        # also stays in range when the kernel is wider than v.
        indices = [(i + j) % size
                   for j in range(kernel_shift, -kernel_shift - 1, -1)]
        v_ = tf.concat(0, [v_list[idx] for idx in indices])
        kernels.append(tf.reduce_sum(v_ * tf.transpose(k), 0, keep_dims=True))

    # tf.dynamic_stitch(list(range(size)), kernels) is NOT a drop-in
    # replacement: with scalar indices each keep_dims row keeps its leading
    # size-1 axis, so the result would gain an extra dimension.
    return tf.concat(0, kernels)
项目:GPflow    作者:GPflow    | 项目源码 | 文件源码
def _partition_and_stitch(self, args, func_name):
        """
        Evaluate ``func_name`` row by row with the likelihood chosen per row.

        args is a sequence of tensors to be passed to ``<func_name>`` of the
        likelihoods in self.likelihood_list.

        args[-1] is the 'Y' argument. Its last column contains the indexes into
        self.likelihood_list; the remaining columns are the observations that
        are actually passed on.

        This function splits up the args using dynamic_partition, calls the
        relevant function on the likelihoods, and re-combines the result in the
        original row order with dynamic_stitch. The caller's ``args`` is not
        modified.
        """
        # get the index from Y
        Y = args[-1]
        ind = tf.gather(tf.transpose(Y), tf.shape(Y)[1]-1)  # ind = Y[:,-1]
        ind = tf.cast(ind, tf.int32)
        Y = tf.transpose(tf.gather(tf.transpose(Y), tf.range(0, tf.shape(Y)[1]-1)))  # Y = Y[:,:-1]
        # Build a new list instead of assigning into args: writing args[-1]
        # mutated the caller's list (and raised TypeError for a tuple).
        args = list(args[:-1]) + [Y]

        # split up the arguments into chunks corresponding to the relevant likelihoods
        per_likelihood_args = zip(*[tf.dynamic_partition(X, ind, self.num_likelihoods) for X in args])

        # apply the likelihood-function to each section of the data
        with params_as_tensors_for(self, convert=False):
            funcs = [getattr(lik, func_name) for lik in self.likelihood_list]
        results = [f(*args_i) for f, args_i in zip(funcs, per_likelihood_args)]

        # stitch the results back together
        partitions = tf.dynamic_partition(tf.range(0, tf.size(ind)), ind, self.num_likelihoods)
        results = tf.dynamic_stitch(partitions, results)

        return results
项目:GPflow    作者:GPflow    | 项目源码 | 文件源码
def __call__(self, X):
        """Evaluate, row by row, the mean function selected by X's last column.

        The last column of X is cast to int32 and used as an index into
        self.meanfunction_list; the remaining columns are the inputs handed to
        that mean function. The outputs are stitched back into X's row order.
        """
        # Selector column X[:, -1], cast to int32 for dynamic_partition.
        selector = tf.cast(tf.gather(tf.transpose(X), tf.shape(X)[1]-1), tf.int32)
        # Feature columns X[:, :-1].
        features = tf.transpose(tf.gather(tf.transpose(X), tf.range(0, tf.shape(X)[1]-1)))

        # One chunk of rows per mean function, each evaluated by its own function.
        chunks = tf.dynamic_partition(features, selector, self.num_meanfunctions)
        outputs = [mean_fn(chunk) for chunk, mean_fn in zip(chunks, self.meanfunction_list)]

        # Original row positions of every chunk, used to restore the row order.
        row_ids = tf.dynamic_partition(tf.range(0, tf.size(selector)), selector, self.num_meanfunctions)
        return tf.dynamic_stitch(row_ids, outputs)
项目:nengo_dl    作者:nengo    | 项目源码 | 文件源码
def test_dynamic_stitch():
    """Analytic and numeric gradients agree for dynamic_stitch with a duplicate index.

    Both data tensors are stitched into row 0, and dynamic_stitch keeps the
    later one (the ones tensor). The gradient with respect to the overwritten
    zeros tensor is then checked against a finite-difference estimate.
    """
    overwritten = tf.zeros((1, 3))
    stitched = tf.dynamic_stitch([[0], [0]], [overwritten, tf.ones((1, 3))])
    first_row = tf.gather(stitched, [0])

    with tf.Session():
        analytic, numeric = tf.test.compute_gradient(
            overwritten, (1, 3), first_row, (1, 3))

        assert np.allclose(analytic, numeric)
项目:TensorFlow-in-a-Nutshell    作者:camrongodbout    | 项目源码 | 文件源码
def split_apply_merge(inp, partitions, fns):
  """Split input according to partitions, pass each part through fns, merge.

  Args:
    inp: the input tensor; rows (axis 0) are routed independently.
    partitions: int32 tensor with one entry per row of ``inp``, each in
      ``[0, len(fns))``.
    fns: sequence of functions; ``fns[p]`` receives the rows whose partition
      is ``p`` and must return one output row per input row.

  Returns:
    the tensor routed, where routed[i] = fns[partitions[i]](inp[i]) for
    row-wise functions.
  """
  num_parts = len(fns)
  new_inputs = tf.dynamic_partition(inp, partitions, num_parts)
  new_outputs = [fn(x) for fn, x in zip(fns, new_inputs)]
  # Partition the row positions the same way, so dynamic_stitch can put every
  # output row back at its original position. tf.shape (instead of the static
  # get_shape) also works when the length of inp is only known at run time.
  row_ids = tf.range(0, tf.shape(inp)[0])
  new_indices = tf.dynamic_partition(row_ids, partitions, num_parts)
  return tf.dynamic_stitch(new_indices, new_outputs)
项目:tensorflow    作者:luyishisi    | 项目源码 | 文件源码
def _create_regression_targets(self, anchors, groundtruth_boxes, match):
    """Compute one regression target per anchor.

    Matched anchors receive their matched groundtruth box encoded against the
    anchor by ``self._box_coder``. Unmatched and ignored anchors receive
    copies of ``self._default_regression_target()``. ``tf.dynamic_stitch``
    places every target at its anchor's column index.

    Args:
      anchors: a BoxList representing N anchors
      groundtruth_boxes: a BoxList representing M groundtruth_boxes
      match: a matcher.Match object

    Returns:
      reg_targets: a float32 tensor with shape [N, box_code_dimension]
    """
    anchor_idx_matched = match.matched_column_indices()
    anchor_idx_other = match.unmatched_or_ignored_column_indices()
    gt_idx_matched = match.matched_row_indices()

    anchors_for_matches = box_list_ops.gather(anchors, anchor_idx_matched)
    gt_boxes_for_matches = box_list_ops.gather(groundtruth_boxes,
                                               gt_idx_matched)
    targets_matched = self._box_coder.encode(gt_boxes_for_matches,
                                             anchors_for_matches)

    # Repeat the default target once per unmatched/ignored anchor.
    default_target = self._default_regression_target()
    tile_multiples = tf.stack([tf.size(anchor_idx_other), 1])
    targets_other = tf.tile(default_target, tile_multiples)

    reg_targets = tf.dynamic_stitch(
        [anchor_idx_matched, anchor_idx_other],
        [targets_matched, targets_other])
    # TODO: summarize the number of matches on average.
    return reg_targets
项目:tensorflow    作者:luyishisi    | 项目源码 | 文件源码
def _create_classification_targets(self, groundtruth_labels, match):
    """Create a classification target for each anchor.

    Each matched anchor is assigned the groundtruth label selected by
    ``match``. Anchors that are unmatched or ignored are assigned
    ``self._unmatched_cls_target``.

    Args:
      groundtruth_labels:  a tensor of shape [num_gt_boxes, d_1, ... d_k]
        with labels for each of the ground_truth boxes. The subshape
        [d_1, ... d_k] can be empty (corresponding to scalar labels).
      match: a matcher.Match object that provides a matching between anchors
        and groundtruth boxes.

    Returns:
      cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
        where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels
        which has shape [num_gt_boxes, d_1, d_2, ... d_k].
    """
    anchor_idx_matched = match.matched_column_indices()
    anchor_idx_other = match.unmatched_or_ignored_column_indices()
    gt_idx_matched = match.matched_row_indices()
    targets_matched = tf.gather(groundtruth_labels, gt_idx_matched)

    # Tile the unmatched target along a new leading axis: one copy per
    # unmatched/ignored anchor, multiplier 1 on every original axis.
    label_rank = self._unmatched_cls_target.shape.ndims
    template = tf.expand_dims(self._unmatched_cls_target, 0)
    tile_multiples = tf.stack([tf.size(anchor_idx_other)] + [1] * label_rank)
    targets_other = tf.tile(template, tile_multiples)

    return tf.dynamic_stitch(
        [anchor_idx_matched, anchor_idx_other],
        [targets_matched, targets_other])
项目:tensorflow    作者:luyishisi    | 项目源码 | 文件源码
def indices_to_dense_vector(indices,
                            size,
                            indices_value=1.,
                            default_value=0,
                            dtype=tf.float32):
  """Creates a dense vector with `indices` set to one value and the rest to another.

  This function exists because it is unclear if it is safe to use
    tf.sparse_to_dense(indices, [size], 1, validate_indices=False)
  with indices which are not ordered.
  This function accepts a dynamic size (e.g. tf.shape(tensor)[0])

  Args:
    indices: 1d Tensor with integer indices which are to be set to
        indices_value.
    size: scalar with size (integer) of output Tensor.
    indices_value: value of elements specified by indices in the output vector
    default_value: value of all other elements in the output vector.
    dtype: data type.

  Returns:
    dense 1D Tensor of shape [size] with indices set to indices_value and the
        rest set to default_value.
  """
  size = tf.to_int32(size)
  # Background: every position holds default_value.
  background = tf.ones([size], dtype=dtype) * default_value
  # Foreground: one indices_value per entry of `indices`.
  foreground = tf.ones_like(indices, dtype=dtype) * indices_value

  all_positions = tf.range(size)
  # For duplicate indices dynamic_stitch keeps the value from the later data
  # tensor, so the foreground overwrites the background at `indices`.
  return tf.dynamic_stitch([all_positions, tf.to_int32(indices)],
                           [background, foreground])
项目:dcn.tf    作者:beopst    | 项目源码 | 文件源码
def replace_features(coarse_features, fine_features, replace_idxs):
    """ Replace coarse features with the corresponding fine features

        Trick.
            use tf.dynamic_stitch ops: the whole flattened coarse map is
            stitched first and the flattened fine features second. For
            duplicate indices dynamic_stitch keeps the later value, so the
            fine features win at every replaced position.

        Args:
            coarse_features: 4-D tensor whose static shape is unpacked as
                (batch_size, map_height, map_width, map_channel); all four
                dimensions must be statically known (``.value`` is used).
            fine_features: list of tensors with the replacement values; they
                are flattened and concatenated in order and must line up
                one-to-one with the flat indices built from replace_idxs
                (presumably batch_size * map_channel values each; TODO
                confirm).
            replace_idxs: list of integer tensors. Column 0 is used as the
                row (height) index and column 1 as the column (width) index;
                each presumably has batch_size rows, one location per batch
                element (TODO confirm).

        Returns:
            merged: coarse_features with the selected locations overwritten
                by fine features, reshaped to coarse_features' shape.
            flat_coarse_replaced: the flattened coarse values at the replaced
                locations (needed for hint-based training).
            flat_fine_features: the flattened, concatenated fine features.
    """

    # TODO: simplify indexing
    def _convert_to_1d_idxs(src_idxs):
        """ Convert 2D idxs to 1D idxs
            within 1D tensor whose shape is (b*h*w*c)
        """
        # Number of flat elements occupied by one batch element.
        batch_idx_len = map_channel.value * map_width.value * map_height.value
        batch_idx_base = [i*batch_idx_len for i in range(batch_size.value)]

        # Flat offset of channel 0 at (row, col) within one batch element.
        batch_1d = map_channel.value * map_width.value * src_idxs[:,0] + \
                   map_channel.value * src_idxs[:,1]
        batch_1d = tf.add(batch_1d,batch_idx_base)

        # One flat index per channel, ordered batch-major then channel.
        # NOTE: tf.pack is the pre-1.0 name of tf.stack.
        flat_idxs = [batch_1d+i for i in range(map_channel.value)]
        flat_idxs = tf.reshape(tf.transpose(tf.pack(flat_idxs)), [-1])

        return flat_idxs

    batch_size, map_height, map_width, map_channel = coarse_features.get_shape()

    # flatten coarse features
    flat_coarse_features = tf.reshape(coarse_features, [batch_size.value,-1])
    flat_coarse_features = tf.reshape(flat_coarse_features, [-1])


    # flatten fine features
    # NOTE: tf.concat(axis, values) is the pre-1.0 TensorFlow argument order.
    flat_fine_features = [tf.reshape(i,[-1]) for i in fine_features]
    flat_fine_features = tf.concat(0,flat_fine_features)

    flat_fine_idxs = [_convert_to_1d_idxs(i) for i in replace_idxs]
    flat_fine_idxs = tf.concat(0,flat_fine_idxs)

    # extract coarse features to be replaced
    # this is required for hint-based training
    flat_coarse_replaced = tf.gather(flat_coarse_features, flat_fine_idxs, validate_indices=False)

    merged = tf.dynamic_stitch([tf.range(0,flat_coarse_features.get_shape()[0]),flat_fine_idxs],
            [flat_coarse_features,flat_fine_features])

    merged = tf.reshape(merged,coarse_features.get_shape())

    return merged, flat_coarse_replaced, flat_fine_features