Python osgeo.gdal module: Dataset() example source code

We extracted the following 35 code examples from open-source Python projects to illustrate how to use osgeo.gdal.Dataset().

Project: python_scripting_for_spatial_data_processing    Author: upsdeepak
def create_mask_from_vector(vector_data_path, cols, rows, geo_transform, projection, target_value=1,
                            output_fname='', dataset_format='MEM'):
    """
    Rasterize the given vector (wrapper for gdal.RasterizeLayer). Return a gdal.Dataset.
    :param vector_data_path: Path to a shapefile
    :param cols: Number of columns of the result
    :param rows: Number of rows of the result
    :param geo_transform: Returned value of gdal.Dataset.GetGeoTransform (coefficients for
                          transforming between pixel/line (P,L) raster space and projection
                          coordinates (Xp,Yp) space).
    :param projection: Projection definition string (Returned by gdal.Dataset.GetProjectionRef)
    :param target_value: Pixel value for the pixels. Must be a valid gdal.GDT_UInt16 value.
    :param output_fname: If the dataset_format is GeoTIFF, this is the output file name
    :param dataset_format: The gdal.Dataset driver name. [default: MEM]
    """
    data_source = gdal.OpenEx(vector_data_path, gdal.OF_VECTOR)
    if data_source is None:
        report_and_exit("File read failed: %s", vector_data_path)
    layer = data_source.GetLayer(0)
    driver = gdal.GetDriverByName(dataset_format)
    target_ds = driver.Create(output_fname, cols, rows, 1, gdal.GDT_UInt16)
    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(projection)
    gdal.RasterizeLayer(target_ds, [1], layer, burn_values=[target_value])
    return target_ds
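The `geo_transform` argument documented above is the six-element tuple returned by `GetGeoTransform()`. The following minimal pure-Python sketch applies the affine mapping described in GDAL's raster data model, Xp = GT[0] + P*GT[1] + L*GT[2] and Yp = GT[3] + P*GT[4] + L*GT[5]. The function name `pixel_to_projected` is hypothetical and not part of GDAL.

```python
def pixel_to_projected(gt, pixel, line):
    """Apply a GDAL-style geotransform to pixel/line (P, L) coordinates.

    gt = (origin_x, pixel_width, row_rotation, origin_y, column_rotation, pixel_height);
    for north-up rasters both rotations are 0 and pixel_height is negative.
    """
    xp = gt[0] + pixel * gt[1] + line * gt[2]
    yp = gt[3] + pixel * gt[4] + line * gt[5]
    return xp, yp


# North-up raster with 10-unit pixels whose upper-left corner is at (100, 500)
gt = (100.0, 10.0, 0.0, 500.0, 0.0, -10.0)
print(pixel_to_projected(gt, 2, 3))  # (120.0, 470.0)
```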
Project: pygeotools    Author: dshean
def ds_getma(ds, bnum=1):
    """Get masked array from input GDAL Dataset

    Parameters
    ----------
    ds : gdal.Dataset
        Input GDAL Dataset
    bnum : int, optional
        Band number

    Returns
    -------
    np.ma.array    
        Masked array containing raster values
    """
    b = ds.GetRasterBand(bnum)
    return b_getma(b)
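`b_getma` itself is not part of this excerpt. Judging from `ds_getma_sub` further down, which calls `np.ma.masked_values(b_array, b_ndv)`, it presumably reads the band and masks its nodata value. A NumPy-only sketch of that masking step, with a made-up array, a made-up nodata value and a hypothetical helper name:

```python
import numpy as np


def mask_nodata(arr, ndv):
    """Return a masked array in which pixels equal to the nodata value are hidden."""
    return np.ma.masked_values(arr, ndv)


arr = np.array([[1.0, -9999.0], [3.0, 4.0]])
ma = mask_nodata(arr, -9999.0)
print(ma.count())  # 3 valid pixels
print(ma.mean())   # statistics skip the masked pixel
```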

Project: pygeotools    Author: dshean
def gdal2np_dtype(b):
    """
    Get NumPy datatype that corresponds with GDAL RasterBand datatype
    Input can be filename, GDAL Dataset, GDAL RasterBand, or GDAL integer dtype
    """
    dt_dict = gdal_array.codes
    if isinstance(b, str):
        b = gdal.Open(b)
    if isinstance(b, gdal.Dataset):
        b = b.GetRasterBand(1)
    if isinstance(b, gdal.Band):
        b = b.DataType
    if isinstance(b, int):
        np_dtype = dt_dict[b]
    else:
        np_dtype = None
        print("Input must be GDAL Dataset or RasterBand object")
    return np_dtype
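`gdal_array.codes` maps GDAL's integer data-type codes to NumPy types. To illustrate that final lookup without GDAL installed, the sketch below hard-codes a subset of the table; the codes follow GDAL's `GDALDataType` enum, while the dictionary and helper names are hypothetical. In real code, prefer `gdal_array.codes` or `gdal_array.GDALTypeCodeToNumericTypeCode`.

```python
import numpy as np

# Hypothetical hard-coded subset of GDALDataType codes -> NumPy scalar types
GDAL_TO_NUMPY = {
    1: np.uint8,    # GDT_Byte
    2: np.uint16,   # GDT_UInt16
    3: np.int16,    # GDT_Int16
    4: np.uint32,   # GDT_UInt32
    5: np.int32,    # GDT_Int32
    6: np.float32,  # GDT_Float32
    7: np.float64,  # GDT_Float64
}


def code_to_np_dtype(code):
    """Map a GDAL integer type code to a NumPy scalar type, or None if unknown."""
    return GDAL_TO_NUMPY.get(code)


print(code_to_np_dtype(6))  # <class 'numpy.float32'>
```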

Project: wradlib    Author: wradlib
def read_gdal_projection(dset):
    """Get a projection (OSR object) from a GDAL dataset.

    Parameters
    ----------
    dset : gdal.Dataset

    Returns
    -------
    srs : OSR.SpatialReference
        dataset projection object

    Examples
    --------

    See :ref:`notebooks/classify/wradlib_clutter_cloud_example.ipynb`.

    """
    wkt = dset.GetProjection()
    srs = osr.SpatialReference()
    srs.ImportFromWkt(wkt)
    # src = None
    return srs
Project: unmixing    Author: arthur-e
def density_slice(rast, rel=np.less_equal, threshold=1000, nodata=-9999):
    '''
    Returns a density slice from a given raster. Arguments:
        rast        A gdal.Dataset or a NumPy array
        rel         A NumPy logic function; defaults to np.less_equal
        threshold   An integer number
        nodata      The NoData value; NoData pixels are always 0 in the output
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast.copy()

    if (len(rastr.shape) > 2 and min(rastr.shape) > 1):
        raise ValueError('Expected a single-band raster array')

    # Compare against scalars (NumPy broadcasts them); rast may be a
    # gdal.Dataset, which has no .shape. np.int0 was removed in NumPy 2.0.
    return np.logical_and(
        rel(rastr, threshold),
        np.not_equal(rastr, nodata)).astype(np.intp)
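A toy check of the comparison at the end of `density_slice`, reimplemented as a hypothetical standalone helper on a made-up array: pixels satisfying `rel(value, threshold)` become 1, and everything else, including NoData, becomes 0.

```python
import numpy as np


def density_slice_array(arr, rel=np.less_equal, threshold=1000, nodata=-9999):
    """Return 1 where rel(arr, threshold) holds and arr is not NoData, else 0."""
    return np.logical_and(rel(arr, threshold), arr != nodata).astype(np.intp)


arr = np.array([[500, 1500], [-9999, 1000]])
print(density_slice_array(arr).tolist())                  # [[1, 0], [0, 1]]
print(density_slice_array(arr, rel=np.greater).tolist())  # [[0, 1], [0, 0]]
```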
Project: unmixing    Author: arthur-e
def spectra_at_xy(rast, ll, gt=None, wkt=None, dd=False):
    '''
    Returns the spectral profile of the pixels indicated by the longitude-
    latitude pairs provided. Arguments:
        rast    A gdal.Dataset or NumPy array
        ll      An array of longitude-latitude pairs
        gt      A GDAL GeoTransform tuple; ignored for gdal.Dataset
        wkt     Well-Known Text projection information; ignored for gdal.Dataset
        dd      Interpret the longitude-latitude pairs as decimal degrees
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        gt = rast.GetGeoTransform()
        wkt = rast.GetProjection()
        rast = rast.ReadAsArray()

    # You would think that transposing the matrix can't be as fast as
    #   transposing the coordinate pairs, however, it is.
    return spectra_at_idx(rast.transpose(), xy_to_pixel(ll,
        gt=gt, wkt=wkt, dd=dd))
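The comment about transposing relies on a shape identity. `transpose()` turns a `(bands, rows, cols)` array into `(cols, rows, bands)`, so indexing it with an (x, y) pair yields that pixel's full spectrum, and NumPy fancy indexing with arrays of x and y values returns many spectra at once. `spectra_at_idx` and `xy_to_pixel` are not shown, so this sketch only demonstrates the indexing on a small synthetic array:

```python
import numpy as np

rast = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # (bands, rows, cols)
t = rast.transpose()                          # (cols, rows, bands)

x, y = 3, 1                                   # column 3, row 1
print(t.shape)                                # (4, 3, 2)
print(t[x, y].tolist())                       # [7, 19]
print(rast[:, y, x].tolist())                 # same spectrum: [7, 19]

# Several pixels at once: one spectrum per (x, y) pair
xs, ys = np.array([0, 3]), np.array([0, 1])
print(t[xs, ys].tolist())                     # [[0, 12], [7, 19]]
```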
Project: python_scripting_for_spatial_data_processing    Author: upsdeepak
def vectors_to_raster(file_paths, rows, cols, geo_transform, projection):
    """
    Rasterize, in a single image, all the vector files listed in file_paths.
    The data of each file will be assigned the same pixel value. This value is defined by the order
    of the file in file_paths, starting with 1: the points/polygons/etc. in the first file will be
    marked as 1, those in the second file as 2, and so on.
    :param file_paths: List of paths to vector files (e.g. shapefiles)
    :param rows: Number of rows of the result
    :param cols: Number of columns of the result
    :param geo_transform: Returned value of gdal.Dataset.GetGeoTransform (coefficients for
                          transforming between pixel/line (P,L) raster space and projection
                          coordinates (Xp,Yp) space).
    :param projection: Projection definition string (Returned by gdal.Dataset.GetProjectionRef)
    """
    labeled_pixels = np.zeros((rows, cols))
    for i, path in enumerate(file_paths):
        label = i+1
        logger.debug("Processing file %s: label (pixel value) %i", path, label)
        ds = create_mask_from_vector(path, cols, rows, geo_transform, projection,
                                     target_value=label)
        band = ds.GetRasterBand(1)
        a = band.ReadAsArray()
        logger.debug("Labeled pixels: %i", len(a.nonzero()[0]))
        labeled_pixels += a
        ds = None
    return labeled_pixels
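Because each file's mask is added into `labeled_pixels`, a pixel covered by features from two different files receives the sum of their labels. The NumPy sketch below uses invented masks in place of `create_mask_from_vector` output. It shows an overlap of labels 1 and 2 producing 3, which is indistinguishable from a pixel of the third file, and a per-pixel coverage count that exposes such overlaps:

```python
import numpy as np

rows, cols = 2, 3
labeled = np.zeros((rows, cols))

# Invented burned masks: file 1 -> label 1, file 2 -> label 2, file 3 -> label 3
masks = [
    np.array([[1, 1, 0], [0, 0, 0]]),
    np.array([[0, 2, 2], [0, 0, 0]]),
    np.array([[0, 0, 0], [3, 0, 0]]),
]
for m in masks:
    labeled += m  # same accumulation as vectors_to_raster

print(labeled.tolist())   # [[1.0, 3.0, 2.0], [3.0, 0.0, 0.0]]

# Number of files touching each pixel; values > 1 flag ambiguous sums
coverage = sum((m > 0).astype(int) for m in masks)
print(coverage.tolist())  # [[1, 2, 1], [1, 0, 0]]
```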
Project: python_scripting_for_spatial_data_processing    Author: upsdeepak
def write_geotiff(fname, data, geo_transform, projection, data_type=gdal.GDT_Byte):
    """
    Create a GeoTIFF file with the given data.
    The color table and metadata use the module-level globals `classes` and `COLORS` (not shown).
    :param fname: Path of the output GeoTIFF file
    :param data: 2-D NumPy array of pixel values (rows x cols)
    :param geo_transform: Returned value of gdal.Dataset.GetGeoTransform (coefficients for
                          transforming between pixel/line (P,L) raster space and projection
                          coordinates (Xp,Yp) space).
    :param projection: Projection definition string (Returned by gdal.Dataset.GetProjectionRef)
    :param data_type: GDAL data type of the output band [default: gdal.GDT_Byte]
    """
    driver = gdal.GetDriverByName('GTiff')
    rows, cols = data.shape
    dataset = driver.Create(fname, cols, rows, 1, data_type)
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)
    band = dataset.GetRasterBand(1)
    band.WriteArray(data)

    ct = gdal.ColorTable()
    for pixel_value in range(len(classes)+1):
        color_hex = COLORS[pixel_value]
        r = int(color_hex[1:3], 16)
        g = int(color_hex[3:5], 16)
        b = int(color_hex[5:7], 16)
        ct.SetColorEntry(pixel_value, (r, g, b, 255))
    band.SetColorTable(ct)

    metadata = {
        'TIFFTAG_COPYRIGHT': 'CC BY 4.0',
        'TIFFTAG_DOCUMENTNAME': 'classification',
        'TIFFTAG_IMAGEDESCRIPTION': 'Supervised classification.',
        'TIFFTAG_MAXSAMPLEVALUE': str(len(classes)),
        'TIFFTAG_MINSAMPLEVALUE': '0',
        'TIFFTAG_SOFTWARE': 'Python, GDAL, scikit-learn'
    }
    dataset.SetMetadata(metadata)

    dataset = None  # Close the file
    return
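The color-table loop reads entries from `COLORS`, a global that is not included in this excerpt; the slicing implies each entry is a `#RRGGBB` string. A standalone version of that conversion (the helper name is hypothetical) makes the expected format explicit. Since the loop runs over `range(len(classes)+1)`, `COLORS` needs at least `len(classes) + 1` entries, one per pixel value from 0 to `len(classes)`.

```python
def hex_to_rgba(color_hex, alpha=255):
    """Convert a '#RRGGBB' string to the (r, g, b, alpha) tuple used by SetColorEntry."""
    r = int(color_hex[1:3], 16)
    g = int(color_hex[3:5], 16)
    b = int(color_hex[5:7], 16)
    return (r, g, b, alpha)


print(hex_to_rgba('#FF8000'))  # (255, 128, 0, 255)
```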
Project: pygeotools    Author: dshean
def get_sub_dim(src_ds, scale=None, maxdim=1024):
    """Compute dimensions of subsampled dataset 

    Parameters
    ----------
    src_ds : gdal.Dataset
        Input GDAL Dataset
    scale : float, optional
        Scaling factor; computed from maxdim if None
    maxdim : int, optional
        Maximum dimension along either axis, in pixels

    Returns
    -------
    ns
        Number of samples in subsampled output
    nl
        Number of lines in subsampled output
    """
    ns = src_ds.RasterXSize
    nl = src_ds.RasterYSize
    maxdim = float(maxdim)
    if scale is None:
        scale_ns = ns/maxdim
        scale_nl = nl/maxdim
        scale = max(scale_ns, scale_nl)
    #Need to check to make sure scale is positive real 
    if scale > 1:
        ns = int(round(ns/scale))
        nl = int(round(nl/scale))
    return ns, nl
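The arithmetic in `get_sub_dim` can be checked without opening a dataset. This sketch copies the logic but takes the raster size (`RasterXSize`, `RasterYSize`) as plain integers; the function name is hypothetical. When `scale` is not given, the longer axis is shrunk to roughly `maxdim` pixels, and rasters already smaller than `maxdim` are left unchanged.

```python
def sub_dim(ns, nl, scale=None, maxdim=1024):
    """Mirror get_sub_dim for a raster of ns samples (columns) by nl lines (rows)."""
    maxdim = float(maxdim)
    if scale is None:
        # The factor that brings the longer axis down to maxdim
        scale = max(ns / maxdim, nl / maxdim)
    if scale > 1:
        ns = int(round(ns / scale))
        nl = int(round(nl / scale))
    return ns, nl


print(sub_dim(5000, 3000))  # (1024, 614)
print(sub_dim(800, 600))    # (800, 600): already below maxdim
```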
Project: pygeotools    Author: dshean
def ds_getma_sub(src_ds, bnum=1, scale=None, maxdim=1024.):    
    """Load a subsampled array, rather than full resolution

    This is useful when working with large rasters

    Uses buf_xsize and buf_ysize options from GDAL ReadAsArray method.

    Parameters
    ----------
    src_ds : gdal.Dataset
        Input GDAL Dataset
    bnum : int, optional
        Band number
    scale : int, optional
        Scaling factor
    maxdim : int, optional 
        Maximum dimension along either axis, in pixels

    Returns
    -------
    np.ma.array    
        Masked array containing raster values
    """
    #print src_ds.GetFileList()[0]
    b = src_ds.GetRasterBand(bnum)
    b_ndv = get_ndv_b(b)
    ns, nl = get_sub_dim(src_ds, scale, maxdim)
    #The buf_size parameters determine the final array dimensions
    b_array = b.ReadAsArray(buf_xsize=ns, buf_ysize=nl)
    bma = np.ma.masked_values(b_array, b_ndv)
    return bma

Project: pygeotools    Author: dshean
def memwarp(src_ds, res=None, extent=None, t_srs=None, r=None, oudir=None, dst_ndv=0, verbose=True):
    """Helper function that calls warp for single input Dataset with output to memory (GDAL Memory Driver)
    """
    driver = iolib.mem_drv
    return warp(src_ds, res, extent, t_srs, r, driver=driver, dst_ndv=dst_ndv, verbose=verbose)

Project: wradlib    Author: wradlib
def read_gdal_coordinates(dataset, mode='centers', z=True):
    """Get the projected coordinates from a GDAL dataset.

    Parameters
    ----------
    dataset : gdal.Dataset
        raster image with georeferencing
    mode : string
        either 'centers' or 'borders'
    z : boolean
        True to get height coordinates (zero).

    Returns
    -------
    coordinates : :class:`numpy:numpy.ndarray`
        Array of projected coordinates (x,y,z)

    Examples
    --------

    See :ref:`notebooks/classify/wradlib_clutter_cloud_example.ipynb`.

    """
    coordinates_pixel = pixel_coordinates(dataset.RasterXSize,
                                          dataset.RasterYSize, mode)
    geotransform = dataset.GetGeoTransform()
    if z:
        coordinates = pixel_to_map3d(geotransform, coordinates_pixel)
    else:
        coordinates = pixel_to_map(geotransform, coordinates_pixel)
    return (coordinates)
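wradlib's `pixel_coordinates`, `pixel_to_map` and `pixel_to_map3d` helpers are not shown here. The idea can be sketched with NumPy alone, assuming a standard six-element geotransform: build a grid of pixel/line positions (offset by 0.5 for pixel centers, or with one extra row and column for pixel edges) and apply the affine transform. The function name and the `'edges'` mode label are assumptions for this sketch, not wradlib's API.

```python
import numpy as np


def grid_coordinates(gt, nx, ny, mode='centers'):
    """Projected (x, y) of pixel centers, shape (ny, nx, 2), or of pixel
    edges/corners, shape (ny + 1, nx + 1, 2), for a 6-element geotransform."""
    if mode == 'centers':
        cols = np.arange(nx) + 0.5
        rows = np.arange(ny) + 0.5
    else:  # 'edges'
        cols = np.arange(nx + 1, dtype=float)
        rows = np.arange(ny + 1, dtype=float)
    px, ln = np.meshgrid(cols, rows)  # both shaped (len(rows), len(cols))
    x = gt[0] + px * gt[1] + ln * gt[2]
    y = gt[3] + px * gt[4] + ln * gt[5]
    return np.dstack((x, y))


gt = (100.0, 10.0, 0.0, 500.0, 0.0, -10.0)
print(grid_coordinates(gt, 3, 2)[0, 0].tolist())             # [105.0, 495.0]
print(grid_coordinates(gt, 3, 2, 'edges')[-1, -1].tolist())  # [130.0, 480.0]
```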
Project: wradlib    Author: wradlib
def extract_raster_dataset(dataset, nodata=None):
    """ Extract data, coordinates and projection information

    Parameters
    ----------
    dataset : gdal.Dataset
        raster dataset
    nodata : scalar
        Value to which the dataset nodata values are mapped.

    Returns
    -------
    data : :class:`numpy:numpy.ndarray`
        Array of shape (rows, cols) or (bands, rows, cols) containing
        the data values.
    coords : :class:`numpy:numpy.ndarray`
        Array of shape (rows, cols, 2) containing xy-coordinates.
    projection : osr object
        Spatial reference system of the used coordinates.
    """

    # data values
    data = read_gdal_values(dataset, nodata=nodata)

    # coords
    coords_pixel = pixel_coordinates(dataset.RasterXSize,
                                     dataset.RasterYSize,
                                     'edges')
    coords = pixel_to_map(dataset.GetGeoTransform(),
                          coords_pixel)

    projection = read_gdal_projection(dataset)

    return data, coords, projection
Project: wradlib    Author: wradlib
def ogr_copy_layer(src_ds, index, dst_ds, reset=True):
    """ Copy OGR.Layer object.

    .. versionadded:: 0.7.0

    Copy OGR.Layer object from src_ds gdal.Dataset to dst_ds gdal.Dataset

    Parameters
    ----------
    src_ds : gdal.Dataset
        object
    index : int
        layer index
    dst_ds : gdal.Dataset
        object
    reset : bool
        if True resets src_layer
    """
    # get and copy src geometry layer

    src_lyr = src_ds.GetLayerByIndex(index)
    if reset:
        src_lyr.ResetReading()
        src_lyr.SetSpatialFilter(None)
        src_lyr.SetAttributeFilter(None)
    dst_ds.CopyLayer(src_lyr, src_lyr.GetName())
Project: wradlib    Author: wradlib
def ogr_copy_layer_by_name(src_ds, name, dst_ds, reset=True):
    """ Copy OGR.Layer object.

    .. versionadded:: 0.8.0

    Copy OGR.Layer object from src_ds gdal.Dataset to dst_ds gdal.Dataset

    Parameters
    ----------
    src_ds : gdal.Dataset
        object
    name : string
        layer name
    dst_ds : gdal.Dataset
        object
    reset : bool
        if True resets src_layer
    """
    # get and copy src geometry layer

    src_lyr = src_ds.GetLayerByName(name)
    if reset:
        src_lyr.ResetReading()
        src_lyr.SetSpatialFilter(None)
        src_lyr.SetAttributeFilter(None)
    dst_ds.CopyLayer(src_lyr, src_lyr.GetName())
Project: wradlib    Author: wradlib
def ogr_add_feature(ds, src, name=None):
    """ Creates OGR.Feature objects in OGR.Layer object.

    .. versionadded:: 0.7.0

    OGR.Features are built from numpy src points or polygons.

    The 'FID' and 'index' of each OGR.Feature correspond to the position of the source data element.

    Parameters
    ----------
    ds : gdal.Dataset
        object
    src : :func:`numpy:numpy.array`
        source data
    name : string
        name of wanted Layer
    """

    if name is not None:
        lyr = ds.GetLayerByName(name)
    else:
        lyr = ds.GetLayer()

    defn = lyr.GetLayerDefn()
    geom_name = ogr.GeometryTypeToName(lyr.GetGeomType())
    fields = [defn.GetFieldDefn(i).GetName()
              for i in range(defn.GetFieldCount())]
    feat = ogr.Feature(defn)

    for index, src_item in enumerate(src):
        geom = numpy_to_ogr(src_item, geom_name)

        if 'index' in fields:
            feat.SetField('index', index)

        feat.SetGeometry(geom)
        lyr.CreateFeature(feat)
Project: wradlib    Author: wradlib
def open_vector(filename, driver=None):
    """Open vector file, return gdal.Dataset and OGR.Layer

        .. warning:: dataset and layer have to live in the same context,
            if dataset is deleted all layer references will get lost

        .. versionadded:: 0.12.0

    Parameters
    ----------
    filename : string
        vector file name
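    driver : string, optional
        GDAL driver short name (e.g. 'ESRI Shapefile'); if given, only
        this driver is tried when opening the file

    Returns
    -------
    dataset : gdal.Dataset
        dataset
    layer : ogr.Layer
        layer
    """
    # Restrict OpenEx to the requested driver, if any
    allowed = [driver] if driver else None
    dataset = gdal.OpenEx(filename, gdal.OF_VECTOR, allowed_drivers=allowed)
    if dataset is None:
        raise IOError('Could not open file: %s' % filename)

    layer = dataset.GetLayer()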

    return dataset, layer
Project: wradlib    Author: wradlib
def open_shape(filename, driver=None):
    """Open shapefile, return gdal.Dataset and OGR.Layer

        .. warning:: dataset and layer have to live in the same context,
            if dataset is deleted all layer references will get lost

        .. versionadded:: 0.6.0

    Parameters
    ----------
    filename : string
        shapefile name
    driver : ogr.Driver, optional
        OGR driver object; defaults to the 'ESRI Shapefile' driver

    Returns
    -------
    dataset : gdal.Dataset
        dataset
    layer : ogr.Layer
        layer
    """

    if driver is None:
        driver = ogr.GetDriverByName('ESRI Shapefile')
    dataset = driver.Open(filename)
    if dataset is None:
        raise IOError('Could not open file: %s' % filename)
    layer = dataset.GetLayer()
    return dataset, layer
Project: unmixing    Author: arthur-e
def normalize_reflectance_within_image(rast, nodata=-9999, scale=100):
    '''
    Following Wu (2004, Remote Sensing of Environment), normalizes the
    reflectances in each pixel by the average reflectance *across bands.*
    This is an attempt to mitigate within-endmember variability. Arguments:
        rast    A gdal.Dataset or numpy.array instance
        nodata  The NoData value to use (and value to ignore)
        scale   (Optional) Wu's definition scales the normalized reflectance
                by 100 for some reason; another reasonable value would
                be 10,000 (approximating scale of Landsat reflectance units);
                set to None for no scaling.
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast.copy()

    shp = rastr.shape
    rastr_normalized = np.divide(
        rastr.reshape((shp[0], shp[1]*shp[2])),
        rastr.mean(axis=0).reshape((1, shp[1]*shp[2])).repeat(shp[0], axis=0))

    # Recover original shape; scale if necessary
    rastr_normalized = rastr_normalized.reshape(shp)
    if scale is not None:
        rastr_normalized = np.multiply(rastr_normalized, scale)

    # Fill in the NoData areas from the original raster
    np.place(rastr_normalized, rastr == nodata, nodata)
    return rastr_normalized
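The reshape-and-repeat construction above divides every band by the per-pixel mean across bands. NumPy broadcasting expresses the same division more compactly. This sketch (hypothetical helper name, made-up values) omits the NoData restoration that `np.place` performs in the original, and a pixel whose band mean is zero would still produce a division by zero.

```python
import numpy as np


def normalize_by_band_mean(rastr, scale=100):
    """Divide each pixel's values by its mean across bands (axis 0); optionally scale."""
    normalized = rastr / rastr.mean(axis=0, keepdims=True)  # (p, m, n) / (1, m, n)
    if scale is not None:
        normalized = normalized * scale
    return normalized


rastr = np.array([[[2.0, 4.0]], [[6.0, 4.0]]])  # (bands=2, rows=1, cols=2)
print(normalize_by_band_mean(rastr).tolist())  # [[[50.0, 100.0]], [[150.0, 100.0]]]
```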
Project: unmixing    Author: arthur-e
def histogram(arr, valid_range=(0, 1), bins=10, normed=False, cumulative=False,
        file_path='hist.png', title=None):
    '''
    Plots a histogram for an input array (or gdal.Dataset) over a specified
    range and saves it to file_path.
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(arr, np.ndarray):
        arr = arr.ReadAsArray()

    # `density` replaces the `normed` keyword, which recent Matplotlib
    # versions no longer accept
    plt.hist(arr.ravel(), range=valid_range, bins=bins, density=normed,
        cumulative=cumulative)
    if title is not None:
        plt.title(title)

    plt.savefig(file_path)
Project: unmixing    Author: arthur-e
def test_file_raster_and_array_access(self):
        '''
        Tests that essential file reading and raster/array conversion utilities
        are working properly.
        '''
        from_as_array = as_array(os.path.join(self.test_dir, 'multi3_raster.tiff'))
        from_as_raster = as_raster(os.path.join(self.test_dir, 'multi3_raster.tiff'))
        self.assertTrue(len(from_as_array) == len(from_as_raster) == 3)
        self.assertTrue(isinstance(from_as_array[0], np.ndarray))
        self.assertTrue(isinstance(from_as_raster[0], gdal.Dataset))
Project: unmixing    Author: arthur-e
def clean_mask(rast):
    '''
    Clips the values in a mask to the interval [0, 1]; values larger than 1
    become 1 and values smaller than 0 become 0.
    Arguments:
        rast    An input gdal.Dataset or numpy.array instance
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast.copy()

    return np.clip(rastr, a_min=0, a_max=1)
Project: unmixing    Author: arthur-e
def dump_raster(rast, rast_path, xoff=0, yoff=0, driver='GTiff', nodata=None):
    '''
    Creates a raster file from a given GDAL dataset (raster). Arguments:
        rast        A gdal.Dataset; does NOT accept NumPy array
        rast_path   The path of the output raster file
        xoff        Offset in the x-direction; should be provided when clipped
        yoff        Offset in the y-direction; should be provided when clipped
        driver      The name of the GDAL driver to use (determines file type)
        nodata      The NoData value; defaults to the source raster's NoData
                    value, or -9999 if none is set.
    '''
    driver = gdal.GetDriverByName(driver)
    sink = driver.Create(rast_path, rast.RasterXSize, rast.RasterYSize,
        rast.RasterCount, rast.GetRasterBand(1).DataType)
    assert sink is not None, 'Cannot create dataset; there may be a problem with the output path you specified'
    sink.SetGeoTransform(rast.GetGeoTransform())
    sink.SetProjection(rast.GetProjection())

    for b in range(1, rast.RasterCount + 1):
        dat = rast.GetRasterBand(b).ReadAsArray()
        sink.GetRasterBand(b).WriteArray(dat)
        sink.GetRasterBand(b).SetStatistics(*map(np.float64,
            [dat.min(), dat.max(), dat.mean(), dat.std()]))

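        band_nodata = nodata
        if band_nodata is None:
            # Fall back to this band's own NoData value, then to -9999
            band_nodata = rast.GetRasterBand(b).GetNoDataValue()

            if band_nodata is None:
                band_nodata = -9999

        sink.GetRasterBand(b).SetNoDataValue(np.float64(band_nodata))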

    sink.FlushCache()
Project: unmixing    Author: arthur-e
def mask_by_query(rast, query, invert=False, nodata=-9999):
    '''
    Mask pixels (across bands) that match a query in any one band or all bands.
    For example: `query = rast[1,...] < -25` queries those pixels with a value
    less than -25 in band 2; these pixels would be masked (if `invert=False`).
    By default, the pixels that are queried are masked, but if `invert=True`,
    the query serves to select pixels NOT to be masked (`np.invert()` can also
    be called on the query before calling this function to achieve the same
    effect). Arguments:
        rast    A gdal.Dataset or numpy.array instance
        query   A NumPy boolean array representing the result of a query
        invert  True to invert the query
        nodata  The NoData value to apply in the masking
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast.copy()

    shp = rastr.shape
    if query.shape != rastr.shape:
        assert len(query.shape) == 2 or len(query.shape) == len(shp), 'Query must either be 2-dimensional (single-band) or have a dimensionality equal to the raster array'
        assert shp[-2] == query.shape[-2] and shp[-1] == query.shape[-1], 'Raster and query must be conformable arrays in two dimensions (must have the same extent)'

        # Transform the query into a 1-band array and then into a multi-band array
        query = query.reshape((1, shp[-2], shp[-1])).repeat(shp[0], axis=0)

    # Mask out areas that match the query
    if invert:
        rastr[np.invert(query)] = nodata

    else:
        rastr[query] = nodata

    return rastr
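A small NumPy walk-through of the broadcasting step in `mask_by_query`: a 2-D query built from one band is reshaped to `(1, rows, cols)` and repeated along the band axis, so a pixel matching the query is set to NoData in every band. The array values are invented for illustration.

```python
import numpy as np

nodata = -9999
rastr = np.array([
    [[-30.0, -10.0], [-20.0, -40.0]],  # band 1
    [[0.5, 0.6], [0.7, 0.8]],          # band 2
])
query = rastr[0, ...] < -25            # 2-D query on band 1: (rows, cols)

# Same reshape/repeat as mask_by_query: (rows, cols) -> (bands, rows, cols)
shp = rastr.shape
full_query = query.reshape((1, shp[-2], shp[-1])).repeat(shp[0], axis=0)

masked = rastr.copy()
masked[full_query] = nodata            # behavior with invert=False
print(masked[:, 0, 0].tolist())        # [-9999.0, -9999.0]
print(masked[:, 0, 1].tolist())        # [-10.0, 0.6]
```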
Project: unmixing    Author: arthur-e
def subarray(rast, filtered_value=-9999, indices=False):
    '''
    Given a (p x m x n) raster (or array), returns a (p x z) subarray where
    z is the number of cases (pixels) not equal to the filtered value; for a
    multi-band image, a pixel is dropped only if it equals the filtered value
    in every band. Arguments:
        rast            The input gdal.Dataset or a NumPy array
        filtered_value  The value to remove from the raster array
        indices         If True, return a tuple: (indices, subarray)
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast.copy()

    shp = rastr.shape
    if len(shp) == 1:
        # If already raveled
        return rastr[rastr != filtered_value]

    if len(shp) == 2 or shp[0] == 1:
        # If a "single-band" image
        arr = rastr.reshape(1, shp[-2]*shp[-1])
        return arr[arr != filtered_value]

    # For multi-band images
    arr = rastr.reshape(shp[0], shp[1]*shp[2])
    idx = (arr != filtered_value).any(axis=0)
    if indices:
        # Return the indices as well
        rast_shp = (shp[-2], shp[-1])
        return (np.indices(rast_shp)[:,idx.reshape(rast_shp)], arr[:,idx])

    return arr[:,idx]
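A toy example of the multi-band path of `subarray` (values invented). Because the filter uses `.any(axis=0)`, a pixel is kept when at least one band differs from the filtered value; switching to `.all(axis=0)` would instead drop any pixel that holds the filtered value in some band.

```python
import numpy as np

fv = -9999
rastr = np.array([
    [[1, fv], [3, 4]],   # band 1
    [[5, fv], [7, fv]],  # band 2
])
shp = rastr.shape                             # (p, m, n) = (2, 2, 2)
arr = rastr.reshape(shp[0], shp[1] * shp[2])  # (p, m*n)

keep_any = (arr != fv).any(axis=0)  # the rule used by subarray
keep_all = (arr != fv).all(axis=0)  # stricter alternative

print(arr[:, keep_any].tolist())    # [[1, 3, 4], [5, 7, -9999]]
print(arr[:, keep_all].tolist())    # [[1, 3], [5, 7]]
```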
Project: pygeotools    Author: dshean
def parse_srs(t_srs, src_ds_list=None):
    """Parse arbitrary input t_srs

    Parameters
    ----------
    t_srs : str or gdal.Dataset or filename
        Arbitrary input t_srs 
    src_ds_list : list of gdal.Dataset objects, optional
        Needed if specifying 'first' or 'last'

    Returns
    -------
    t_srs : osr.SpatialReference() object
        Output spatial reference system
    """
    if t_srs is None and src_ds_list is None:
        print("Input t_srs and src_ds_list are both None")
    else:
        if t_srs is None:
            t_srs = 'first'
        if t_srs == 'first' and src_ds_list is not None:
            t_srs = geolib.get_ds_srs(src_ds_list[0])
        elif t_srs == 'last' and src_ds_list is not None:
            t_srs = geolib.get_ds_srs(src_ds_list[-1])
        #elif t_srs == 'source':
        #    t_srs = None 
        elif isinstance(t_srs, osr.SpatialReference): 
            pass
        elif isinstance(t_srs, gdal.Dataset):
            t_srs = geolib.get_ds_srs(t_srs)
        elif isinstance(t_srs, str) and os.path.exists(t_srs): 
            t_srs = geolib.get_ds_srs(gdal.Open(t_srs))
        elif isinstance(t_srs, str):
            temp = osr.SpatialReference()
            if 'EPSG' in t_srs.upper():
                epsgcode = int(t_srs.split(':')[-1])
                temp.ImportFromEPSG(epsgcode)
            elif 'proj' in t_srs:
                temp.ImportFromProj4(t_srs)
            else:
                #Assume the user knows what they are doing
                temp.ImportFromWkt(t_srs)
            t_srs = temp
        else:
            t_srs = None
    return t_srs
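One caveat in the string branch of `parse_srs`: it tests `'EPSG' in t_srs.upper()` first. WKT emitted by GDAL for EPSG-defined systems usually carries an `AUTHORITY["EPSG", ...]` or `ID["EPSG", ...]` node, so such a WKT string would be routed to the EPSG branch, where `int()` raises `ValueError`. A hypothetical, more defensive classifier that only treats a leading `EPSG:` prefix as an EPSG code might look like this; it only classifies the string and does not build an `osr.SpatialReference`.

```python
def classify_srs_string(t_srs):
    """Classify an SRS string as ('epsg', code), ('proj4', s) or ('wkt', s)."""
    s = t_srs.strip()
    # Only a leading "EPSG:" prefix counts as an EPSG code, so WKT that merely
    # mentions an EPSG authority falls through to the WKT case
    if s.upper().startswith('EPSG:'):
        return 'epsg', int(s.split(':')[-1])
    if s.startswith('+proj'):
        return 'proj4', s
    return 'wkt', s


wkt = 'GEOGCS["WGS 84",AUTHORITY["EPSG","4326"]]'  # abbreviated WKT for illustration
print('EPSG' in wkt.upper())                       # True: parse_srs would take the EPSG branch
print(classify_srs_string('EPSG:32633'))           # ('epsg', 32633)
print(classify_srs_string(wkt)[0])                 # wkt
```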
Project: pygeotools    Author: dshean
def parse_res(res, src_ds_list=None, t_srs=None):
    """Parse arbitrary input res 

    Parameters
    ----------
    res : str or gdal.Dataset or filename or float
        Arbitrary input res 
    src_ds_list : list of gdal.Dataset objects, optional
        Needed if specifying 'first' or 'last'
    t_srs : osr.SpatialReference() object 
        Projection for res calculations, optional

    Returns
    -------
    res : float 
        Output resolution
        None if source resolution should be preserved
    """
    #Default to using first t_srs for res calculations
    #Assumes src_ds_list is not None
    t_srs = parse_srs(t_srs, src_ds_list)

    #Valid strings
    res_str_list = ['first', 'last', 'min', 'max', 'mean', 'med']

    #Compute output resolution in t_srs
    if res in res_str_list and src_ds_list is not None:
        #Returns min, max, mean, med
        res_stats = geolib.get_res_stats(src_ds_list, t_srs=t_srs)
        if res == 'first':
            res = geolib.get_res(src_ds_list[0], t_srs=t_srs, square=True)[0]
        elif res == 'last':
            res = geolib.get_res(src_ds_list[-1], t_srs=t_srs, square=True)[0]
        elif res == 'min':
            res = res_stats[0]
        elif res == 'max':
            res = res_stats[1]
        elif res == 'mean':
            res = res_stats[2]
        elif res == 'med':
            res = res_stats[3]
    elif res == 'source':
        res = None
    elif isinstance(res, gdal.Dataset):
        res = geolib.get_res(res, t_srs=t_srs, square=True)[0]
    elif isinstance(res, str) and os.path.exists(res): 
        res = geolib.get_res(gdal.Open(res), t_srs=t_srs, square=True)[0]
    else:
        res = float(res)
    return res
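The keyword options of `parse_res` reduce to picking one value from the per-input resolutions. Assuming, as the inline comment says, that `geolib.get_res_stats` returns `(min, max, mean, median)`, the selection can be sketched on plain floats with the standard library. The helper name is hypothetical, and no reprojection to `t_srs` is done.

```python
import statistics


def pick_res(keyword, resolutions):
    """Pick an output resolution from a list of per-input resolutions (floats)."""
    choices = {
        'first': resolutions[0],
        'last': resolutions[-1],
        'min': min(resolutions),
        'max': max(resolutions),
        'mean': statistics.mean(resolutions),
        'med': statistics.median(resolutions),
    }
    return float(choices[keyword])


res = [30.0, 10.0, 20.0, 60.0]
print(pick_res('med', res))  # 25.0
```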
Project: pygeotools    Author: dshean
def parse_extent(extent, src_ds_list, t_srs=None):
    """Parse arbitrary input extent

    Parameters
    ----------
    extent : str or gdal.Dataset or filename or list of float
        Arbitrary input extent
    src_ds_list : list of gdal.Dataset objects, optional
        Needed if specifying 'first', 'last', 'intersection', or 'union'
    t_srs : osr.SpatialReference() object, optional
        Projection for extent calculations

    Returns
    -------
    extent : list of float 
        Output extent [xmin, ymin, xmax, ymax] 
        None if source extent should be preserved
    """
    #Default to using first t_srs for extent calculations
    #Assumes src_ds_list is not None
    t_srs = parse_srs(t_srs, src_ds_list)

    #Valid strings
    extent_str_list = ['first', 'last', 'intersection', 'union']

    if extent in extent_str_list and src_ds_list is not None:
        if len(src_ds_list) == 1 and (extent == 'intersection' or extent == 'union'):
            extent = None
        elif extent == 'first':
            extent = geolib.ds_geom_extent(src_ds_list[0], t_srs=t_srs)
            #extent = geolib.ds_extent(src_ds_list[0], t_srs=t_srs)
        elif extent == 'last':
            extent = geolib.ds_geom_extent(src_ds_list[-1], t_srs=t_srs)
            #extent = geolib.ds_extent(src_ds_list[-1], t_srs=t_srs)
        elif extent == 'intersection':
            #By default, compute_intersection takes ref_srs from ref_ds
            extent = geolib.ds_geom_intersection_extent(src_ds_list, t_srs=t_srs)
            if len(src_ds_list) > 1 and extent is None:
                sys.exit("Input images do not intersect")
        elif extent == 'union':
            #Need to clean up union t_srs handling
            extent = geolib.ds_geom_union_extent(src_ds_list, t_srs=t_srs)
    elif extent == 'source':
        extent = None
    elif isinstance(extent, gdal.Dataset):
        extent = geolib.ds_geom_extent(extent, t_srs=t_srs)
    elif isinstance(extent, str) and os.path.exists(extent): 
        extent = geolib.ds_geom_extent(gdal.Open(extent), t_srs=t_srs)
    elif isinstance(extent, (list, tuple)):
        extent = list(extent)
    else:
        extent = [float(i) for i in extent.split(' ')]
    return extent
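`geolib.ds_geom_intersection_extent` and `geolib.ds_geom_union_extent` are not shown. Judging by their names and the `t_srs` argument, they presumably work with dataset footprint geometries in `t_srs`. For axis-aligned extents already in a common SRS, the same idea reduces to max/min comparisons on `[xmin, ymin, xmax, ymax]` lists. The sketch below is a simplified, hypothetical stand-in, not the pygeotools implementation. It returns `None` for non-overlapping inputs, which is the case where `parse_extent` exits with "Input images do not intersect".

```python
def bbox_intersection(extents):
    """Intersect [xmin, ymin, xmax, ymax] boxes; return None if they do not overlap."""
    xmin = max(e[0] for e in extents)
    ymin = max(e[1] for e in extents)
    xmax = min(e[2] for e in extents)
    ymax = min(e[3] for e in extents)
    if xmin >= xmax or ymin >= ymax:
        return None
    return [xmin, ymin, xmax, ymax]


def bbox_union(extents):
    """Smallest [xmin, ymin, xmax, ymax] box containing all input boxes."""
    return [min(e[0] for e in extents), min(e[1] for e in extents),
            max(e[2] for e in extents), max(e[3] for e in extents)]


a, b = [0, 0, 10, 10], [5, 5, 15, 20]
print(bbox_intersection([a, b]))  # [5, 5, 10, 10]
print(bbox_union([a, b]))         # [0, 0, 15, 20]

# The final branch of parse_extent accepts a space-separated string
print([float(i) for i in '0 0 10 10'.split(' ')])  # [0.0, 0.0, 10.0, 10.0]
```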
Project: wradlib    Author: wradlib
def _create_dst_datasource(self, **kwargs):
        """ Create destination target gdal.Dataset

        Creates one layer for each target polygon, consisting of
        the needed source data attributed with index and weights fields

        Returns
        -------
        ds_mem : gdal.Dataset object
        """

        # TODO: kwargs necessary?

        # create intermediate mem dataset
        ds_mem = io.gdal_create_dataset('Memory', 'dst',
                                        gdal_type=gdal.OF_VECTOR)

        # get src geometry layer
        src_lyr = self.src.ds.GetLayerByName('src')
        src_lyr.ResetReading()
        src_lyr.SetSpatialFilter(None)
        geom_type = src_lyr.GetGeomType()

        # create temp Buffer layer (time consuming)
        ds_tmp = io.gdal_create_dataset('Memory', 'tmp',
                                        gdal_type=gdal.OF_VECTOR)
        georef.ogr_copy_layer(self.trg.ds, 0, ds_tmp)
        tmp_trg_lyr = ds_tmp.GetLayer()

        for i in range(tmp_trg_lyr.GetFeatureCount()):
            feat = tmp_trg_lyr.GetFeature(i)
            feat.SetGeometryDirectly(feat.GetGeometryRef().
                                     Buffer(self._buffer))
            tmp_trg_lyr.SetFeature(feat)

        # get target layer, iterate over polygons and calculate intersections
        tmp_trg_lyr.ResetReading()

        self.tmp_lyr = georef.ogr_create_layer(ds_mem, 'dst', srs=self._srs,
                                               geom_type=geom_type)

        print("Calculate Intersection source/target-layers")
        try:
            tmp_trg_lyr.Intersection(src_lyr, self.tmp_lyr,
                                     options=['SKIP_FAILURES=YES',
                                              'INPUT_PREFIX=trg_',
                                              'METHOD_PREFIX=src_',
                                              'PROMOTE_TO_MULTI=YES',
                                              'PRETEST_CONTAINMENT=YES'],
                                     callback=gdal.TermProgress)
        except RuntimeError:
            # Catch RuntimeError that was reported on gdal 1.11.1
            # on Windows systems
            tmp_trg_lyr.Intersection(src_lyr, self.tmp_lyr,
                                     options=['SKIP_FAILURES=YES',
                                              'INPUT_PREFIX=trg_',
                                              'METHOD_PREFIX=src_',
                                              'PROMOTE_TO_MULTI=YES',
                                              'PRETEST_CONTAINMENT=YES'])

        return ds_mem
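The method above buffers each target polygon with OGR's `Buffer` and then intersects the buffered target layer with the source layer through `Layer.Intersection`. The sketch below shows the same idea without GDAL, as a simplified analogue under loud assumptions. Targets and sources are axis-aligned boxes `(xmin, ymin, xmax, ymax)` instead of real polygons, and the overlap area stands in for a weight. The helper names are hypothetical, and this is not wradlib's implementation.

```python
# Hypothetical helpers: a box-only analogue of "buffer the target, then
# intersect it with every source cell". Boxes are (xmin, ymin, xmax, ymax).


def buffer_box(box, distance):
    """Grow (or shrink, if distance < 0) an axis-aligned box on all sides."""
    xmin, ymin, xmax, ymax = box
    return (xmin - distance, ymin - distance, xmax + distance, ymax + distance)


def intersection_area(a, b):
    """Area of the overlap of two axis-aligned boxes (0.0 if disjoint)."""
    width = min(a[2], b[2]) - max(a[0], b[0])
    height = min(a[3], b[3]) - max(a[1], b[1])
    return max(width, 0.0) * max(height, 0.0)


def overlap_weights(target, sources, distance=0.0):
    """Map source index -> overlap area with the buffered target box.

    Sources that do not touch the buffered target are skipped, much like
    source features that yield no intersection geometry.
    """
    grown = buffer_box(target, distance)
    weights = {}
    for idx, src in enumerate(sources):
        area = intersection_area(grown, src)
        if area > 0:
            weights[idx] = area
    return weights


sources = [(0.5, 0.5, 2, 2), (3, 3, 4, 4)]
print(overlap_weights((0, 0, 1, 1), sources, distance=0.5))
print(overlap_weights((0, 0, 1, 1), sources))
```

With a buffer of 0.5, the first source overlaps the grown target with an area of 1.0, and the second source is dropped. Without the buffer, the overlap shrinks to 0.25.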
Project: wradlib    Author: wradlib
def create_raster_dataset(data, coords, projection=None, nodata=-9999):
    """ Create In-Memory Raster Dataset

    .. versionadded:: 0.10.0

    Parameters
    ----------
    data : :class:`numpy:numpy.ndarray`
        Array of shape (rows, cols) or (bands, rows, cols) containing
        the data values.
    coords : :class:`numpy:numpy.ndarray`
        Array of shape (rows, cols, 2) containing xy-coordinates.
    projection : osr.SpatialReference
        Spatial reference system of the used coordinates, defaults to None.
    nodata : int or float
        NoData value for the raster bands, defaults to -9999.

    Returns
    -------
    dataset : gdal.Dataset
        In-Memory raster dataset

    Note
    ----
    The origin of the provided data and coordinates is UPPER LEFT.
    """

    # align data
    data = data.copy()
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    bands, rows, cols = data.shape

    # create In-Memory Raster with correct dtype
    mem_drv = gdal.GetDriverByName('MEM')
    gdal_type = gdal_array.NumericTypeCodeToGDALTypeCode(data.dtype)
    dataset = mem_drv.Create('', cols, rows, bands, gdal_type)

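    # initialize geotransform: origin at coords[0, 0], pixel size from the
    # diagonal neighbour (assumes a regular, north-up grid)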
    x_ps, y_ps = coords[1, 1] - coords[0, 0]
    geotran = [coords[0, 0, 0], x_ps, 0, coords[0, 0, 1], 0, y_ps]
    dataset.SetGeoTransform(geotran)

    if projection:
        dataset.SetProjection(projection.ExportToWkt())

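    # replace np.nan with nodata (float data only) and flag nodata on every band
    if np.issubdtype(data.dtype, np.floating):
        data[np.isnan(data)] = nodata

    for i, band in enumerate(data, start=1):
        raster_band = dataset.GetRasterBand(i)
        raster_band.SetNoDataValue(nodata)
        raster_band.WriteArray(band)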

    return dataset
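The geotransform built above follows GDAL's affine model: `Xgeo = GT[0] + col * GT[1] + row * GT[2]` and `Ygeo = GT[3] + col * GT[4] + row * GT[5]`. The NumPy-only sketch below reproduces the same construction from a `(rows, cols, 2)` coordinate grid and applies it to a pixel position. The helper names are hypothetical. It assumes a regular, north-up grid whose `coords[0, 0]` is the upper-left corner. If your coordinates are pixel centres, shift the origin by half a pixel, because GDAL's origin refers to the corner of the upper-left pixel.

```python
import numpy as np


def geotransform_from_coords(coords):
    """Build a 6-element GDAL-style geotransform from a (rows, cols, 2) grid.

    Mirrors the logic above: origin at coords[0, 0], pixel size taken from
    the diagonal neighbour, no rotation terms (regular north-up grid).
    """
    x_ps, y_ps = coords[1, 1] - coords[0, 0]
    return [coords[0, 0, 0], x_ps, 0.0, coords[0, 0, 1], 0.0, y_ps]


def pixel_to_world(gt, col, row):
    """Apply the affine geotransform to a (col, row) pixel/line position."""
    x = gt[0] + col * gt[1] + row * gt[2]
    y = gt[3] + col * gt[4] + row * gt[5]
    return x, y


# A 3 x 4 grid of 10 m pixels whose upper-left corner is at (500000, 4100000)
xs = 500000.0 + 10.0 * np.arange(4)
ys = 4100000.0 - 10.0 * np.arange(3)
coords = np.dstack(np.meshgrid(xs, ys))  # shape (3, 4, 2), last axis is (x, y)

gt = geotransform_from_coords(coords)
print(gt)
print(pixel_to_world(gt, 2, 1))  # column 2, row 1
```

The negative `gt[5]` reflects that row numbers grow downwards while northing decreases, which matches the UPPER LEFT origin noted in the docstring.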
Project: wradlib    Author: wradlib
def gdal_create_dataset(drv, name, cols=0, rows=0, bands=0,
                        gdal_type=gdal.GDT_Unknown, remove=False):
    """Create a gdal.Dataset object.

    .. versionadded:: 0.7.0

    .. versionchanged:: 0.11.0
        - changed parameters to keyword args
        - added 'bands' as parameter

    Parameters
    ----------
    drv : string
        GDAL driver short name, e.g. 'GTiff' or 'MEM'
    name : string
        path to the output file
    cols : int
        number of columns
    rows : int
        number of rows
    bands : int
        number of raster bands
    gdal_type : raster data type
        e.g. gdal.GDT_Float32
    remove : bool
        if True, existing gdal.Dataset will be
        removed before creation

    Returns
    -------
    out : gdal.Dataset
        object

    """
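    driver = gdal.GetDriverByName(drv)
    if driver is None:
        raise IOError("Unknown GDAL driver: {}".format(drv))
    metadata = driver.GetMetadata()

    # GDAL reports driver capabilities as the string 'YES'
    if metadata.get('DCAP_CREATE') != 'YES':
        raise IOError("Driver {} doesn't support Create() method.".format(drv))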

    if remove:
        if os.path.exists(name):
            driver.Delete(name)
    ds = driver.Create(name, cols, rows, bands, gdal_type)

    return ds
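Mixing a `%s` placeholder with `str.format()` does not raise an error; `format()` simply ignores the unused argument, so the driver name never reaches the message. The pure-Python sketch below demonstrates the pitfall and adds a small capability check that works on a plain dict shaped like the result of `driver.GetMetadata()`. The helper `require_capability` and the driver name `SomeReadOnlyDriver` are hypothetical.

```python
def require_capability(metadata, capability, driver_name):
    """Raise IOError unless a GDAL-style metadata dict reports capability=YES.

    Hypothetical helper: GDAL drivers report capabilities such as
    'DCAP_CREATE' or 'DCAP_CREATECOPY' as the string 'YES'.
    """
    if metadata.get(capability) != 'YES':
        raise IOError("Driver {} doesn't support {}.".format(driver_name, capability))


# A %-style placeholder is left untouched by str.format():
broken = "Driver %s doesn't support Create() method.".format("GTiff")
fixed = "Driver {} doesn't support Create() method.".format("GTiff")
print(broken)  # Driver %s doesn't support Create() method.
print(fixed)   # Driver GTiff doesn't support Create() method.

require_capability({'DCAP_CREATE': 'YES'}, 'DCAP_CREATE', 'GTiff')  # passes silently
try:
    require_capability({}, 'DCAP_CREATE', 'SomeReadOnlyDriver')
except IOError as err:
    print(err)  # Driver SomeReadOnlyDriver doesn't support DCAP_CREATE.
```

Comparing against `'YES'`, rather than testing truthiness, also rejects a hypothetical `'NO'` value, which is a non-empty and therefore truthy string.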
Project: wradlib    Author: wradlib
def write_raster_dataset(fpath, dataset, format, options=None, remove=False):
    """ Write a raster dataset to a file in the given format

    .. versionadded:: 0.10.0

    Parameters
    ----------
    fpath : string
        File path; its extension should correspond to the format.
    dataset : gdal.Dataset
        GDAL raster dataset
    format : string
        GDAL raster driver short name, e.g. 'GTiff'
    options : list
        List of option strings for the corresponding format.
    remove : bool
        if True, existing gdal.Dataset will be
        removed before creation

    Note
    ----
    For format and options refer to
    `formats_list <http://www.gdal.org/formats_list.html>`_.

    Examples
    --------
    See :ref:`notebooks/fileio/wradlib_gis_export_example.ipynb`.
    """
    # check for option list
    if options is None:
        options = []

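    driver = gdal.GetDriverByName(format)
    if driver is None:
        raise IOError("Unknown GDAL driver: {}".format(format))
    metadata = driver.GetMetadata()

    # check driver capability; GDAL falls back to Create() when a driver
    # has no native CreateCopy(), so either capability is sufficient
    if (metadata.get('DCAP_CREATECOPY') != 'YES' and
            metadata.get('DCAP_CREATE') != 'YES'):
        raise IOError("Driver {} supports neither CreateCopy() nor Create() "
                      "method.".format(format))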

    if remove:
        if os.path.exists(fpath):
            driver.Delete(fpath)

    target = driver.CreateCopy(fpath, dataset, 0, options)
    del target
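GDAL expects `options` as a list of `'KEY=VALUE'` strings. The accepted keys depend on the driver; GTiff, for example, accepts `COMPRESS`, `TILED` and `BLOCKXSIZE`. The sketch below builds such a list from a dict. The helper `to_creation_options` is hypothetical and only formats strings; GDAL validates the keys when the file is written.

```python
def to_creation_options(opts):
    """Turn {'compress': 'LZW', 'tiled': True} into GDAL 'KEY=VALUE' strings.

    Hypothetical helper: keys are upper-cased, booleans become 'YES'/'NO',
    and any other value is converted with str().
    """
    options = []
    for key, value in opts.items():
        if isinstance(value, bool):
            value = 'YES' if value else 'NO'
        options.append('{}={}'.format(key.upper(), value))
    return options


print(to_creation_options({'compress': 'LZW', 'tiled': True, 'blockxsize': 256}))
# ['COMPRESS=LZW', 'TILED=YES', 'BLOCKXSIZE=256']
```

The result can then be passed through unchanged, e.g. `write_raster_dataset('out.tif', ds, 'GTiff', options=to_creation_options({...}))`.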
Project: unmixing    Author: arthur-e
def hall_rectification(reference, subject, out_path, ref_set, sub_set, dd=False, nodata=-9999,
    dtype=np.int32, keys=('High/Bright', 'Low/Dark')):
    '''
    Performs radiometric rectification after Hall et al. (1991) in Remote
    Sensing of Environment. Assumes first raster is the reference image and
    that none of the targets are NoData pixels in the reference image (they
    are filtered out in the subject images). Arguments:
        reference   The reference image, a gdal.Dataset
        subject     The subject image, a gdal.Dataset
        out_path    Path to a directory where the rectified images should be stored
        ref_set     Dictionary holding the "bright" and "dark" radiometric
                    control sets for the reference image, keyed by `keys`
        sub_set     As with ref_set, but for the subject image, e.g.
                    {'High/Bright': [<bright targets>], 'Low/Dark': [<dark targets>]}
        dd          Coordinates are in decimal degrees?
        dtype       Data type (NumPy dtype) for the array; default is 32-bit Int
        nodata      The NoData value to use for all the rasters
        keys        The names of the dictionary keys for the bright, dark sets,
                    respectively
    '''
    # Unpack bright, dark control sets for subject image
    bright_targets, dark_targets = (sub_set[keys[0]], sub_set[keys[1]])

    # Calculate the mean reflectance in each band for bright, dark targets
    bright_ref = spectra_at_xy(reference, ref_set[keys[0]], dd=dd).mean(axis=0)
    dark_ref = spectra_at_xy(reference, ref_set[keys[1]], dd=dd).mean(axis=0)

    # Calculate transformation for the target image
    brights = spectra_at_xy(subject, bright_targets, dd=dd) # Prepare to filter NoData pixels
    darks = spectra_at_xy(subject, dark_targets, dd=dd)
    # Get the "subject" image means for each radiometric control set
    mean_bright = brights[
        np.sum(brights, axis=1) != (nodata * brights.shape[1])
    ].mean(axis=0)
    mean_dark = darks[
        np.sum(darks, axis=1) != (nodata * darks.shape[1])
    ].mean(axis=0)

    # Calculate the coefficients of the linear transformation
    m = (bright_ref - dark_ref) / (mean_bright - mean_dark)
    b = (dark_ref * mean_bright - mean_dark * bright_ref) / (mean_bright - mean_dark)

    arr = subject.ReadAsArray()
    shp = arr.shape # Remember the original shape
    mask = arr.copy() # Save the NoData value locations
    m = m.reshape((shp[0], 1))
    b = b.reshape((shp[0], 1)).T.repeat(shp[1] * shp[2], axis=0).T
    arr2 = ((arr.reshape((shp[0], shp[1] * shp[2])) * m) + b).reshape(shp)
    arr2[mask == nodata] = nodata # Re-apply NoData values

    # Dump the raster to a file
    out_path = os.path.join(out_path, 'rect_%s' % os.path.basename(subject.GetDescription()))
    dump_raster(
        array_to_raster(arr2, subject.GetGeoTransform(), subject.GetProjection(), dtype=dtype), out_path)
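The gain `m` and offset `b` above satisfy `m * Bs + b = Br` and `m * Ds + b = Dr`, so the subject's bright and dark control means land exactly on the reference means. The NumPy sketch below uses the same two formulas and applies them by broadcasting `(bands, 1, 1)` coefficients instead of the reshape/repeat used above. The helper names and sample numbers are hypothetical.

```python
import numpy as np


def hall_coefficients(bright_ref, dark_ref, bright_sub, dark_sub):
    """Per-band gain m and offset b mapping subject values onto the reference.

    Same formulas as above: m = (Br - Dr) / (Bs - Ds),
    b = (Dr * Bs - Ds * Br) / (Bs - Ds).
    """
    m = (bright_ref - dark_ref) / (bright_sub - dark_sub)
    b = (dark_ref * bright_sub - dark_sub * bright_ref) / (bright_sub - dark_sub)
    return m, b


def apply_rectification(arr, m, b, nodata=-9999):
    """Apply y = m * x + b band by band to a (bands, rows, cols) array."""
    out = arr * m[:, None, None] + b[:, None, None]  # broadcast over rows, cols
    out[arr == nodata] = nodata  # restore NoData pixels
    return out


# Two bands; control-set means for the reference and subject images
bright_ref, dark_ref = np.array([900.0, 800.0]), np.array([100.0, 50.0])
bright_sub, dark_sub = np.array([1000.0, 700.0]), np.array([200.0, 100.0])
m, b = hall_coefficients(bright_ref, dark_ref, bright_sub, dark_sub)

subject = np.array([[[1000.0, 200.0]], [[700.0, -9999.0]]])  # shape (2, 1, 2)
print(apply_rectification(subject, m, b))
```

Here band 1 gets `m = 1.0, b = -100` and band 2 gets `m = 1.25, b = -75`. The subject's bright values 1000 and 700 therefore map to the reference's 900 and 800, and the NoData pixel stays at -9999.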
Project: unmixing    Author: arthur-e
def binary_mask(rast, mask, nodata=-9999, invert=False):
    '''
    Applies an arbitrary, binary mask (data in [0,1]) where pixels with
    a value of 1 are pixels to be masked out. Arguments:
        rast    A gdal.Dataset or a NumPy array
        mask    A gdal.Dataset or a NumPy array
        nodata  The NoData value; defaults to -9999.
        invert  Invert the mask (swap the meaning of 0 and 1)? Defaults to False.
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast.copy()

    if not isinstance(mask, np.ndarray):
        maskr = mask.ReadAsArray()

    else:
        maskr = mask.copy()

    # np.alltrue and np.int0 were removed in NumPy 2.0
    if rastr.shape[-2:] != maskr.shape[-2:]:
        raise ValueError('Raster and mask do not have the same shape')

    # Convert Boolean arrays to ones and zeros
    if maskr.dtype == bool:
        maskr = maskr.astype(np.intp)

    # Transform into a "1-band" array and apply the mask
    if maskr.shape != rastr.shape:
        maskr = maskr.reshape((1, maskr.shape[-2], maskr.shape[-1]))\
            .repeat(rastr.shape[0], axis=0) # Copy the mask across the "bands"

    # TODO Compare to place(), e.g.,
    # np.place(rastr, mask.repeat(rastr.shape[0], axis=0), (nodata,))
    # Mask out areas that match the mask (==1)
    if invert:
        rastr[maskr < 1] = nodata

    else:
        rastr[maskr > 0] = nodata

    return rastr
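The TODO above suggests comparing the reshape/repeat approach with `np.place()`. Another option is `np.where` with broadcasting: a `(rows, cols)` mask aligns with the trailing axes of a `(bands, rows, cols)` raster, so it never has to be copied across bands. The sketch below is a hypothetical alternative, not the author's function. Unlike `binary_mask`, it returns a new array. It assumes the raster dtype can hold `nodata`.

```python
import numpy as np


def apply_binary_mask(rast, mask, nodata=-9999, invert=False):
    """Set pixels to nodata where mask == 1 (or where mask == 0 if invert).

    rast is (rows, cols) or (bands, rows, cols); mask is (rows, cols).
    Returns a new array; rast's dtype must be able to hold nodata.
    """
    hit = (mask < 1) if invert else (mask > 0)
    # The (rows, cols) boolean mask broadcasts across the leading band axis
    return np.where(hit, nodata, rast)


rast = np.arange(8).reshape(2, 2, 2)  # two bands of 2 x 2 pixels
mask = np.array([[1, 0], [0, 0]])     # mask out the upper-left pixel only
print(apply_binary_mask(rast, mask))
```

Because broadcasting does the band replication implicitly, the mask needs no `reshape(...).repeat(...)` step.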
Project: unmixing    Author: arthur-e
def mask_ledaps_qa(rast, mask, nodata=-9999):
    '''
    Applies a given LEDAPS QA mask to a raster. It's unclear how these
    bit-packed QA values ought to be converted back into 16-bit binary numbers:

    "{0:b}".format(42).zfill(16) # Convert decimal to binary, padded left?
    "{0:b}".format(42).ljust(16, '0') # Or convert ... padded right?

    The temporary solution is to use the most common (modal) value as the
    "clear" pixel classification and discard everything else. We'd like to
    just discard pixels above a certain value knowing that everything above
    this threshold has a certain bit-packed QA meaning. For example, mask
    pixels with QA values greater than or equal to 12287:

    int("1000000000000000", 2) == 32768 # Maybe clouds
    int("0010000000000000", 2) == 8192 # Maybe cirrus

    Similarly, we'd like to discard pixels at or below 4, as these small binary
    numbers correspond to dropped frames, designated fill values, and/or
    terrain occlusion. Arguments:
        rast    A gdal.Dataset or a NumPy array
        mask    A gdal.Dataset or a NumPy array
    '''
    # Can accept either a gdal.Dataset or numpy.array instance
    if not isinstance(rast, np.ndarray):
        rastr = rast.ReadAsArray()

    else:
        rastr = rast.copy()

    if not isinstance(mask, np.ndarray):
        maskr = mask.ReadAsArray()

    else:
        maskr = mask.copy()

    # Since the QA output is so unreliable (e.g., clouds are called water),
    #   we take the most common QA bit-packed value and assume it refers to
    #   the "okay" pixels
    maskr[np.isnan(maskr)] = 0 # Zero out NaNs before counting values
    mode = np.argmax(np.bincount(maskr.ravel().astype(np.int64)))
    assert 4 < mode < 12287, "The modal value corresponds to a known error value"
    maskr[maskr != mode] = 0
    maskr[maskr == mode] = 1

    # Transform into a "1-band" array and apply the mask
    maskr = maskr.reshape((1, maskr.shape[0], maskr.shape[1]))\
        .repeat(rastr.shape[0], axis=0) # Copy the mask across the "bands"
    rastr[maskr == 0] = nodata
    return rastr
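The docstring above is unsure how to read the bit-packed QA values. Whatever each bit means in the LEDAPS product (not asserted here), a Python integer already is the binary number. Padding on the left with `zfill` gives the usual most-significant-bit-first string, and testing one bit is a shift-and-mask operation. The pure-Python sketch below uses hypothetical helper names. The same expression, `(maskr >> bit) & 1`, also works element-wise on NumPy integer arrays.

```python
def bit_is_set(value, bit):
    """True if the given bit (0 = least significant) is set in a QA integer."""
    return ((value >> bit) & 1) == 1


def set_bits(value, width=16):
    """List the positions of all set bits in a QA value, lowest first."""
    return [bit for bit in range(width) if bit_is_set(value, bit)]


print("{0:b}".format(42).zfill(16))  # 0000000000101010
print(set_bits(42))                  # [1, 3, 5]
print(set_bits(32768))               # [15]
print(set_bits(8192))                # [13]
```

Flag-based masking like this tests individual bits directly. It does not rely on a single numeric threshold, which can be misleading when several flags are packed into one value.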