Source code for satproc.postprocess.smooth

import logging
import math
import os
import tempfile
from glob import glob

import numpy as np
import rasterio
import rasterio.mask
import rasterio.merge
import rasterio.windows
import scipy.signal
from rasterio.transform import Affine
from rtree import index
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


def spline_window(window_size, power=2):
    """
    Squared spline window function:
    https://www.wolframalpha.com/input/?i=y%3Dx**2,+y%3D-(x-2)**2+%2B2,+y%3D(x-4)**2,+from+y+%3D+0+to+2

    The window is built piecewise from a triangular window ``t``:

    * the outer quarters (first and last ``int(window_size / 4)`` samples)
      use ``|2 t| ** power / 2``;
    * the central part uses ``1 - |2 (t - 1)| ** power / 2``.

    The result is divided by its mean.

    Parameters
    ----------
    window_size : int
        Number of samples in the window.
    power : float
        Exponent of the spline pieces (2 gives the squared spline).

    Returns
    -------
    numpy.ndarray
        1D float array of length ``window_size`` whose mean is 1.
    """
    # scipy.signal.triang was removed in SciPy 1.13; the windows submodule
    # is the supported location.
    tri = scipy.signal.windows.triang(window_size)
    intersection = int(window_size / 4)
    # Start of the trailing outer section. Slicing with ``-intersection``
    # breaks when intersection == 0 (``x[-0:]`` is the whole array and
    # ``x[0:-0]`` is empty), so use an explicit positive index instead.
    tail = window_size - intersection

    wind_outer = (abs(2 * tri) ** power) / 2
    wind_outer[intersection:tail] = 0

    wind_inner = 1 - (abs(2 * (tri - 1)) ** power) / 2
    wind_inner[:intersection] = 0
    wind_inner[tail:] = 0

    wind = wind_inner + wind_outer
    return wind / np.average(wind)
def window_2D(size, power=2, n_channels=1):
    """Build a 2D spline window and stack one copy per channel.

    The 1D window ``w = spline_window(size, power)`` is combined with itself
    via an outer product (``w[row] * w[col]``) and divided by 4. The resulting
    ``(size, size)`` grid is repeated along a new leading axis.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_channels, size, size)``.
    """
    profile_1d = spline_window(size, power)
    grid = np.outer(profile_1d, profile_1d) / 4
    stacked = np.repeat(grid[np.newaxis, :, :], n_channels, axis=0)
    return stacked.reshape(n_channels, size, size)
def generate_spline_window_chips(*, image_paths, output_dir, power=2):
    """Interpolates all images using a squared spline window.

    Every chip is multiplied by a min-max normalized 2D spline window (see
    ``window_2D``). The result is written as float64 to ``output_dir`` under
    the same basename as its source.

    The window size and band count are taken from the first image. The other
    images presumably share that shape; this is not checked here.

    Parameters
    ----------
    image_paths : list of str
        Paths of the input chips.
    output_dir : str
        Destination directory; created if missing.
    power : float
        Exponent passed to ``window_2D``.

    Returns
    -------
    list of str
        Output paths, in the same order as ``image_paths``. The list is empty
        if ``image_paths`` is empty.

    Raises
    ------
    ValueError
        If the first image is not square.
    """
    if not image_paths:
        return []

    logger.info(
        "Interpolate all images using a squared spline window (power=%s) "
        "and store chips on %s",
        power,
        output_dir,
    )

    first_image = image_paths[0]
    with rasterio.open(first_image) as src:
        chip_size = src.width
        chip_height = src.height
        n_channels = src.count
    # Explicit check instead of ``assert``, which is stripped under ``-O``
    if chip_size != chip_height:
        raise ValueError(
            f"Expected square chips, but {first_image} is "
            f"{chip_size}x{chip_height} pixels"
        )

    win = window_2D(size=chip_size, power=power, n_channels=n_channels)
    win_min = win.min()
    win_range = win.max() - win_min
    if win_range > 0:
        norm_win = (win - win_min) / win_range
    else:
        # A constant window (e.g. 1x1 chips) cannot be min-max normalized;
        # leave pixel values unscaled instead of producing NaNs (0 / 0).
        norm_win = np.ones_like(win)

    os.makedirs(output_dir, exist_ok=True)
    res = []
    with logging_redirect_tqdm():
        for img_path in tqdm(image_paths, ascii=True, desc="Smooth chips"):
            with rasterio.open(img_path) as src:
                profile = src.profile.copy()
                img = src.read()
            profile.update(dtype=np.float64)
            img = img * norm_win

            out_path = os.path.join(output_dir, os.path.basename(img_path))
            with rasterio.open(out_path, "w", **profile) as dst:
                for band_idx, band in enumerate(img, start=1):
                    dst.write(band, band_idx)
            res.append(out_path)
    return res
def build_bounds_index(image_files):
    """Index every image footprint in an R-Tree and compute their union extent.

    Each image's bounds ``(left, bottom, right, top)`` are inserted into an
    ``rtree`` index. The id used is the image's position in ``image_files``.

    Returns
    -------
    tuple
        ``(idx, (west, south, east, north))``. The second item is the bounding
        box covering all images. An empty ``image_files`` makes ``min`` raise
        ``ValueError``.
    """
    spatial_index = index.Index()
    footprints = []
    with logging_redirect_tqdm():
        numbered = list(enumerate(image_files))
        for pos, path in tqdm(numbered, ascii=True, desc="Build bounds R-Tree index"):
            with rasterio.open(path) as src:
                box = tuple(src.bounds)
            spatial_index.insert(pos, box)
            footprints.append(box)

    xs = [box[0] for box in footprints] + [box[2] for box in footprints]
    ys = [box[1] for box in footprints] + [box[3] for box in footprints]
    return spatial_index, (min(xs), min(ys), max(xs), max(ys))
def sliding_windows(size, whole=False, step_size=None, *, width, height):
    """Slide a square window of ``size`` pixels over a ``width`` x ``height`` grid.

    The window moves ``step_size`` pixels at a time. A falsy ``step_size``
    defaults to ``size``, which gives non-overlapping windows.

    Parameters
    ----------
    size : int
        Window side length, in pixels.
    whole : bool
        If True, only windows lying fully inside the grid are yielded.
        Otherwise windows at the right and bottom edges are clipped to the grid.
    step_size : int, optional
        Stride between consecutive windows, in pixels.
    width, height : int
        Grid dimensions, in pixels.

    Yields
    ------
    tuple
        ``(window, (row_pos, col_pos))``. ``window`` is a
        ``rasterio.windows.Window``; the positions are 0-based grid indices.
    """
    if not step_size:
        step_size = size
    # ``+ 1`` so that a whole window ending exactly at the grid edge is still
    # produced (previously e.g. height == size yielded no rows at all).
    end_row = height - size + 1 if whole else height
    end_col = width - size + 1 if whole else width
    for row_pos, row_off in enumerate(range(0, end_row, step_size)):
        for col_pos, col_off in enumerate(range(0, end_col, step_size)):
            if whole:
                win_w = win_h = size
            else:
                # Offsets are always < width/height here, so no abs() needed
                win_w = min(size, width - col_off)
                win_h = min(size, height - row_off)
            yield (
                rasterio.windows.Window(col_off, row_off, win_w, win_h),
                (row_pos, col_pos),
            )
def merge_chips(images_files, *, win_bounds):
    """Merge chips over ``win_bounds``, keeping the per-pixel maximum.

    Overlapping pixels are combined with ``rasterio.merge.merge(...,
    method="max")``. NOTE(review): the previous docstring said "mean" and the
    caller's comment said "median". The code has always used max; confirm
    which is intended.

    Parameters
    ----------
    images_files : iterable of str
        Paths of the chips to merge.
    win_bounds : tuple
        ``(left, bottom, right, top)`` bounds of the output array.

    Returns
    -------
    numpy.ndarray
        Merged array as returned by ``rasterio.merge.merge``.
    """
    datasets = []
    try:
        for path in images_files:
            datasets.append(rasterio.open(path))
        img, _ = rasterio.merge.merge(datasets, bounds=win_bounds, method="max")
    finally:
        # Close everything opened so far, even if an open or the merge failed
        for ds in datasets:
            ds.close()
    return img
def smooth_stitch(*, input_dir, output_dir, power=1.5, temp_dir=None):
    """
    Takes input directory of overlapping chips, and generates a new directory
    of non-overlapping chips with smooth edges.

    Each input chip is first multiplied by a normalized spline window and
    stored in a temporary directory (see ``generate_spline_window_chips``).
    The weighted chips are then mosaicked on a regular grid of
    non-overlapping ``chip_size`` windows; see ``merge_chips`` for how
    overlaps are combined. Output chips are named ``{row}_{col}.tif``. Edge
    chips are clipped to the mosaic extent.

    Parameters
    ----------
    input_dir : str
        Directory containing the overlapping ``*.tif`` chips. The first chip
        (in sorted order) is the template for profile, resolution and size.
    output_dir : str
        Destination directory for the stitched chips; created if missing.
    power : float
        Exponent of the spline window.
    temp_dir : str, optional
        Directory for the intermediate weighted chips. If not given, a
        temporary directory is created and removed when done (also on error).

    Raises
    ------
    RuntimeError
        If ``input_dir`` contains no ``.tif`` file.
    ValueError
        If the template chip is not square.
    """
    # Sorted so the template chip does not depend on filesystem order
    image_paths = sorted(glob(os.path.join(input_dir, "*.tif")))
    if not image_paths:
        raise RuntimeError("%s does not contain any .tif file" % (input_dir))

    # Get the profile and resolution of some image as template for output chips
    first_image = image_paths[0]
    with rasterio.open(first_image) as src:
        profile = src.profile.copy()
        src_res = src.res
        chip_size = src.width
        chip_height = src.height
    if chip_size != chip_height:
        raise ValueError(
            f"Expected square chips, but {first_image} is "
            f"{chip_size}x{chip_height} pixels"
        )

    # Only a directory created here is cleaned up afterwards
    owned_tmp_dir = None if temp_dir else tempfile.TemporaryDirectory()
    tmpdir = temp_dir if temp_dir else owned_tmp_dir.name
    try:
        tmp_image_paths = generate_spline_window_chips(
            image_paths=image_paths, output_dir=tmpdir, power=power
        )

        # Get bounds from all images and build R-Tree index
        idx, (dst_w, dst_s, dst_e, dst_n) = build_bounds_index(tmp_image_paths)

        # Affine transform for the complete bounds: origin at the north-west
        # corner, with negative y pixel size (rows go southwards)
        logger.info("Output bounds: %r", (dst_w, dst_s, dst_e, dst_n))
        output_transform = Affine.translation(dst_w, dst_n)
        logger.info("Output transform, before scaling: %r", output_transform)
        output_transform *= Affine.scale(src_res[0], -src_res[1])
        logger.info("Output transform, after scaling: %r", output_transform)

        # Compute output array shape. We guarantee it will cover the output
        # bounds completely. We need this to build windows list later.
        output_width = int(math.ceil((dst_e - dst_w) / src_res[0]))
        output_height = int(math.ceil((dst_n - dst_s) / src_res[1]))

        profile.update(tiled=True)

        windows = list(
            sliding_windows(chip_size, width=output_width, height=output_height)
        )
        logger.info("Num. windows: %d", len(windows))

        os.makedirs(output_dir, exist_ok=True)
        with logging_redirect_tqdm():
            for win, (row, col) in tqdm(windows, ascii=True, desc="Merge chips"):
                # Get window affine transform and bounds
                win_transform = rasterio.windows.transform(win, output_transform)
                win_bounds = rasterio.windows.bounds(win, output_transform)

                # Get chips that intersect with window
                intersect_chip_paths = [
                    tmp_image_paths[chip_id]
                    for chip_id in idx.intersection(win_bounds)
                ]
                if not intersect_chip_paths:
                    continue

                img = merge_chips(intersect_chip_paths, win_bounds=win_bounds)

                # Edge windows are clipped by sliding_windows, so the merged
                # array can be smaller than chip_size. Size the output dataset
                # from the array so the written data matches the dataset shape.
                profile.update(
                    transform=win_transform,
                    height=img.shape[1],
                    width=img.shape[2],
                )
                output_path = os.path.join(output_dir, f"{row}_{col}.tif")
                with rasterio.open(output_path, "w", **profile) as dst:
                    for band_idx, band in enumerate(img, start=1):
                        dst.write(band, band_idx)
    finally:
        if owned_tmp_dir is not None:
            owned_tmp_dir.cleanup()