001: /*
002: * JScience - Java(TM) Tools and Libraries for the Advancement of Sciences.
003: * Copyright (C) 2006 - JScience (http://jscience.org/)
004: * All rights reserved.
005: *
006: * Permission to use, copy, modify, and distribute this software is
007: * freely granted, provided that this notice is preserved.
008: */
009: package org.jscience.mathematics.vector;
010:
011: import java.util.Comparator;
012:
013: import org.jscience.mathematics.structure.Field;
014: import org.jscience.mathematics.number.Number;
015:
016: import javolution.context.LocalContext;
017: import javolution.context.ObjectFactory;
018: import javolution.util.FastTable;
019: import javolution.util.Index;
020:
021: /**
022: * <p> This class represents the decomposition of a {@link Matrix matrix}
023: * <code>A</code> into a product of a {@link #getLower lower}
024: * and {@link #getUpper upper} triangular matrices, <code>L</code>
025: * and <code>U</code> respectively, such as <code>A = P·L·U<code> with
026: * <code>P<code> a {@link #getPermutation permutation} matrix.</p>
027: *
028: * <p> This decomposition</a> is typically used to resolve linear systems
029: * of equations (Gaussian elimination) or to calculate the determinant
030: * of a square {@link Matrix} (<code>O(m³)</code>).</p>
031: *
032: * <p> Numerical stability is guaranteed through pivoting if the
033: * {@link Field} elements are {@link Number numbers}
034: * For others elements types, numerical stability can be ensured by setting
035: * the {@link javolution.context.LocalContext context-local} pivot
036: * comparator (see {@link #setPivotComparator}).</p>
037: *
038: * <p> Pivoting can be disabled by setting the {@link #setPivotComparator
039: * pivot comparator} to <code>null</code> ({@link #getPermutation P}
040: * is then the matrix identity).</p>
041: *
042: * @author <a href="mailto:jean-marie@dautelle.com">Jean-Marie Dautelle</a>
043: * @version 3.3, January 2, 2007
044: * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">
045: * Wikipedia: LU decomposition</a>
046: */
047: public final class LUDecomposition<F extends Field<F>> {
048:
049: /**
050: * Holds the default comparator for pivoting.
051: */
052: public static final Comparator<Field> NUMERIC_COMPARATOR = new Comparator<Field>() {
053:
054: @SuppressWarnings("unchecked")
055: public int compare(Field left, Field right) {
056: if ((left instanceof Number) && (right instanceof Number))
057: return ((Number) left).isLargerThan((Number) right) ? 1
058: : -1;
059: if (left.equals(left.plus(left))) // Zero
060: return -1;
061: if (right.equals(right.plus(right))) // Zero
062: return 1;
063: return 0;
064: }
065: };
066:
067: /**
068: * Holds the local comparator.
069: */
070: private static final LocalContext.Reference<Comparator<Field>> PIVOT_COMPARATOR = new LocalContext.Reference<Comparator<Field>>(
071: NUMERIC_COMPARATOR);
072:
073: /**
074: * Holds the dimension of the square matrix source.
075: */
076: private int _n;
077:
078: /**
079: * Holds the pivots indexes.
080: */
081: private final FastTable<Index> _pivots = new FastTable<Index>();
082:
083: /**
084: * Holds the LU elements.
085: */
086: private DenseMatrix<F> _LU;
087:
088: /**
089: * Holds the number of permutation performed.
090: */
091: private int _permutationCount;
092:
093: /**
094: * Returns the lower/upper decomposition of the specified matrix.
095: *
096: * @param source the matrix for which the decomposition is calculated.
097: * @return the lower/upper decomposition of the specified matrix.
098: * @throws DimensionException if the specified matrix is not square.
099: */
100: @SuppressWarnings("unchecked")
101: public static <F extends Field<F>> LUDecomposition<F> valueOf(
102: Matrix<F> source) {
103: if (!source.isSquare())
104: throw new DimensionException("Matrix is not square");
105: int dimension = source.getNumberOfRows();
106: LUDecomposition lu = FACTORY.object();
107: lu._n = dimension;
108: lu._permutationCount = 0;
109: lu.construct(source);
110: return lu;
111: }
112:
113: /**
114: * Constructs the LU decomposition of the specified matrix.
115: * We make the choise of Lii = ONE (diagonal elements of the
116: * lower triangular matrix are multiplicative identities).
117: *
118: * @param source the matrix to decompose.
119: * @throws MatrixException if the matrix source is not square.
120: */
121: private void construct(Matrix<F> source) {
122: _LU = source instanceof DenseMatrix ? ((DenseMatrix<F>) source)
123: .copy() : DenseMatrix.valueOf(source);
124: _pivots.clear();
125: for (int i = 0; i < _n; i++) {
126: _pivots.add(Index.valueOf(i));
127: }
128:
129: // Main loop.
130: Comparator<Field> cmp = LUDecomposition.getPivotComparator();
131: final int n = _n;
132: for (int k = 0; k < _n; k++) {
133:
134: if (cmp != null) { // Pivoting enabled.
135: // Rearranges the rows so that the absolutely largest
136: // elements of the matrix source in each column lies
137: // in the diagonal.
138: int pivot = k;
139: for (int i = k + 1; i < n; i++) {
140: if (cmp.compare(_LU.get(i, k), _LU.get(pivot, k)) > 0) {
141: pivot = i;
142: }
143: }
144: if (pivot != k) { // Exchanges.
145: for (int j = 0; j < n; j++) {
146: F tmp = _LU.get(pivot, j);
147: _LU.set(pivot, j, _LU.get(k, j));
148: _LU.set(k, j, tmp);
149: }
150: int j = _pivots.get(pivot).intValue();
151: _pivots.set(pivot, _pivots.get(k));
152: _pivots.set(k, Index.valueOf(j));
153: _permutationCount++;
154: }
155: }
156:
157: // Computes multipliers and eliminate k-th column.
158: F lukkInv = _LU.get(k, k).inverse();
159: for (int i = k + 1; i < n; i++) {
160: // Multiplicative order is important
161: // for non-commutative elements.
162: _LU.set(i, k, _LU.get(i, k).times(lukkInv));
163: for (int j = k + 1; j < n; j++) {
164: _LU.set(i, j, _LU.get(i, j).plus(
165: _LU.get(i, k).times(
166: _LU.get(k, j).opposite())));
167: }
168: }
169: }
170: }
171:
172: /**
173: * Sets the {@link javolution.context.LocalContext local} comparator used
174: * for pivoting or <code>null</code> to disable pivoting.
175: *
176: * @param cmp the comparator for pivoting or <code>null</code>.
177: */
178: public static void setPivotComparator(Comparator<Field> cmp) {
179: PIVOT_COMPARATOR.set(cmp);
180: }
181:
182: /**
183: * Returns the {@link javolution.context.LocalContext local}
184: * comparator used for pivoting or <code>null</code> if pivoting
185: * is not performed (default {@link #NUMERIC_COMPARATOR}).
186: *
187: * @return the comparator for pivoting or <code>null</code>.
188: */
189: public static Comparator<Field> getPivotComparator() {
190: return PIVOT_COMPARATOR.get();
191: }
192:
193: /**
194: * Returns the solution X of the equation: A * X = B with
195: * <code>this = A.lu()</code> using back and forward substitutions.
196: *
197: * @param B the input matrix.
198: * @return the solution X = (1 / A) * B.
199: * @throws DimensionException if the dimensions do not match.
200: */
201: public DenseMatrix<F> solve(Matrix<F> B) {
202: if (_n != B.getNumberOfRows())
203: throw new DimensionException("Input vector has "
204: + B.getNumberOfRows() + " rows instead of " + _n);
205:
206: // Copies B with pivoting.
207: final int n = B.getNumberOfColumns();
208: DenseMatrix<F> X = createNullDenseMatrix(_n, n);
209: for (int i = 0; i < _n; i++) {
210: for (int j = 0; j < n; j++) {
211: X.set(i, j, B.get(_pivots.get(i).intValue(), j));
212: }
213: }
214:
215: // Solves L * Y = pivot(B)
216: for (int k = 0; k < _n; k++) {
217: for (int i = k + 1; i < _n; i++) {
218: F luik = _LU.get(i, k);
219: for (int j = 0; j < n; j++) {
220: X.set(i, j, X.get(i, j).plus(
221: luik.times(X.get(k, j).opposite())));
222: }
223: }
224: }
225:
226: // Solves U * X = Y;
227: for (int k = _n - 1; k >= 0; k--) {
228: for (int j = 0; j < n; j++) {
229: X.set(k, j, (_LU.get(k, k).inverse())
230: .times(X.get(k, j)));
231: }
232: for (int i = 0; i < k; i++) {
233: F luik = _LU.get(i, k);
234: for (int j = 0; j < n; j++) {
235: X.set(i, j, X.get(i, j).plus(
236: luik.times(X.get(k, j).opposite())));
237: }
238: }
239: }
240: return X;
241: }
242:
243: private DenseMatrix<F> createNullDenseMatrix(int m, int n) {
244: DenseMatrix<F> M = DenseMatrix.newInstance(n, false);
245: for (int i = 0; i < m; i++) {
246: DenseVector<F> V = DenseVector.newInstance();
247: M._rows.add(V);
248: for (int j = 0; j < n; j++) {
249: V._elements.add(null);
250: }
251: }
252: return M;
253: }
254:
255: /**
256: * Returns the solution X of the equation: A * X = Identity with
257: * <code>this = A.lu()</code> using back and forward substitutions.
258: *
259: * @return <code>this.solve(Identity)</code>
260: */
261: public DenseMatrix<F> inverse() {
262: // Calculates inv(U).
263: final int n = _n;
264: DenseMatrix<F> R = createNullDenseMatrix(n, n);
265: for (int i = 0; i < n; i++) {
266: for (int j = i; j < n; j++) {
267: R.set(i, j, _LU.get(i, j));
268: }
269: }
270: for (int j = n - 1; j >= 0; j--) {
271: R.set(j, j, R.get(j, j).inverse());
272: for (int i = j - 1; i >= 0; i--) {
273: F sum = R.get(i, j).times(R.get(j, j).opposite());
274: for (int k = j - 1; k > i; k--) {
275: sum = sum.plus(R.get(i, k).times(
276: R.get(k, j).opposite()));
277: }
278: R.set(i, j, (R.get(i, i).inverse()).times(sum));
279: }
280: }
281: // Solves inv(A) * L = inv(U)
282: for (int i = 0; i < n; i++) {
283: for (int j = n - 2; j >= 0; j--) {
284: for (int k = j + 1; k < n; k++) {
285: F lukj = _LU.get(k, j);
286: if (R.get(i, j) != null) {
287: R.set(i, j, R.get(i, j).plus(
288: R.get(i, k).times(lukj.opposite())));
289: } else {
290: R.set(i, j, R.get(i, k).times(lukj.opposite()));
291: }
292: }
293: }
294: }
295: // Swaps columns (reverses pivots permutations).
296: FastTable<F> tmp = FastTable.newInstance();
297: for (int i = 0; i < n; i++) {
298: tmp.reset();
299: for (int j = 0; j < n; j++) {
300: tmp.add(R.get(i, j));
301: }
302: for (int j = 0; j < n; j++) {
303: R.set(i, _pivots.get(j).intValue(), tmp.get(j));
304: }
305: }
306: FastTable.recycle(tmp);
307: return R;
308: }
309:
310: /**
311: * Returns the determinant of the {@link Matrix} having this
312: * decomposition.
313: *
314: * @return the determinant of the matrix source.
315: */
316: public F determinant() {
317: F product = _LU.get(0, 0);
318: for (int i = 1; i < _n; i++) {
319: product = product.times(_LU.get(i, i));
320: }
321: return ((_permutationCount & 1) == 0) ? product : product
322: .opposite();
323: }
324:
325: /**
326: * Returns the lower matrix decomposition (<code>L</code>) with diagonal
327: * elements equal to the multiplicative identity for F.
328: *
329: * @param zero the additive identity for F.
330: * @param one the multiplicative identity for F.
331: * @return the lower matrix.
332: */
333: public DenseMatrix<F> getLower(F zero, F one) {
334: DenseMatrix<F> L = _LU.copy();
335: for (int j = 0; j < _n; j++) {
336: for (int i = 0; i < j; i++) {
337: L.set(i, j, zero);
338: }
339: L.set(j, j, one);
340: }
341: return L;
342: }
343:
344: /**
345: * Returns the upper matrix decomposition (<code>U</code>).
346: *
347: * @param zero the additive identity for F.
348: * @return the upper matrix.
349: */
350: public DenseMatrix<F> getUpper(F zero) {
351: DenseMatrix<F> U = _LU.copy();
352: for (int j = 0; j < _n; j++) {
353: for (int i = j + 1; i < _n; i++) {
354: U.set(i, j, zero);
355: }
356: }
357: return U;
358: }
359:
360: /**
361: * Returns the permutation matrix (<code>P</code>).
362: *
363: * @param zero the additive identity for F.
364: * @param one the multiplicative identity for F.
365: * @return the permutation matrix.
366: */
367: public SparseMatrix<F> getPermutation(F zero, F one) {
368: SparseMatrix<F> P = SparseMatrix.newInstance(_n, zero, false);
369: for (int i = 0; i < _n; i++) {
370: P.getRow(_pivots.get(i).intValue())._elements.put(Index
371: .valueOf(i), one);
372: }
373: return P;
374: }
375:
376: /**
377: * Returns the lower/upper decomposition in one single matrix.
378: *
379: * @return the lower/upper matrix merged in a single matrix.
380: */
381: public DenseMatrix<F> getLU() {
382: return _LU;
383: }
384:
385: /**
386: * Returns the pivots elements of this decomposition.
387: *
388: * @return the row indices after permutation.
389: */
390: public FastTable<Index> getPivots() {
391: return _pivots;
392: }
393:
394: ///////////////////////
395: // Factory creation. //
396: ///////////////////////
397:
398: private static final ObjectFactory<LUDecomposition> FACTORY = new ObjectFactory<LUDecomposition>() {
399: protected LUDecomposition create() {
400: return new LUDecomposition();
401: }
402:
403: @SuppressWarnings("unchecked")
404: protected void cleanup(LUDecomposition lu) {
405: lu._LU = null;
406: }
407: };
408:
409: private LUDecomposition() {
410: }
411:
412: }
|