/* nag_kalman_sqrt_filt_info_invar (g13edc) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group.
 *
 * Mark 24, 2013.
 */

#include <nag.h>
#include <stdio.h>
#include <nag_stdlib.h>
#include <nagf03.h>
#include <nagf16.h>
#include <nagg13.h>

/* Direction selector for mat_io(): read a matrix from stdin (read) or
 * write it to stdout (print). */
typedef enum
{
  read,
  print
} ioflag;

/* The two example drivers; each returns 0 on success. */
static int ex1(void);
static int ex2(void);

/* Program entry point: discard the title line of the data file, print the
 * results banner, then run Example 1 followed by Example 2.
 * Exits with 0 only when both examples report success, otherwise 1. */
int main(void)
{
  Integer  status1;
  Integer  status2;

  /* The first line of the data file is a title; throw it away together with
   * the whitespace that follows it. */
  scanf("%*[^\n] ");

  printf("nag_kalman_sqrt_filt_info_invar (g13edc) Example Program "
         "Results\n\n");

  status1 = ex1();
  status2 = ex2();

  if (status1 != 0 || status2 != 0)
    return 1;
  return 0;
}

/* Row-major element accessors for the matrices used in ex1(): X(I, J)
 * expands to x[(I) *tdx + J], where tdx is the local leading-dimension
 * (row length) variable of that matrix (tdainv, tdqinv, ...).
 * NOTE(review): J is not parenthesised, so only pass simple index
 * expressions. The same names are #defined again before ex2(); unless they
 * are #undef'd first, C requires that redefinition to be identical (same
 * tokens and whitespace separation), so change both sets together. */
#define AINV(I, J)  ainv[(I) *tdainv + J]
#define QINV(I, J)  qinv[(I) *tdqinv + J]
#define RINV(I, J)  rinv[(I) *tdrinv + J]
#define T(I, J)     t[(I) *tdt + J]
#define AINVB(I, J) ainvb[(I) *tdainvb + J]
#define C(I, J)     c[(I) *tdc + J]

/* Example 1: three iterations of the time-invariant square root information
 * Kalman filter nag_kalman_sqrt_filt_info_invar (g13edc).
 *
 * Reads from stdin (after skipping one heading line) n, m, p and tol, then,
 * row-major: ainv (n x n), c (p x n), rinv (p x p), ainvb (n x m),
 * qinv (m x m), t (n x n), followed by the vectors z (m), x (n) and
 * rinvy (p). After three filter steps it prints the updated t and x.
 *
 * Returns 0 on success, 1 on invalid input or a NAG error, -1 if an
 * allocation fails.
 */
static int ex1(void)
{
  Integer  exit_status = 0, i, istep, j, m, n, p, tdainv, tdainvb, tdc, tdqinv;
  Integer  tdrinv, tdt;
  double   *ainv = 0, *ainvb = 0, *c = 0, *qinv = 0, *rinv = 0, *rinvy = 0;
  double   *t = 0, tol, *x = 0, *z = 0;

  /* Nag Types */
  NagError fail;

  INIT_FAIL(fail);

  printf("Example 1\n");

  /* Skip the heading in the data file   */
  scanf("%*[^\n]");
  /* NOTE(review): %ld assumes Integer is long in this NAG build. */
  if (scanf("%ld%ld%ld%lf", &n, &m, &p, &tol) != 4)
    {
      printf("Error reading n, m, p and tol.\n");
      return 1;
    }
  /* Every dimension must be at least 1 -- all three, not just one of them;
   * otherwise the allocation sizes below would be zero or negative. */
  if (n < 1 || m < 1 || p < 1)
    {
      printf("Invalid n or m or p.\n");
      exit_status = 1;
      return exit_status;
    }

  if (!(ainv = NAG_ALLOC(n*n, double)) ||
      !(qinv = NAG_ALLOC(m*m, double)) ||
      !(rinv = NAG_ALLOC(p*p, double)) ||
      !(t = NAG_ALLOC(n*n, double)) ||
      !(ainvb = NAG_ALLOC(n*m, double)) ||
      !(c = NAG_ALLOC(p*n, double)) ||
      !(x = NAG_ALLOC(n, double)) ||
      !(z = NAG_ALLOC(m, double)) ||
      !(rinvy = NAG_ALLOC(p, double)))
    {
      printf("Allocation failure\n");
      exit_status = -1;
      goto END;
    }
  /* Leading dimensions (row lengths) of the row-major matrices. */
  tdainv = n;
  tdqinv = m;
  tdrinv = p;
  tdt = n;
  tdainvb = m;
  tdc = n;

  /* Read data */
  for (i = 0; i < n; ++i)
    for (j = 0; j < n; ++j)
      scanf("%lf", &AINV(i, j));
  for (i = 0; i < p; ++i)
    for (j = 0; j < n; ++j)
      scanf("%lf", &C(i, j));
  /* NOTE(review): rinv is always non-null here after the allocation check. */
  if (rinv)
    for (i = 0; i < p; ++i)
      for (j = 0; j < p; ++j)
        scanf("%lf", &RINV(i, j));
  for (i = 0; i < n; ++i)
    for (j = 0; j < m; ++j)
      scanf("%lf", &AINVB(i, j));
  for (i = 0; i < m; ++i)
    for (j = 0; j < m; ++j)
      scanf("%lf", &QINV(i, j));
  for (i = 0; i < n; ++i)
    for (j = 0; j < n; ++j)
      scanf("%lf", &T(i, j));
  for (j = 0; j < m; ++j)
    scanf("%lf", &z[j]);
  for (j = 0; j < n; ++j)
    scanf("%lf", &x[j]);
  for (j = 0; j < p; ++j)
    scanf("%lf", &rinvy[j]);

  /* Perform three iterations of the Kalman filter recursion  */
  for (istep = 1; istep <= 3; ++istep)
    {
      /* nag_kalman_sqrt_filt_info_invar (g13edc).
       * One iteration step of the time-invariant Kalman filter
       * recursion using the square root information
       * implementation with (A^(-1)(A^(-1)B)) in upper
       * controller Hessenberg form
       */
      nag_kalman_sqrt_filt_info_invar(n, m, p, t, tdt, ainv,
                                      tdainv, ainvb, tdainvb, rinv,
                                      tdrinv, c, tdc, qinv,
                                      tdqinv, x, rinvy, z, tol, &fail);
      /* Check after every step, not just the last: a failed step must not
       * feed its output into the next one, and a later call could otherwise
       * mask the earlier error. */
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_kalman_sqrt_filt_info_invar (g13edc).\n%s\n",
                 fail.message);
          exit_status = 1;
          goto END;
        }
    }

  printf("\nThe inverse of the square root of the state covariance "
         "matrix is \n\n");

  for (i = 0; i < n; ++i)
    {
      for (j = 0; j < n; ++j)
        printf("%8.4f ", T(i, j));
      printf("\n");
    }
  printf("\nThe components of the estimated filtered state are\n\n");
  printf("   k       x(k)  \n");
  for (i = 0; i < n; ++i)
    {
      printf("   %ld  ", i);
      printf("  %8.4f  \n", x[i]);
    }

 END:
  NAG_FREE(ainv);
  NAG_FREE(qinv);
  NAG_FREE(rinv);
  NAG_FREE(t);
  NAG_FREE(ainvb);
  NAG_FREE(c);
  NAG_FREE(x);
  NAG_FREE(z);
  NAG_FREE(rinvy);

  return exit_status;
}

/* Read (flag == read) from stdin, or print (flag == print) to stdout, the
 * n x m row-major matrix mat whose leading dimension is tdmat; when printing,
 * message is written first. */
static void mat_io(Integer n,
                   Integer m,
                   double *mat,
                   Integer tdmat,
                   ioflag flag,
                   const char *message);

/* Row-major element accessors for the matrices used in ex2(): X(I, J)
 * expands to x[(I) * tdx + (J)], where tdx is the local leading-dimension
 * (row length) variable of that matrix (tdainv, tdu, ...).
 * Both index arguments are parenthesised so compound index expressions
 * (e.g. a conditional) bind as a whole. Any earlier definition of these
 * names (ex1's accessors) is #undef'd first, so this set may differ from it
 * without an incompatible-macro-redefinition diagnostic. */
#undef AINV
#undef AINVB
#undef C
#undef AINVU
#undef AINVBU
#undef QINV
#undef RINV
#undef T
#undef RWORK
#undef TU
#undef IG
#undef IH
#undef CU
#undef U
#define AINV(I, J)   ainv[(I) * tdainv + (J)]
#define AINVB(I, J)  ainvb[(I) * tdainvb + (J)]
#define C(I, J)      c[(I) * tdc + (J)]
#define AINVU(I, J)  ainvu[(I) * tdainvu + (J)]
#define AINVBU(I, J) ainvbu[(I) * tdainvbu + (J)]
#define QINV(I, J)   qinv[(I) * tdqinv + (J)]
#define RINV(I, J)   rinv[(I) * tdrinv + (J)]
#define T(I, J)      t[(I) * tdt + (J)]
#define RWORK(I, J)  rwork[(I) * tdrwork + (J)]
#define TU(I, J)     tu[(I) * tdtu + (J)]
#define IG(I, J)     ig[(I) * tdig + (J)]
#define IH(I, J)     ih[(I) * tdih + (J)]
#define CU(I, J)     cu[(I) * tdcu + (J)]
#define U(I, J)      u[(I) * tdu + (J)]

static int ex2(void)
{
  Integer            dete, exit_status = 0, i, ione = 1, istep, j, m, n, p,
                     tdainv, tdainvb;
  Integer            tdainvbu, tdainvu, tdc, tdcu, tdig, tdih, tdqinv, tdrinv,
                     tdrwork;
  Integer            tdt, tdtu, tdu;
  Nag_ControllerForm reduceto = Nag_UH_Controller;
  Nag_ab_input       inp_ab = Nag_ab_prod;
  double             *ainv = 0, *ainvb = 0, *ainvbu = 0, *ainvu = 0, *c = 0;
  double             *cu = 0, detf, *diag = 0, *ig = 0, *ih = 0, one = 1.0;
  double             *qinv = 0, *rinv = 0, *rinvy = 0, *rwork = 0, *t = 0;
  double             tol, *tu = 0, *u = 0, *ux = 0, *x = 0, *z = 0, zero = 0.0;

  /* Nag Types */
  NagError           fail;

  INIT_FAIL(fail);

  printf("\n\nExample 2\n");

  /* skip the heading in the data file */
  scanf(" %*[^\n]");
  scanf("%ld%ld%ld%lf", &n, &m, &p, &tol);
  if (n >= 1 || m >= 1 || p >= 1)
    {
      if (!(ainv = NAG_ALLOC(n*n, double)) ||
          !(ainvb = NAG_ALLOC(n*m, double)) ||
          !(c = NAG_ALLOC(p*n, double)) ||
          !(ainvu = NAG_ALLOC(n*n, double)) ||
          !(ainvbu = NAG_ALLOC(n*m, double)) ||
          !(qinv = NAG_ALLOC(m*m, double)) ||
          !(rinv = NAG_ALLOC(p*p, double)) ||
          !(t = NAG_ALLOC(n*n, double)) ||
          !(x = NAG_ALLOC(n, double)) ||
          !(z = NAG_ALLOC(n, double)) ||
          !(rwork = NAG_ALLOC(n*n, double)) ||
          !(tu = NAG_ALLOC(n*n, double)) ||
          !(rinvy = NAG_ALLOC(p, double)) ||
          !(ig = NAG_ALLOC(n*n, double)) ||
          !(ih = NAG_ALLOC(n*n, double)) ||
          !(cu = NAG_ALLOC(p*n, double)) ||
          !(u = NAG_ALLOC(n*n, double)) ||
          !(ux = NAG_ALLOC(n, double)) ||
          !(diag = NAG_ALLOC(n, double)))
        {
          printf("Allocation failure\n");
          exit_status = -1;
          goto END;
        }
      tdainv = n;
      tdainvb = m;
      tdc = n;
      tdainvu = n;
      tdainvbu = m;
      tdqinv = m;
      tdrinv = p;
      tdt = n;
      tdrwork = n;
      tdtu = n;
      tdig = n;
      tdih = n;
      tdcu = n;
      tdu = n;
    }
  else
    {
      printf("Invalid n or m or p.\n");
      exit_status = 1;
      return exit_status;
    }

  /* Read data */
  mat_io(n, n, ainv, tdainv, read, "");
  mat_io(p, n, c, tdc, read, "");
  if (rinv)
    mat_io(p, p, rinv, tdrinv, read, "");
  mat_io(n, m, ainvb, tdainvb, read, "");
  mat_io(m, m, qinv, tdqinv, read, "");
  mat_io(n, n, t, tdt, read, "");
  for (j = 0; j < m; ++j)
    scanf("%lf", &z[j]);
  for (j = 0; j < n; ++j)
    scanf("%lf", &x[j]);
  for (j = 0; j < p; ++j)
    scanf("%lf", &rinvy[j]);

  for (i = 0; i < n; ++i)     /* Initialise the identity matrix u */
    {
      for (j = 0; j < n; ++j)
        U(i, j) = zero;
      U(i, i) = one;
    }

  /* Copy the arrays ainv[] and ainvb[] into ainvu[] and ainvbu[] */
  for (i = 0; i < n; ++i)
    for (j = 0; j < n; ++j)
      AINVU(j, i) = AINV(j, i);
  for (j = 0; j < m; ++j)
    for (i = 0; i < n; ++i)
      AINVBU(i, j) = AINVB(i, j);

  /* Transform (ainvu[],ainvbu[]) to reduceto controller Hessenberg form */
  /* nag_trans_hessenberg_controller (g13exc).
   * Unitary state-space transformation to reduce (BA) to
   * lower or upper controller Hessenberg form
   */
  nag_trans_hessenberg_controller(n, m, reduceto, ainvu, tdainvu, ainvbu,
                                  tdainvbu, u, tdu, &fail);
  if (fail.code != NE_NOERROR)
    {
      printf("Error from nag_trans_hessenberg_controller (g13exc).\n%s\n",
              fail.message);
      exit_status = 1;
      goto END;
    }

  /* Calculate the matrix cu = c*u'    */
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_Trans, p, n, n, one, c, tdc,
            u, tdu, zero, cu, tdcu, &fail);

  /* Calculate the vector ux = u*x     */
  nag_dgemv(Nag_RowMajor, Nag_NoTrans, n, n, one, u, tdu, x, ione,
            zero, ux, ione, &fail);

  /* Form the information matrices ih = u*ig*u' and ig = t'*t     */
  nag_dgemm(Nag_RowMajor, Nag_Trans, Nag_NoTrans, n, n, n, one, t, tdt,
            t, tdt, zero, ig, tdig, &fail);
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_Trans, n, n, n, one, ig, tdig,
            u, tdu, zero, rwork, tdrwork, &fail);
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_NoTrans, n, n, n, one, u, tdu,
            rwork, tdrwork, zero, ih, tdih, &fail);

  /* Now find the reduceto triangular (right) cholesky factor of ih */
  /* nag_real_cholesky (f03aec).
   * LL^T factorization and determinant of real symmetric
   * positive-definite matrix
   */
  f03aec(n, ih, tdih, diag, &detf, &dete, &fail);
  if (fail.code != NE_NOERROR)
    {
      printf("Error from nag_real_cholesky (f03aec).\n%s\n",
              fail.message);
      exit_status = 1;
      goto END;
    }
  for (i = 0; i < n; ++i)
    {
      TU(i, i) = one/diag[i];
      for (j = 0; j < i; ++j)
        {
          TU(j, i) = IH(i, j);
          TU(i, j) = zero;
        }
    }

  /* Do three iterations of the Kalman filter recursion  */
  for (istep = 1; istep <= 3; ++istep)
    {
      /* nag_kalman_sqrt_filt_info_var (g13ecc).
       * One iteration step of the time-varying Kalman filter
       * recursion using the square root information
       * implementation
       */
      nag_kalman_sqrt_filt_info_var(n, m, p, inp_ab, t, tdt, ainv,
                                    tdainv, ainvb, tdainvb, rinv, tdrinv,
                                    c, tdc, qinv, tdqinv, x,
                                    rinvy, z, tol, &fail);
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_kalman_sqrt_filt_info_var (g13ecc).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }
      /* nag_kalman_sqrt_filt_info_invar (g13edc), see above. */
      nag_kalman_sqrt_filt_info_invar(n, m, p, tu, tdtu, ainvu, tdainvu,
                                      ainvbu, tdainvbu, rinv, tdrinv, cu,
                                      tdcu, qinv, tdqinv, ux, rinvy,
                                      z, tol, &fail);
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_kalman_sqrt_filt_info_invar (g13edc).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }
    }

  /* Print Results */
  printf("\nResults from nag_kalman_sqrt_filt_info_var (g13ecc) \n\n");
  /* let ig = t' * t */
  nag_dgemm(Nag_RowMajor, Nag_Trans, Nag_NoTrans, n, n, n, one, t, tdt,
            t, tdt, zero, ig, tdig, &fail);
  mat_io(n, n, ig, tdig, print, "The information matrix ig is\n");
  printf("\nThe components of the estimated filtered state are\n\n");
  printf("  k       x(k)  \n");
  for (i = 0; i < n; ++i)
    printf("  %ld   %8.4f \n", i, x[i]);
  printf("\nResults from nag_kalman_sqrt_filt_info_invar (g13edc) \n\n");
  /* let ih = tu' * tu */
  nag_dgemm(Nag_RowMajor, Nag_Trans, Nag_NoTrans, n, n, n, one, tu, tdtu,
            tu, tdtu, zero, ih, tdih, &fail);
  mat_io(n, n, ih, tdih, print, "The information matrix ih is\n");

  /* Calculate ih = u'*ih*u   */
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_NoTrans, n, n, n, one, ih, tdih,
            u, tdu, zero, rwork, tdrwork, &fail);
  nag_dgemm(Nag_RowMajor, Nag_Trans, Nag_NoTrans, n, n, n, one, u, tdu,
            rwork, tdrwork, zero, ih, tdih, &fail);
  mat_io(n, n, ih, tdih, print, "\nThe matrix u' * ih * u is\n");

  /* Calculate x = u' * ux */
  nag_dgemv(Nag_RowMajor, Nag_Trans, n, n, one, u, tdu, ux, ione,
            zero, x, ione, &fail);
  printf("\nThe components of the estimated filtered state are \n\n");
  printf("  k       x(k)  \n");
  for (i = 0; i < n; ++i)
    printf("  %ld   %8.4f \n", i, x[i]);

 END:
  NAG_FREE(ainv);
  NAG_FREE(ainvb);
  NAG_FREE(c);
  NAG_FREE(ainvu);
  NAG_FREE(ainvbu);
  NAG_FREE(qinv);
  NAG_FREE(rinv);
  NAG_FREE(t);
  NAG_FREE(x);
  NAG_FREE(z);
  NAG_FREE(rwork);
  NAG_FREE(tu);
  NAG_FREE(rinvy);
  NAG_FREE(ig);
  NAG_FREE(ih);
  NAG_FREE(cu);
  NAG_FREE(u);
  NAG_FREE(ux);
  NAG_FREE(diag);

  return exit_status;
}

/* Transfer an n x m row-major matrix with leading dimension tdmat.
 * flag == read : fill mat element by element from stdin with scanf("%lf").
 * flag == print: write message, then each row as "%8.4f " fields followed
 *                by a newline.
 * Any other flag value leaves mat untouched and prints nothing. */
static void mat_io(Integer n, Integer m, double mat[], Integer tdmat,
                   ioflag flag, const char *message)
{
  Integer row, col;

  if (flag == print)
    printf("%s \n", message);

  for (row = 0; row < n; ++row)
    {
      /* Start of the current row inside the row-major storage. */
      double *row_ptr = mat + row * tdmat;

      for (col = 0; col < m; ++col)
        {
          if (flag == read)
            scanf("%lf", &row_ptr[col]);
          if (flag == print)
            printf("%8.4f ", row_ptr[col]);
        }
      if (flag == print)
        printf("\n");
    }
} /* mat_io */