#!/usr/bin/env python3
#  -*- coding: utf-8 -*-

"""
Classes for post-processing OpenDrift simulation outputs (particle trajectories) into gridded maps

Author: Thomas Dobbelaere, Earth and Life Institute, UCLouvain, Belgium
Last modified: 19 July 2022
"""

import numpy as np
from typing import List, Tuple, Dict
from netCDF4 import Dataset

def _extract_lonlat(ncfile: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load particle positions from a NetCDF output file.

    Returns the ``lon`` and ``lat`` variables as plain arrays, in that order,
    with masked (fill-value) entries replaced by NaN.
    """
    with Dataset(ncfile) as nc:
        lon, lat = (np.ma.filled(nc[var][:], np.nan) for var in ("lon", "lat"))
    return lon, lat

class DataMap:
    """
    Base class to produce maps with a given resolution over a given bounding box, based on particle trajectories computed by OpenDrift.

    Cell edges are stored in ``self.lon`` / ``self.lat``. They start at ``minlon`` / ``minlat``,
    are spaced by ``resolution``, and stop at the first edge at or beyond ``maxlon`` / ``maxlat``.
    Cell values live in the flat array ``self.data``, where cell (iy, ix) is at ``iy * n_lon_cells + ix``.
    Its 2-D shape is ``self.shape == (n_lat_cells, n_lon_cells)``.
    """

    # Slack (in number of cells) used when counting cells, so that floating-point
    # round-off in (max - min) / resolution does not create a spurious extra cell.
    _EDGE_TOL = 1e-9

    def __init__(self, minlon: float, maxlon: float, minlat: float, maxlat: float, resolution: float) -> None:
        """
        Build a regular grid covering [minlon, maxlon] x [minlat, maxlat].

        :param resolution: cell size, in the same units as the coordinates (must be > 0)
        :raises ValueError: if ``resolution`` is not positive or a max bound is below its min bound
        """
        if not resolution > 0:  # also rejects NaN
            raise ValueError(f"resolution must be positive, got {resolution!r}")
        self.lon = self._cell_edges(minlon, maxlon, resolution)
        self.lat = self._cell_edges(minlat, maxlat, resolution)
        self.shape = (self.lat.size - 1, self.lon.size - 1)
        self.data = np.zeros(self.shape[0] * self.shape[1])
        self.dx = resolution

    @classmethod
    def _cell_edges(cls, vmin: float, vmax: float, step: float) -> np.ndarray:
        """Return the edges vmin, vmin + step, ... up to the first edge >= vmax."""
        if vmax < vmin:
            raise ValueError(f"upper bound {vmax!r} is below lower bound {vmin!r}")
        # np.arange(vmin, vmax + step, step) can return one edge too many with a
        # float step (e.g. 1.1 / 0.1 == 11.000000000000002), so count cells explicitly.
        n_cells = int(np.ceil((vmax - vmin) / step - cls._EDGE_TOL))
        return vmin + step * np.arange(n_cells + 1)

    def _contains(self, lon: np.ndarray, lat: np.ndarray) -> np.ndarray:
        """Return a boolean mask of non-NaN positions inside the grid's bounding box (edges included)."""
        isok = ~(np.isnan(lon) | np.isnan(lat))
        isok[isok] = (lon[isok] >= self.lon.min()) & (lon[isok] <= self.lon.max()) & \
                        (lat[isok] >= self.lat.min()) & (lat[isok] <= self.lat.max())
        return isok

    def get_cell_indices(self, lon: np.ndarray, lat: np.ndarray) -> np.ndarray:
        """
        Map positions to flat cell indices into ``self.data``.

        :param lon: longitudes, NaN marking missing positions
        :param lat: latitudes, same shape as ``lon``
        :return: int32 array with the shape of ``lon``; -1 where the position is NaN or outside the grid
        """
        indices = np.full(lon.shape, -1, dtype=np.int32)
        ny, nx = self.shape
        if nx == 0 or ny == 0:
            return indices  # degenerate grid: there is no cell to fall into
        isok = self._contains(lon, lat)
        # Offsets are non-negative here, so truncation toward zero is a floor.
        ix = ((lon[isok] - self.lon[0]) / self.dx).astype(np.int32)
        iy = ((lat[isok] - self.lat[0]) / self.dx).astype(np.int32)
        # Points on the upper edges would get index n (one past the last cell) and
        # spill into the next row or past the end of self.data; keep them in the last cell.
        ix = np.minimum(ix, nx - 1)
        iy = np.minimum(iy, ny - 1)
        indices[isok] = iy * nx + ix
        return indices

class ArrivalTimeMap(DataMap):
    """
    Compute a map of minimum arrival time based on backtracking simulation outputs from OpenOil.

    For each pixel, a trajectory's arrival time is its elapsed time (relative to the first output time)
    at the first output step where it lies in that pixel. The map value is the 5th percentile q of these
    times over all trajectories crossing the pixel, i.e. 95% of the particles crossing the pixel needed
    at least q days to reach their source point.
    """

    # Per-pixel list of first-arrival times accumulated over all processed files.
    # Initialised in __init__: a class-level ``= {}`` would be shared by every instance.
    arrival_times_per_pixel: Dict[int, List[float]]

    def __init__(self, minlon: float, maxlon: float, minlat: float, maxlat: float, resolution: float) -> None:
        super().__init__(minlon, maxlon, minlat, maxlat, resolution)
        self.touched = np.full(self.data.size, False)
        self.arrival_times_per_pixel = {}

    def add_arrival_times_from_file(self, ncfile: str) -> None:
        """
        Accumulate per-pixel first-arrival times from one simulation output file.

        Reads ``time`` and the particle positions from ``ncfile``. For every pixel crossed, it appends
        one elapsed time per crossing trajectory to ``arrival_times_per_pixel`` and marks the pixel as touched.
        """
        with Dataset(ncfile, "r") as da:
            time = da["time"][:]
        # Elapsed time since the first output (abs() because backtracking runs backwards in time).
        time[:] = np.abs(time - time[0])
        # Axis 1 of `indices` is indexed like `time` below, i.e. assumed (trajectory, time) layout.
        indices = self.get_cell_indices(*_extract_lonlat(ncfile))
        unique_ids = np.unique(indices)
        # -1 marks NaN/out-of-grid positions; filter it explicitly rather than assuming it is
        # always present as the smallest value (which would drop a real pixel otherwise).
        unique_ids = unique_ids[unique_ids >= 0]
        for px in unique_ids:
            hits = indices == px
            # trajectories that crossed this pixel at least once
            itraj = np.flatnonzero(hits.any(axis=1))
            # first output step in the pixel, i.e. the earliest elapsed time, assuming time is monotonic
            itime = hits[itraj].argmax(axis=1)
            min_time = time[itime].tolist()
            self.arrival_times_per_pixel.setdefault(int(px), []).extend(min_time)
        self.touched[unique_ids] = True

    def get_arrival_time_days(self, quantile: float = 0.05) -> np.ndarray:
        """
        Return the arrival-time map in days, with shape ``self.shape``.

        Each touched pixel gets the ``quantile`` of its accumulated arrival times divided by 86400,
        so the stored times are assumed to be in seconds. Untouched pixels are NaN.
        The per-pixel values are also written into ``self.data``.
        """
        for px, arrival_times in self.arrival_times_per_pixel.items():
            self.data[px] = np.quantile(arrival_times, quantile) / 86400.0
        arrival_days = np.where(self.touched, self.data, np.nan)
        return arrival_days.reshape(self.shape)