package myArray;

import java.util.Arrays;
import java.lang.IllegalArgumentException;

public class IntegerMatrix {
	int rows;
	int cols;
	int[][] data;

	public IntegerMatrix(int rows, int cols) {
		this.rows=rows;
		this.cols=cols;
		data=new int[rows][cols];
	}

	public IntegerMatrix(int rows,int cols,int init) {
		this(rows,cols);
		for(int r=0;r<rows;r++) {
			for(int c=0;c<cols;c++) {
				if (init==0) data[r][c]= (r==c ? 1:0);
				else data[r][c]=init;
			}
		}
	}

	public void set(int r,int c,int v) { data[r][c]=v; }
	public int get(int r,int c) { return data[r][c]; }
	public void addTo(int r,int c, int v) { data[r][c] += v; }

	public int[] getRow(int r) { return data[r]; }
	public int[] getCol(int c) {
		int[] col = new int[rows];
		for(int r=0;r<rows;r++) col[r]=data[r][c];
		return col;
	}

	public IntegerMatrix product(IntegerMatrix that) {
		if (this.cols != that.rows) throw new IllegalArgumentException("cols of this must match rows of that");
		IntegerMatrix prod = new IntegerMatrix(this.rows,that.cols);
		for(int r=0;r<this.rows;r++) {
			for(int c=0;c<that.cols;c++) {
				for(int c2=0;c2<this.cols;c2++) {
					prod.addTo(r,c,this.get(r,c2) * that.get(c2,c));
				}
			}
		}
		return prod;
	}

	public String[] toStringArray() {
		String[] mat = new String[rows];
		for(int r=0;r<rows;r++) mat[r] = Arrays.toString(data[r]);
		return mat;
	}

	static public void main(String[] args) {
		IntegerMatrix a = new IntegerMatrix(4,5,3);
		IntegerMatrix b = new IntegerMatrix(5,6,0);
		IntegerMatrix c = new IntegerMatrix(5,6,1);

		System.out.println("Matrix a:");
		for(String s: a.toStringArray()) System.out.println("  " + s);

		System.out.println("Matrix b:");
		for(String s: b.toStringArray()) System.out.println("  " + s);

		System.out.println("Matrix c:");
		for(String s: c.toStringArray()) System.out.println("  " + s);

		System.out.println("Matrix a x b:");
		for(String s: a.product(b).toStringArray()) System.out.println("  " + s);

		System.out.println("Matrix a x c:");
		for(String s: a.product(c).toStringArray()) System.out.println("  " + s);

	}
}