#include "bigfloat.h"

#include <vector>

int driver(int argc, char**argv)
{
    // find what degree formula to generate
    ASSERT(argc == 2, "usage: multistep 5, to get 11th degree (5*2+1=11)\n");
    int halfDegree;
    sscanf(argv[1], "%d", &halfDegree);
    ASSERT(halfDegree > 0, "half degree must be at least 1");
    int degree = 2*halfDegree+1;

    // build polynomials of the appropriate degree, and second derivative
    BigFloat *p2 = new BigFloat[3*degree];  // position: poly of degree
    BigFloat *p = &p2[degree];
    BigFloat *a2 = new BigFloat[3*degree];  // acceleration: poly of degree-2
    BigFloat *a = &a2[degree];
    for (int i=0; i<3*degree; ++i)
    {
        BigFloat x(1);
        BigFloat c(i-degree);
        for (int j=0; j<degree-3; ++j)
            x.Mult(c);
        a2[i].Copy(x);
        a2[i].Mult((degree-1)*(degree-2));
        x.Mult(c);
        x.Mult(c);
        p2[i].Copy(x);
    }

    // allocate the matrix
    BigFloat **m = new BigFloat*[halfDegree]();
    for (int i=0; i<halfDegree; ++i)
        m[i] = new BigFloat[halfDegree+1];

    // fill in the matrix using the appropriate polynomials
    // p-2 - p-1 - p1 + p2sy == c0*a0 + c1*(a-1 + a1)
    // c0, c1 are what we are solving for
    for (int i=0; i<halfDegree; ++i)
    {
        m[i][0].Copy(a[i]);
        for (int j=1; j<halfDegree; ++j)
        {
            m[i][j].Copy(a[i-j]);
            m[i][j].Add(a[i+j]);
        }
        m[i][halfDegree].Copy(p[i-halfDegree]);
        m[i][halfDegree].Sub(p[i+1-halfDegree]);
        m[i][halfDegree].Sub(p[i-1+halfDegree]);
        m[i][halfDegree].Add(p[i+halfDegree]);
    }

    // Solve for the coefficients
    BigFloat::GaussianElimination(m, halfDegree, halfDegree);

    // report the coefficients
    for (int i=0; i<halfDegree; ++i)
    {
        BigFloat num;
        BigFloat denom;
        m[i][halfDegree].ToFraction(num, denom);
        num.PrintDecimal();
        printf(" / ");
        denom.PrintDecimal();
        printf("\n");
    }
    return 0;
}


// Program entry point: runs driver() and converts any std::exception it throws
// into a message on stderr and exit status 1.  Otherwise the driver's own
// return value becomes the exit status.
int main(int argc, char**argv)
{
    int status;
    try
    {
        status = driver(argc, argv);
    }
    catch (const std::exception &failure)
    {
        fprintf(stderr, "%s\n", failure.what());
        status = 1;
    }
    return status;
}