第一版 第 18 章 :  程式碼  13




解答 :
#include <iostream>
#include <vector>
#include <cassert>

using namespace std ;

const double  SMALL = 1e-13 ;   // values with magnitude below this are printed as 0 by operator<<

template <class T>
T  abs( const T& a ) {
    // Values that compare >= 0 are returned unchanged; everything else
    // (including values for which the comparison is false) is negated.
    if ( a >= static_cast<T>(0) ) return  a ;
    return  -a ;
}


template  <class T>
class   Matrix {

  private :

    vector< vector<T> >  mat ;            // row storage: mat[i] is row i
    int                  row , col ;      // number of rows and columns

    // Magnitude of a scalar, used to rank pivot candidates.  Kept inside the
    // class so pivoting only needs T to support <, unary - and construction
    // from 0 (no dependency on a free abs() overload).
    static T  magnitude( const T& x ) {
        return  x < static_cast<T>(0) ? -x : x ;
    }

    // Partial pivoting: index of the row in [c, row) whose entry in column c
    // has the largest magnitude.  Choosing the largest pivot (rather than the
    // first non-zero one) keeps elimination multipliers bounded by 1 and
    // avoids the loss of accuracy caused by tiny pivots.
    int  pivot_row( const vector< vector<T> >& a , int c ) const {
        int  best = c ;
        for ( int j = c+1 ; j < row ; ++j ) {
            if ( magnitude( a[j][c] ) > magnitude( a[best][c] ) ) best = j ;
        }
        return  best ;
    }

  public :

    // Default constructor: an empty 0 x 0 matrix.
    Matrix() : row(0) , col(0) {}

    // Constructor: an r x c matrix with every entry set to val.
    Matrix( int r , int c , const T& val = 0 )
        : mat( r , vector<T>( c , val ) ) , row(r) , col(c) {}

    // Constructor from a vector of rows.  The column count is taken from the
    // first row (0 when there are no rows); rows are assumed to share one
    // length.  Takes a const reference so temporaries and const vectors are
    // accepted too.
    Matrix( const vector< vector<T> >& a )
        : mat(a) ,
          row( static_cast<int>( a.size() ) ) ,
          col( a.empty() ? 0 : static_cast<int>( a[0].size() ) ) {}

    // Number of rows and number of columns.
    inline  int  rows() const {  return row ;  }
    inline  int  cols() const {  return col ;  }

    // Reference to row i of the matrix (no bounds checking).
    inline vector<T>&        operator [] ( int i )       {
        return  mat[i] ;
    }

    inline const vector<T>&  operator [] ( int i ) const {
        return  mat[i] ;
    }

    // Unary minus: a copy of the matrix with every entry negated.
    Matrix<T>  operator - () const {
        Matrix<T>  foo = *this ;
        for ( int i = 0 ; i < row ; ++i ) {
            for ( int j = 0 ; j < col ; ++j ) foo.mat[i][j] = -foo.mat[i][j] ;
        }
        return  foo ;
    }

    // Determinant by Gaussian elimination with partial pivoting.
    // Every row interchange negates the determinant, so the sign is folded
    // into the running product of the pivots.  Returns 0 as soon as a column
    // has no non-zero pivot candidate.  The matrix must be square (asserted);
    // a 0 x 0 matrix yields 1 (empty product).
    T  det() const {

        assert( row == col ) ;
        vector< vector<T> >  a = mat ;
        T  prod = static_cast<T>(1) ;

        for ( int i = 0 ; i < row ; ++i ) {

            const int  p = pivot_row( a , i ) ;
            if ( a[p][i] == static_cast<T>(0) ) return  static_cast<T>(0) ;
            if ( p != i ) {
                a[i].swap( a[p] ) ;      // O(1) row exchange
                prod = -prod ;           // a row swap flips the sign of det
            }

            for ( int j = i+1 ; j < row ; ++j ) {
                const T  m = a[j][i] / a[i][i] ;
                a[j][i] = static_cast<T>(0) ;
                for ( int k = i+1 ; k < col ; ++k ) a[j][k] -= m * a[i][k] ;
            }

            prod *= a[i][i] ;
        }

        return  prod ;
    }

    // Inverse by Gauss-Jordan elimination with partial pivoting on the
    // augmented system [A | I].  Asserts that the matrix is square and that
    // every pivot (including the last one) is non-zero, i.e. non-singular.
    Matrix<T>  inverse() const {

        assert( row == col ) ;
        vector< vector<T> >  a = mat ;
        vector< vector<T> >  b( row , vector<T>( col , static_cast<T>(0) ) ) ;
        for ( int i = 0 ; i < row ; ++i ) b[i][i] = static_cast<T>(1) ;

        // Forward elimination: reduce a to upper-triangular form, applying
        // the same row operations (including swaps) to b.
        for ( int i = 0 ; i < row ; ++i ) {

            const int  p = pivot_row( a , i ) ;
            assert( a[p][i] != static_cast<T>(0) ) ;   // singular matrix
            if ( p != i ) {
                a[i].swap( a[p] ) ;
                b[i].swap( b[p] ) ;
            }

            for ( int j = i+1 ; j < row ; ++j ) {
                const T  m = a[j][i] / a[i][i] ;
                a[j][i] = static_cast<T>(0) ;
                for ( int k = i+1 ; k < col ; ++k ) a[j][k] -= m * a[i][k] ;
                for ( int k = 0   ; k < col ; ++k ) b[j][k] -= m * b[i][k] ;
            }
        }

        // Back substitution: clear the entries above each pivot.  Row i of a
        // is zero everywhere except its diagonal at this point, so only
        // column i of row j changes in a.
        for ( int i = row-1 ; i >= 0 ; --i ) {
            for ( int j = i-1 ; j >= 0 ; --j ) {
                const T  m = a[j][i] / a[i][i] ;
                a[j][i] = static_cast<T>(0) ;
                for ( int k = 0 ; k < col ; ++k ) b[j][k] -= m * b[i][k] ;
            }
        }

        // Scale each row so that the pivots become 1.
        for ( int i = 0 ; i < row ; ++i )
            for ( int j = 0 ; j < col ; ++j )
                b[i][j] /= a[i][i] ;

        return  Matrix<T>(b) ;
    }

    // Element-wise addition in place; both matrices must have the same shape.
    Matrix<T>&   operator += ( const Matrix<T>& rhs ) {
        assert( row == rhs.row && col == rhs.col ) ;
        for ( int i = 0 ; i < row ; ++i ) {
            for ( int j = 0 ; j < col ; ++j ) mat[i][j] += rhs.mat[i][j] ;
        }
        return  *this ;
    }

    // Element-wise subtraction in place; both matrices must have the same shape.
    Matrix<T>&   operator -= ( const Matrix<T>& rhs ) {
        assert( row == rhs.row && col == rhs.col ) ;
        for ( int i = 0 ; i < row ; ++i ) {
            for ( int j = 0 ; j < col ; ++j ) mat[i][j] -= rhs.mat[i][j] ;
        }
        return  *this ;
    }

};


// Output operator for Matrix: every row starts on a new line, entries are
// separated by two spaces, and entries whose magnitude is below SMALL are
// printed as 0 to hide floating-point round-off.  A final newline follows.
template <class T>
ostream&  operator << ( ostream& out , const Matrix<T>& m ) {

    for ( int r = 0 ; r < m.rows() ; ++r ) {
        out << '\n' ;
        for ( int c = 0 ; c < m.cols() ; ++c ) {
            const T  value = m[r][c] ;
            const T  shown = abs(value) < SMALL ? static_cast<T>(0) : value ;
            out << shown << "  " ;
        }
    }
    out << '\n' ;
    return  out ;
}

// Matrix product: an (r x n) matrix times an (n x c) matrix gives an
// (r x c) matrix whose entry (i,j) is sum_k m1[i][k] * m2[k][j].
template  <class T>
Matrix<T>  operator * ( const Matrix<T>& m1 , 
                        const Matrix<T>& m2 ) {

    // Inner dimensions must agree; otherwise m2[k] would be read past its
    // last row.
    assert( m1.cols() == m2.rows() ) ;

    Matrix<T>  m( m1.rows() , m2.cols() ) ;

    for ( int i = 0 ; i < m1.rows() ; ++i ) {
        for ( int j = 0 ; j < m2.cols() ; ++j ) {
            T  sum = static_cast<T>(0) ;
            for ( int k = 0 ; k < m1.cols() ; ++k ) 
                sum += m1[i][k] * m2[k][j] ;
            m[i][j] = sum ;
        }
    }
    return  m ;
}

// Two matrices are equal when they have the same shape and every pair of
// corresponding entries compares equal.
template <class T>
bool   operator==( const Matrix<T>& m1 ,
                   const Matrix<T>& m2 ) {
    const bool  same_shape = m1.rows() == m2.rows() && m1.cols() == m2.cols() ;
    if ( !same_shape ) return  false ;

    for ( int r = 0 ; r < m1.rows() ; ++r ) {
        for ( int c = 0 ; c < m1.cols() ; ++c ) {
            if ( m1[r][c] != m2[r][c] ) return  false ;
        }
    }
    return  true ;
}

// Inequality is defined through operator== so the two can never disagree.
template <class T>
bool   operator!=( const Matrix<T>& m1 ,
                   const Matrix<T>& m2 ) {
    const bool  equal = ( m1 == m2 ) ;
    return  !equal ;
}



// Matrix addition: a copy of m1 with m2 accumulated into it via +=.
template  <class T>
Matrix<T>  operator + ( const Matrix<T>& m1 , 
                        const Matrix<T>& m2 ) {
    Matrix<T>  sum( m1 ) ;
    sum += m2 ;
    return  sum ;
}


// Matrix subtraction: a copy of m1 with m2 subtracted from it via -=.
template  <class T>
Matrix<T>  operator - ( const Matrix<T>& m1 , 
                        const Matrix<T>& m2 ) {
    Matrix<T>  diff( m1 ) ;
    diff -= m2 ;
    return  diff ;
}
    

int  main() {

    // A is a 4 x 4 matrix; its determinant is -12, so A is invertible and
    // A * inverse(A) should print as the 4 x 4 identity.
    Matrix<double>  A(4,4) ;
    A[0][0] = 4 ;  A[0][1] = 1 ;  A[0][2] = 5 ;  A[0][3] = 0 ;  
    A[1][0] = 0 ;  A[1][1] =-4 ;  A[1][2] = 0 ;  A[1][3] = 3 ;  
    A[2][0] = 3 ;  A[2][1] = 0 ;  A[2][2] = 4 ;  A[2][3] = 0 ;  
    A[3][0] = 2 ;  A[3][1] = 0 ;  A[3][2] = 0 ;  A[3][3] = 9 ;  

    cout << "> A = " << A << '\n' ;
    cout << "> det(A) = " << A.det() << endl ;
    cout << "> A * inverse(A) = " << A * A.inverse() <<  endl ;

    return 0 ;
    
}