#include "resnet.h"

#include <fcntl.h>      // NOLINT(build/include_order)
#include <getopt.h>     // NOLINT(build/include_order)
#include <sys/time.h>   // NOLINT(build/include_order)
#include <sys/types.h>  // NOLINT(build/include_order)
#include <sys/uio.h>    // NOLINT(build/include_order)
#include <unistd.h>     // NOLINT(build/include_order)

#include <algorithm>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc/imgproc_c.h"

#include "itidl_rt.h"

// Minimal logging shim: the severity argument (INFO, ERROR, ...) is discarded and every
// LOG(...) expression is simply std::cerr, so messages must supply their own "\n".
#define LOG(x) std::cerr

// Pointer tables with 16 zero-initialised slots. Nothing in this file writes to them,
// so reads here see NULL unless another translation unit fills them in.
// NOTE(review): presumably kept for compatibility with resnet.h or other translation
// units — confirm before removing.
void* resnet_in_ptrs[16] = {NULL};
void* resnet_out_ptrs[16]= {NULL};

// Converts a timeval to microseconds as a double.
// Both fields are converted to double *before* the multiply. With a 32-bit time_t/long
// (e.g. 32-bit ARM targets), tv_sec * 1000000 would overflow as soon as tv_sec > 2147,
// which is signed-overflow UB for every realistic wall-clock time.
double get_us(struct timeval t) {
  return static_cast<double>(t.tv_sec) * 1000000.0 + static_cast<double>(t.tv_usec);
}

void saveFeatureVector(const std::string& filename, float* features, int feature_dim) {
  std::ofstream file(filename);
  if (!file) {
    LOG(ERROR) << "Cannot open file " << filename << " for writing\n";
    return;
  }
  
  for (int i = 0; i < feature_dim; i++) {
    file << features[i];
    if (i < feature_dim - 1) {
      file << ",";
    }
  }
  file << std::endl;
  file.close();
  LOG(INFO) << "Feature vector saved to " << filename << "\n";
}

void printFeatureStats(float* features, int feature_dim) {
  float min_val = features[0];
  float max_val = features[0];
  float sum = 0;
  
  for (int i = 0; i < feature_dim; i++) {
    if (features[i] < min_val) min_val = features[i];
    if (features[i] > max_val) max_val = features[i];
    sum += features[i];
  }
  
  float mean = sum / feature_dim;
  
  LOG(INFO) << "Feature vector statistics:\n";
  LOG(INFO) << "  Dimension: " << feature_dim << "\n";
  LOG(INFO) << "  Min value: " << min_val << "\n";
  LOG(INFO) << "  Max value: " << max_val << "\n";
  LOG(INFO) << "  Mean value: " << mean << "\n";
}

// Loads `input_image_name` as a colour image, converts BGR->RGB, and resizes it to
// wanted_width x wanted_height with INTER_AREA. It then writes (pixel - mean) / scale
// for every interleaved (HWC) byte into `out`.
// `out` must hold wanted_height * wanted_width * wanted_channels elements.
// For integral T, results are saturated to T's range: an out-of-range float->integer
// conversion would be undefined behaviour.
// Returns 0 on success. Returns -1 if the image cannot be read or its channel count
// differs from wanted_channels.
template <class T>
int preprocImage(const std::string &input_image_name, T *out, int wanted_height, int wanted_width, int wanted_channels, float mean, float scale)
{
    cv::Mat image = cv::imread(input_image_name, cv::IMREAD_COLOR);
    // imread reports failure (missing file, unsupported format) with an empty Mat;
    // cvtColor would throw on it, so bail out first.
    if (image.empty())
    {
      printf("Error : Could not read image %s \n", input_image_name.c_str());
      return (-1);
    }
    cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
    cv::resize(image, image, cv::Size(wanted_width, wanted_height), 0, 0, cv::INTER_AREA);

    if (image.channels() != wanted_channels)
    {
      printf("Warning : Number of channels wanted differs from number of channels in the actual image \n");
      return (-1);
    }

    // resize() wrote into a freshly allocated Mat, so its pixels are contiguous.
    const uint8_t *pSrc = (const uint8_t *)image.data;
    const int count = wanted_height * wanted_width * wanted_channels;
    for (int i = 0; i < count; i++)
    {
      auto value = ((T)pSrc[i] - mean) / scale;
      if (std::numeric_limits<T>::is_integer)
      {
        value = std::max(value, static_cast<decltype(value)>(std::numeric_limits<T>::lowest()));
        value = std::min(value, static_cast<decltype(value)>(std::numeric_limits<T>::max()));
      }
      out[i] = static_cast<T>(value);
    }
    return 0;
}

// Runs `cmd` through the shell and copies the first whitespace-delimited token of its
// standard output into `out`.
// `out` is set to "" when the command cannot be started or prints nothing. At most
// 511 characters are stored, so `out` must have room for 512 bytes.
static void readFirstTokenFromCommand(const std::string &cmd, char *out)
{
  out[0] = '\0';
  FILE *fp = popen(cmd.c_str(), "r");
  if (fp == NULL)
  {
    printf("Error while runing command : %s", cmd.c_str());
    return;
  }
  if (fscanf(fp, "%511s", out) != 1)
  {
    out[0] = '\0';
  }
  // A stream opened with popen() must be closed with pclose(), which also reaps the
  // child process; fclose() on it is undefined.
  pclose(fp);
}

// Locates the TIDL artifacts inside directory `path` and writes them to the outputs:
//  - `net_name` gets the first file (in `ls` order) matching "*net.bin";
//  - `io_name` gets the first file matching "*io_buff1.bin".
// An output is left as "" when nothing matches or the shell command fails.
// Both output buffers must hold at least 512 bytes.
// NOTE(review): `path` is interpolated unquoted into a shell command, exactly as
// before, so paths containing spaces or shell metacharacters are not supported.
void getModelNameromArtifactsDir(char* path, char * net_name, char *io_name)
{
  // std::string avoids the fixed 500-byte command buffer overflowing on long paths.
  const std::string dir(path);
  readFirstTokenFromCommand("ls " + dir + "/*net.bin | head -1", net_name);
  readFirstTokenFromCommand("ls " + dir + "/*io_buff1.bin | head -1", io_name);
}

// Reads exactly `size` bytes from the start of `fileName` into `addr`.
// Returns 0 on success. Returns -1, after printing why, in these cases:
//  - the arguments are invalid (negative size, or NULL addr with a non-zero size);
//  - the file cannot be opened;
//  - the file holds fewer than `size` bytes. A short read used to be reported as
//    success, leaving part of the buffer unfilled.
int32_t TIDLReadBinFromFile(const char *fileName, void *addr, int32_t size)
{
    if (size < 0 || (addr == NULL && size > 0))
    {
      printf("Invalid buffer for reading %s file \n", fileName);
      return -1;
    }

    FILE *fptr = fopen(fileName, "rb");
    if (fptr == NULL)
    {
      printf("Could not open %s file for reading \n", fileName);
      return -1;
    }

    // Read with an element size of 1 so the return value is the byte count actually copied.
    const size_t bytes_read = (size > 0) ? fread(addr, 1, (size_t)size, fptr) : 0;
    fclose(fptr);

    if (bytes_read != (size_t)size)
    {
      printf("Could only read %zu of %d bytes from %s \n", bytes_read, (int)size, fileName);
      return -1;
    }
    return 0;
}

int RunResNetInference(ResNetSettings* s) {

  char net_name[512];
  char io_name[512];

  getModelNameromArtifactsDir((char *)s->artifact_path.c_str(), net_name, io_name);

  printf("Model Files names : %s,%s\n", net_name, io_name);

  sTIDLRT_Params_t prms;
  void *handle = NULL;
  int32_t status;

  status = TIDLRT_setParamsDefault(&prms);

  FILE * fp_network = fopen(net_name, "rb");
  if (fp_network == NULL)
  {
    printf("Invoke  : ERROR: Unable to open network file %s \n", net_name);
    return -1;
  }
  prms.stats = (sTIDLRT_PerfStats_t*)malloc(sizeof(sTIDLRT_PerfStats_t));

  fseek(fp_network, 0, SEEK_END);
  prms.net_capacity = ftell(fp_network);
  fseek(fp_network, 0, SEEK_SET);
  fclose(fp_network);
  prms.netPtr = malloc(prms.net_capacity);

  status = TIDLReadBinFromFile(net_name, prms.netPtr, prms.net_capacity);

  FILE * fp_config = fopen(io_name, "rb");
  if (fp_config == NULL)
  {
    printf("Invoke  : ERROR: Unable to open IO config file %s \n", io_name);
    return -1;
  }
  fseek(fp_config, 0, SEEK_END);
  prms.io_capacity = ftell(fp_config);
  fseek(fp_config, 0, SEEK_SET);
  fclose(fp_config);
  prms.ioBufDescPtr = malloc(prms.io_capacity);
  status = TIDLReadBinFromFile(io_name, prms.ioBufDescPtr, prms.io_capacity);

  status = TIDLRT_create(&prms, &handle);

  sTIDLRT_Tensor_t *in[16];
  sTIDLRT_Tensor_t *out[16];

  sTIDLRT_Tensor_t in_tensor;
  sTIDLRT_Tensor_t out_tensor;

  int32_t j = 0;
  in[j] = &in_tensor;
  status = TIDLRT_setTensorDefault(in[j]);
  in[j]->layout = TIDLRT_LT_NHWC;
  in[j]->elementType = TIDLRT_Uint8;
  int32_t in_tensor_size = 112 * 112 * 3 * sizeof(uint8_t);

  if (s->device_mem)
  { 
      in[j]->ptr =  TIDLRT_allocSharedMem(64, in_tensor_size);
      in[j]->memType = TIDLRT_MEM_SHARED;
  }
  else
  {
      in[j]->ptr =  malloc(in_tensor_size);
  }

  out[j] = &out_tensor;
  status = TIDLRT_setTensorDefault(out[j]);
  out[j]->layout = TIDLRT_LT_NHWC;
  out[j]->elementType = TIDLRT_Float32;
  int32_t out_tensor_size = s->feature_dim * sizeof(float);

  if (s->device_mem)
  { 
      out[j]->ptr =  TIDLRT_allocSharedMem(64, out_tensor_size);
      out[j]->memType = TIDLRT_MEM_SHARED;
  }
  else
  {
      out[j]->ptr =  malloc(out_tensor_size);
  }

  status = preprocImage<uint8_t>(s->input_image_name, (uint8_t*)in[j]->ptr, 112, 112, 3, s->input_mean, s->input_std);
  LOG(INFO) << "invoked \n";

  struct timeval start_time, stop_time;
  gettimeofday(&start_time, nullptr);

  for (int i = 0; i < s->loop_count; i++)
  {
    TIDLRT_invoke(handle, in, out);
  }

  gettimeofday(&stop_time, nullptr);

  LOG(INFO) << "average time: "
            << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
            << " ms \n";

  float *output = (float *)out[j]->ptr;
  
  printFeatureStats(output, s->feature_dim);
  
  std::string output_filename = s->input_image_name + ".features.txt";
  saveFeatureVector(output_filename, output, s->feature_dim);
  
  if (s->verbose) {
    LOG(INFO) << "First 10 feature values: ";
    for (int i = 0; i < std::min(10, s->feature_dim); i++) {
      LOG(INFO) << "  [" << i << "]: " << output[i] << "\n";
    }
  }

  status = TIDLRT_deactivate(handle);
  status = TIDLRT_delete(handle);

  if (s->device_mem)
  {
    for (uint32_t i = 0; i < 1; i++)
    {
      if (resnet_in_ptrs[i])
      {
        TIDLRT_freeSharedMem(in[i]->ptr);
      }
    }
    for (uint32_t i = 0; i < 1; i++)
    {
      if (resnet_out_ptrs[i])
      {
        TIDLRT_freeSharedMem(out[i]->ptr);
      }
    }
  }
  return 0;
}

// Prints the command-line help for resnet_feature_extraction through LOG (std::cerr),
// one option per line, followed by a blank line.
void display_usage() {
  static const char* const kUsageLines[] = {
      "resnet_feature_extraction\n",
      "--device_mem, -d: [0|1], use device memory or not\n",
      "--artifact_path, -f: Path for Delegate artifacts folder \n",
      "--count, -c: loop interpreter->Invoke() for certain times\n",
      "--input_mean, -b: input mean\n",
      "--input_std, -s: input standard deviation\n",
      "--image, -i: image_name.jpg (required)\n",
      "--feature_dim, -n: feature dimension (default: 512)\n",
      "--profiling, -p: [0|1], profiling or not\n",
      "--verbose, -v: [0|1] print more information\n",
      "--warmup_runs, -w: number of warmup runs\n",
      "\n",
  };

  std::ostream& usage_stream = LOG(INFO);
  for (const char* line : kUsageLines) {
    usage_stream << line;
  }
}

// Entry point: parses the command-line options into a ResNetSettings, requires
// --image, and runs feature extraction.
// Prints usage and exits with -1 on -h/--help or an unrecognised option. Otherwise it
// returns RunResNetInference's status.
int main(int argc, char** argv) {
  ResNetSettings s;

  // "--num_results" was listed here before, but nothing handled its 'r' value, so
  // passing it made the program exit silently. It is now rejected like any other
  // unknown option, with usage shown.
  static const struct option long_options[] = {
      {"device_mem", required_argument, nullptr, 'd'},
      {"artifact_path", required_argument, nullptr, 'f'},
      {"count", required_argument, nullptr, 'c'},
      {"verbose", required_argument, nullptr, 'v'},
      {"image", required_argument, nullptr, 'i'},
      {"feature_dim", required_argument, nullptr, 'n'},
      {"profiling", required_argument, nullptr, 'p'},
      {"input_mean", required_argument, nullptr, 'b'},
      {"input_std", required_argument, nullptr, 's'},
      {"warmup_runs", required_argument, nullptr, 'w'},
      {"help", no_argument, nullptr, 'h'},
      {nullptr, 0, nullptr, 0}};

  int c;
  int option_index = 0;
  // 'h' is in the short-option string so that -h reaches the help case directly.
  while ((c = getopt_long(argc, argv, "b:c:d:f:hi:n:p:s:v:w:", long_options,
                          &option_index)) != -1) {
    switch (c) {
      case 'b':
        s.input_mean = strtod(optarg, nullptr);
        break;
      case 'c':
        s.loop_count =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'd':
        s.device_mem =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'f':
        s.artifact_path = optarg;
        break;
      case 'i':
        s.input_image_name = optarg;
        break;
      case 'n':
        s.feature_dim =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'p':
        s.profiling =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 's':
        s.input_std = strtod(optarg, nullptr);
        break;
      case 'v':
        s.verbose =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'w':
        s.number_of_warmup_runs =
            strtol(optarg, nullptr, 10);  // NOLINT(runtime/deprecated_fn)
        break;
      case 'h':
      case '?':
      default:
        display_usage();
        exit(-1);
    }
  }

  if (s.input_image_name.empty()) {
    LOG(ERROR) << "Error: Input image path is required. Use -i or --image option.\n";
    display_usage();
    return -1;
  }

  return (RunResNetInference(&s));
}