/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

/*
   This module contains the following operators:

      Detrend    detrend         Detrend
*/

#include <cdi.h>

#include "varray.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "cdo_options.h"
#include "datetime.h"
#include "cimdOmp.h"
#include "pmlist.h"
#include "param_conversion.h"
#include "field_functions.h"
#include "arithmetic.h"

// Removes the least-squares linear trend from a single time series.
//
// A straight line x(t) ~ intercept + slope * t is fitted to every sample of array1 that is
// not the missing value, with deltaTS0[j] as the time coordinate of sample j. The residual
// x - (intercept + slope * t) is then written to array2[j] for every j in [0, nts).
//
// Missing-value propagation is delegated to the ADDM/SUBM/MULM/DIVM macros (arithmetic.h).
// Presumably they yield missval1 for missing operands or a zero divisor (e.g. no valid
// samples) — confirm against arithmetic.h.
//
//   nts      : number of samples to process
//   deltaTS0 : time coordinate of each sample
//   missval1 : missing value of the series; may be NaN
//   array1   : input series
//   array2   : output residuals (entries [0, nts) are overwritten)
static void
detrend(long nts, const Varray<double> &deltaTS0, double missval1, const Varray<double> &array1, Varray<double> &array2)
{
  // The arithmetic macros expand to code that reads the locals `missval1`, `missval2` and
  // `is_EQ`, so these exact names must be visible wherever a macro is used below.
  auto missval2 = missval1;

  // NaN never compares equal to itself, so a NaN missing value needs the NaN-aware comparison.
  const bool missvalIsNaN = std::isnan(missval1);

  // Normal-equation accumulators over the non-missing samples (t = time, x = value).
  double sumT = 0.0, sumTT = 0.0;
  double sumX = 0.0, sumTX = 0.0;
  long numValid = 0;

  auto accumulate = [&](auto is_EQ) {
    for (long j = 0; j < nts; ++j)
      {
        const auto x = array1[j];
        if (is_EQ(x, missval1)) continue;

        const auto t = deltaTS0[j];
        sumT += t;
        sumTT += t * t;
        sumTX += t * x;
        sumX += x;
        ++numValid;
      }
  };

  if (missvalIsNaN)
    accumulate(dbl_is_equal);
  else
    accumulate(is_equal);

  // Slope and intercept of the fit. This part always uses the plain is_equal comparison,
  // independent of whether missval1 is NaN.
  double slope = 0.0, intercept = 0.0;
  {
    auto is_EQ = is_equal;
    const auto numerator = SUBM(sumTX, DIVM(MULM(sumT, sumX), numValid));
    const auto denominator = SUBM(sumTT, DIVM(MULM(sumT, sumT), numValid));
    slope = DIVM(numerator, denominator);
    intercept = SUBM(DIVM(sumX, numValid), MULM(DIVM(sumT, numValid), slope));
  }

  // Residual of every sample with respect to the fitted line.
  auto removeTrend = [&](auto is_EQ) {
    for (long j = 0; j < nts; ++j) array2[j] = SUBM(array1[j], ADDM(intercept, MULM(slope, deltaTS0[j])));
  };

  if (missvalIsNaN)
    removeTrend(dbl_is_equal);
  else
    removeTrend(is_equal);
}

// Fills deltaTS0[0..nts) with the time coordinate used for the trend fit.
//
// When tstepIsEqual is set, the coordinate of timestep tsID is tsID itself. Each timestep's
// date/time is still handed to check_time_increment(), presumably to report non-uniform
// spacing — confirm. Otherwise the coordinate comes from delta_time_step_0(), which receives
// the calendar and the timestep's date/time and carries state across calls in
// julianDate0/deltat1.
static void
computeDeltaTS0(bool tstepIsEqual, int nts, int calendar, DateTimeList &dtlist, Varray<double> &deltaTS0)
{
  if (tstepIsEqual)
    {
      CheckTimeIncr checkTimeIncr;
      for (int tsID = 0; tsID < nts; ++tsID)
        {
          auto vDateTime = dtlist.get_vDateTime(tsID);
          check_time_increment(tsID, calendar, vDateTime, checkTimeIncr);
          deltaTS0[tsID] = static_cast<double>(tsID);
        }
      return;
    }

  // State threaded through successive delta_time_step_0() calls.
  JulianDate julianDate0;
  double deltat1 = 0.0;
  for (int tsID = 0; tsID < nts; ++tsID)
    {
      auto vDateTime = dtlist.get_vDateTime(tsID);
      deltaTS0[tsID] = delta_time_step_0(tsID, calendar, vDateTime, julianDate0, deltat1);
    }
}

// Parses the optional operator parameters of this module as key/value pairs.
//
// Only the key "equal" is recognised; its value is converted with parameter_to_bool() and
// stored in tstepIsEqual. A parse error, an unknown key, or a key that does not carry exactly
// one value ends in cdo_abort(). If no parameters are given, tstepIsEqual is left untouched.
static void
detrendGetParameter(bool &tstepIsEqual)
{
  if (cdo_operator_argc() == 0) return;

  const auto &pargv = cdo_get_oper_argv();

  KVList kvlist;
  kvlist.name = cdo_module_name();
  if (kvlist.parse_arguments(pargv) != 0) cdo_abort("Parse error!");
  if (Options::cdoVerbose) kvlist.print();

  for (const auto &kv : kvlist)
    {
      const auto &key = kv.key;
      if (kv.nvalues > 1) cdo_abort("Too many values for parameter key >%s<!", key);
      if (kv.nvalues < 1) cdo_abort("Missing value for parameter key >%s<!", key);

      if (key == "equal")
        {
          tstepIsEqual = parameter_to_bool(kv.values[0]);
        }
      else
        {
          cdo_abort("Invalid parameter key >%s<!", key);
        }
    }
}

// Detrend module.
//
// Buffers every timestep of the input stream in memory, then subtracts a least-squares
// linear fit in time (see detrend()) from the time series at each grid point of each
// variable and level. Finally it writes all timesteps, with their original date/time,
// to the output stream.
class Detrend : public Process
{
public:
  using Process::Process;
  // Module registration: a single operator "detrend" documented by DetrendHelp.
  // constraints {1, 1, ...}: presumably one input and one output stream (matches
  // cdo_open_read(0) / cdo_open_write(1) in init()).
  inline static CdoModule module = {
    .name = "Detrend",
    .operators = { { "detrend", DetrendHelp } },
    .aliases = {},
    .mode = EXPOSED,     // Module mode: 0:intern 1:extern
    .number = CDI_REAL,  // Allowed number type
    .constraints = { 1, 1, NoRestriction },
  };
  inline static RegisterEntry<Detrend> registration = RegisterEntry<Detrend>(module);

  // Loop indices used by run(). They are members rather than locals and are also named in
  // the OpenMP shared clause of the detrend loop.
  int varID, levelID;
  // Date/time of every input timestep; replayed onto the output taxis when writing.
  DateTimeList dtlist;

  CdoStreamID streamID1;  // input stream
  CdoStreamID streamID2;  // output stream

  VarList varList;
  // vars[tsID][varID][levelID]: every input field, kept in memory until all timesteps are read.
  FieldVector3D vars;

  int taxisID1;
  int taxisID2;
  int vlistID1;
  int nvars;

  // true: the time coordinate of timestep tsID is tsID; false: it is derived from the
  // calendar date/time. Overridable via the "equal" operator parameter.
  bool tstepIsEqual = true;

public:
  // Parses operator parameters, opens both streams and defines the output vlist/taxis.
  void
  init()
  {
    detrendGetParameter(tstepIsEqual);

    streamID1 = cdo_open_read(0);

    // The output vlist is a copy of the input one with vlist_unpack() applied (presumably
    // so detrended values are not written with the input's packing — TODO confirm), plus
    // its own duplicated taxis.
    vlistID1 = cdo_stream_inq_vlist(streamID1);
    auto vlistID2 = vlistDuplicate(vlistID1);

    vlist_unpack(vlistID2);

    taxisID1 = vlistInqTaxis(vlistID1);
    taxisID2 = taxisDuplicate(taxisID1);
    vlistDefTaxis(vlistID2, taxisID2);

    streamID2 = cdo_open_write(1);
    cdo_def_vlist(streamID2, vlistID2);

    varList_init(varList, vlistID1);

    nvars = vlistNvars(vlistID1);
  }

  // Reads all timesteps, detrends every grid point's time series in place, then writes
  // all timesteps.
  void
  run()
  {
    // Phase 1: read every timestep into vars[tsID]; the outer vector grows in chunks of
    // NALLOC_INC timesteps.
    int tsID = 0;
    while (true)
      {
        auto nrecs = cdo_stream_inq_timestep(streamID1, tsID);
        if (nrecs == 0) break;

        constexpr size_t NALLOC_INC = 1024;
        if ((size_t) tsID >= vars.size()) vars.resize(vars.size() + NALLOC_INC);

        dtlist.taxis_inq_timestep(taxisID1, tsID);

        fields_from_vlist(vlistID1, vars[tsID]);

        // Only fields that actually appear as records in this timestep get init() + data.
        for (int recID = 0; recID < nrecs; ++recID)
          {
            cdo_inq_record(streamID1, &varID, &levelID);
            auto &field = vars[tsID][varID][levelID];
            field.init(varList[varID]);
            cdo_read_record(streamID1, field);
          }

        tsID++;
      }

    // Phase 2: time coordinate of every timestep, plus one scratch input/output series per
    // OpenMP thread (indexed by cdo_omp_get_thread_num() below).
    auto nts = tsID;
    Varray<double> deltaTS0(nts);
    Varray2D<double> array1_2D(Threading::ompNumThreads, Varray<double>(nts));
    Varray2D<double> array2_2D(Threading::ompNumThreads, Varray<double>(nts));

    auto calendar = taxisInqCalendar(taxisID1);
    computeDeltaTS0(tstepIsEqual, nts, calendar, dtlist, deltaTS0);

    // Phase 3: detrend each grid point's time series, overwriting the buffered fields.
    for (varID = 0; varID < nvars; ++varID)
      {
        const auto &var = varList[varID];
        // Constant variables are processed (and later written) for the first timestep only.
        // NOTE(review): detrend() then fits a line to a single sample; the outcome depends on
        // the missing-value arithmetic macros — confirm this is the intended behavior.
        auto nsteps = var.isConstant ? 1 : nts;
        auto missval = var.missval;
        auto fieldMemType = var.memType;
        auto gridsize = var.gridsize;
        for (levelID = 0; levelID < var.nlevels; ++levelID)
          {
            // Grid points are independent; each thread works on its own scratch arrays.
            // NOTE(review): assumes every non-constant field was read (init()-ed) at every
            // timestep; otherwise the vec_f/vec_d access below touches a field that was never
            // initialised — confirm inputs always satisfy this.
#ifdef _OPENMP
#pragma omp parallel for default(none) schedule(static) \
    shared(fieldMemType, gridsize, nsteps, deltaTS0, missval, array1_2D, array2_2D, vars, varID, levelID)
#endif
            for (size_t i = 0; i < gridsize; ++i)
              {
                auto ompthID = cdo_omp_get_thread_num();
                auto &array1 = array1_2D[ompthID];
                auto &array2 = array2_2D[ompthID];

                // Gather the time series at grid point i (float or double storage).
                if (fieldMemType == MemType::Float)
                  for (int k = 0; k < nsteps; ++k) array1[k] = vars[k][varID][levelID].vec_f[i];
                else
                  for (int k = 0; k < nsteps; ++k) array1[k] = vars[k][varID][levelID].vec_d[i];

                detrend(nsteps, deltaTS0, missval, array1, array2);

                // Scatter the residuals back into the buffered fields.
                if (fieldMemType == MemType::Float)
                  for (int k = 0; k < nsteps; ++k) vars[k][varID][levelID].vec_f[i] = array2[k];
                else
                  for (int k = 0; k < nsteps; ++k) vars[k][varID][levelID].vec_d[i] = array2[k];
              }
          }
      }

    // Phase 4: write every timestep with its original date/time; constant variables are
    // written only at the first timestep.
    for (tsID = 0; tsID < nts; ++tsID)
      {
        dtlist.taxis_def_timestep(taxisID2, tsID);
        cdo_def_timestep(streamID2, tsID);

        for (varID = 0; varID < nvars; ++varID)
          {
            const auto &var = varList[varID];
            if (tsID && var.isConstant) continue;
            for (levelID = 0; levelID < var.nlevels; ++levelID)
              {
                cdo_def_record(streamID2, varID, levelID);
                auto &field = vars[tsID][varID][levelID];
                cdo_write_record(streamID2, field);
              }
          }
      }
  }

  // Closes the output stream, then the input stream opened in init().
  void
  close()
  {

    cdo_stream_close(streamID2);
    cdo_stream_close(streamID1);
  }
};
