Skip to content
Extraits de code Groupes Projets
Valider f9b50f55 rédigé par Thomas Dobbelaere's avatar Thomas Dobbelaere
Parcourir les fichiers

put correct script for arrival time map figures

parent 481446cc
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
...@@ -9,156 +9,59 @@ Last modified: 19 July 2022 ...@@ -9,156 +9,59 @@ Last modified: 19 July 2022
""" """
import numpy as np import numpy as np
from netCDF4 import Dataset import os
import os, calendar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import geopandas as gpd
from map_data import plot_landmark from map_data import plot_landmark
import shapely.geometry from postpro import ArrivalTimeMap
from typing import List
import itertools
# --- USEFUL VARIABLES --- #
# grid
minlon, maxlon, minlat, maxlat = 50, 53, 24, 27.5 minlon, maxlon, minlat, maxlat = 50, 53, 24, 27.5
res = 0.01 res = 0.01
longr = np.arange(minlon, maxlon+res, res)
latgr = np.arange(minlat, maxlat+res, res)
ox, oy = longr[0], latgr[0]
nx, ny = longr.size-1, latgr.size-1
# paths
basedir = "/export/miro/students/tanselain/OpenOil/" basedir = "/export/miro/students/tanselain/OpenOil/"
sources = ['ras_laffan','ras_laffan_port', 'abu_fontas', 'umm_al_houl']
# --- USEFUL FUNCTIONS --- # years = ['2017', '2018', '2019', '2020']
def is_inside(x, y, lon, lat):
"""Check if particles with coordinates x,y are inside grid defined by lon,lat""" def plot_arrival_time(sources: List[str], years: List[str], month: str, outfile: str) -> None:
isok = ~np.isnan(x)
isok[isok] = (x[isok] >= lon.min()) & (x[isok] <= lon.max()) & \ sources = [sources] if isinstance(sources,str) else sources
(y[isok] >= lat.min()) & (y[isok] <= lat.max()) years = [years] if isinstance(years,str) else years
return isok months = [months] if isinstance(months,str) else months
def get_index(x, y, lon, lat): atm = ArrivalTimeMap(minlon, maxlon, minlat, maxlat, res)
"""Get the indices of the pixels containing the particles of coordinates (x,y) for src,yr in itertools.product(sources,years):
if src == "lusail":
Keyword arguments nc_path = basedir + f'/Output_backward_lusail/{yr}/{month}'
x -- longitudes of the particles
y -- latitudes of the particles
lon -- longitude of the regular grid points forming the pixels
lat -- latitude of the regular grid points foring the pixels
"""
isok = is_inside(x, y, lon, lat)
dx = lon[1]-lon[0]
dy = lon[1]-lon[0]
indices = np.full(x.shape, -1, dtype=np.int32)
ix = ((x[isok]-lon[0]) / dx).astype(np.int32)
iy = ((y[isok]-lat[0]) / dy).astype(np.int32)
indices[isok] = iy*(lon.size-1)+ix
return indices
def _arrival_time_month_(record, gridlon, gridlat, source, year, month):
""" Store the minimum arrival time to the source from pixels of a regular grid for all trajectories crossing the pixels
Keyword arguments:
record -- a dictionary used to store all arrival times
gridlon -- longitude of the regular grid points forming the pixels
gridlat -- latitude of the regular grid points foring the pixels
source -- source of the particles in the backtracking simulations (used to find netcdf output files)
year -- year of the simulations (used to find netcdf output files)
month -- month of the simulations (used to find netcdf output files)
"""
for date in range(1,32):
if source == "lusail":
fn = basedir + f"Output_backward_lusail/{year}/{month}/out_{month}_{date}.nc"
else: else:
fn = basedir + f"Output_backward/{year}/{source}/{month}/out_{month}_{date}.nc" nc_path = basedir + f'/Output_backward/{yr}/{src}/{month}'
if not os.path.isfile(fn): for day in range(1,32):
continue ncfile = nc_path+f'/out_{month}_{day}.nc'
with Dataset(fn,"r") as ds: if not os.path.isfile(ncfile):
lats = np.ma.filled(ds["lat"][:], np.nan) continue
lons = np.ma.filled(ds["lon"][:], np.nan) atm.add_arrival_times_from_file(ncfile)
time = ds["time"][:]
tau = np.abs(time-time[0]) arrival_time_days = atm.get_arrival_time_days()
# find index of all pixels crossed by particles during the simulation hasdata = ~np.isnan(arrival_time_days)
indices = get_index(lons, lats, gridlon, gridlat) arrival_time_days[hasdata] = np.floor(arrival_time_days[hasdata])
uids = np.unique(indices)[1:]
for px in uids: _, ax = plt.subplots(figsize=(12,12))
# for each pixel, find all trajectories that crossed it ax.pcolormesh(atm.lon, atm.lat, arrival_time_days, cmap="YlOrRd_r")
itraj = np.where(np.sum(indices == px, axis=1) > 0)[0] cmap = plt.get_cmap('YlOrRd_r')
# then find minimum arrival time from the pixel for all found trajectories colors = cmap(np.arange(5)/5)
itime = np.argmax(indices[itraj] == px, axis=1) for i,clr in enumerate(colors):
min_time = tau[itime].tolist() label = "< 1 day" if i == 0 else f"{i} - {i+1} days"
if px in record.keys(): ax.add_patch(plt.Rectangle((0,0), 0, 0, color=clr, label=label))
record[px].extend(min_time)
else:
record[px] = min_time
def to_iterable(a):
is_iterable = isinstance(a,list) or isinstance(a,tuple)
return a if is_iterable else [a]
def to_shp(lon, lat, a, shpfile):
data = []
ny, nx = a.shape
for i in range(ny):
for j in range(nx):
X = lon[[j, j, j+1, j+1]]
Y = lat[[i, i+1, i+1, i]]
poly = shapely.geometry.Polygon(np.column_stack([X,Y]))
data.append( (a[i,j],poly) )
df = gpd.GeoDataFrame(data, columns=["arrival_time", "geometry"], crs="EPSG:4326")
df.to_file(shpfile)
def plot_arrival_time(sources, years, months, outfile):
"""Plot arrival time map to sources based on backtracking simulations with OpenDrift for given months and years
Keyword arguments:
sources -- source points in the backtracking simulations
years -- years of the backtracking simulations
months -- months of the backtracking simulations
outfile -- file name of the generated map
"""
# COMPUTE QUANTILE MAP
data = np.full((ny*nx), np.nan)
record = {}
sources = to_iterable(sources)
years = to_iterable(years)
months = to_iterable(months)
for src in sources:
print(f"current source is {src}")
for yr in years:
for mn in months:
print(f"{mn} {yr}")
_arrival_time_month_(record, longr, latgr, src, yr, mn)
for k,v in record.items():
data[k] = np.quantile(v,0.05) / 86400.0
data = data.reshape(ny,nx)
# ACTUAL PLOT
fig, ax = plt.subplots(figsize=(12,12))
# plot data
pm = ax.pcolormesh(longr, latgr, data, cmap="YlOrRd_r")
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="4.5%", pad=0.2)
fig.colorbar(pm, cax=cax, label="arrival time [days] (5% quantile)")
plot_landmark(ax, sources) plot_landmark(ax, sources)
current_month = months[0]
name = os.path.splitext(outfile)[0] figname = current_month[0].upper()+current_month[1:]
np.save(name+".npy", {"lon":longr, "lat":latgr, "data":data}) ax.text(0, 0, " "+figname+" ", va="bottom", ha="left", color="k", \
to_shp(longr, latgr, data, name+".shp") fontsize=20, transform=ax.transAxes)
plt.savefig(outfile, dpi=200, bbox_inches="tight") ax.legend( loc="upper right", title="Arrival time\n(5th percentile)", \
fontsize=14, title_fontsize=14)
plt.savefig(outfile, dpi=300, bbox_inches="tight")
plt.close() plt.close()
if __name__ == "__main__": if __name__ == "__main__":
sources = ['ras_laffan','ras_laffan_port', 'abu_fontas', 'umm_al_houl'] for month in ['april', 'june']:
years = ['2017', '2018', '2019', '2020'] outfile = f"fig_arrival_{month}.pdf"
for year in years: plot_arrival_time(sources, years, month, outfile)
firstiter= 9 if year == "2017" else 1 \ No newline at end of file
for i in range(firstiter,13):
time_tag = year + f'{i:02d}'
month = calendar.month_name[i].lower()
dirname = "shp/"+time_tag
os.makedirs(dirname, exist_ok=True)
fname = dirname+f"/arrival_time_percentile_{time_tag}.png"
plot_arrival_time(sources, year, month, fname)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Some classes to process OpenDrift's outputs
Author: Thomas Dobbelaere, Earth and Life Institute, UCLouvain, Belgium
Last modified: 19 July 2022
"""
import numpy as np
from typing import List, Tuple, Dict
from netCDF4 import Dataset
def _extract_lonlat(ncfile: str) -> Tuple[np.ndarray, np.ndarray]:
with Dataset(ncfile) as da:
lon = np.ma.filled(da['lon'][:], np.nan)
lat = np.ma.filled(da['lat'][:], np.nan)
return lon, lat
class DataMap:
"""
Base class to produce maps with given resolution over a given bounding box based on particle trajtories computed by OpenDrift
"""
def __init__(self, minlon: float, maxlon: float, minlat: float, maxlat: float, resolution: int) -> None:
self.lon = np.arange(minlon, maxlon+resolution, resolution)
nx = len(self.lon)
self.lat = np.arange(minlat, maxlat+resolution, resolution)
ny = len(self.lat)
self.data = np.zeros((ny-1)*(nx-1))
self.shape = (ny-1,nx-1)
self.dx = resolution
def _contains(self, lon: np.ndarray, lat: np.ndarray) -> np.ndarray:
isok = ~np.isnan(lon)
isok[isok] = (lon[isok] >= self.lon.min()) & (lon[isok] <= self.lon.max()) & \
(lat[isok] >= self.lat.min()) & (lat[isok] <= self.lat.max())
return isok
def get_cell_indices(self, lon: np.ndarray, lat: np.ndarray) -> np.ndarray:
isok = self._contains(lon, lat)
indices = np.full(lon.shape, -1, dtype=np.int32)
ix = ((lon[isok]-self.lon[0]) / self.dx).astype(np.int32)
iy = ((lat[isok]-self.lat[0]) / self.dx).astype(np.int32)
indices[isok] = iy*(self.lon.size-1)+ix
return indices
class ArrivalTimeMap(DataMap):
"""
Compute map of minimum arrival time based on backtracking simulation outputs from OpenOil.
This is performed by computing the 5% percentile q such that 95% of particles crossing a given pixel needed at least q days to reach their source point
"""
arrival_times_per_pixel: Dict[int,List[int]] = {}
def __init__(self, minlon: float, maxlon: float, minlat: float, maxlat: float, resolution: int) -> None:
super().__init__(minlon, maxlon, minlat, maxlat, resolution)
self.touched = np.full(self.data.size, False)
def add_arrival_times_from_file(self, ncfile: str) -> None:
with Dataset(ncfile,"r") as da:
time = da["time"][:]
time[:] = np.abs(time-time[0])
indices = self.get_cell_indices(*_extract_lonlat(ncfile))
unique_ids = np.unique(indices)[1:]
for px in unique_ids:
# for each pixel, find all trajectories that crossed it
itraj = np.where(np.sum(indices == px, axis=1) > 0)[0]
# then find minimum arrival time from the pixel for all found trajectories
itime = np.argmax(indices[itraj] == px, axis=1)
min_time = time[itime].tolist()
if self.touched[px]:
self.arrival_times_per_pixel[px].extend(min_time)
else:
self.arrival_times_per_pixel[px] = min_time
self.touched[unique_ids] = True
def get_arrival_time_days(self) -> np.ndarray:
for px, arrival_times in self.arrival_times_per_pixel.items():
self.data[px] = np.quantile(arrival_times,0.05) / 86400.0
arrival_days = np.where(self.touched, self.data, np.nan)
return arrival_days.reshape(self.shape)
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter