#include "tidl_custom_ops.h"

#include "../algo/inc/tidl_commonUtils.h"
#include "../utils/perfsim/common.h" /* include MAX MIN  macro */
#include <math.h>


/*
 * Reference implementation of the Tile custom op.
 *
 * The output element at per-dimension coordinate c reads the input element at coordinate
 * (c[d] % inShape[d]) in every dimension d, over all TIDL_DIM_MAX dimensions.
 *
 * Input : params.inDataPtrs[0], shape params.inDataDims[0]
 * Output: params.outDataPtrs[0], shape params.outDataDims, params.outDataCount elements
 *
 * Assumes every inShape[d] is non-zero (it is used as a divisor on the general path).
 */
template <typename T>
void TIDL_refCustomTileProcess(const CustomOpParams& params)
{
    const T* __restrict__ input = static_cast<const T*>(params.inDataPtrs[0]);
    T* __restrict__ output      = static_cast<T*>(params.outDataPtrs[0]);
    const int32_t* inShape      = params.inDataDims[0];
    const int32_t* outShape     = params.outDataDims;
    constexpr int32_t numDims   = TIDL_DIM_MAX;

    /*
     * Row-major strides: how many elements to jump in memory to move by 1 along each dimension.
     */
    int32_t inStride[TIDL_DIM_MAX]  = {0};
    int32_t outStride[TIDL_DIM_MAX] = {0};
    inStride[numDims - 1]  = 1;
    outStride[numDims - 1] = 1;
    for (int i = numDims - 2; i >= 0; --i)
    {
        inStride[i]  = inStride[i + 1]  * inShape[i + 1];
        outStride[i] = outStride[i + 1] * outShape[i + 1];
    }

    /*
     * Detect the linear tail: the trailing dims [linearDims, numDims) whose sizes are equal in input
     * and output. Within those dims no wrap-around happens, so each run of innerBlock output elements
     * maps to one contiguous run of innerBlock input elements and can be copied in one block.
     */
    int32_t linearDims = numDims;
    while (linearDims > 0 && (inShape[linearDims - 1] == outShape[linearDims - 1]))
    {
        linearDims--;
    }

    /* Size of the contiguous block copied at once (product of the matching trailing dims). */
    int32_t innerBlock = 1;
    for (int d = linearDims; d < numDims; ++d)
    {
        innerBlock *= outShape[d];
    }

    /* A zero-sized trailing dim means the output is empty; innerBlock must not be used as a divisor. */
    if (innerBlock == 0)
    {
        return;
    }

    /********** FAST PATH when innerBlock == 1 **********
     * Element-by-element copy. Mixed-radix counters (with carry) advance the output coordinate, and
     * the input coordinate/offset is updated incrementally with wrap at inShape[d], so no per-element
     * division or modulo is needed.
     *****************************************************/
    if (innerBlock == 1)
    {
        const int32_t total = params.outDataCount;
        // Coordinate in output space (mixed-radix)
        int32_t coord[TIDL_DIM_MAX]   = {0};
        // Corresponding input coordinate (output coordinate wrapped by inShape)
        int32_t inCoord[TIDL_DIM_MAX] = {0};
        int32_t inOffset = 0;

        const T* __restrict__ inBase = input;
        T* __restrict__ outPtr       = output;

        for (int32_t i = 0; i < total; ++i)
        {
            // Copy single element
            *outPtr++ = inBase[inOffset];

            // Increment the mixed-radix coordinate with carry
            for (int d = numDims - 1; d >= 0; --d)
            {
                coord[d] += 1;
                if (coord[d] < outShape[d])
                {
                    // No carry beyond this digit: advance the input coord (wrapping) and offset
                    int32_t oldIn = inCoord[d];
                    int32_t newIn = oldIn + 1;
                    if (newIn >= inShape[d])
                    {
                        newIn = 0; // Wrap
                    }
                    inCoord[d] = newIn;
                    inOffset += (newIn - oldIn) * inStride[d];
                    break; // Done incrementing
                }
                else
                {
                    // Output digit overflowed: reset it and carry into the next more-significant digit
                    coord[d] = 0;
                    int32_t oldIn = inCoord[d];
                    inCoord[d] = 0;
                    inOffset += (0 - oldIn) * inStride[d];
                }
            }
        }

        return;
    }

    /********** GENERAL PATH for innerBlock > 1 **********
     * "outer" indexes the output with the innerBlock elements stripped out, so its strides over the
     * dims [0, linearDims) are the output strides divided by innerBlock. This division is exact,
     * because for d < linearDims outStride[d] contains every trailing dim that makes up innerBlock.
     *
     * Plain integer division/modulo is used deliberately. A fixed-point reciprocal (2^32 / x kept in
     * an int32_t) does not fit for x <= 2 (2^32 / 1 truncates to 0). It also rounds exact quotients
     * down (3 / 3 would evaluate to 0). The cost here is per block of innerBlock elements, not per
     * element.
     *****************************************************/
    int32_t outerStride[TIDL_DIM_MAX] = {0};
    for (int d = 0; d < linearDims; ++d)
    {
        outerStride[d] = outStride[d] / innerBlock;
    }

    const int32_t outerCount = params.outDataCount / innerBlock;
    for (int32_t outer = 0; outer < outerCount; ++outer)
    {
        int32_t inOffset = 0;
        int32_t rem      = outer;

        /* Decompose "outer" into output coordinates and wrap each by inShape (tile semantics). */
        for (int d = 0; d < linearDims; ++d)
        {
            const int32_t coord   = rem / outerStride[d];
            rem                  -= coord * outerStride[d];
            const int32_t inCoord = coord % inShape[d];
            inOffset += inCoord * inStride[d];
        }

        /*
         * The trailing dims are identical in input and output, so this block is contiguous in both.
         * inPtr/outPtr are derived from input/output, so the __restrict__ promise still holds.
         */
        const T* __restrict__ inPtr = input + inOffset;
        T* __restrict__ outPtr      = output + outer * innerBlock;
        memcpy(outPtr, inPtr, static_cast<size_t>(innerBlock) * sizeof(T));
    }
}


/*
 * Reference (plain C++) dispatcher: runs the reference kernel selected by params.customLayerType.
 *
 * @param tidlHandle  Not used by the reference path (kept for signature parity with other paths).
 * @param params      Populated custom-op parameters.
 * @return            CUSTOM_SUCCESS if a kernel ran, CUSTOM_FAIL for an unsupported layer type.
 */
template <typename T>
int32_t TIDL_customOpsProcessRef(void* tidlHandle, const CustomOpParams& params)
{
  (void)tidlHandle; /* unused by the reference implementation */

  int32_t status = CUSTOM_SUCCESS;

  switch(params.customLayerType)
  {
    case TIDL_CUSTOM_TYPE_TILE:
      TIDL_refCustomTileProcess<T>(params);
      break;
    default:
      status = CUSTOM_FAIL;
      break;
  }

  return status;
}


// int32_t TIDL_customOpsDspProcess(void* tidlHandle, const CustomOpParams& params)
// {
//   int32_t status = CUSTOM_SUCCESS;
//   TIDL_DataflowInitParams initParams;
//   TIDL_CustomOpsIxXOxXInitArgs kernelInitArgs;

//   // Populate the kernel arguments
//   kernelInitArgs.funcStyle  = TIDL_CUSTOM_FUNCTION_OPTIMIZED_C7X;
//   kernelInitArgs.customOpParams     = &params;

//   // Setup the TIDL dataflow interface
//   initParams.dataFlowType     = TIDL_DataFlowTypeOneTensorInProcOneChannel;
//   initParams.getHandleSize    = TIDL_customOps_ixX_oxX_getHandleSize;
//   initParams.initFuncPtr      = TIDL_customOps_ixX_oxX_init;
//   initParams.execFuncPtr      = TIDL_customOps_ixX_oxX_exec;
//   initParams.kernelInitArgs   = &kernelInitArgs;

//   // Initialize dataflow
//   status = TIDL_DataflowInit(tidlHandle, &initParams);

//   // Ensure cache coherency
//   TIDL_enableL1DandL2CacheWb();

//   if (status != 0)
//   {
//     status = CUSTOM_FAIL;
//   }

//   // Run the kernel on input/output buffers
//   if (status == CUSTOM_SUCCESS)
//   {
//     status = TIDL_DataflowProcess(tidlHandle, params.inDataPtrs, params.outDataPtrs);
//   }

//   return status;
// }


/*
 * Runs the custom op for element type T according to execMode.
 *
 * The stats-collection, reference-inference and target-inference modes all execute the reference
 * implementation. Any other execMode performs no work here and reports CUSTOM_SUCCESS.
 */
template <typename T>
int32_t TIDL_customOpsProcessCore(void* tidlHandle,
                                  const CustomOpParams& params,
                                  int32_t execMode)
{
  const bool runsReference = (execMode == TIDL_EXEC_MODE_STATS_COLLECTION) ||
                             (execMode == TIDL_EXEC_MODE_INFER_PROCESS_REF);
  const bool runsTarget    = (execMode == TIDL_EXEC_MODE_INFER_PROCESS);

  if (!runsReference && !runsTarget)
  {
    return CUSTOM_SUCCESS;
  }

  /*
   * TIDL_EXEC_MODE_INFER_PROCESS is temporarily redirected to the reference implementation.
   * Issues with the tiling buffer descriptor currently prevent successful execution of the
   * optimized DSP kernel through the dataflow.
   * Intended target call once resolved: status = TIDL_customOpsDspProcess(tidlHandle, params);
   */
  return TIDL_customOpsProcessRef<T>(tidlHandle, params);
}

/*
 * Returns a pointer to the additional (second) input data kept in the custom-layer params buffer.
 *
 * Assumed buffer layout: [CustomOpParams][additional input data ...]. The data therefore starts
 * sizeof(CustomOpParams) bytes after the beginning of the buffer.
 *
 * @param params  Start of the custom-layer params buffer.
 * @return        Pointer just past the CustomOpParams header, or nullptr (after logging) if params is null.
 */
static void* getAdditionalInputBuf(void* params)
{
  if(params == nullptr)
  {
    /* Critical error: 'params' is unexpectedly null. This indicates an issue with assigning CustomOpParams to the weights buffer. */
    tidl_printf(0, "[TIDL_CUSTOM_PROCESS] FATAL Cannot access custom layers params. Buffer is NULL.");

    return nullptr;
  }

  /* CustomOpParams is stored at the beginning of the buffer; the additional input follows it. */
  return static_cast<uint8_t*>(params) + sizeof(CustomOpParams);
}

/*
 * Returns the CustomOpParams stored at the beginning of the custom-layer params buffer.
 *
 * @param params  Custom-layer params buffer.
 * @return        params viewed as CustomOpParams*, or nullptr (after logging) if params is null.
 */
CustomOpParams* getCustomOpParams(void* params)
{
  if(params == nullptr)
  {
    /* Critical error: 'params' is unexpectedly null. This indicates an issue with assigning CustomOpParams to the weights buffer. */
    tidl_printf(0, "[TIDL_CUSTOM_PROCESS] FATAL Cannot access custom layers params. Buffer is NULL.");

    return nullptr;
  }

  /* CustomOpParams occupies the start of the buffer; void* -> T* is a static_cast. */
  return static_cast<CustomOpParams*>(params);
}

/*
 * Fills the CustomOpParams stored at the start of the custom-layer params buffer from the TIDL
 * layer description and the runtime input/output buffer pointers.
 *
 * @param params        Custom-layer params buffer. CustomOpParams sits at its start; for layers with
 *                      more than one input, the additional input data follows it.
 * @param tidlLayerPtr  Layer description supplying element types, tensor scales and dimensions.
 * @param inDataPtrs    Input buffer pointer array (the array pointer is stored, not copied).
 * @param outDataPtrs   Output buffer pointer array (the array pointer is stored, not copied).
 * @return              The populated params, or nullptr if params or tidlLayerPtr is null, or if the
 *                      layer has more inputs than CUSTOM_MAX_INPUTS_SUPPORTED.
 */
const CustomOpParams* populateCustomOpParams(void* params, sTIDL_Layer_t* tidlLayerPtr, void** inDataPtrs, void** outDataPtrs)
{
  /* getCustomOpParams logs and returns nullptr when params is null. */
  CustomOpParams* customOpParamsPtr = getCustomOpParams(params);
  if((nullptr == customOpParamsPtr) || (nullptr == tidlLayerPtr))
  {
    return nullptr;
  }

  /*
   * The per-input fields (inDataType, inTensorScale, inDataCount, inDataDims) are indexed up to
   * numInBufs - 1. Presumably they hold CUSTOM_MAX_INPUTS_SUPPORTED entries (confirm in
   * tidl_custom_ops.h). assert() is compiled out under NDEBUG, so the bound is also enforced at
   * run time rather than writing past those arrays.
   */
  assert(tidlLayerPtr->numInBufs <= CUSTOM_MAX_INPUTS_SUPPORTED);
  if(tidlLayerPtr->numInBufs > CUSTOM_MAX_INPUTS_SUPPORTED)
  {
    return nullptr;
  }

  customOpParamsPtr->customLayerType = tidlLayerPtr->layerParams.customParams.customLayerType;

  /* Populate input params */
  customOpParamsPtr->inputCount   = tidlLayerPtr->numInBufs;
  customOpParamsPtr->inDataPtrs   = inDataPtrs;
  /* For each input */
  for(size_t i = 0; i < customOpParamsPtr->inputCount; i++)
  {
    const sTIDL_DataParams_t* inDataParams  = &(tidlLayerPtr->inDataPtr[i]);
    customOpParamsPtr->inDataType[i]    = inDataParams->elementType;
    customOpParamsPtr->inTensorScale[i] = inDataParams->tensorScale;
    customOpParamsPtr->inDataCount[i]   = inDataParams->dimValues[TIDL_DIM_BATCH] *
                                          inDataParams->dimValues[TIDL_DIM_DIM1] *
                                          inDataParams->dimValues[TIDL_DIM_DIM2] *
                                          inDataParams->dimValues[TIDL_DIM_NUMCH] *
                                          inDataParams->dimValues[TIDL_DIM_HEIGHT] *
                                          inDataParams->dimValues[TIDL_DIM_WIDTH];
    /* Copy input dimension values */
    for(size_t dimIdx = 0; dimIdx < TIDL_DIM_MAX; dimIdx++)
    {
      customOpParamsPtr->inDataDims[i][dimIdx] = inDataParams->dimValues[dimIdx];
    }
  }

  /* The additional (second) input is stored in the params buffer right after CustomOpParams. */
  customOpParamsPtr->additionalInput = nullptr;
  if(customOpParamsPtr->inputCount > 1)
  {
    customOpParamsPtr->additionalInput = getAdditionalInputBuf(params);
  }

  /* Populate output params */
  const sTIDL_DataParams_t* outDataParams = &(tidlLayerPtr->outData);
  customOpParamsPtr->outDataPtrs    = outDataPtrs;
  customOpParamsPtr->outDataType    = outDataParams->elementType;
  customOpParamsPtr->outTensorScale = outDataParams->tensorScale;
  customOpParamsPtr->outDataCount   = outDataParams->dimValues[TIDL_DIM_BATCH] *
                                      outDataParams->dimValues[TIDL_DIM_DIM1] *
                                      outDataParams->dimValues[TIDL_DIM_DIM2] *
                                      outDataParams->dimValues[TIDL_DIM_NUMCH] *
                                      outDataParams->dimValues[TIDL_DIM_HEIGHT] *
                                      outDataParams->dimValues[TIDL_DIM_WIDTH];
  /* Copy output dimension values */
  for(size_t dimIdx = 0; dimIdx < TIDL_DIM_MAX; dimIdx++)
  {
    customOpParamsPtr->outDataDims[dimIdx] = outDataParams->dimValues[dimIdx];
  }

  /* Implicit conversion adds const; no cast needed. */
  return customOpParamsPtr;
}

/*
 * Entry point for TIDL custom layers.
 *
 * Populates CustomOpParams from the layer description and buffers, then dispatches to the kernel
 * instantiated for the first input's element type (uint8, int8, uint16, int16 or float).
 * dmaUtilsContext and sysMems are not used by this implementation.
 *
 * @return CUSTOM_FAIL if the params could not be populated or the element type is unsupported,
 *         otherwise the status returned by TIDL_customOpsProcessCore.
 */
int32_t TIDL_customOpsProcess(void* tidlHandle,
    sTIDL_Layer_t* tidlLayer,
    void* inPtrs[],
    void* outPtrs[],
    void* params,
    void* dmaUtilsContext,
    const sTIDL_sysMemHandle_t sysMems[TIDL_SYSMEM_MAX],
    int32_t execMode)
{
  const CustomOpParams* const opParamsPtr = populateCustomOpParams(params, tidlLayer, inPtrs, outPtrs);

  /* A null result means the params buffer could not be accessed/populated. */
  if(opParamsPtr == nullptr)
  {
    return CUSTOM_FAIL;
  }

  const CustomOpParams& opParams = *opParamsPtr;

  /* All inputs and the output are assumed to share the first input's element type. */
  const int32_t elementType = opParams.inDataType[0];

  switch(elementType)
  {
    case TIDL_UnsignedChar:
      return TIDL_customOpsProcessCore<uint8_t>(tidlHandle, opParams, execMode);
    case TIDL_SignedChar:
      return TIDL_customOpsProcessCore<int8_t>(tidlHandle, opParams, execMode);
    case TIDL_UnsignedShort:
      return TIDL_customOpsProcessCore<uint16_t>(tidlHandle, opParams, execMode);
    case TIDL_SignedShort:
      return TIDL_customOpsProcessCore<int16_t>(tidlHandle, opParams, execMode);
    case TIDL_SinglePrecFloat:
      return TIDL_customOpsProcessCore<float32_tidl>(tidlHandle, opParams, execMode);
    default:
      return CUSTOM_FAIL;
  }
}
