#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Compute the proportion of time each ocean cell spends in a mesotrophic or eutrophic state
@author: Noam Vogt-Vincent
"""

import numpy as np
import xarray as xr
import dask as ds
import matplotlib.pyplot as plt
import cmasher as cmr
import matplotlib.colors as colors
import matplotlib.ticker as mticker
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from dask.diagnostics import ProgressBar
from matplotlib.gridspec import GridSpec

###############################################################################
# LOAD DATA ###################################################################
###############################################################################

# Keep a handle on the dataset instead of a `with` block: CHL is opened lazily
# (dask chunks), so the file must stay open until the reductions below have
# actually been computed. It is closed explicitly once they are in memory.
chl_file = xr.open_dataset('DATA/CMEMS_CHL_1998-2022.nc',
                           chunks={'lon': 500, 'lat': 500, 'time': 10})
chl = chl_file.CHL

###############################################################################
# PROCESS DATA ################################################################
###############################################################################

meso_thresh = {'lower': 2.6, # mg/m3
               'upper': 7.3}

# Time steps without a valid chlorophyll value (NaN) are excluded from the
# proportion rather than counted as "not meso/eutrophic": any comparison with
# NaN is False, so unmasked gaps would bias the proportions low. The masks are
# cast to float first so that `.where` can insert NaN, which `.mean` then skips.
valid = chl.notnull()
meso = (((chl >= meso_thresh['lower']) & (chl < meso_thresh['upper']))
        .astype(float).where(valid))                                  # Mesotrophic
eu   = (chl >= meso_thresh['upper']).astype(float).where(valid)       # Eutrophic

with ProgressBar():
    # Evaluate both reductions in a single dask call so the shared work
    # (reading and thresholding each CHL chunk) is done once, not once per
    # separate .compute().
    meso_prop, eu_prop = ds.compute(meso.mean(dim='time'),
                                    eu.mean(dim='time'))

# Both proportion fields are now in memory. Later code only reads chl's lon/lat
# coordinates, which xarray holds in memory as dimension indexes.
chl_file.close()

###############################################################################
# PLOT DATA ###################################################################
###############################################################################

# Two stacked map panels in the left column, each with a narrow colourbar
# axis immediately to its right.
fig = plt.figure(figsize=(13, 16))
layout = GridSpec(2, 2, figure=fig, width_ratios=[1, 0.03], wspace=0.05, hspace=0.06)

# (field, colour map, panel label) for each row, top to bottom.
panels = [(meso_prop, cmr.jungle, 'Mesotrophic conditions'),
          (eu_prop, cmr.flamingo, 'Eutrophic conditions')]

# All map axes are created first (top, then bottom), followed by the
# colourbar axes in the same row order.
map_axes = [fig.add_subplot(layout[row, 0], projection=ccrs.PlateCarree())
            for row in range(len(panels))]
cbar_axes = [fig.add_subplot(layout[row, 1]) for row in range(len(panels))]


def _cell_edges(centres):
    """Return the len(centres)+1 cell edges around a run of cell centres.

    The spacing is taken from the first two centres, so this assumes the
    coordinate is evenly spaced.
    """
    step = centres[1].values - centres[0].values
    first_edge = centres[0].values - 0.5*step
    last_edge = centres[-1].values + 0.5*step
    return np.linspace(first_edge, last_edge, num=len(centres)+1)


# pcolormesh is given cell edges rather than centres so that each value fills
# exactly its own grid cell.
lon_bnd = _cell_edges(chl.lon)
lat_bnd = _cell_edges(chl.lat)

# Natural Earth 10 m land polygons: black fill with a thin white outline.
land_10k = cfeature.NaturalEarthFeature('physical', 'land', '10m',
                                        edgecolor='w', facecolor='k',
                                        zorder=1, linewidth=0.5)

# Draw both fields on a shared logarithmic scale from 0.1 % to 100 % of the time.
meshes = []
for ax_map, (field, cmap, _label) in zip(map_axes, panels):
    meshes.append(ax_map.pcolormesh(lon_bnd, lat_bnd, field, cmap=cmap,
                                    norm=colors.LogNorm(vmin=1e-3, vmax=1e0),
                                    transform=ccrs.PlateCarree()))

# Background, land overlay and labelled white graticule for each map.
for row, ax_map in enumerate(map_axes):
    ax_map.set_aspect(1)
    ax_map.set_facecolor('k')
    ax_map.add_feature(land_10k)

    grid_lines = ax_map.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                                  linewidth=0.5, color='white', linestyle='-',
                                  zorder=11)
    grid_lines.xlocator = mticker.FixedLocator(np.arange(-220, 220, 20))
    grid_lines.ylocator = mticker.FixedLocator(np.arange(-80, 120, 20))
    grid_lines.top_labels = False
    grid_lines.right_labels = False
    grid_lines.ylabel_style = {'size': 20}
    grid_lines.xlabel_style = {'size': 20}
    if row == 0:
        # Longitude labels appear only beneath the bottom panel.
        grid_lines.bottom_labels = False

# Panel titles, placed at (27, -44) in the maps' PlateCarree data coordinates.
for ax_map, (_field, _cmap, label) in zip(map_axes, panels):
    ax_map.text(27, -44, label, fontsize=28, color='w', fontweight='semibold')

# One colourbar per map, drawn into its dedicated right-hand axis.
for mesh, cax in zip(meshes, cbar_axes):
    cbar = plt.colorbar(mesh, cax=cax)
    cbar.set_label('Proportion of time', size=24)
    cax.tick_params(axis='y', labelsize=22)

plt.savefig('IO_State.png', bbox_inches='tight', dpi=300)