# Source code for gkit.core.raster

import numpy as np
from numpy.ma import MaskedArray
from osgeo import gdal, osr, ogr
from scipy.ndimage.filters import generic_filter as gf

import gkit as gk

# Data type mapping between numpy and gdal.
# Keys are ``np.dtype`` instances so that ``array.dtype`` can be used directly
# as the lookup key (see ``Raster._gdal_dtype``).
# NOTE(review): ``np.int8`` is mapped to ``GDT_Byte``, which GDAL defines as an
# unsigned 8-bit type, so negative int8 values cannot round-trip. Newer GDAL
# releases provide ``GDT_Int8``; confirm the minimum supported GDAL version
# before switching.
TYPE = {
    np.dtype(np.int8): gdal.GDT_Byte,
    np.dtype(np.uint8): gdal.GDT_Byte,
    np.dtype(np.int16): gdal.GDT_Int16,
    np.dtype(np.uint16): gdal.GDT_UInt16,
    np.dtype(np.int32): gdal.GDT_Int32,
    np.dtype(np.uint32): gdal.GDT_UInt32,
    np.dtype(np.float32): gdal.GDT_Float32,
    np.dtype(np.float64): gdal.GDT_Float64,
    np.dtype(np.complex64): gdal.GDT_CFloat32,
    np.dtype(np.complex128): gdal.GDT_CFloat64,
}

def _srs_to_wkt(a_srs):
    """Convert an ``"AUTHORITY:CODE"`` string into a WKT projection string.

    Args:
        a_srs (str): Spatial reference given as ``"<authority>:<code>"``,
            e.g. ``"EPSG:4326"``. The authority is matched
            case-insensitively; the code must be an integer.

    Returns:
        str: The result of ``SpatialReference.ExportToWkt()``. When the
        authority is neither EPSG nor ESRI the spatial reference is left
        uninitialised and whatever ``ExportToWkt()`` yields for it is
        returned.

    Raises:
        ValueError: If ``a_srs`` does not contain exactly one ``":"`` or the
            code is not an integer.
    """
    name, code = a_srs.split(":")
    name = name.upper()
    code = int(code)

    srs = osr.SpatialReference()
    # NOTE(review): the bodies of both branches were missing from this copy
    # of the file; the importer calls below are a reconstruction and should
    # be confirmed against the upstream source.
    if name.startswith("EPSG"):
        srs.ImportFromEPSG(code)
    elif name.startswith("ESRI"):
        # Resolves the "ESRI:<code>" authority string through GDAL/PROJ.
        srs.SetFromUserInput("ESRI:{}".format(code))

    projection = srs.ExportToWkt()

    del srs
    return projection

class Raster(MaskedArray):
    """A masked array that carries a GDAL geotransform and a WKT projection.

    The georeferencing is stored in ``_raster_meta`` (keys ``'projection'``
    and ``'transform'``) and exposed through the :attr:`projection` and
    :attr:`transform` properties. ``transform`` is indexed the way GDAL
    geotransforms are: ``[0]``/``[3]`` are the X/Y origin and ``[1]``/``[5]``
    the pixel width/height.
    """

    def __new__(cls, array, transform, projection=None, a_srs="EPSG:4326",
                nodatavalue=None, mask=None, filepath=None):
        """Create a raster from array data plus georeferencing.

        Args:
            array: Array-like pixel data.
            transform: Geotransform; stored as a tuple.
            projection (str): WKT projection. When falsy it is derived from
                ``a_srs``.
            a_srs (str): ``"AUTHORITY:CODE"`` fallback for ``projection``.
            nodatavalue: Cells equal to this value are masked, and it is
                used as the fill value.
            mask: Optional extra mask to apply to ``array``.
            filepath: Optional source path; stored as ``str``.
        """
        projection = projection or _srs_to_wkt(a_srs)
        _raster_meta = {
            'projection': projection,
            'transform': tuple(transform),
        }

        # NOTE(review): the right-hand sides of the first and third
        # assignments were lost from this copy of the file. They are
        # reconstructed as ``np.ma`` calls; confirm against upstream.
        array = np.ma.masked_invalid(array)
        if mask is not None:
            array = MaskedArray(array, mask=mask)
        if nodatavalue is not None:
            array = np.ma.masked_equal(array, nodatavalue)

        obj = super().__new__(cls, array, fill_value=nodatavalue)

        if filepath is not None:
            filepath = str(filepath)
        setattr(obj, 'filepath', filepath)
        setattr(obj, '_raster_meta', _raster_meta)
        return obj

    def _update_from(self, obj):
        """Copy raster metadata from ``obj`` when numpy derives a new array.

        The metadata dict is copied rather than shared, so that changing the
        transform or projection of a derived raster (for example the slice
        produced by :meth:`clip_by_extent`) does not alter its parent.
        """
        meta = getattr(obj, '_raster_meta', None)
        self.__dict__.update({
            "_raster_meta": dict(meta) if meta is not None else None,
            "filepath": getattr(obj, 'filepath', None)
        })
        super()._update_from(obj)
        return

    @property
    def projection(self):
        """The projection stored in the raster metadata."""
        return self._raster_meta['projection']

    @projection.setter
    def projection(self, value):
        self._raster_meta['projection'] = value

    @property
    def transform(self):
        """The geotransform stored in the raster metadata (a tuple)."""
        return self._raster_meta['transform']

    @transform.setter
    def transform(self, value):
        self._raster_meta['transform'] = tuple(value)

    @property
    def extent(self):
        """The extent of raster in current coordinates.

        Returns:
            tuple: ``(left, right, bottom, top)``, computed from the origin
            and pixel size in :attr:`transform` and from ``self.shape``.
        """
        left = self.transform[0]
        right = left + self.transform[1] * self.shape[1]
        top = self.transform[3]
        bottom = top + self.transform[5] * self.shape[0]
        return left, right, bottom, top

    def _gdal_dtype(self):
        """Return the GDAL data type matching ``self.dtype``.

        Raises:
            TypeError: If ``self.dtype`` has no entry in ``TYPE``.
        """
        try:
            return TYPE[self.dtype]
        except KeyError:
            # The original raised a bare string, which itself fails with
            # TypeError on Python 3; raise a proper exception instead.
            raise TypeError(
                'Cannot convert {} into gdal.'.format(self.dtype))

    def coord(self, x, y):
        """Get point value by coordinate.

        Args:
            x (float): The X coordinate of the point.
            y (float): The Y coordinate of the point.

        Returns:
            The value of the cell containing the point.

        Raises:
            IndexError: If the point lies outside the raster.
        """
        transform = self._raster_meta['transform']
        # Same index convention as ``extent``: [0]/[1] describe the X axis
        # (columns), [3]/[5] the Y axis (rows).
        col = int(np.floor((x - transform[0]) / transform[1]))
        row = int(np.floor((y - transform[3]) / transform[5]))
        # Reject negative indices explicitly; numpy would otherwise wrap
        # them around to the opposite edge of the raster.
        if not (0 <= row < self.shape[0] and 0 <= col < self.shape[1]):
            raise IndexError(
                'Point ({}, {}) is outside the raster.'.format(x, y))
        return self[row, col]

    def set_fill_value(self, value=None):
        """Set fill value.

        Args:
            value (int or float): Fill value. The max or min value of
                current dtype will be used when this argument is greater or
                lower than current dtype's range. ``None`` re-applies the
                current fill value.
        """
        value = self.fill_value if value is None else value

        kind = self.dtype.kind
        if kind in "iu":
            # Unsigned integers need ``iinfo`` as well; ``finfo`` rejects
            # every integer dtype.
            limits = np.iinfo(self.dtype)
        elif kind == "f":
            limits = np.finfo(self.dtype)
        else:
            limits = None

        if limits is not None:
            if value < limits.min:
                value = limits.min
            elif value > limits.max:
                value = limits.max
        super().set_fill_value(value)

    def save(self, out_path=None, driver_name="GTiff", compress=True,
             options=None):
        """Save this raster to a file or to a :class:`gdal.Dataset`.

        Thin wrapper that forwards to ``gk.save``.

        Args:
            out_path (str): The output path. If it is ``None``, return a
                :class:`gdal.Dataset`.(use MEM driver)
            driver_name (str): Use which driver to save.(default="GTiff")
            compress (bool or str):
                | Could be following options:
                | ``compress=True`` (default) Use LZW to compress
                | ``compress=False`` Do not compress
                | ``compress='DEFAULT'``
                | ``compress='PACKBITS'``
                | ... other algorithms gdal supported
            options: Forwarded unchanged to ``gk.save``.

        Returns:
            `None` or `gdal.Dataset`
        """
        # NOTE(review): the callee was missing from this copy of the file;
        # ``gk.save(self, ...)`` is reconstructed by analogy with the other
        # ``gk.*`` wrappers in this class.
        return gk.save(self, out_path, driver_name, compress, options)

    def clip_by_layer(self, layer):
        """Clip raster by layer.

        The layer is rasterised onto a temporary in-memory band with this
        raster's shape and georeferencing; every cell the layer does not
        cover is masked in the result.

        Args:
            layer: An OGR layer.

        Returns:
            :class:`Raster`
        """
        # TODO Convert layer's projection into raster's projection when
        # they have different projection.
        mem_raster_driver = gdal.GetDriverByName("MEM")
        tmp_raster = mem_raster_driver.Create(
            "", self.shape[1], self.shape[0], 1, gdal.GDT_Byte)
        tmp_raster.SetProjection(self._raster_meta['projection'])
        tmp_raster.SetGeoTransform(self._raster_meta['transform'])
        tmp_raster.GetRasterBand(1).SetNoDataValue(0)
        tmp_raster.GetRasterBand(1).Fill(0)
        gdal.RasterizeLayer(tmp_raster, [1], layer, burn_values=[1])

        # Cells still holding the fill value 0 were not burned by the layer.
        mask = np.ma.masked_equal(tmp_raster.ReadAsArray(), 0).mask
        del tmp_raster

        # ``MaskedArray`` combines the new mask with the existing one.
        array = MaskedArray(self, mask=mask)
        return Raster(array, self.transform, self.projection)

    def clip_by_shp(self, shp_path):
        """Clip raster by shapefile.

        Args:
            shp_path (str): Path of the vector data source; its first layer
                is used.

        Returns:
            :class:`Raster`

        Raises:
            IOError: If the data source cannot be opened.
        """
        shp = ogr.Open(str(shp_path))
        if shp is None:
            raise IOError(
                'Cannot open vector data source: {}'.format(shp_path))
        return self.clip_by_layer(shp.GetLayer())

    def clip_by_feature(self, feature):
        """Clip raster by one or more features.

        Args:
            feature: A single ``ogr.Feature`` or an iterable of them.

        Returns:
            :class:`Raster`
        """
        feature = [feature] if isinstance(feature, ogr.Feature) else feature

        # Clones of the features go into a temporary in-memory layer that
        # uses this raster's projection.
        mem_shp_driver = ogr.GetDriverByName("Memory")
        tmp_shp = mem_shp_driver.CreateDataSource("")
        srs = osr.SpatialReference()
        srs.ImportFromWkt(self.projection)
        tmp_layer = tmp_shp.CreateLayer("tmp", srs)
        for f in feature:
            tmp_layer.CreateFeature(f.Clone())
        return self.clip_by_layer(tmp_layer)

    def clip_by_extent(self, extent):
        """Clip raster by extent.

        Args:
            extent (tuple or list): [left, right, bottom, top]

        Returns:
            :class:`Raster`: A slice of this raster with its transform
            origin moved to the clipped window. The requested extent is
            limited to the raster's own bounds.
        """
        x_origin, x_res = self.transform[0], self.transform[1]
        y_origin, y_res = self.transform[3], self.transform[5]

        left, right, bottom, top = extent
        # Order the edges so that the start index precedes the stop index
        # whatever the sign of the pixel size.
        if x_res < 0:
            left, right = right, left
        if y_res > 0:
            bottom, top = top, bottom

        cols = np.round([
            (left - x_origin) / x_res,
            (right - x_origin) / x_res,
        ]).astype(int)
        rows = np.round([
            (top - y_origin) / y_res,
            (bottom - y_origin) / y_res,
        ]).astype(int)

        # Clamp to the raster: negative slice bounds would otherwise be
        # interpreted by numpy as offsets from the far edge.
        col_start, col_stop = (int(i) for i in np.clip(cols, 0, self.shape[1]))
        row_start, row_stop = (int(i) for i in np.clip(rows, 0, self.shape[0]))

        raster = self[row_start:row_stop, col_start:col_stop]
        transform = list(self.transform)
        transform[0] = col_start * x_res + x_origin
        transform[3] = row_start * y_res + y_origin
        raster.transform = transform
        return raster

    def split_by_shp(self, shp, by=None, overall=False):
        """Forward to ``gk.split_by_shp`` with this raster as first argument."""
        return gk.split_by_shp(self, shp, by, overall)

    def zonal_apply(self, shp_path, func, by=None, overall=False,
                    args=(), kwargs=None):
        """Forward to ``gk.zonal_apply`` with this raster as first argument.

        ``kwargs=None`` is passed on as an empty dict; a fresh dict is
        created per call instead of sharing one mutable default.
        """
        kwargs = {} if kwargs is None else kwargs
        return gk.zonal_apply(self, shp_path, func, by, overall, args, kwargs)

    def reproject(self, x_count=None, y_count=None, transform=None,
                  projection=None, a_srs=None, method=gdal.GRA_Bilinear):
        """Reproject/Resample

        Args:
            x_count (int): Column count. (``RasterXSize``) Defaults to the
                current column count.
            y_count (int): Row count. (``RasterYSize``) Defaults to the
                current row count.
            transform (list): Use current transform in default.
            projection: Use current projection in default.
            a_srs (str): ``"AUTHORITY:CODE"`` alternative to ``projection``;
                only used when ``projection`` is not given.
            method (int):
                | Could be following options:
                | ``gdal.GRA_Bilinear`` (default, also used for ``None``)
                | ``gdal.GRA_Average``
                | ``gdal.GRA_Cubic``
                | ``gdal.GRA_CubicSpline``
                | ``gdal.GRA_Lanczos``
                | ``gdal.GRA_NearestNeighbour``

        Returns:
            :class:`Raster`
        """
        x_count = x_count or self.shape[1]
        y_count = y_count or self.shape[0]
        if transform is None:
            transform = self.transform
        if projection or a_srs:
            projection = projection or _srs_to_wkt(a_srs)
        else:
            projection = self.projection
        # ``resample`` forwards ``method=None`` by default.
        if method is None:
            method = gdal.GRA_Bilinear

        mem_raster_driver = gdal.GetDriverByName("MEM")
        tmp_raster = mem_raster_driver.Create(
            "", x_count, y_count, 1, self._gdal_dtype()
        )
        tmp_raster.SetProjection(projection)
        tmp_raster.SetGeoTransform(transform)
        tmp_band = tmp_raster.GetRasterBand(1)
        tmp_band.SetNoDataValue(np.float64(self.fill_value))
        tmp_band.Fill(np.float64(self.fill_value))

        # NOTE(review): the source-dataset argument was missing from this
        # copy of the file; ``self.save()`` (documented to return an
        # in-memory dataset when ``out_path`` is None) is a reconstruction.
        gdal.ReprojectImage(self.save(), tmp_raster, self.projection,
                            projection, method)

        array = tmp_band.ReadAsArray()
        del tmp_band, tmp_raster
        return Raster(
            array, transform, projection, nodatavalue=self.fill_value)

    def resample(self, x_count=None, y_count=None, transform=None,
                 method=None):
        """Alias of :meth:`self.reproject`.

        The only difference is that :meth:`resample` cannot change the
        projection of raster. ``method=None`` selects the default of
        :meth:`reproject`.
        """
        return self.reproject(x_count, y_count, transform, method=method)

    def plot(self, *args, ax=None, cmap_name='seismic', if_show=False,
             **kwargs):
        """Use ``matplotlib`` to plot preview picture

        Args:
            ax: The ``Axes`` instance. If it's ``None``, use the current
                ``Axes`` instance.
            cmap_name (str): Color map name, looked up with
                ``matplotlib.pyplot.get_cmap``.
            if_show (bool): If true, add a horizontal colorbar and call
                ``matplotlib.pyplot.show`` after plotting.
            *args, **kwargs: Forwarded to ``imshow``.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()
        # Draw on the requested axes rather than on whichever axes happens
        # to be current.
        image = ax.imshow(
            self, *args, cmap=plt.get_cmap(cmap_name),
            extent=self.extent, **kwargs
        )
        if if_show:
            ax.figure.colorbar(image, ax=ax, orientation='horizontal')
            plt.show()

    def show(self, *args, **kwargs):
        """A shortcut of :meth:`self.plot`.

        Just set ``if_show=True``.
        """
        kwargs['if_show'] = True
        self.plot(*args, **kwargs)

    def rolling(self, function, size=None, footprint=None,
                mode='reflect', cval=0.0):
        """Calculate a 2D filter using the given function.

        At each element the provided function is called. The input values
        within the filter footprint at that element are passed to the
        function as a 1D array of double values. Masked cells are passed as
        NaN, and the result keeps this raster's mask.

        Args:
            function (callable): Function to apply at each element.
            size (scalar, tuple): See footprint, below. Ignored if
                footprint is given.
            footprint (array):
                | Either `size` or `footprint` must be defined. `size`
                  gives the shape that is taken from the input array, at
                  every element position, to define the input to the filter
                  function.
                | `footprint` is a boolean array that specifies
                  (implicitly) a shape, but also which of the elements
                  within this shape will get passed to the filter function.
                | When `footprint` is given, `size` is ignored.
            mode (str):
                | The `mode` parameter determines how the input array is
                  extended when the filter overlaps a border. Default value
                  is 'reflect'. The valid values and their behavior is as
                  follows:
                | 'reflect' (`d c b a | a b c d | d c b a`) The input is
                  extended by reflecting about the edge of the last pixel.
                | 'constant' (`k k k k | a b c d | k k k k`) The input is
                  extended by filling all values beyond the edge with the
                  same constant value, defined by the `cval` parameter.
                | 'nearest' (`a a a a | a b c d | d d d d`) The input is
                  extended by replicating the last pixel.
                | 'mirror' (`d c b | a b c d | c b a`) The input is
                  extended by reflecting about the center of the last
                  pixel.
                | 'wrap' (`a b c d | a b c d | a b c d`) The input is
                  extended by wrapping around to the opposite edge.
            cval (scalar): Value to fill past edges of input if `mode` is
                'constant'. Default is 0.0.

        Returns:
            :class:`Raster`
        """
        # NaN cannot be stored in integer arrays, so non-floating rasters
        # are promoted before their masked cells are filled.
        data = self if self.dtype.kind in "fc" else self.astype(np.float64)
        res = gf(
            data.filled(np.nan), function, size, footprint,
            mode=mode, cval=cval)
        return Raster(res, self.transform, self.projection, mask=self.mask)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Return ``Raster<geogcs name, shape, extent>``."""
        projection_name = osr.SpatialReference(
            wkt=self.projection).GetAttrValue('geogcs')
        return "Raster<{}, {}, {}>".format(
            projection_name, self.shape, self.extent)