#include	"solver/Gauss_column_pivot.h"

#include	<string.h>
#include	"transport/Equation_set.h"
#include	"transport/Matrix.h"
#include	"transport/State_vector.h"

// Type object for this class.
// NOTE(review): presumably what TYPE_OF(Gauss_column_pivot) resolves to — confirm.
Type*	Gauss_column_pivot_Type_pointer;

// Name passed to name() by the (Metaclass*, Solver_def*) constructor.
// Declared as an array rather than `char*` pointing at a string literal:
// binding a literal to a non-const `char*` is deprecated in C++98 and
// ill-formed since C++11, while the array still decays to `char*`.
static	char	class_name[] = "the_solver(gauss_column_pivot)";


// Construct the solver from its metaclass and definition object.
// Forwards both to Solver, then calls name() with class_name and
// directType() with this class's type.
Gauss_column_pivot::Gauss_column_pivot(Metaclass* meta, Solver_def* def)
				: Solver(meta, def)
{
	// Each DEBUG line is a complete statement, so the trace stays
	// well-formed however the DEBUG macro is defined.
	DEBUG	<< "Gauss_column_pivot::Gauss_column_pivot(" << meta->oid() << ", ";
	DEBUG	<< def->oid() << ")\n";

	name(class_name);
	directType(TYPE_OF(Gauss_column_pivot));
}

// Construct from an APL; all initialisation is delegated to Solver.
// NOTE(review): unlike the (Metaclass*, Solver_def*) constructor, this one
// calls neither name() nor directType() — confirm that is intended.
Gauss_column_pivot::Gauss_column_pivot(APL* theAPL)
	: Solver(theAPL)
{
}

// Destructor. The visible code gives this class no resources of its own
// to release; it only emits a debug trace.
Gauss_column_pivot::~Gauss_column_pivot()
{
	DEBUG << "Gauss_column_pivot::~Gauss_column_pivot()\n";
}


Vector&	Gauss_column_pivot::execute(Equation_iterator eqn_iter)
{
	DEBUG1	<< "State_vector*  Gauss_column_pivot::execute(Equation_iterator eqn_iter)\n";

				// Unencapsulate Equation_set into Matrix
				// (should be a separate class in future).

	int n = eqn_iter.size();
	Matrix	a(n,n,0);
	Vector&	b = *new Vector(n,0);
	Vector st = Vector(n,0);

  	register int i=0, j=0, k=0;
	eqn_iter(RESET);
	Equation*	eq;
	while (eq = eqn_iter(FORWARD)) {
					// ?should check eqn type
		i = eq->state_variable()->id();
		b[i] = eq->gain().value();
		st[i] = eq->state_variable()->state_variable();
		List_iterator(Coefficient) coeff_iter = eq->coeff_iterator();
		Coefficient* coeff;
		while(coeff = coeff_iter(FORWARD)) {
					// ?should check coeff type
			j = coeff->state_variable()->id();
			a[i][j] = coeff->value();
		};
	};

// start time discretisation: temporalily force time step = 1 hour.
	time_discretise(a, b, st, n, 3600);
	

  	float temp, pivot, err;
  	register int io, in, ii;

	DEBUG2	<< "build eqn matrix\nGauss_column_pivot::\n";
	DEBUG2	<< "Matrix dimension = " << n << "\nMatrix:" << a << "\n";
	DEBUG2	<< "Right vector:" << b << "\n";

					// Select pivoting.
	err=1e-9;
  	for(k=0; k<n; k++) {
    		pivot = 0;
    		for(i=k; i<n; i++) 
      			if(fabs(a[i][k]) > fabs(pivot)) {
        			pivot = a[i][k];
				io = i;
			};

					// Check for singularity.
    		if((fabs(pivot)-err) <= 0) {
      			puts("No solution !!");
      			break;
    		}
    		if(io > k) {
      			for(j=k; j<n; j++) {
        			temp = a[k][j];
				a[k][j] = a[io][j];
				a[io][j] = temp;
      				temp = b[k];
      				b[k] = b[io];
      				b[io] = temp;
			};
    		}
    		pivot = 1/pivot; in = n-1;
    		if(k<in) {
      			for(j=k; j<in; j++) {
        			a[k][j+1] *= pivot;
        			for(i=k; i<in; i++)
          				a[i+1][j+1] -= a[i+1][k]*a[k][j+1];
     			}
    		}
    		b[k] *= pivot;
    		if(k==in) break;
    		for(j=k; j<in; j++)
			b[j+1] -= b[k]*a[j+1][k];
  	}
  	for(ii=1; ii<n; ii++) {
    		i = in-ii;
   		for(j=i; j<in; j++)
			b[i] -= a[i][j+1]*b[j+1];
  	}

					// return new state_vector.

	DEBUG2	<< "solved matrix\nGauss_column_pivot::result=\n";
	DEBUG2	<< "New right vector: " << b << "\n";
	return b;
};
