#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <sys/resource.h>
#include <unistd.h>

typedef double REAL;

/* Element (i, j) of a matrix stored row by row, whose rows are `an` elements
 * apart.  Expands to an lvalue, so it can be assigned to or have its address
 * taken. */
#define ELEM(A, an, i, j) ((A)[(i)*(an)+(j)])

/* Compute C = A*B, where A, B, and C are n*n real matrices stored row by row.
 * n may be any value; for n <= 0 nothing is written.
 *   an is the number of elements in a row of A (A's row stride)
 *   bn is the number of elements in a row of B
 *   cn is the number of elements in a row of C
 * C must not overlap A or B.
 *
 * The multiply runs in i-j-l order so the innermost loop walks one row of B
 * and one row of C contiguously; an i-l-j order would stride down a column of
 * B and touch a new cache line on every inner iteration.  Every C[i][l] still
 * starts at 0.0 and receives its products in increasing j, i.e. the same
 * sequence of floating-point operations as the textbook triple loop.
 */
void matrixmul (int n,
		REAL *A, int an,
		REAL *B, int bn,
		REAL *C, int cn)
{
    int i, j, l;

    /* Clear all of C before accumulating into it. */
    for (i = 0; i < n; i++) {
        for (l = 0; l < n; l++) {
            ELEM(C, cn, i, l) = 0.0;
        }
    }

    for (i = 0; i < n; i++) {
        REAL *c_row = &ELEM(C, cn, i, 0);
        for (j = 0; j < n; j++) {
            REAL a_ij = ELEM(A, an, i, j);
            REAL *b_row = &ELEM(B, bn, j, 0);
            /* Row i of C += A[i][j] * row j of B: unit-stride on both. */
            for (l = 0; l < n; l++) {
                c_row[l] += a_ij * b_row[l];
            }
        }
    }
}

/* Compute C = A*B for n*n matrices, using the same arguments as matrixmul:
 *   an, bn, cn are the row lengths (row strides) of A, B and C.
 * NOTE(review): this is still a placeholder.  It does not implement Strassen's
 * algorithm and does no parallel work; it simply delegates to the serial
 * matrixmul.  main() checks this function's output against matrixmul using
 * exact element equality (compare_matrices), so a real Strassen implementation,
 * which reorders the floating-point arithmetic, may need a tolerance-based
 * comparison instead. */
cilk void strassen(int n,
		   REAL *A, int an,
		   REAL *B, int bn, 
                   REAL *C, int cn) {
    /* You should put your code here. */
    matrixmul(n, A, an, B, bn, C, cn);
}


/* Effect: Make an N by M matrix, and fill it in with zeros.
 * Returns NULL if N or M is negative, if the byte size N*M*sizeof(REAL) does
 * not fit in a size_t, or if the allocation fails.  An empty (0-element)
 * matrix is still returned as a non-NULL pointer, so NULL always means
 * failure.  The caller owns the result and releases it with destroy_matrix(). */
REAL *make_matrix (int n, int m) {
    REAL *result;
    size_t count, i;

    if (n < 0 || m < 0)
	return NULL;
    /* Do the size arithmetic in size_t with explicit overflow checks: the
     * int expression n*m*sizeof(REAL) overflows (undefined behavior) once
     * n*m exceeds INT_MAX. */
    if (m != 0 && (size_t)n > (size_t)-1 / (size_t)m)
	return NULL;
    count = (size_t)n * (size_t)m;
    if (count > (size_t)-1 / sizeof *result)
	return NULL;

    /* malloc(0) may legitimately return NULL; ask for one element instead so
     * that a NULL result unambiguously signals allocation failure. */
    result = malloc((count ? count : 1) * sizeof *result);
    if (result == NULL)
	return NULL;

    for (i = 0; i < count; i++)
	result[i] = 0.0;
    return result;
}

/* Effect: Hand a matrix obtained from make_matrix() back to the system.
 * Passing NULL is harmless, since free(NULL) does nothing. */
void destroy_matrix (REAL *m)
{
    free(m);
}

/* Effect: Compare two n-by-m matrices entry by entry.
 * Returns 0 when each of the n*m entries of a equals the matching entry of b,
 * and 1 as soon as a differing entry is found.  The test is exact
 * floating-point inequality, so a NaN entry always counts as a difference.
 * When n*m <= 0 there is nothing to compare and 0 is returned. */
int compare_matrices (REAL *a, REAL *b, int n, int m)
{
    const REAL *pa = a;
    const REAL *pb = b;
    int left;

    for (left = n * m; left > 0; left--, pa++, pb++) {
        if (*pa != *pb)
            return 1;
    }
    return 0;
}

/* Benchmark driver: multiply two random size x size matrices with strassen()
 * and report the user CPU time (getrusage ru_utime) the call took.
 *   -n <n>  matrix dimension (default 256)
 *   -c      also compute the product with matrixmul() and check that the two
 *           results are identical; prints "OK" or "Not OK"
 *   -p      print the operands and the product
 *   -h      print usage and exit
 * Returns 0 on success, 1 on a usage error, allocation failure, or mismatch.
 */
cilk int main (int argc, char *argv[]) {
    char **org_argv = argv;
    int do_compare = 0;
    int do_print = 0;
    int size = 256;
    int status = 0;
    REAL *m1, *m2, *m3;
    int i, j;

    argc--; argv++;
    while (argc > 0) {
        if (strcmp(*argv, "-c") == 0) {
            do_compare = 1;
        } else if (strcmp(*argv, "-p") == 0) {
            do_print = 1;
        } else if (strcmp(*argv, "-n") == 0) {
            char *end;
            long value;
            argc--; argv++;
            /* Validate explicitly rather than with assert(), which is
             * compiled out under NDEBUG and would let bad input through. */
            if (argc == 0) {
                fprintf(stderr, "%s: -n requires a value\n", org_argv[0]);
                return 1;
            }
            errno = 0;
            value = strtol(*argv, &end, 10);
            /* Reject non-numbers, negatives, and sizes whose square does not
             * fit in an int (size*size bounds the fill loop below). */
            if (errno != 0 || end == *argv || *end != '\0'
                || value < 0 || (value > 0 && value > INT_MAX / value)) {
                fprintf(stderr, "%s: invalid matrix size '%s'\n", org_argv[0], *argv);
                return 1;
            }
            size = (int)value;
        } else if (strcmp(*argv, "-h") == 0) {
            printf("Usage: %s [-n <n>] [-c] [-p]\n", org_argv[0]);
            return 0;
        } else {
            /* Don't silently run the default benchmark on a typo. */
            fprintf(stderr, "Unrecognized argument. Try %s -h for help\n", org_argv[0]);
            return 1;
        }
        argc--; argv++;
    }

    m1 = make_matrix(size, size);
    m2 = make_matrix(size, size);
    m3 = make_matrix(size, size);
    /* A 0x0 matrix may come back as NULL (malloc(0) is allowed to return
     * NULL), so only treat NULL as a failure when there is data to hold. */
    if (size > 0 && (m1 == NULL || m2 == NULL || m3 == NULL)) {
        fprintf(stderr, "%s: out of memory for %dx%d matrices\n", org_argv[0], size, size);
        destroy_matrix(m1);
        destroy_matrix(m2);
        destroy_matrix(m3);
        return 1;
    }

    {
        /* Entries are integers in [0, size^3).  The bound is computed in
         * long long because size*size*size overflows int once size > 1290. */
        long long range = (long long)size * size * size;
        for (i = 0; i < size * size; i++) {
            m1[i] = (REAL)(random() % range);
            m2[i] = (REAL)(random() % range);
        }
    }

    {
        struct rusage rstart, rend;
        getrusage(RUSAGE_SELF, &rstart);
        spawn strassen(size, m1, size, m2, size, m3, size);
        /* The spawned call is only guaranteed to be complete after sync;
         * without it the clock could stop, and m3 be read, too early. */
        sync;
        getrusage(RUSAGE_SELF, &rend);

        if (do_print) {
            for (i = 0; i < size; i++) {
                for (j = 0; j < size; j++) {
                    printf("%g ", ELEM(m1, size, i, j));
                }
                if (i*2 == size) printf(" *  ");
                else             printf("    ");
                for (j = 0; j < size; j++) {
                    printf("%g ", ELEM(m2, size, i, j));
                }
                if (i*2 == size) printf(" =  ");
                else             printf("    ");
                for (j = 0; j < size; j++) {
                    printf("%g ", ELEM(m3, size, i, j));
                }
                printf("\n");
            }
        }

        if (do_compare) {
            REAL *m4 = make_matrix(size, size);
            if (size > 0 && m4 == NULL) {
                fprintf(stderr, "%s: out of memory for reference matrix\n", org_argv[0]);
                status = 1;
            } else {
                matrixmul(size, m1, size, m2, size, m4, size);
                if (compare_matrices(m3, m4, size, size) == 0) {
                    printf("OK\n");
                } else {
                    printf("Not OK\n");
                    status = 1;
                }
            }
            destroy_matrix(m4);
        }

        /* Print the rusage only if the compares were OK. */
        if (status == 0) {
            double elapsed = (double)(rend.ru_utime.tv_sec - rstart.ru_utime.tv_sec)
                           + (rend.ru_utime.tv_usec - rstart.ru_utime.tv_usec) * 1e-6;
            printf("size=%d Time=%f user seconds\n", size, elapsed);
        }
    }

    destroy_matrix(m1);
    destroy_matrix(m2);
    destroy_matrix(m3);
    return status;
}
