/******************************************************************************
 * Copyright (c) 2013-2014, Texas Instruments Incorporated - http://www.ti.com/
 *   All rights reserved.
 *
 *   Redistribution and use in source and binary forms, with or without
 *   modification, are permitted provided that the following conditions are met:
 *       * Redistributions of source code must retain the above copyright
 *         notice, this list of conditions and the following disclaimer.
 *       * Redistributions in binary form must reproduce the above copyright
 *         notice, this list of conditions and the following disclaimer in the
 *         documentation and/or other materials provided with the distribution.
 *       * Neither the name of Texas Instruments Incorporated nor the
 *         names of its contributors may be used to endorse or promote products
 *         derived from this software without specific prior written permission.
 *
 *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *   AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *   IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *   ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 *   LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *   INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *   CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *   ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 *   THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <math.h>
#include <signal.h>
#include <ocl_util.h>
#include "kernel.dsp_h"

extern "C" {
#include "cblas.h"
}

using namespace std;
using namespace cl;

#include "gemm_dsp.h"

/*
extern "C" void test_dsp_finish()
{
    printf("test for lib!!!!!!!!\n");
}
*/

/*
 * Kernel sizing parameters. All five start at 0 and are filled in by
 * dsp_init_mem(); __cblas_sgemm() forwards them to the DSP kernel launch.
 */
/* CL_DEVICE_LOCAL_MEM_SIZE rounded down to a power of two; size of the
 * __local scratch area handed to the kernel. */
int L2_BUF_SIZE                     = 0;
/* Size requested for bufMsmc; 0 when only the 4-byte placeholder is used. */
int MSMC_BUF_SIZE                   = 0;
/* (L2_BUF_SIZE / 2) divided by the 8 KiB A-panel size. */
int NUMAPANELS                      = 0;
/* (L2_BUF_SIZE / 4) divided by the 16 KiB B-panel size. */
int NUMBPANELS                      = 0;
/* CL_DEVICE_MAX_COMPUTE_UNITS of the device; also the global work size. */
int NUMCOMPUNITS                    = 0;

/*
 * Module-wide OpenCL state. The pointers are allocated with new by
 * dsp_init_opencl() / dsp_init_mem() and deleted by dsp_release_opencl().
 */
Context* context = NULL;//cl_context
std::vector<Device>* devices = NULL;
CommandQueue* commandQueue = NULL;//cl_command_queue
/* Assigned in __cblas_sgemm(); released there and in dsp_release_opencl(). */
Kernel* kernel = NULL;
Program* program = NULL;
Program::Binaries* binary  = NULL;
/* Buffer passed as the last kernel argument (MSMC-backed, or a 4-byte dummy). */
Buffer* bufMsmc = NULL;
/* Launch functor, re-bound to the kernel and queue on every __cblas_sgemm(). */
KernelFunctor kernelFunc;

/*
 * Delete obj if it is non-NULL and reset it to NULL.
 * Expands to a bare brace block and evaluates its argument more than once,
 * so pass a plain pointer variable only.
 */
#define RELEASE_OBJ(obj)    { if (obj)  {delete obj; obj = NULL;}}

/* Defined further down; called by dsp_init_opencl(). */
cl_int dsp_init_mem(Device& device);

/*
 * Round value down to the nearest power of two; 0 stays 0.
 *
 * Works purely on integers. The former "1 << ilogb(value)" shifted a plain
 * int, which overflows (or sign-extends when widened to ulong) as soon as the
 * answer needs bit 31 or above, and its conversion to double can round very
 * large inputs up to the next power of two.
 */
static ulong roundDownPower2(ulong value)
{
    ulong result = (value == 0) ? 0 : 1;

    /* One doubling of the result for every bit position below the top set bit. */
    while ((value >>= 1) != 0)
        result <<= 1;

    return result;
}

extern "C" void dsp_release_opencl()
{
    commandQueue->finish();
    
    RELEASE_OBJ(context);
    RELEASE_OBJ(devices);
    RELEASE_OBJ(commandQueue);
    RELEASE_OBJ(kernel);
    RELEASE_OBJ(program);
    RELEASE_OBJ(binary);
    RELEASE_OBJ(bufMsmc);

    return;
}

extern "C" int dsp_init_opencl()
{
    int ret = 0;
    cl_int status = 0;
    int deviceListSize;

    context = new Context(CL_DEVICE_TYPE_ACCELERATOR);    
    devices = new std::vector<Device> (context->getInfo<CL_CONTEXT_DEVICES>());
    commandQueue = new CommandQueue(*context, devices[0][0]);

    /*---------------------------------------------------------------------
    * Compile the Kernel Source for the devices
    *--------------------------------------------------------------------*/
    binary = new Program::Binaries(1, std::make_pair(kernel_dsp_bin, sizeof(kernel_dsp_bin)));

    program = new Program(*context, *devices, *binary);
    program->build(*devices);

    if (CL_SUCCESS != (status = dsp_init_mem(devices[0][0])))
    {
        printf("Error%d: dsp_init_mem\n", status);
        ret = EXIT_FAILURE;
        goto EXIT;
    }

//    kernel = Kernel(program, "K_ocl_sgemm_dsp");
//    kernelFunc = kernel.bind(commandQueue, NDRange(NUMCOMPUNITS), NDRange(1));

EXIT:
    if (ret)
    {
        dsp_release_opencl();
    }
    return ret;
}

/*
 * Query the device and derive the sizing globals used by the sgemm kernel:
 * NUMCOMPUNITS, L2_BUF_SIZE, MSMC_BUF_SIZE, NUMAPANELS and NUMBPANELS.
 * Also allocates bufMsmc (a 4-byte placeholder when MSMC_BUF_SIZE ends up 0).
 *
 * Returns CL_SUCCESS, or EXIT_FAILURE when the device reports zero compute
 * units.
 */
cl_int dsp_init_mem(Device& device)
{
    int aPanelBytes      = 8  * 1024;
    int bPanelBytes      = 16 * 1024;
    cl_ulong localBytes  = 0;
    cl_ulong msmcBytes   = 0;

    device.getInfo(CL_DEVICE_MAX_COMPUTE_UNITS, &NUMCOMPUNITS);
    if (NUMCOMPUNITS == 0)
    {
        return EXIT_FAILURE;
    }

    device.getInfo(CL_DEVICE_LOCAL_MEM_SIZE, &localBytes);

    /* msmcBytes stays 0 when the query is not available in the headers. */
#ifdef CL_DEVICE_MSMC_MEM_SIZE_TI
    device.getInfo(CL_DEVICE_MSMC_MEM_SIZE_TI, &msmcBytes);
#endif

    L2_BUF_SIZE   = roundDownPower2(localBytes);
    MSMC_BUF_SIZE = roundDownPower2(msmcBytes);

    /* Half of the local buffer holds A panels, a quarter holds B panels. */
    NUMAPANELS = L2_BUF_SIZE / 2 / aPanelBytes;
    NUMBPANELS = L2_BUF_SIZE / 4 / bPanelBytes;

    /* Use MSMC only if it can hold the A panels of every compute unit. */
    MSMC_BUF_SIZE = ((NUMCOMPUNITS * aPanelBytes * NUMAPANELS) > MSMC_BUF_SIZE)
                        ? 0
                        : NUMCOMPUNITS * aPanelBytes * NUMAPANELS;

    if (MSMC_BUF_SIZE == 0)
    {
        /* The kernel still takes a buffer argument, so pass a dummy one. */
        bufMsmc = new Buffer(*context, CL_MEM_READ_WRITE, 4);
    }
    else
    {
        bufMsmc = new Buffer(*context, CL_MEM_READ_WRITE|CL_MEM_USE_MSMC_TI,
                             MSMC_BUF_SIZE);
    }

    return CL_SUCCESS;
}

/*
 * Allocate a device buffer holding n floats and fill it from x.
 *
 * x     host data to copy (n floats)
 * n     number of floats, not bytes
 * type  cl_mem_flags used to create the buffer
 *
 * Returns a new Buffer that the caller releases with dsp_free(), or NULL when
 * x is NULL, n is not positive, or the mapping yields no pointer. OpenCL
 * errors are propagated as exceptions after the buffer has been freed.
 */
Buffer *dsp_mem_alloc_copy(float* x, int n, cl_mem_flags type)
{
    if (x == NULL || n <= 0)
        return NULL;

    const size_t size = sizeof(float) * static_cast<size_t>(n);
    Buffer* buf = new Buffer(*context, type, size);

    try
    {
        void* mapped = commandQueue->enqueueMapBuffer(*buf, CL_TRUE, CL_MAP_WRITE, 0, size);
        if (mapped == NULL)
        {
            delete buf;
            return NULL;
        }

        memcpy(mapped, x, size);

        /* A region mapped for writing must be unmapped before the device
         * uses the buffer; wait so the data is in place when we return. */
        Event unmapped;
        commandQueue->enqueueUnmapMemObject(*buf, mapped, NULL, &unmapped);
        unmapped.wait();
    }
    catch (...)
    {
        /* Do not leak the buffer when mapping or unmapping fails. */
        delete buf;
        throw;
    }

    return buf;
}

extern "C" float *__dsp_malloc(int n, cl_mem_flags type, void** cl_buf)
{
    float* x_dsp;
    Buffer* buf;
    cl_int status;
    int size = sizeof(float) * n;

    buf = new Buffer(*context, type,  size);
    if (buf == NULL)
        return NULL;

    x_dsp = (float*) commandQueue->enqueueMapBuffer(*buf, CL_TRUE, CL_MAP_WRITE, 0, size);

    if (cl_buf)
    {
        *cl_buf = buf;
    }
    
    return x_dsp;
}

/* Destroy a Buffer obtained from dsp_mem_alloc_copy(); NULL is accepted. */
void dsp_free(Buffer* x)
{
    /* delete on a NULL pointer is a no-op, so no explicit check is needed. */
    delete x;
}

/*x_dsp could be NULL*/
extern "C" void __dsp_free(float *x_dsp, void* cl_buf)
{

    if (!cl_buf)
    {
        cerr << "ERROR:  Buffer is null!" << endl;
        return;
    }
    if (x_dsp)
    {
        commandQueue->enqueueUnmapMemObject(*(Buffer*)cl_buf, x_dsp);
        delete (Buffer*)cl_buf;
    }
}

/*
 * Transpose a dense row-major matrix in place.
 *
 * a     rows x cols floats on entry, cols x rows floats on return
 * rows  number of rows of the input
 * cols  number of columns of the input
 *
 * Does nothing for a NULL pointer or non-positive dimensions. The work is
 * done through a temporary copy; if that allocation fails the process is
 * terminated with exit(-1), because carrying on would silently produce a
 * wrong matrix.
 */
static void transpose_matrix(float *a, int rows, int cols)
{
    if (a == NULL || rows <= 0 || cols <= 0)
        return;

    /* A single row or column has the same memory layout once transposed. */
    if (rows == 1 || cols == 1)
        return;

    /* size_t arithmetic: rows * cols may not fit in an int. */
    const size_t nRows = static_cast<size_t>(rows);
    const size_t nCols = static_cast<size_t>(cols);
    const size_t bytes = nRows * nCols * sizeof(float);

    /* Every element is overwritten below, so no zero-fill is needed. */
    float *scratch = static_cast<float*>(malloc(bytes));
    if (scratch == NULL)
    {
        fprintf(stderr, "ERROR: transpose_matrix: cannot allocate %d x %d floats\n",
                rows, cols);
        exit(-1);
    }

    for (size_t r = 0; r < nRows; ++r)
    {
        const float *src = a + r * nCols;
        for (size_t c = 0; c < nCols; ++c)
            scratch[c * nRows + r] = src[c];
    }

    memcpy(a, scratch, bytes);
    free(scratch);
}

extern "C" void __cblas_sgemm(int TA, int TB, int M, int N, int K, float ALPHA, 
        float* A, int lda, 
        float* B, int ldb,
        const float BETA,
        float* C, int ldc)
/*(int TA, int TB, int M, int N, int K, float ALPHA, 
        float* A, void* A_cl, int lda, 
        float* B, void* B_cl, int ldb,
        const float BETA,
        float* C, void* C_cl, int ldc)*/
{
    enum CBLAS_ORDER order = CblasRowMajor; /*Fixed value*/
/*
    Buffer* a = (Buffer*)A_cl;
    Buffer* b = (Buffer*)B_cl;
    Buffer* c = (Buffer*)C_cl;
*/
    Buffer* a = dsp_mem_alloc_copy(A, (M*K), CL_MEM_READ_ONLY);
    Buffer* b = dsp_mem_alloc_copy(B, (N*K), CL_MEM_READ_ONLY);
    Buffer* c = dsp_mem_alloc_copy(C, (M*N), CL_MEM_READ_WRITE);

    kernel = new Kernel(*program, "K_ocl_sgemm_dsp");
    NDRange* global = new NDRange(NUMCOMPUNITS);//test
    NDRange* local = new NDRange(1);
    kernelFunc = kernel->bind(*commandQueue, *global, *local);

    if (TA)
    {
        transpose_matrix(A, M, K);
    }

    if (TB)
    {
        transpose_matrix(B, K, N);
    }

    try
    {
       if (order == CblasRowMajor)
           kernelFunc(N, M, K, ALPHA, b, N, a, K, BETA, c, N,
                  NUMAPANELS, NUMBPANELS,
                  __local(L2_BUF_SIZE), *bufMsmc).wait();
       else
           kernelFunc(M, N, K, ALPHA, a, M, b, K, BETA, c, M,
                  NUMAPANELS, NUMBPANELS,
                  __local(L2_BUF_SIZE), *bufMsmc).wait();    
    }
    catch (Error err)
   {
       cerr << "ERROR: " << err.what() << "(" << err.err() << ", "
            << ocl_decode_error(err.err()) << ")" << endl;
       exit(-1);
   }

    RELEASE_OBJ(kernel);
    RELEASE_OBJ(global);
    RELEASE_OBJ(local);
    dsp_free(a);
    dsp_free(b);
    dsp_free(c);
    return;
}

