diff --git a/arrival_time_percentile.py b/arrival_time_percentile.py
index 073fe0f33a1e80e6ebbf74ce1c315abfda4e3360..dadfd989b35717bbd97ce94e224ebe0c18139b81 100644
--- a/arrival_time_percentile.py
+++ b/arrival_time_percentile.py
@@ -9,156 +9,59 @@ Last modified: 19 July 2022
 """
 
 import numpy as np
-from netCDF4 import Dataset
-import os, calendar
+import os
 import matplotlib.pyplot as plt
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-import geopandas as gpd
 from map_data import plot_landmark
-import shapely.geometry
+from postpro import ArrivalTimeMap
+from typing import List
+import itertools
 
-# --- USEFUL VARIABLES --- #
-# grid
 minlon, maxlon, minlat, maxlat = 50, 53, 24, 27.5
 res = 0.01
-longr = np.arange(minlon, maxlon+res, res)
-latgr = np.arange(minlat, maxlat+res, res)
-ox, oy = longr[0], latgr[0]
-nx, ny = longr.size-1, latgr.size-1
-# paths
 basedir = "/export/miro/students/tanselain/OpenOil/"
-
-# --- USEFUL FUNCTIONS --- #
-def is_inside(x, y, lon, lat):
-    """Check if particles with coordinates x,y are inside grid defined by lon,lat"""
-    isok = ~np.isnan(x)
-    isok[isok] = (x[isok] >= lon.min()) & (x[isok] <= lon.max()) & \
-                    (y[isok] >= lat.min()) & (y[isok] <= lat.max())
-    return isok
-
-def get_index(x, y, lon, lat):
-    """Get the indices of the pixels containing the particles of coordinates (x,y)
-
-    Keyword arguments
-    x -- longitudes of the particles
-    y -- latitudes of the particles
-    lon -- longitude of the regular grid points forming the pixels
-    lat -- latitude of the regular grid points foring the pixels
-    """
-    isok = is_inside(x, y, lon, lat)
-    dx = lon[1]-lon[0]
-    dy = lon[1]-lon[0]
-    indices = np.full(x.shape, -1, dtype=np.int32)
-    ix = ((x[isok]-lon[0]) / dx).astype(np.int32)
-    iy = ((y[isok]-lat[0]) / dy).astype(np.int32)
-    indices[isok] = iy*(lon.size-1)+ix
-    return indices
-
-def _arrival_time_month_(record, gridlon, gridlat, source, year, month):
-    """ Store the minimum arrival time to the source from pixels of a regular grid for all trajectories crossing the pixels
-
-    Keyword arguments:
-    record -- a dictionary used to store all arrival times
-    gridlon -- longitude of the regular grid points forming the pixels
-    gridlat -- latitude of the regular grid points foring the pixels
-    source -- source of the particles in the backtracking simulations (used to find netcdf output files)
-    year -- year of the simulations (used to find netcdf output files)
-    month -- month of the simulations (used to find netcdf output files)
-    """
-    for date in range(1,32):
-        if source == "lusail":
-            fn = basedir + f"Output_backward_lusail/{year}/{month}/out_{month}_{date}.nc"
+sources = ['ras_laffan','ras_laffan_port', 'abu_fontas', 'umm_al_houl']  
+years = ['2017', '2018', '2019', '2020']
+
+def plot_arrival_time(sources: List[str], years: List[str], month: str, outfile: str) -> None:
+    
+    sources = [sources] if isinstance(sources,str) else sources
+    years = [years] if isinstance(years,str) else years
+    months = [months] if isinstance(months,str) else months
+
+    atm = ArrivalTimeMap(minlon, maxlon, minlat, maxlat, res)
+    for src,yr in itertools.product(sources,years):
+        if src == "lusail":
+            nc_path = basedir + f'/Output_backward_lusail/{yr}/{month}'
         else:
-            fn = basedir + f"Output_backward/{year}/{source}/{month}/out_{month}_{date}.nc"
-        if not os.path.isfile(fn):
-            continue
-        with Dataset(fn,"r") as ds:
-            lats = np.ma.filled(ds["lat"][:], np.nan)
-            lons = np.ma.filled(ds["lon"][:], np.nan)
-            time = ds["time"][:]
-        tau = np.abs(time-time[0])
-        # find index of all pixels crossed by particles during the simulation
-        indices = get_index(lons, lats, gridlon, gridlat)
-        uids = np.unique(indices)[1:]
-        for px in uids:
-            # for each pixel, find all trajectories that crossed it
-            itraj = np.where(np.sum(indices == px, axis=1) > 0)[0]
-            # then find minimum arrival time from the pixel for all found trajectories
-            itime = np.argmax(indices[itraj] == px, axis=1)
-            min_time = tau[itime].tolist()
-            if px in record.keys():
-                record[px].extend(min_time)
-            else:
-                record[px] = min_time
-
-def to_iterable(a):
-    is_iterable = isinstance(a,list) or isinstance(a,tuple)
-    return a if is_iterable else [a]
-
-def to_shp(lon, lat, a, shpfile):
-    data = []
-    ny, nx = a.shape
-    for i in range(ny):
-        for j in range(nx):
-            X = lon[[j,   j, j+1, j+1]]
-            Y = lat[[i, i+1, i+1,   i]]
-            poly = shapely.geometry.Polygon(np.column_stack([X,Y]))
-            data.append( (a[i,j],poly) )
-    df = gpd.GeoDataFrame(data, columns=["arrival_time", "geometry"], crs="EPSG:4326")
-    df.to_file(shpfile)
-
-def plot_arrival_time(sources, years, months, outfile):
-    """Plot arrival time map to sources based on backtracking simulations with OpenDrift for given months and years
-
-    Keyword arguments:
-    sources -- source points in the backtracking simulations
-    years -- years of the backtracking simulations
-    months -- months of the backtracking simulations
-    outfile -- file name of the generated map
-    """
-    # COMPUTE QUANTILE MAP
-    data = np.full((ny*nx), np.nan)
-    record = {}
-    sources = to_iterable(sources)
-    years = to_iterable(years)
-    months = to_iterable(months)
-    for src in sources:
-        print(f"current source is {src}")
-        for yr in years:
-            for mn in months:
-                print(f"{mn} {yr}")
-                _arrival_time_month_(record, longr, latgr, src, yr, mn)
-    for k,v in record.items():
-        data[k] = np.quantile(v,0.05) / 86400.0
-    data = data.reshape(ny,nx)
-
-    # ACTUAL PLOT
-    fig, ax = plt.subplots(figsize=(12,12))
-
-    # plot data
-    pm = ax.pcolormesh(longr, latgr, data, cmap="YlOrRd_r")
-    divider = make_axes_locatable(ax)
-    cax = divider.append_axes("right", size="4.5%", pad=0.2)
-    fig.colorbar(pm, cax=cax, label="arrival time [days] (5% quantile)")
-
+            nc_path = basedir + f'/Output_backward/{yr}/{src}/{month}'
+        for day in range(1,32):
+            ncfile = nc_path+f'/out_{month}_{day}.nc'
+            if not os.path.isfile(ncfile):  
+                continue
+            atm.add_arrival_times_from_file(ncfile)
+
+    arrival_time_days = atm.get_arrival_time_days()
+    hasdata = ~np.isnan(arrival_time_days)
+    arrival_time_days[hasdata] = np.floor(arrival_time_days[hasdata])
+
+    _, ax = plt.subplots(figsize=(12,12))
+    ax.pcolormesh(atm.lon, atm.lat, arrival_time_days, cmap="YlOrRd_r")
+    cmap = plt.get_cmap('YlOrRd_r')
+    colors = cmap(np.arange(5)/5)
+    for i,clr in enumerate(colors):
+        label = "< 1 day" if i == 0 else f"{i} - {i+1} days"
+        ax.add_patch(plt.Rectangle((0,0), 0, 0, color=clr, label=label))
     plot_landmark(ax, sources)
-
-    name = os.path.splitext(outfile)[0]
-    np.save(name+".npy", {"lon":longr, "lat":latgr, "data":data})
-    to_shp(longr, latgr, data, name+".shp")
-    plt.savefig(outfile, dpi=200, bbox_inches="tight")
+    current_month = months[0]
+    figname = current_month[0].upper()+current_month[1:]
+    ax.text(0, 0, " "+figname+" ", va="bottom", ha="left", color="k", \
+            fontsize=20, transform=ax.transAxes)
+    ax.legend( loc="upper right", title="Arrival time\n(5th percentile)", \
+                fontsize=14, title_fontsize=14)
+    plt.savefig(outfile, dpi=300, bbox_inches="tight")
     plt.close()
 
 if __name__ == "__main__":
-    sources = ['ras_laffan','ras_laffan_port', 'abu_fontas', 'umm_al_houl']  
-    years = ['2017', '2018', '2019', '2020']
-    for year in years:
-        firstiter= 9 if year == "2017" else 1
-        for i in range(firstiter,13):
-            time_tag = year + f'{i:02d}'
-            month = calendar.month_name[i].lower()
-            dirname = "shp/"+time_tag
-            os.makedirs(dirname, exist_ok=True)
-            fname = dirname+f"/arrival_time_percentile_{time_tag}.png"
-
-            plot_arrival_time(sources, year, month, fname)
+    for month in ['april', 'june']:
+        outfile = f"fig_arrival_{month}.pdf"
+        plot_arrival_time(sources, years, month, outfile)
\ No newline at end of file
diff --git a/postpro.py b/postpro.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac809a9aa47d86d0ef81e3cbdd38c20f25f4f85f
--- /dev/null
+++ b/postpro.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+#  -*- coding: utf-8 -*-
+
+"""
+Some classes to process OpenDrift's outputs
+
+Author: Thomas Dobbelaere, Earth and Life Institute, UCLouvain, Belgium
+Last modified: 19 July 2022
+"""
+
+import numpy as np
+from typing import List, Tuple, Dict
+from netCDF4 import Dataset
+
+def _extract_lonlat(ncfile: str) -> Tuple[np.ndarray, np.ndarray]:
+    with Dataset(ncfile) as da:
+        lon = np.ma.filled(da['lon'][:], np.nan)
+        lat = np.ma.filled(da['lat'][:], np.nan)
+    return lon, lat
+
+class DataMap:
+    """
+    Base class to produce maps with given resolution over a given bounding box based on particle trajtories computed by OpenDrift
+    """
+
+    def __init__(self, minlon: float, maxlon: float, minlat: float, maxlat: float, resolution: int) -> None:
+        self.lon = np.arange(minlon, maxlon+resolution, resolution)
+        nx = len(self.lon)
+        self.lat = np.arange(minlat, maxlat+resolution, resolution)
+        ny = len(self.lat)
+        self.data = np.zeros((ny-1)*(nx-1))
+        self.shape = (ny-1,nx-1)
+        self.dx = resolution
+    
+    def _contains(self, lon: np.ndarray, lat: np.ndarray) -> np.ndarray:
+        isok = ~np.isnan(lon)
+        isok[isok] = (lon[isok] >= self.lon.min()) & (lon[isok] <= self.lon.max()) & \
+                        (lat[isok] >= self.lat.min()) & (lat[isok] <= self.lat.max())
+        return isok
+
+    def get_cell_indices(self, lon: np.ndarray, lat: np.ndarray) -> np.ndarray:
+        isok = self._contains(lon, lat)
+        indices = np.full(lon.shape, -1, dtype=np.int32)
+        ix = ((lon[isok]-self.lon[0]) / self.dx).astype(np.int32)
+        iy = ((lat[isok]-self.lat[0]) / self.dx).astype(np.int32)
+        indices[isok] = iy*(self.lon.size-1)+ix
+
+        return indices        
+
+class ArrivalTimeMap(DataMap):
+    """
+    Compute map of minimum arrival time based on backtracking simulation outputs from OpenOil.
+    This is performed by computing the 5% percentile q such that 95% of particles crossing a given pixel needed at least q days to reach their source point
+    """
+
+    arrival_times_per_pixel: Dict[int,List[int]] = {}
+
+    def __init__(self, minlon: float, maxlon: float, minlat: float, maxlat: float, resolution: int) -> None:
+        super().__init__(minlon, maxlon, minlat, maxlat, resolution)
+        self.touched = np.full(self.data.size, False)
+
+    def add_arrival_times_from_file(self, ncfile: str) -> None:
+        with Dataset(ncfile,"r") as da:
+            time = da["time"][:]
+        time[:] = np.abs(time-time[0])
+        indices = self.get_cell_indices(*_extract_lonlat(ncfile))
+        unique_ids = np.unique(indices)[1:]
+        for px in unique_ids:
+            # for each pixel, find all trajectories that crossed it
+            itraj = np.where(np.sum(indices == px, axis=1) > 0)[0]
+            # then find minimum arrival time from the pixel for all found trajectories
+            itime = np.argmax(indices[itraj] == px, axis=1)
+            min_time = time[itime].tolist()
+            if self.touched[px]:
+                self.arrival_times_per_pixel[px].extend(min_time)
+            else:
+                self.arrival_times_per_pixel[px] = min_time
+        self.touched[unique_ids] = True
+    
+    def get_arrival_time_days(self) -> np.ndarray:
+        for px, arrival_times in self.arrival_times_per_pixel.items():
+            self.data[px] = np.quantile(arrival_times,0.05) / 86400.0
+        arrival_days = np.where(self.touched, self.data, np.nan)
+        return arrival_days.reshape(self.shape)