MatrixMath library for Arduino
by Charlie Matlack
contact: eecharlie in Arduino forums
This library was modified from code posted by RobH45345,
notably including replacement of the inversion algorithm.

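The version of the library here (updated 4/3/2013) is patched to work with all Arduino IDE versions, both before and after Arduino 1.0. A patch that makes the original library work on the Arduino Due was discussed separately; the link to that discussion was not preserved in this copy.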

FEATURES

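Briefly, the functions provided by MatrixMath are called as methods of the pre-instantiated object `Matrix` (for example `Matrix.Multiply(...)`). This list previously gave them as free functions named MatrixPrint, MatrixCopy, MatrixMult, MatrixAdd, MatrixSubtract, MatrixTranspose and MatrixInvert; those names do not match the source code below.

void Print(float* A, int m, int n, String label);
void Copy(float* A, int n, int m, float* B);
void Multiply(float* A, float* B, int m, int p, int n, float* C);
void Add(float* A, float* B, int m, int n, float* C);
void Subtract(float* A, float* B, int m, int n, float* C);
void Transpose(float* A, int m, int n, float* C);
void Scale(float* A, int m, int n, float k);
int Invert(float* A, int n);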

Matrices should be stored in row-major arrays, which is fairly standard. The user must keep track of array dimensions and send them to the functions; mistakes on dimensions will not be caught by the library.

It's worth pointing out that the Invert() function uses Gauss-Jordan elimination with partial pivoting. At each step, partial pivoting picks the row with the largest entry in the current column. It is a compromise between no pivoting, which is numerically unstable, and full pivoting, which is more stable but involves more searching and swapping of matrix elements.

Also, the inversion algorithm stores the result matrix on top of the input matrix. No second matrix is allocated during inversion (only a small temporary array of n row indices), but your original matrix is gone. Copy it first with Copy() if you still need it.

HOW TO IMPORT/INSTALL

Grab the source code below and put it in a folder called MatrixMath. [Apparently recent Playground changes have killed the zip file. Please edit this if there is a preferred way of making the source available.]

Put the MatrixMath folder in "libraries\".

In the Arduino IDE, create a new sketch (or open one) and

select from the menubar "Sketch->Import Library->MatrixMath".

Once the library is imported, an `#include <MatrixMath.h>` line will appear at the top of your sketch.

The MatrixMathExample in the Examples folder demonstrates multiplication and inversion using the MatrixPrint() function to show results.

SOURCE CODE

MatrixMath.h

  1. /*
  2.  *  MatrixMath.h Library for Matrix Math
  3.  *
  4.  *  Created by Charlie Matlack on 12/18/10.
  5.  *  Modified from code by RobH45345 on Arduino Forums, taken from unknown source.
  6.  */
  7.  
  8. #ifndef MatrixMath_h
  9. #define MatrixMath_h
  10.  
  11. #if defined(ARDUINO) && ARDUINO >= 100
  12. #include "Arduino.h"
  13. #else
  14. #include "WProgram.h"
  15. #endif
  16.  
  17. class MatrixMath
  18. {
  19. public:
  20.     //MatrixMath();
  21.     void Print(float* A, int m, int n, String label);
  22.     void Copy(float* A, int n, int m, float* B);
  23.     void Multiply(float* A, float* B, int m, int p, int n, float* C);
  24.     void Add(float* A, float* B, int m, int n, float* C);
  25.     void Subtract(float* A, float* B, int m, int n, float* C);
  26.     void Transpose(float* A, int m, int n, float* C);
  27.     void Scale(float* A, int m, int n, float k);
  28.     int Invert(float* A, int n);
  29. };
  30.  
  31. extern MatrixMath Matrix;
  32. #endif

MatrixMath.cpp

  1. /*
  2.  *  MatrixMath.cpp Library for Matrix Math
  3.  *
  4.  *  Created by Charlie Matlack on 12/18/10.
  5.  *  Modified from code by RobH45345 on Arduino Forums, taken from unknown source.
  6.  *  
  7.  */
  8.  
  9. #include "MatrixMath.h"
  10.  
  11. #define NR_END 1
  12.  
  13. MatrixMath Matrix;          // Pre-instantiate
  14.  
  15. // Matrix Printing Routine
  16. // Uses tabs to separate numbers under assumption printed float width won't cause problems
  17. void MatrixMath::Print(float* A, int m, int n, String label){
  18.     // A = input matrix (m x n)
  19.     int i,j;
  20.     Serial.println();
  21.     Serial.println(label);
  22.     for (i=0; i<m; i++){
  23.         for (j=0;j<n;j++){
  24.             Serial.print(A[n*i+j]);
  25.             Serial.print("\t");
  26.         }
  27.         Serial.println();
  28.     }
  29. }
  30.  
  31. void MatrixMath::Copy(float* A, int n, int m, float* B)
  32. {
  33.     int i, j, k;
  34.     for (i=0;i<m;i++)
  35.         for(j=0;j<n;j++)
  36.         {
  37.             B[n*i+j] = A[n*i+j];
  38.         }
  39. }
  40.  
  41. //Matrix Multiplication Routine
  42. // C = A*B
  43. void MatrixMath::Multiply(float* A, float* B, int m, int p, int n, float* C)
  44. {
  45.     // A = input matrix (m x p)
  46.     // B = input matrix (p x n)
  47.     // m = number of rows in A
  48.     // p = number of columns in A = number of rows in B
  49.     // n = number of columns in B
  50.     // C = output matrix = A*B (m x n)
  51.     int i, j, k;
  52.     for (i=0;i<m;i++)
  53.         for(j=0;j<n;j++)
  54.         {
  55.             C[n*i+j]=0;
  56.             for (k=0;k<p;k++)
  57.                 C[n*i+j]= C[n*i+j]+A[p*i+k]*B[n*k+j];
  58.         }
  59. }
  60.  
  61.  
  62. //Matrix Addition Routine
  63. void MatrixMath::Add(float* A, float* B, int m, int n, float* C)
  64. {
  65.     // A = input matrix (m x n)
  66.     // B = input matrix (m x n)
  67.     // m = number of rows in A = number of rows in B
  68.     // n = number of columns in A = number of columns in B
  69.     // C = output matrix = A+B (m x n)
  70.     int i, j;
  71.     for (i=0;i<m;i++)
  72.         for(j=0;j<n;j++)
  73.             C[n*i+j]=A[n*i+j]+B[n*i+j];
  74. }
  75.  
  76.  
  77. //Matrix Subtraction Routine
  78. void MatrixMath::Subtract(float* A, float* B, int m, int n, float* C)
  79. {
  80.     // A = input matrix (m x n)
  81.     // B = input matrix (m x n)
  82.     // m = number of rows in A = number of rows in B
  83.     // n = number of columns in A = number of columns in B
  84.     // C = output matrix = A-B (m x n)
  85.     int i, j;
  86.     for (i=0;i<m;i++)
  87.         for(j=0;j<n;j++)
  88.             C[n*i+j]=A[n*i+j]-B[n*i+j];
  89. }
  90.  
  91.  
  92. //Matrix Transpose Routine
  93. void MatrixMath::Transpose(float* A, int m, int n, float* C)
  94. {
  95.     // A = input matrix (m x n)
  96.     // m = number of rows in A
  97.     // n = number of columns in A
  98.     // C = output matrix = the transpose of A (n x m)
  99.     int i, j;
  100.     for (i=0;i<m;i++)
  101.         for(j=0;j<n;j++)
  102.             C[m*j+i]=A[n*i+j];
  103. }
  104.  
  105. void MatrixMath::Scale(float* A, int m, int n, float k)
  106. {
  107.     for (int i=0; i<m; i++)
  108.         for (int j=0; j<n; j++)
  109.             A[n*i+j] = A[n*i+j]*k;
  110. }
  111.  
  112.  
  113. //Matrix Inversion Routine
  114. // * This function inverts a matrix based on the Gauss Jordan method.
  115. // * Specifically, it uses partial pivoting to improve numeric stability.
  116. // * The algorithm is drawn from those presented in
  117. //   NUMERICAL RECIPES: The Art of Scientific Computing.
  118. // * The function returns 1 on success, 0 on failure.
  119. // * NOTE: The argument is ALSO the result matrix, meaning the input matrix is REPLACED
  120. int MatrixMath::Invert(float* A, int n)
  121. {
  122.     // A = input matrix AND result matrix
  123.     // n = number of rows = number of columns in A (n x n)
  124.     int pivrow;     // keeps track of current pivot row
  125.     int k,i,j;      // k: overall index along diagonal; i: row index; j: col index
  126.     int pivrows[n]; // keeps track of rows swaps to undo at end
  127.     float tmp;      // used for finding max value and making column swaps
  128.  
  129.     for (k = 0; k < n; k++)
  130.     {
  131.         // find pivot row, the row with biggest entry in current column
  132.         tmp = 0;
  133.         for (i = k; i < n; i++)
  134.         {
  135.             if (abs(A[i*n+k]) >= tmp)   // 'Avoid using other functions inside abs()?'
  136.             {
  137.                 tmp = abs(A[i*n+k]);
  138.                 pivrow = i;
  139.             }
  140.         }
  141.  
  142.         // check for singular matrix
  143.         if (A[pivrow*n+k] == 0.0f)
  144.         {
  145.             Serial.println("Inversion failed due to singular matrix");
  146.             return 0;
  147.         }
  148.  
  149.         // Execute pivot (row swap) if needed
  150.         if (pivrow != k)
  151.         {
  152.             // swap row k with pivrow
  153.             for (j = 0; j < n; j++)
  154.             {
  155.                 tmp = A[k*n+j];
  156.                 A[k*n+j] = A[pivrow*n+j];
  157.                 A[pivrow*n+j] = tmp;
  158.             }
  159.         }
  160.         pivrows[k] = pivrow;    // record row swap (even if no swap happened)
  161.  
  162.         tmp = 1.0f/A[k*n+k];    // invert pivot element
  163.         A[k*n+k] = 1.0f;        // This element of input matrix becomes result matrix
  164.  
  165.         // Perform row reduction (divide every element by pivot)
  166.         for (j = 0; j < n; j++)
  167.         {
  168.             A[k*n+j] = A[k*n+j]*tmp;
  169.         }
  170.  
  171.         // Now eliminate all other entries in this column
  172.         for (i = 0; i < n; i++)
  173.         {
  174.             if (i != k)
  175.             {
  176.                 tmp = A[i*n+k];
  177.                 A[i*n+k] = 0.0f;  // The other place where in matrix becomes result mat
  178.                 for (j = 0; j < n; j++)
  179.                 {
  180.                     A[i*n+j] = A[i*n+j] - A[k*n+j]*tmp;
  181.                 }
  182.             }
  183.         }
  184.     }
  185.  
  186.     // Done, now need to undo pivot row swaps by doing column swaps in reverse order
  187.     for (k = n-1; k >= 0; k--)
  188.     {
  189.         if (pivrows[k] != k)
  190.         {
  191.             for (i = 0; i < n; i++)
  192.             {
  193.                 tmp = A[i*n+k];
  194.                 A[i*n+k] = A[i*n+pivrows[k]];
  195.                 A[i*n+pivrows[k]] = tmp;
  196.             }
  197.         }
  198.     }
  199.     return 1;
  200. }
  201.  

Example code demonstrating usage:

  1.  
  2. #include <MatrixMath.h>
  3.  
  4.  
  5. #define N  (2)
  6.  
  7. float A[N][N];
  8. float B[N][N];
  9. float C[N][N];
  10. float v[N];      // This is a row vector
  11. float w[N];
  12.  
  13. float max = 10;  // maximum random matrix entry range
  14.  
  15. void setup() {
  16.     Serial.begin(9600);
  17.  
  18.         // Initialize matrices
  19.         for (int i = 0; i < N; i++)
  20.         {
  21.           v[i] = i+1;                    // vector of sequential numbers
  22.           for (int j = 0; j < N; j++)
  23.           {
  24.             A[i][j] = random(max) - max/2.0f;  // A is random
  25.             if (i == j)
  26.             {
  27.               B[i][j] = 1.0f;                  // B is identity
  28.             } else
  29.             {
  30.               B[i][j] = 0.0f;
  31.             }
  32.           }
  33.         }
  34.  
  35. }
  36.  
  37. void loop(){
  38.  
  39.   Matrix.Multiply((float*)A,(float*)B,N,N,N,(float*)C);
  40.  
  41.         Serial.println("\nAfter multiplying C = A*B:");
  42.     Matrix.Print((float*)A,N,N,"A");
  43.  
  44.     Matrix.Print((float*)B,N,N,"B");
  45.     Matrix.Print((float*)C,N,N,"C");
  46.         Matrix.Print((float*)v,N,1,"v");
  47.  
  48.         Matrix.Add((float*) B, (float*) C, N, N, (float*) C);
  49.         Serial.println("\nC = B+C (addition in-place)");
  50.         Matrix.Print((float*)C,N,N,"C");
  51.         Matrix.Print((float*)B,N,N,"B");
  52.  
  53.         Matrix.Copy((float*)A,N,N,(float*)B);
  54.         Serial.println("\nCopied A to B:");
  55.     Matrix.Print((float*)B,N,N,"B");
  56.  
  57.         Matrix.Invert((float*)A,N);
  58.         Serial.println("\nInverted A:");
  59.     Matrix.Print((float*)A,N,N,"A");
  60.  
  61.         Matrix.Multiply((float*)A,(float*)B,N,N,N,(float*)C);
  62.         Serial.println("\nC = A*B");
  63.     Matrix.Print((float*)C,N,N,"C");
  64.  
  65.         // Because the library uses pointers and DIY indexing,
  66.         // a 1D vector can be smoothly handled as either a row or col vector
  67.         // depending on the dimensions we specify when calling a function
  68.         Matrix.Multiply((float*)C,(float*)v,N,N,1,(float*)w);
  69.         Serial.println("\n C*v = w:");
  70.         Matrix.Print((float*)v,N,1,"v");
  71.         Matrix.Print((float*)w,N,1,"w");
  72.  
  73. while(1);
  74. }

Share