//============================================================
// COOOL           version 1.1           ---     Nov,  1995
//   Center for Wave Phenomena, Colorado School of Mines
//============================================================
//
//   This code is part of a preliminary release of COOOL (CWP
// Object-Oriented Optimization Library) and associated class 
// libraries. 
//
// The COOOL library is a free software. You can do anything you want
// with it including make a fortune.  However, neither the authors,
// the Center for Wave Phenomena, nor anyone else you can think of
// makes any guarantees about anything in this package or any aspect
// of its functionality.
//
// Since you've got the source code you can also modify the
// library to suit your own purposes. We would appreciate it 
// if the headers that identify the authors are kept in the 
// source code.


#ifndef DIAGMATRIX_HH
#define DIAGMATRIX_HH

//=====================================
// Diagonal-Matrix class template library for efficient algebraic operations
//
//  H. Lydia Deng,  01/24/1994
//======================================

// .NAME DiagMatrix class template
// .LIBRARY Base
// .HEADER c++ utility classes
// .INCLUDE diagmatrix.hh
// .FILE diagmatrix.hh
// 
// .SECTION Description
// .B DiagMatrix
// is a simple class template for diagonal matrices, derived from the abstract
// Matrix class. It stores only the diagonal (as a Vector) and supports simple
// element-wise and matrix-vector operations on it.
// 
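// .SECTION Caveats
// Adding or subtracting a scalar (operator+=, operator-=, and the free
// operator+ / operator-) changes only the diagonal elements, i.e. A + c
// computes A + cI rather than adding c to every entry of the matrix.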

#include "Vector.hh"
#include "Matrix.hh"

// Name reported by DiagMatrix<Type>::matrixType(). The pointer itself is
// const, so no translation unit that includes this header can re-point it.
// A const object also does not trigger "defined but not used" warnings in
// C++ translation units that include the header without using it.
static const char* const myName1 =  "diagonal matrix class";

//******************************
// define the class of Diagonal Matrix
/********************************/

//@Man:
//@Memo: a diagonal matrix class
/*Doc: 
 This is a preliminary class template for diagonal matrix computation.
This class is derived from the abstract class, Matrix<Type>.
*/

namespace coool 
{
    using namespace coool;

#pragma interface
   
template<class Type>
class DiagMatrix : public Matrix<Type> 
{
   private:
   // Owning pointer to the diagonal entries. Every constructor allocates
   // it with new and the destructor deletes it.
   Vector<Type> *a;
   using Matrix<Type>::nrow;
   using Matrix<Type>::ncol;

   public:			
    //@ManMemo: a default constructor
    DiagMatrix();
    //@ManMemo: construct a m x m diagonal matrix
    DiagMatrix(int m);
    //@ManMemo: construct a DiagMatrix with diagonal elements the same as the Vector
    DiagMatrix(const Vector<Type>& x);
    //@ManMemo: construct an identical DiagMatrix (deep copy of the diagonal)
    DiagMatrix(const DiagMatrix<Type>& A);
    //@ManMemo: releases the diagonal storage
    ~DiagMatrix(){delete a; } 
    
    //@ManMemo: returns the Matrix Type
    const char* matrixType() const {return myName1;}

    //@ManMemo: returns the maximum element on the diagonal
    Type max() const{ return a->max();}
    //@ManMemo: returns the minimum element on the diagonal
    Type min() const{ return a->min();}
    //@ManMemo: returns the index of the maximum diagonal element
    int indexMax()  const{ return a->indexMax();}
    //@ManMemo: returns the index of the minimum diagonal element
    int indexMin()  const{ return a->indexMin();}
    //@ManMemo: returns a copy of the diagonal as a Vector
    Vector<Type> diagVector() const { return *a;}

    //@ManMemo: returns the $i$th-row vector
    // NOTE(review): only entry i is written; this relies on Vector<Type>(n)
    // zero-filling the remaining entries. Confirm in Vector.hh.
    Vector<Type> rowVector(int i) const { 
	Vector<Type> v(ncol); v[i] = (*a)[i]; return v;}
    //@ManMemo: returns the $j$th-column vector (same zero-fill assumption as rowVector)
    Vector<Type> colVector(int j) const{ 
	Vector<Type> v(nrow); v[j] = (*a)[j]; return v;}

   //implicit type conversion
   operator DiagMatrix<int>() const;
   operator DiagMatrix<long>() const;
   operator DiagMatrix<float>() const;
   operator DiagMatrix<double>() const;

   //@ManMemo: fetch size (number of rows, which equals number of columns) of the DiagMatrix
      int size() const {return nrow;} 	
   //@ManMemo:  returns the $i$th diagonal element, re-writable
    Type& operator[](int i) {return (*a)[i];}
   //@ManMemo:  returns the $i$th diagonal element, non re-writable
    Type  operator[](int i) const {return (*a)[i];}

    //@ManMemo: sum across the $j$th-row, which for a diagonal matrix is the $j$th diagonal element
    Type operator() (char*, int j) { return (*a)[j];}
    //@ManMemo: sum across the $i$th-column, which for a diagonal matrix is the $i$th diagonal element
    Type operator() (int i, char*) { return (*a)[i];}
    //@ManMemo: returns the diagonal element of row ir
    // NOTE(review): the four row/col extrema below return the diagonal entry
    // only; the off-diagonal zeros are not considered. For example, with
    // n > 1 and a negative diagonal entry the true row maximum would be 0.
    // Confirm the intended meaning against Matrix.hh.
    Type rowMax(int ir) const {return (*a)[ir];}
    //@ManMemo: returns the diagonal element of row ir
    Type rowMin(int ir) const {return (*a)[ir];}
    //@ManMemo: returns the diagonal element of column ic
    Type colMax(int ic) const {return (*a)[ic];}
    //@ManMemo: returns the diagonal element of column ic
    Type colMin(int ic) const {return (*a)[ic];}
    
			//@ManMemo: Overloading operators; sizes must agree for DiagMatrix/Vector assignment
    DiagMatrix<Type>& operator=(const DiagMatrix<Type>&);
    //@ManMemo: copies the Vector onto the diagonal
    DiagMatrix<Type>& operator=(const Vector<Type>&);
    //@ManMemo: assigns a scalar to the diagonal
    DiagMatrix<Type>& operator= (const Type);
    //@ManMemo: copies size() values from a raw array onto the diagonal
    DiagMatrix<Type>& operator= (const Type*);
    //@ManMemo: element-wise diagonal addition
    DiagMatrix<Type>& operator+=(const DiagMatrix<Type>&);
    //@ManMemo: element-wise diagonal subtraction
    DiagMatrix<Type>& operator-=(const DiagMatrix<Type>&);
    //@ManMemo: adds a scalar to the diagonal only (A + cI)
    DiagMatrix<Type>& operator+=(const Type);
    //@ManMemo: subtracts a scalar from the diagonal only (A - cI)
    DiagMatrix<Type>& operator-=(const Type);
    //@ManMemo: scales the diagonal
    DiagMatrix<Type>& operator*=(const Type);
    //@ManMemo: divides the diagonal (asserts a non-zero divisor)
    DiagMatrix<Type>& operator/=(const Type);

   //@ManMemo: friend functions for I/O streams
template < class T >
    friend ostream& operator<< (ostream&, const DiagMatrix<T>&);
   //@ManMemo: stream extraction reads directly into the private diagonal,
   // so it must be a friend as well
template < class T >
    friend istream& operator>> (istream&, DiagMatrix<T>&);

   //@ManMemo: friend functions, DiagMatrix * Vector
template < class T >
friend Vector<T> operator* (const DiagMatrix<T>& d, const Vector<T>& x);
   //@ManMemo: friend function, Vector * DiagMatrix
template < class T >
friend Vector<T> operator* (const Vector<T>& x, const DiagMatrix<T>& d);
 
   //@ManMemo: G(transpose)axpy, returns $A^T v$, which equals $A v$ for a diagonal matrix
    Vector<Type> atdotx(const Vector<Type>& v){
	// Multiply directly; copying *this first only cost an extra allocation.
	return (*this)*v;
    }
   //@ManMemo: Gaxpy, returns $A v$
    Vector<Type> adotx(const Vector<Type>& v){
	return (*this)*v;
    }
};

#pragma implementation

template<class Type>
Vector<Type> operator*(const DiagMatrix<Type>& d, const Vector<Type>& x)
{
   // D x for a diagonal D is the element-wise product of x with the diagonal.
   // A size mismatch is reported through inValidSize(). If that call
   // returns, only the overlapping leading entries are multiplied.
   if (d.size() != x.size())
   {
      inValidSize();
   }

   const int len = Min(d.size(), x.size());
   Vector<Type> result(len);
   Vector<Type>& diag = *d.a;

   for (int k = 0; k < len; ++k)
   {
      result[k] = x[k] * diag[k];
   }
   return result;
}

 
template<class Type>
Vector<Type> operator*(const Vector<Type>& x, const DiagMatrix<Type>& d)
{
   // x^T D: with D diagonal this is the same element-wise product as D x.
   // A size mismatch is reported through inValidSize(). If that call
   // returns, only the overlapping leading entries are multiplied.
   if (d.size() != x.size())
   {
      inValidSize();
   }

   const int len = Min(d.size(), x.size());
   Vector<Type> result(len);
   Vector<Type>& diag = *d.a;

   for (int k = 0; k < len; ++k)
   {
      result[k] = x[k] * diag[k];
   }
   return result;
}

 
template<class Type>
DiagMatrix<Type>::DiagMatrix()
: Matrix<Type>(1, 1), a(new Vector<Type>)
{
   // NOTE(review): the base records a 1 x 1 shape while the diagonal comes
   // from Vector's default constructor. Confirm that Vector<Type>() has
   // length 1, so that size() and the stored diagonal agree.
}

 
template<class Type>
DiagMatrix<Type>::DiagMatrix(int m)
: Matrix<Type>(m, m), a(new Vector<Type>(m))
{
   // An m x m matrix: only its m diagonal entries are stored.
}

 
template<class Type>
DiagMatrix<Type>::DiagMatrix(const Vector<Type>& x) 
: Matrix<Type>(x.size(), x.size()), a(new Vector<Type>(nrow))
{
   // Allocate a diagonal of the recorded size, then copy x's values into it.
   Vector<Type>& diag = *a;
   diag = x;
}

 
template<class Type>
DiagMatrix<Type>::DiagMatrix(const DiagMatrix<Type>& A)
: Matrix<Type>(A.nrow, A.ncol), a(new Vector<Type>(nrow))
{
   // Deep copy: fresh storage of our own, then the element values from A.
   Vector<Type>& mine = *a;
   mine = *A.a;
}

 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator=(const DiagMatrix<Type>& d)
{
   // Dimensions must agree; the existing storage is reused, never reallocated.
   if (d.nrow != nrow)
   {
      inValidSize();
   }
   Vector<Type>& lhs = *a;
   lhs = *d.a;
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator=(const Vector<Type>& v)
{
    // The Vector's length must match the matrix dimension.
    if (v.size() != nrow)
    {
        inValidSize();
    }
    Vector<Type>& diag = *a;
    diag = v;
    return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator= (const Type c)
{
   // Delegate to Vector's scalar assignment on the diagonal (presumably
   // filling every diagonal entry with c; see Vector.hh).
   Vector<Type>& diag = *a;
   diag = c;
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator= (const Type *p)
{
   // Copy nrow consecutive values from p onto the diagonal; p must point to
   // at least nrow readable elements.
   Vector<Type>& diag = *a;
   for (int k = 0; k < nrow; ++k)
   {
      diag[k] = p[k];
   }
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator+=(const DiagMatrix<Type>& d)
{
   // The sum of two diagonal matrices is the element-wise sum of their diagonals.
   Vector<Type>& mine = *a;
   mine += *d.a;
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator-=(const DiagMatrix<Type>& d)
{
   // The difference of diagonal matrices is the element-wise difference of their diagonals.
   Vector<Type>& mine = *a;
   mine -= *d.a;
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator+=(const Type c)
{
   // Only the diagonal changes: the result is A + cI.
   Vector<Type>& diag = *a;
   diag += c;
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator-=(const Type c)
{
   // Only the diagonal changes: the result is A - cI.
   Vector<Type>& diag = *a;
   diag -= c;
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator*=(const Type c)
{
   // Scaling a diagonal matrix scales each diagonal entry.
   Vector<Type>& diag = *a;
   diag *= c;
   return *this;
}
 
template<class Type>
DiagMatrix<Type>& DiagMatrix<Type>::operator/=(const Type c)
{
   // A zero divisor is treated as a programming error; assert() checks it
   // only when NDEBUG is not defined.
   assert(c != 0);
   Vector<Type>& diag = *a;
   diag /= c;
   return *this;
}

//implicit type conversion
// Converts to an integer DiagMatrix of the same size, converting each
// diagonal element. The diagonal lives in *a, so elements are read as
// (*a)[i]. The former a[i] indexed the Vector pointer itself, running past
// the single Vector it points to for every i > 0.
template<class Type>
DiagMatrix<Type>::operator DiagMatrix<int>() const
{
    DiagMatrix<int> A(nrow);

    for (int i=0; i<nrow; i++) A[i] = static_cast<int>((*a)[i]);

    return A;
}
 
// Converts to a long DiagMatrix of the same size, converting each diagonal
// element. Elements are read from the diagonal Vector as (*a)[i]; the
// former a[i] indexed the Vector pointer itself (out of bounds for i > 0).
template<class Type>
DiagMatrix<Type>::operator DiagMatrix<long>() const
{
    DiagMatrix<long> A(nrow);

    for (int i=0; i<nrow; i++) A[i] = static_cast<long>((*a)[i]);

    return A;
}
 
// Converts to a float DiagMatrix of the same size, converting each diagonal
// element. Elements are read from the diagonal Vector as (*a)[i]; the
// former a[i] indexed the Vector pointer itself (out of bounds for i > 0).
template<class Type>
DiagMatrix<Type>::operator DiagMatrix<float>() const
{
    DiagMatrix<float> A(nrow);

    for (int i=0; i<nrow; i++) A[i] = static_cast<float>((*a)[i]);

    return A;
}
 
// Converts to a double DiagMatrix of the same size, converting each diagonal
// element. Elements are read from the diagonal Vector as (*a)[i]; the
// former a[i] indexed the Vector pointer itself (out of bounds for i > 0).
template<class Type>
DiagMatrix<Type>::operator DiagMatrix<double>() const
{
    DiagMatrix<double> A(nrow);

    for (int i=0; i<nrow; i++) A[i] = static_cast<double>((*a)[i]);

    return A;
}

// Write the full nrow x ncol matrix in ASCII, zeros included: a blank line,
// then one line per row with every entry followed by ", ", then a blank line.
template<class Type>
ostream& operator<<(ostream& ofp, const DiagMatrix<Type>& A)
{
	Vector<Type>& diag = *A.a;

	ofp << endl;

	for (int row = 0; row < A.nrow; ++row)
	{
	   for (int col = 0; col < A.ncol; ++col)
	   {
	      if (row == col)
	         ofp << diag[row] << ", ";
	      else
	         ofp << "0, ";
	   }
	   ofp << endl;
	}

	ofp << endl;

	return ofp;
}

 
// Read the diagonal using Vector's stream extractor. This writes straight
// into the private member A.a, so it requires friend access to
// DiagMatrix<Type>.
template<class Type>
istream& operator>>(istream& ifp,  DiagMatrix<Type>& A)
{
   Vector<Type>& diag = *A.a;
   ifp >> diag;
   return ifp;
}

 
// A + c: a copy of A with c added to the diagonal only (A + cI).
template <class Type>
inline DiagMatrix<Type>  operator+(const DiagMatrix<Type>& A, Type c)
{
   DiagMatrix<Type> sum(A);
   sum += c;
   return sum;
}
 
// c + A: addition with a scalar is commutative, so reuse A + c.
template <class Type>
inline DiagMatrix<Type>  operator+(Type c, const DiagMatrix<Type>& A)
{ 
   return A + c;
}
 
// A - c: a copy of A with c subtracted from the diagonal only (A - cI).
template <class Type>
inline DiagMatrix<Type>  operator-(const DiagMatrix<Type>& A, Type c)
{
   DiagMatrix<Type> diff(A);
   diff -= c;
   return diff;
}
 
// c - A: a copy of A whose diagonal entries become c - A[i] (scalar applied
// to the diagonal only, i.e. cI - A).
template <class Type>
inline DiagMatrix<Type>  operator-(Type c, const DiagMatrix<Type>& A)
{
   // Built from compound operators that DiagMatrix declares. DiagMatrix
   // itself declares no unary operator-, so the former "return -B;"
   // depended on one being supplied elsewhere that yields a DiagMatrix<Type>.
   DiagMatrix<Type> B(A);
   B *= static_cast<Type>(-1);
   B += c;
   return B;
}
 
// A * c: a copy of A with every diagonal entry scaled by c.
template <class Type>
inline DiagMatrix<Type>  operator*(const DiagMatrix<Type>& A, Type c)
{
   DiagMatrix<Type> scaled(A);
   scaled *= c;
   return scaled;
}
 
// c * A: scalar scaling is commutative, so reuse A * c.
template <class Type>
inline DiagMatrix<Type>  operator*(Type c, const DiagMatrix<Type>& A)
{
   return A * c;
}
 
// A / c: a copy of A with every diagonal entry divided by c
// (operator/= asserts that c is non-zero).
template <class Type>
inline DiagMatrix<Type>  operator/(const DiagMatrix<Type>& A, Type c)
{
   DiagMatrix<Type> quotient(A);
   quotient /= c;
   return quotient;
}

/*
template <class Type>	
SpaMatrix<Type> operator*(const DiagMatrix<Type>& d, const SpaMatrix<Type>& A)
{
    int j, nz=A.size(), nrow=A.numOfRows(), ncol=A.numOfCol();
    SpaMatrix bb(A);
    if (d.nrow != nrow || nrow != ncol) inValidSize();
    
    for (int i=0; i<nrow; i++)
	for(j=(A.irow)[0][i]; j<(A.irow)[0][i+1]; j++)
	    bb[j] = bb[j]*d.a[i];
    return bb;
}

template <class Type>	
SpaMatrix<Type> operator*(const SpaMatrix<Type>& A, const DiagMatrix<Type>& d)
{
    int j, nz=A.size(), nrow=A.numOfRows(), ncol=A.numOfCol();
    SpaMatrix bb(A);
    if (d.nrow != nrow || nrow != ncol) inValidSize();

    for (int i=0; i<nz; i++)
	bb[i] = (A->nzelem(i))*d.a[A->colPos(i)];
    return bb;
}
*/
 
}

#endif
