Source code for satproc.utils

import json
import logging
import multiprocessing as mp
import os
import subprocess
import tempfile
from itertools import zip_longest
from multiprocessing.pool import ThreadPool

import fiona
import numpy as np
import pyproj
import rasterio
from packaging import version
from pyproj.crs import CRS
from pyproj.enums import WktVersion
from rasterio.windows import Window
from shapely.geometry import mapping
from shapely.ops import transform
from skimage import exposure
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

__author__ = "Damián Silvani"
__copyright__ = "Dymaxion Labs"
__license__ = "Apache-2.0"

_logger = logging.getLogger(__name__)


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

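# Illustrative usage sketch: split a sequence into fixed-size batches,
# padding the last one with the fill value.
def _example_grouper():
    for batch in grouper("ABCDEFG", 3, "x"):
        print(batch)  # ('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')
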
def sliding_windows(size, step_size, width, height, mode="exact"):
    """Slide a window of +size+ by moving it +step_size+ pixels

    Parameters
    ----------
    size : Tuple[int, int]
        window size (width, height), in pixels
    step_size : Tuple[int, int]
        step or *stride* size (width, height) when sliding window, in pixels
    width : int
        image width
    height : int
        image height
    mode : str (default: 'exact')
        either one of 'exact', 'whole', 'whole_overlap'.
        - 'exact': clip windows at borders, if needed
        - 'whole': only whole windows
        - 'whole_overlap': only whole windows, allow overlapping windows at borders

    Yields
    ------
    Tuple[Window, Tuple[int, int]]
        a pair of Window and a pair of position (i, j)

    """
    w, h = size
    sw, sh = step_size
    whole = mode in ("whole", "whole_overlap")
    end_i = height - h if whole else height
    end_j = width - w if whole else width
    last_pos_i, last_pos_j = 0, 0
    for pos_i, i in enumerate(range(0, end_i, sh)):
        for pos_j, j in enumerate(range(0, end_j, sw)):
            real_w = w if whole else min(w, abs(width - j))
            real_h = h if whole else min(h, abs(height - i))
            yield Window(j, i, real_w, real_h), (pos_i, pos_j)
            last_pos_i, last_pos_j = pos_i, pos_j
    if mode == "whole_overlap" and (height % sh != 0 or width % sw != 0):
        # Add extra windows along the right and bottom edges so the whole
        # image is covered, allowing overlap with previously yielded windows.
        for pos_i, i in enumerate(range(0, height - h, sh)):
            yield Window(width - w, i, w, h), (pos_i, last_pos_j + 1)
        for pos_j, j in enumerate(range(0, width - w, sw)):
            yield Window(j, height - h, w, h), (last_pos_i + 1, pos_j)
        yield Window(width - w, height - h, w, h), (last_pos_i + 1, last_pos_j + 1)

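# Illustrative usage sketch (hypothetical dimensions): tile a 1000x700 image
# into 256x256 chips with a 128-pixel stride, clipping windows at the borders.
def _example_sliding_windows():
    for window, (i, j) in sliding_windows(
        (256, 256), (128, 128), width=1000, height=700, mode="exact"
    ):
        print(window, (i, j))
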
def rescale_intensity(image, rescale_mode, rescale_range):
    """
    Calculate percentiles from a range cut and rescale intensity of image
    to byte range

    Parameters
    ----------
    image : numpy.ndarray
        image array
    rescale_mode : str
        rescaling mode, either 'percentiles', 'values' or 's2_rgb_extra'
    rescale_range : Tuple[number, number]
        input range for rescaling

    Returns
    -------
    numpy.ndarray
        rescaled image

    """
    if rescale_mode == "percentiles":
        in_range = np.percentile(image, rescale_range, axis=(1, 2)).T
    elif rescale_mode == "values":
        min_value, max_value = rescale_range
        if min_value is None:
            min_value = np.min(image)
        if max_value is None:
            max_value = np.max(image)
        in_range = np.array([(min_value, max_value) for _ in range(image.shape[0])])
    elif rescale_mode == "s2_rgb_extra":
        in_range = np.percentile(image, rescale_range, axis=(1, 2)).T
        # Override first 3 ranges for (0, 0.3) (Sentinel-2 L2A TCI range)
        in_range[0] = (0, 0.3)
        in_range[1] = (0, 0.3)
        in_range[2] = (0, 0.3)
    else:
        raise RuntimeError(f"unknown rescale_mode {rescale_mode}")

    return np.array(
        [
            exposure.rescale_intensity(
                image[i, :, :], in_range=tuple(in_range[i]), out_range=(1, 255)
            ).astype(np.uint8)
            for i in range(image.shape[0])
        ]
    )

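# Illustrative usage sketch: rescale a random 3-band float image to uint8
# using a 2-98 percentile cut. The shape and percentile values are arbitrary.
def _example_rescale_intensity():
    img = np.random.rand(3, 64, 64).astype(np.float32)
    img8 = rescale_intensity(img, rescale_mode="percentiles", rescale_range=(2, 98))
    assert img8.dtype == np.uint8
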
def write_chips_geojson(output_path, chip_pairs, *, chip_type, crs, basename):
    """Write a GeoJSON containing chips polygons as features

    Parameters
    ----------
    output_path : str
        GeoJSON output path
    chip_pairs : List[Tuple[Shape, Tuple[int, int, int]]]
        pairs of chip polygon geometry and a tuple of (feature id, x, y)
    chip_type : str
        chip file type extension (e.g. tif, jpg)
    crs : str
        CRS epsg code of chip polygon geometry
    basename : str
        basename of chip files

    Returns
    -------
    None

    """
    if not chip_pairs:
        _logger.warning("No chips to save")
        return

    _logger.info("Write chips geojson")

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w") as f:
        d = {"type": "FeatureCollection", "features": []}
        if crs != "epsg:4326":
            code = crs.split(":")[1]
            d["crs"] = {
                "type": "name",
                "properties": {"name": f"urn:ogc:def:crs:EPSG::{code}"},
            }
        for i, (chip, (_fi, xi, yi)) in enumerate(chip_pairs):
            filename = f"{basename}_{xi}_{yi}.{chip_type}"
            feature = {
                "type": "Feature",
                "geometry": mapping(chip),
                "properties": {"id": i, "x": xi, "y": yi, "filename": filename},
            }
            d["features"].append(feature)
        f.write(json.dumps(d))

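# Illustrative usage sketch (made-up paths and geometries): write two chip
# footprints, indexed by (x, y), to a GeoJSON file.
def _example_write_chips_geojson():
    from shapely.geometry import box

    chips = [(box(0, 0, 1, 1), (0, 0, 0)), (box(1, 0, 2, 1), (1, 1, 0))]
    write_chips_geojson(
        "./out/chips.geojson", chips, chip_type="tif", crs="epsg:4326", basename="chip"
    )
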
def get_raster_band_count(path):
    """Get raster band count

    Parameters
    ----------
    path : str
        path of the raster image

    Returns
    -------
    int
        band count

    """
    with rasterio.open(path) as src:
        return src.count

def reproject_shape(shp, from_crs, to_crs, project=None):
    """Reproject a shape from `from_crs` to `to_crs`

    Parameters
    ----------
    shp : Shape
        shape to reproject
    from_crs : str
        CRS epsg code of shape geometry
    to_crs : str
        CRS epsg code of reprojected shape geometry
    project : Optional[Callable]
        a transformer function (e.g. ``pyproj.Transformer.transform``) to use
        for reprojection

    Returns
    -------
    Shape
        reprojected shape

    """
    if from_crs == to_crs:
        return shp
    if project is None:
        project = pyproj.Transformer.from_crs(
            from_crs, to_crs, always_xy=True
        ).transform
    return transform(project, shp)

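# Illustrative usage sketch: reproject a point from WGS84 to Web Mercator.
# The coordinates and EPSG codes are only an example.
def _example_reproject_shape():
    from shapely.geometry import Point

    return reproject_shape(Point(-58.4, -34.6), "epsg:4326", "epsg:3857")
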
def proj_crs_from_fiona_dataset(fio_ds):
    """Build a pyproj CRS from the WKT of a Fiona dataset"""
    return CRS.from_wkt(fio_ds.crs_wkt)

def fiona_crs_from_proj_crs(proj_crs):
    """Convert a pyproj CRS into a WKT string usable by Fiona"""
    if version.parse(fiona.__gdal_version__) < version.parse("3.0.0"):
        fio_crs = proj_crs.to_wkt(WktVersion.WKT1_GDAL)
    else:
        # GDAL 3+ can use WKT2
        fio_crs = proj_crs.to_wkt()
    return fio_crs

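# Illustrative usage sketch (hypothetical path): read a vector file's CRS as
# a pyproj CRS and convert it back to a WKT string Fiona can write with.
def _example_crs_roundtrip():
    with fiona.open("data.gpkg") as src:
        proj_crs = proj_crs_from_fiona_dataset(src)
    return fiona_crs_from_proj_crs(proj_crs)
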
def build_virtual_raster(image_paths, output_path, separate=None, band=None):
    """Build a virtual raster (VRT) from a list of images using gdalbuildvrt"""
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # For some reason, relative paths won't work, so we get the absolute path
    # of each input image path.
    image_paths = [os.path.abspath(p) for p in image_paths]
    with tempfile.NamedTemporaryFile() as f:
        # Write a list of image files to a temporary file
        for image_path in image_paths:
            f.write(f"{image_path}\n".encode())
        f.flush()
        output_path = os.path.abspath(output_path)
        run_command(
            f"gdalbuildvrt -q -overwrite "
            f"-input_file_list {f.name} "
            f"{'-separate ' if separate else ''}"
            f"{f'-b {band} ' if band else ''}"
            f"{output_path}",
            cwd=os.path.dirname(output_path),
        )

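# Illustrative usage sketch (made-up paths; requires gdalbuildvrt on PATH):
# mosaic two scenes into a single virtual raster.
def _example_build_virtual_raster():
    build_virtual_raster(["scene_a.tif", "scene_b.tif"], "./vrt/mosaic.vrt")
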
def run_command(cmd, quiet=True, *, cwd=None):
    """Run a shell command

    Parameters
    ----------
    cmd : str
        command to run
    quiet : bool (default: True)
        silent output (stdout and stderr)
    cwd : Optional[str]
        working directory to run the command in

    Returns
    -------
    None

    """
    stderr = subprocess.DEVNULL if quiet else None
    stdout = subprocess.DEVNULL if quiet else None
    subprocess.run(cmd, shell=True, stderr=stderr, stdout=stdout, check=True, cwd=cwd)

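# Illustrative usage sketch (assumes the GDAL CLI is installed): print the
# GDAL version by running gdalinfo without silencing its output.
def _example_run_command():
    run_command("gdalinfo --version", quiet=False)
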
def map_with_threads(items, worker, num_jobs=None, total=None, desc=None):
    """Map a worker function to an iterable of items, using a thread pool

    Parameters
    ----------
    items : iterable
        items to map
    worker : Function
        worker function to apply to each item
    num_jobs : int
        number of threads to use
    total : int (optional)
        total number of items (for the progress bar)
    desc : str (optional)
        description of the task (for the progress bar)

    Returns
    -------
    None

    """
    if not total:
        total = len(items)
    if not num_jobs:
        num_jobs = mp.cpu_count()
    with ThreadPool(num_jobs) as pool:
        with logging_redirect_tqdm():
            with tqdm(total=total, ascii=True, desc=desc) as pbar:
                for _ in pool.imap_unordered(worker, items):
                    pbar.update()

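# Illustrative usage sketch: apply a placeholder worker over a list of items
# with 4 threads and a progress bar.
def _example_map_with_threads():
    def worker(item):
        pass  # e.g. download or convert `item`

    map_with_threads(list(range(10)), worker, num_jobs=4, desc="processing")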