/* -*- c-file-style: "GNU" -*- */
/*
 * Copyright (C) CNRS, INRIA, Université Bordeaux 1, Télécom SudParis
 * See COPYING in top-level directory.
 */

#ifndef _REENTRANT
#define _REENTRANT
#endif

#include "eztrace-core/eztrace_config.h"
#include "eztrace-lib/eztrace.h"
#include "eztrace-lib/eztrace_otf2.h"
#include "eztrace-lib/eztrace_module.h"
#include "eztrace-core/eztrace_attributes.h"
#include <limits.h>


#if USE_MPI
#include "mpi.h"
#include <otf2/OTF2_MPI_Collectives.h>
#include "eztrace-lib/eztrace_mpi.h"

/* Module identifier for this file; the same name (mpi_core) is passed to
 * INSTRUMENT_FUNCTIONS, the PPTRACE intercept table and EZT_REGISTER_MODULE
 * below.
 */
#define CURRENT_MODULE mpi_core
DECLARE_CURRENT_MODULE;

/* Set to 1 at the end of _eztrace_init_mpi_core() and reset to 0 by
 * finalize_mpi_core() / _mpi_core_conclude().
 */
static int _mpi_core_initialized = 0;

/* Defined in another translation unit. _mpi_init_generic() stores its address
 * in the communicator table under the key of MPI_COMM_WORLD.
 */
extern OTF2_CommRef comm_world_ref;
/* When non-zero, the _EZT_MPI_* helpers log their arguments.
 * NOTE(review): nothing in this file ever assigns it, so it stays 0 here.
 */
static int mpi_verbose = 0;

/* pointers to actual MPI functions (C version)
 * Each pointer is paired with the corresponding MPI symbol name in the
 * PPTRACE intercept table further down in this file; presumably that table
 * is what fills them in -- TODO confirm against the pptrace macros.
 */
int (*libMPI_Init)(int*, char***);
int (*libMPI_Init_thread)(int*, char***, int, int*);
int (*libMPI_Comm_size)(MPI_Comm, int*);
int (*libMPI_Comm_rank)(MPI_Comm, int*);
int (*libMPI_Finalize)(void);
/* NOTE(review): libMPI_Initialized has no entry in the intercept table of
 * this file and is not called here.
 */
int (*libMPI_Initialized)(int*);
int (*libMPI_Barrier)(MPI_Comm);

/* fortran bindings
 * These pointers are listed in the intercept table but are not called from
 * this file.
 */
void (*libmpi_init_)(int* e);
void (*libmpi_init_thread_)(int*, int*, int*);
void (*libmpi_finalize_)(int*);
void (*libmpi_comm_size_)(MPI_Comm*, int*, int*);
void (*libmpi_comm_rank_)(MPI_Comm*, int*, int*);

/* Set to 1 at the end of _mpi_init_generic() so that the generic
 * initialization runs only once.
 */
static int _mpi_init_called = 0;

/* Interposed MPI_Finalize(): tracing is stopped first, then the real
 * MPI_Finalize is invoked and its return code is handed back to the caller.
 */
int MPI_Finalize() {
  eztrace_stop();

  const int mpi_status = libMPI_Finalize();
  return mpi_status;
}


/* Interposed MPI_Comm_size(): forwards to the real implementation without
 * recording anything and returns its result unchanged.
 */
int MPI_Comm_size(MPI_Comm c, int* s) {
  const int mpi_status = libMPI_Comm_size(c, s);
  return mpi_status;
}

/* Interposed MPI_Comm_rank(): forwards to the real implementation without
 * recording anything and returns its result unchanged.
 */
int MPI_Comm_rank(MPI_Comm c, int* r) {
  const int mpi_status = libMPI_Comm_rank(c, r);
  return mpi_status;
}


/* Defined further down in this file. */
void ezt_mpi_initialize_trace();

/* Not defined in this file; called at the end of _mpi_init_generic(). */
void ezt_set_mpi_mode();

/* Internal function.
 * It is used by the various MPI_Init* functions (C and Fortran versions)
 * once the real MPI initialization has returned.
 * It records the MPI-related information of the process in mpi_infos
 * (rank, communicator size, process id string, MPI constants), resets the
 * timestamp origin after a barrier, and registers MPI_COMM_WORLD and
 * MPI_COMM_SELF.
 * Only the first call does any work; later calls return immediately.
 */
void _mpi_init_generic() {
  if(_mpi_init_called)
    return;

  set_recursion_shield_on();

  libMPI_Comm_size(MPI_COMM_WORLD, &mpi_infos.size);
  libMPI_Comm_rank(MPI_COMM_WORLD, &mpi_infos.rank);

  /* When asprintf() fails the content of its output pointer is unspecified.
   * Reset it so that the free() done at finalization stays well defined.
   */
  if(asprintf(&mpi_infos.proc_id, "%d", mpi_infos.rank) < 0) {
    mpi_infos.proc_id = NULL;
    eztrace_warn("%s: cannot allocate the process id string\n", __func__);
  }

  // First, let's synchronize the clocks of the MPI ranks
  libMPI_Barrier(MPI_COMM_WORLD);
  first_timestamp = 0;
  // This initializes first_timestamp. From now on, the timestamps will be relative to the current timestamp
  ezt_get_timestamp();

  /* Keep a copy of the MPI constants of this MPI implementation */
  mpi_infos.mpi_any_source = MPI_ANY_SOURCE;
  mpi_infos.mpi_any_tag = MPI_ANY_TAG;
  mpi_infos.mpi_proc_null = MPI_PROC_NULL;
  mpi_infos.mpi_request_null = (app_ptr)MPI_REQUEST_NULL;
  mpi_infos.mpi_comm_world = (app_ptr)MPI_COMM_WORLD;
  mpi_infos.mpi_comm_self = (app_ptr)MPI_COMM_SELF;

  /* initialize communicators */
  int hashtable_size = 128; /* todo: make that limit configurable ? */
  ezt_hashtable_init(&mpi_infos.mpi_communicators,
                     hashtable_size);

  /* Report that MPI is initialized, give the rank to the OTF2 layer, then
   * wait until the "ezt_otf2" step is complete.
   */
  todo_set_status("mpi_init", init_complete);
  ezt_otf2_set_mpi_rank(mpi_infos.rank, mpi_infos.size);
  todo_wait("ezt_otf2", init_complete);

  /* now mpi_comm_world has been registered to OTF2 */
  ezt_hashtable_insert(&mpi_infos.mpi_communicators,
                       _ezt_hash_mpi_comm(MPI_COMM_WORLD),
                       &comm_world_ref);
  EZT_New_MPI_Comm(MPI_COMM_SELF);
  _mpi_init_called = 1;
  ezt_set_mpi_mode();
  set_recursion_shield_off();
}


/* This function records initialization events. It is called during mpi_init
 * if autostart is enabled, or when eztrace_start is reached.
 * It only acts while the trace is running or paused, and only once: it
 * asserts that _mpi_init_generic() already ran and then marks the MPI part
 * of the trace as initialized.
 */
void ezt_mpi_initialize_trace() {
  static int mpi_trace_initialized = 0;

  /* Nothing to do unless the trace is active (running or paused). */
  if (_ezt_trace.status != ezt_trace_status_running &&
      _ezt_trace.status != ezt_trace_status_paused)
    return;

  /* Already done by a previous call. */
  if (mpi_trace_initialized)
    return;

  assert(_mpi_init_called);
  mpi_trace_initialized = 1;
}

/* Interposed MPI_Init_thread(): runs the real MPI_Init_thread, then performs
 * the EZTrace-side MPI initialization. The value returned by the real
 * function is passed back to the application.
 */
int MPI_Init_thread(int* argc, char*** argv, int required, int* provided) {
  eztrace_log(dbg_lvl_debug, "MPI_Init_thread intercepted\n");

  /* Must precede the call through libMPI_Init_thread
   * (presumably this is what binds the libMPI_* pointers -- TODO confirm).
   */
  INSTRUMENT_FUNCTIONS(mpi_core);

  const int mpi_status = libMPI_Init_thread(argc, argv, required, provided);
  _mpi_init_generic();

  eztrace_log(dbg_lvl_debug, "End of MPI_Init_thread interception\n");
  return mpi_status;
}

/* Interposed MPI_Init(): runs the real MPI_Init, then performs the
 * EZTrace-side MPI initialization. The value returned by the real function
 * is passed back to the application.
 */
int MPI_Init(int* argc, char*** argv) {
  eztrace_log(dbg_lvl_debug, "MPI_Init intercepted\n");

  /* Must precede the call through libMPI_Init
   * (presumably this is what binds the libMPI_* pointers -- TODO confirm).
   */
  INSTRUMENT_FUNCTIONS(mpi_core);

  const int mpi_status = libMPI_Init(argc, argv);
  _mpi_init_generic();

  eztrace_log(dbg_lvl_debug, "End of MPI_Init interception\n");
  return mpi_status;
}

/* Pointers to the actual MPI functions used by the _EZT_MPI_* helpers below.
 * libMPI_Barrier is not repeated here: it is already declared with the other
 * core pointers near the top of this file.
 */
int (*libMPI_Send)(CONST void* buf, int count, MPI_Datatype datatype, int dest,
                   int tag, MPI_Comm comm);
int (*libMPI_Recv)(void* buf, int count, MPI_Datatype datatype, int source,
                   int tag, MPI_Comm comm, MPI_Status* status);
int (*libMPI_Reduce)(CONST void*, void*, int, MPI_Datatype, MPI_Op, int,
                     MPI_Comm);

/* Receive `size` bytes into `buffer` from rank `src` with tag `tag` on
 * MPI_COMM_WORLD, ignoring the receive status.
 *
 * Returns 0 on success, 1 on failure (a warning is emitted). A size that
 * does not fit in the `int` count taken by MPI is rejected instead of being
 * silently truncated.
 */
static int _EZT_MPI_Recv(void* buffer, size_t size, int src, int tag) {
  if(mpi_verbose)
    eztrace_log(dbg_lvl_verbose, "[%d] %s(buffer=%p, size=%zu, src=%d, tag=%x)\n",
                mpi_infos.rank, __func__, buffer, size, src, tag);

  /* The MPI count parameter is an int */
  if(size > (size_t)INT_MAX) {
    eztrace_warn("%s failed: size %zu exceeds the maximum MPI count\n", __func__, size);
    return 1;
  }

  int ret = libMPI_Recv(buffer, (int)size, MPI_BYTE, src, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  if(ret == MPI_SUCCESS)
    return 0;
  eztrace_warn("%s failed\n",  __func__);
  return 1;
}

/* Send `size` bytes from `buffer` to rank `dest` with tag `tag` on
 * MPI_COMM_WORLD.
 *
 * Returns 0 on success, 1 on failure (a warning is emitted). A size that
 * does not fit in the `int` count taken by MPI is rejected instead of being
 * silently truncated.
 */
static int _EZT_MPI_Send(void* buffer, size_t size, int dest, int tag) {
  if(mpi_verbose)
    eztrace_log(dbg_lvl_verbose, "[%d] %s(buffer=%p, size=%zu, dest=%d, tag=%x)\n",
                mpi_infos.rank, __func__, buffer, size, dest, tag);

  /* The MPI count parameter is an int */
  if(size > (size_t)INT_MAX) {
    eztrace_warn("%s failed: size %zu exceeds the maximum MPI count\n", __func__, size);
    return 1;
  }

  int ret = libMPI_Send(buffer, (int)size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
  if(ret == MPI_SUCCESS)
    return 0;
  eztrace_warn("%s failed\n",  __func__);
  return 1;
}

/* Reduce `count` elements of `sendbuf` into `recvbuf` on rank `root` over
 * MPI_COMM_WORLD. The EZTrace datatype and operation are converted to their
 * MPI counterparts with EZT_DATATYPE_TO_MPI / EZT_OP_TO_MPI.
 *
 * Returns 0 on success, 1 on failure (a warning is emitted).
 */
static int _EZT_MPI_Reduce(const void *sendbuf, void *recvbuf, int count,
                           enum EZT_MPI_Datatype datatype, enum EZT_MPI_Op op, int root) {
  if(mpi_verbose)
    eztrace_log(dbg_lvl_normal, "[%d] %s(sendbuf=%p, recvbuf=%p, count=%d, type=%x, op=%x, root=%d)\n",
                mpi_infos.rank, __func__, sendbuf, recvbuf, count, datatype, op, root);

  const int mpi_status = libMPI_Reduce(sendbuf, recvbuf, count,
                                       EZT_DATATYPE_TO_MPI(datatype),
                                       EZT_OP_TO_MPI(op), root, MPI_COMM_WORLD);
  if(mpi_status != MPI_SUCCESS) {
    eztrace_warn("%s failed\n", __func__);
    return 1;
  }
  return 0;
}

/* Install the OTF2 MPI collective callbacks on `archive`, passing
 * MPI_COMM_WORLD and MPI_COMM_SELF as the two communicator arguments.
 *
 * Returns 0 on success, 1 on failure (a warning is emitted).
 */
static int _EZT_MPI_SetMPICollectiveCallbacks(OTF2_Archive *archive) {
  const OTF2_ErrorCode otf2_status =
    OTF2_MPI_Archive_SetCollectiveCallbacks(archive, MPI_COMM_WORLD, MPI_COMM_SELF);

  if(otf2_status != OTF2_SUCCESS) {
    eztrace_warn("%s failed\n", __func__);
    return 1;
  }
  return 0;
}

/* Barrier over MPI_COMM_WORLD using the real MPI_Barrier.
 *
 * Returns 0 on success, 1 on failure (a warning is emitted).
 */
static int _EZT_MPI_Barrier() {
  const int mpi_status = libMPI_Barrier(MPI_COMM_WORLD);

  if(mpi_status != MPI_SUCCESS) {
    eztrace_warn("%s failed\n", __func__);
    return 1;
  }
  return 0;
}

/* Return the value of MPI_Wtime(). MPI_Wtime is called directly; this file
 * declares no libMPI_* pointer for it.
 */
static double _EZT_MPI_Wtime() {
  const double now = MPI_Wtime();
  return now;
}

/* Fallback communicator callbacks, installed by ezt_mpi_init() when
 * EZT_New_MPI_Comm / EZT_Delete_MPI_Comm are still unset. Neither of them
 * records anything about the communicator.
 */
void EZT_New_MPI_Comm_dummy(MPI_Comm c    MAYBE_UNUSED) {
  /* Use the EZTrace debug log rather than writing to the standard output of
   * the traced application.
   */
  eztrace_log(dbg_lvl_debug, "%s\n", __func__);
}
void EZT_Delete_MPI_Comm_dummy(MPI_Comm c MAYBE_UNUSED) { }


/* Intercept table of the mpi_core module: each entry pairs an MPI symbol
 * name (C names first, then the Fortran mpi_*_ names) with the function
 * pointer of this file that is used to call the real implementation.
 * NOTE(review): the exact binding mechanism lives in the PPTRACE/INTERCEPT3
 * macros, which are not visible here.
 */
PPTRACE_START_INTERCEPT_FUNCTIONS(mpi_core)
     INTERCEPT3("MPI_Init_thread", libMPI_Init_thread)
     INTERCEPT3("MPI_Init", libMPI_Init)
     INTERCEPT3("MPI_Finalize", libMPI_Finalize)
     INTERCEPT3("MPI_Comm_size", libMPI_Comm_size)
     INTERCEPT3("MPI_Comm_rank", libMPI_Comm_rank)
     INTERCEPT3("MPI_Send", libMPI_Send)
     INTERCEPT3("MPI_Recv", libMPI_Recv)
     INTERCEPT3("MPI_Barrier", libMPI_Barrier)
     INTERCEPT3("MPI_Reduce", libMPI_Reduce)
     INTERCEPT3("mpi_init_", libmpi_init_)
     INTERCEPT3("mpi_init_thread_", libmpi_init_thread_)
     INTERCEPT3("mpi_finalize_", libmpi_finalize_)
     INTERCEPT3("mpi_comm_size_", libmpi_comm_size_)
     INTERCEPT3("mpi_comm_rank_", libmpi_comm_rank_)
PPTRACE_END_INTERCEPT_FUNCTIONS(mpi_core);

/* Callback of the "ezt_mpi" todo entry (enqueued by _mpi_core_init()).
 * It plugs the MPI-based helpers of this file into the EZT_MPI_* function
 * pointers and makes sure the communicator callbacks are never left unset.
 * The "ezt_mpi" status goes from init_started to init_complete around it.
 */
void ezt_mpi_init() {
  todo_set_status("ezt_mpi", init_started);

  /* Communication helpers */
  EZT_MPI_Send = _EZT_MPI_Send;
  EZT_MPI_Recv = _EZT_MPI_Recv;
  EZT_MPI_Reduce = _EZT_MPI_Reduce;
  EZT_MPI_Barrier = _EZT_MPI_Barrier;

  /* OTF2 and timing helpers */
  EZT_MPI_SetMPICollectiveCallbacks = _EZT_MPI_SetMPICollectiveCallbacks;
  EZT_MPI_Wtime = _EZT_MPI_Wtime;

  /* The MPI module may already have set the EZT_New_MPI_Comm and
   * EZT_Delete_MPI_Comm callbacks: install the dummy versions only when
   * they are still unset, so that existing callbacks are not overridden.
   */
  if(EZT_New_MPI_Comm == NULL)
    EZT_New_MPI_Comm = EZT_New_MPI_Comm_dummy;
  if(EZT_Delete_MPI_Comm == NULL)
    EZT_Delete_MPI_Comm = EZT_Delete_MPI_Comm_dummy;

  todo_set_status("ezt_mpi", init_complete);
}

/* Initialization function of the mpi_core module (registered by
 * _mpi_core_init()). The first call instruments the functions of the module,
 * creates the request hashtables and starts the trace when autostart is
 * enabled. Later calls do nothing.
 */
void _eztrace_init_mpi_core() {
  if(_mpi_core_initialized)
    return;

  INSTRUMENT_FUNCTIONS(mpi_core);
  ezt_hashtable_init(&mpi_infos.ezt_mpi_requests, 128);
  ezt_hashtable_init(&mpi_infos.ezt_mpi_persistent_requests, 128);

  if (eztrace_autostart_enabled())
    eztrace_start();

  /* The flag is raised only after the whole initialization is done. */
  _mpi_core_initialized = 1;
}

/* Finalization function of the mpi_core module (registered by
 * _mpi_core_init()). When the module is initialized, it stops the trace and
 * releases the resources held in mpi_infos. It does nothing otherwise, so
 * calling it several times is harmless.
 */
static void finalize_mpi_core() {
  if(!_mpi_core_initialized)
    return;

  _mpi_core_initialized = 0;
  eztrace_stop();

  ezt_hashtable_finalize(&mpi_infos.mpi_communicators);
  ezt_hashtable_finalize(&mpi_infos.ezt_mpi_requests);
  /* This table is created by _eztrace_init_mpi_core() together with
   * ezt_mpi_requests, so it is released here as well.
   */
  ezt_hashtable_finalize(&mpi_infos.ezt_mpi_persistent_requests);

  free(mpi_infos.proc_id);
  /* Do not leave a dangling pointer behind. */
  mpi_infos.proc_id = NULL;
}

/* Library constructor: registers the mpi_core module with its init and
 * finalize functions, then enqueues the "ezt_mpi" todo entry whose callback
 * is ezt_mpi_init().
 */
static void __attribute__((constructor)) _mpi_core_init(void) {
  EZT_REGISTER_MODULE(mpi_core, "EZTrace MPI core",
                      _eztrace_init_mpi_core, finalize_mpi_core);

  enqueue_todo("ezt_mpi", ezt_mpi_init, NULL, status_invalid);
}

/* Library destructor: if the module is still marked as initialized, clear
 * the flag and stop the trace. Unlike finalize_mpi_core(), it does not
 * release the mpi_infos resources.
 */
static void __attribute__((destructor)) _mpi_core_conclude(void) {
  if(!_mpi_core_initialized)
    return;

  _mpi_core_initialized = 0;
  eztrace_stop();
}

#endif

