/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
20
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
24
26
27  /**
28   * Computes the optimal (minimal cost) assignment of jobs to workers (or other
29   * analogous) concepts given a cost matrix of each pair of job and worker, using
30   * the algorithm by James Munkres in "Algorithms for the Assignment and
31   * Transportation Problems", with additional optimizations as described by Jin
32   * Kue Wong in "A New Implementation of an Algorithm for the Optimal Assignment
33   * Problem: An Improved Version of Munkres' Algorithm". The algorithm runs in
34   * O(n^3) time and need O(n^2) auxiliary space where n is the number of jobs or
35   * workers, whichever is greater.
36   */
37  @InterfaceAudience.Private
38  public class MunkresAssignment {
39
40    // The original algorithm by Munkres uses the terms STAR and PRIME to denote
41    // different states of zero values in the cost matrix. These values are
42    // represented as byte constants instead of enums to save space in the mask
43    // matrix by a factor of 4n^2 where n is the size of the problem.
44    private static final byte NONE = 0;
45    private static final byte STAR = 1;
46    private static final byte PRIME = 2;
47
48    // The algorithm requires that the number of column is at least as great as
49    // the number of rows. If that is not the case, then the cost matrix should
50    // be transposed before computation, and the solution matrix transposed before
51    // returning to the caller.
52    private final boolean transposed;
53
54    // The number of rows of internal matrices.
55    private final int rows;
56
57    // The number of columns of internal matrices.
58    private final int cols;
59
60    // The cost matrix, the cost of assigning each row index to column index.
61    private float[][] cost;
62
63    // Mask of zero cost assignment states.
65
66    // Covering some rows of the cost matrix.
67    private boolean[] rowsCovered;
68
69    // Covering some columns of the cost matrix.
70    private boolean[] colsCovered;
71
72    // The alternating path between starred zeroes and primed zeroes
73    private Deque<Pair<Integer, Integer>> path;
74
75    // The solution, marking which rows should be assigned to which columns. The
76    // positions of elements in this array correspond to the rows of the cost
77    // matrix, and the value of each element correspond to the columns of the cost
78    // matrix, i.e. assignments[i] = j indicates that row i should be assigned to
79    // column j.
80    private int[] assignments;
81
82    // Improvements described by Jin Kue Wong cache the least value in each row,
83    // as well as the column index of the least value in each row, and the pending
84    // adjustments to each row and each column.
85    private float[] leastInRow;
86    private int[] leastInRowIndex;
89
90    /**
91     * Construct a new problem instance with the specified cost matrix. The cost
92     * matrix must be rectangular, though not necessarily square. If one dimension
93     * is greater than the other, some elements in the greater dimension will not
94     * be assigned. The input cost matrix will not be modified.
95     * @param costMatrix
96     */
97    public MunkresAssignment(float[][] costMatrix) {
98      // The algorithm assumes that the number of columns is at least as great as
99      // the number of rows. If this is not the case of the input matrix, then
100     // all internal structures must be transposed relative to the input.
101     this.transposed = costMatrix.length > costMatrix[0].length;
102     if (this.transposed) {
103       this.rows = costMatrix[0].length;
104       this.cols = costMatrix.length;
105     } else {
106       this.rows = costMatrix.length;
107       this.cols = costMatrix[0].length;
108     }
109
110     cost = new float[rows][cols];
112     rowsCovered = new boolean[rows];
113     colsCovered = new boolean[cols];
114     path = new LinkedList<Pair<Integer, Integer>>();
115
116     leastInRow = new float[rows];
117     leastInRowIndex = new int[rows];
120
121     assignments = null;
122
123     // Copy cost matrix.
124     if (transposed) {
125       for (int r = 0; r < rows; r++) {
126         for (int c = 0; c < cols; c++) {
127           cost[r][c] = costMatrix[c][r];
128         }
129       }
130     } else {
131       for (int r = 0; r < rows; r++) {
132         System.arraycopy(costMatrix[r], 0, cost[r], 0, cols);
133       }
134     }
135
136     // Costs must be finite otherwise the matrix can get into a bad state where
137     // no progress can be made. If your use case depends on a distinction
138     // between costs of MAX_VALUE and POSITIVE_INFINITY, you're doing it wrong.
139     for (int r = 0; r < rows; r++) {
140       for (int c = 0; c < cols; c++) {
141         if (cost[r][c] == Float.POSITIVE_INFINITY) {
142           cost[r][c] = Float.MAX_VALUE;
143         }
144       }
145     }
146   }
147
148   /**
149    * Get the optimal assignments. The returned array will have the same number
150    * of elements as the number of elements as the number of rows in the input
151    * cost matrix. Each element will indicate which column should be assigned to
152    * that row or -1 if no column should be assigned, i.e. if result[i] = j then
153    * row i should be assigned to column j. Subsequent invocations of this method
154    * will simply return the same object without additional computation.
155    * @return an array with the optimal assignments
156    */
157   public int[] solve() {
158     // If this assignment problem has already been solved, return the known
159     // solution
160     if (assignments != null) {
161       return assignments;
162     }
163
164     preliminaries();
165
166     // Find the optimal assignments.
167     while (!testIsDone()) {
168       while (!stepOne()) {
169         stepThree();
170       }
171       stepTwo();
172     }
173
174     // Extract the assignments from the mask matrix.
175     if (transposed) {
176       assignments = new int[cols];
177       outer:
178       for (int c = 0; c < cols; c++) {
179         for (int r = 0; r < rows; r++) {
180           if (mask[r][c] == STAR) {
181             assignments[c] = r;
182             continue outer;
183           }
184         }
185         // There is no assignment for this row of the input/output.
186         assignments[c] = -1;
187       }
188     } else {
189       assignments = new int[rows];
190       outer:
191       for (int r = 0; r < rows; r++) {
192         for (int c = 0; c < cols; c++) {
193           if (mask[r][c] == STAR) {
194             assignments[r] = c;
195             continue outer;
196           }
197         }
198       }
199     }
200
201     // Once the solution has been computed, there is no need to keep any of the
202     // other internal structures. Clear all unnecessary internal references so
203     // the garbage collector may reclaim that memory.
204     cost = null;
206     rowsCovered = null;
207     colsCovered = null;
208     path = null;
209     leastInRow = null;
210     leastInRowIndex = null;
213
214     return assignments;
215   }
216
217   /**
218    * Corresponds to the "preliminaries" step of the original algorithm.
219    * Guarantees that the matrix is an equivalent non-negative matrix with at
220    * least one zero in each row.
221    */
222   private void preliminaries() {
223     for (int r = 0; r < rows; r++) {
224       // Find the minimum cost of each row.
225       float min = Float.POSITIVE_INFINITY;
226       for (int c = 0; c < cols; c++) {
227         min = Math.min(min, cost[r][c]);
228       }
229
230       // Subtract that minimum cost from each element in the row.
231       for (int c = 0; c < cols; c++) {
232         cost[r][c] -= min;
233
234         // If the element is now zero and there are no zeroes in the same row
235         // or column which are already starred, then star this one. There
236         // must be at least one zero because of subtracting the min cost.
237         if (cost[r][c] == 0 && !rowsCovered[r] && !colsCovered[c]) {
239           // Cover this row and column so that no other zeroes in them can be
240           // starred.
241           rowsCovered[r] = true;
242           colsCovered[c] = true;
243         }
244       }
245     }
246
247     // Clear the covered rows and columns.
248     Arrays.fill(rowsCovered, false);
249     Arrays.fill(colsCovered, false);
250   }
251
252   /**
253    * Test whether the algorithm is done, i.e. we have the optimal assignment.
254    * This occurs when there is exactly one starred zero in each row.
255    * @return true if the algorithm is done
256    */
257   private boolean testIsDone() {
258     // Cover all columns containing a starred zero. There can be at most one
259     // starred zero per column. Therefore, a covered column has an optimal
260     // assignment.
261     for (int r = 0; r < rows; r++) {
262       for (int c = 0; c < cols; c++) {
263         if (mask[r][c] == STAR) {
264           colsCovered[c] = true;
265         }
266       }
267     }
268
269     // Count the total number of covered columns.
270     int coveredCols = 0;
271     for (int c = 0; c < cols; c++) {
272       coveredCols += colsCovered[c] ? 1 : 0;
273     }
274
275     // Apply an row and column adjustments that are pending.
276     for (int r = 0; r < rows; r++) {
277       for (int c = 0; c < cols; c++) {
280       }
281     }
282
283     // Clear the pending row and column adjustments.
286
287     // The covers on columns and rows may have been reset, recompute the least
288     // value for each row.
289     for (int r = 0; r < rows; r++) {
290       leastInRow[r] = Float.POSITIVE_INFINITY;
291       for (int c = 0; c < cols; c++) {
292         if (!rowsCovered[r] && !colsCovered[c] && cost[r][c] < leastInRow[r]) {
293           leastInRow[r] = cost[r][c];
294           leastInRowIndex[r] = c;
295         }
296       }
297     }
298
299     // If all columns are covered, then we are done. Since there may be more
300     // columns than rows, we are also done if the number of covered columns is
301     // at least as great as the number of rows.
302     return (coveredCols == cols || coveredCols >= rows);
303   }
304
305   /**
306    * Corresponds to step 1 of the original algorithm.
307    * @return false if all zeroes are covered
308    */
309   private boolean stepOne() {
310     while (true) {
311       Pair<Integer, Integer> zero = findUncoveredZero();
312       if (zero == null) {
313         // No uncovered zeroes, need to manipulate the cost matrix in step
314         // three.
315         return false;
316       } else {
317         // Prime the uncovered zero and find a starred zero in the same row.
319         Pair<Integer, Integer> star = starInRow(zero.getFirst());
320         if (star != null) {
321           // Cover the row with both the newly primed zero and the starred zero.
322           // Since this is the only place where zeroes are primed, and we cover
323           // it here, and rows are only uncovered when primes are erased, then
324           // there can be at most one primed uncovered zero.
325           rowsCovered[star.getFirst()] = true;
326           colsCovered[star.getSecond()] = false;
327           updateMin(star.getFirst(), star.getSecond());
328         } else {
329           // Will go to step two after, where a path will be constructed,
330           // starting from the uncovered primed zero (there is only one). Since
331           // we have already found it, save it as the first node in the path.
332           path.clear();
333           path.offerLast(new Pair<Integer, Integer>(zero.getFirst(),
334               zero.getSecond()));
335           return true;
336         }
337       }
338     }
339   }
340
341   /**
342    * Corresponds to step 2 of the original algorithm.
343    */
344   private void stepTwo() {
345     // Construct a path of alternating starred zeroes and primed zeroes, where
346     // each starred zero is in the same column as the previous primed zero, and
347     // each primed zero is in the same row as the previous starred zero. The
348     // path will always end in a primed zero.
349     while (true) {
350       Pair<Integer, Integer> star = starInCol(path.getLast().getSecond());
351       if (star != null) {
352         path.offerLast(star);
353       } else {
354         break;
355       }
356       Pair<Integer, Integer> prime = primeInRow(path.getLast().getFirst());
357       path.offerLast(prime);
358     }
359
360     // Augment path - unmask all starred zeroes and star all primed zeroes. All
361     // nodes in the path will be either starred or primed zeroes. The set of
362     // starred zeroes is independent and now one larger than before.
363     for (Pair<Integer, Integer> p : path) {
364       if (mask[p.getFirst()][p.getSecond()] == STAR) {
366       } else {
368       }
369     }
370
371     // Clear all covers from rows and columns.
372     Arrays.fill(rowsCovered, false);
373     Arrays.fill(colsCovered, false);
374
375     // Remove the prime mask from all primed zeroes.
376     for (int r = 0; r < rows; r++) {
377       for (int c = 0; c < cols; c++) {
378         if (mask[r][c] == PRIME) {
380         }
381       }
382     }
383   }
384
385   /**
386    * Corresponds to step 3 of the original algorithm.
387    */
388   private void stepThree() {
389     // Find the minimum uncovered cost.
390     float min = leastInRow[0];
391     for (int r = 1; r < rows; r++) {
392       if (leastInRow[r] < min) {
393         min = leastInRow[r];
394       }
395     }
396
397     // Add the minimum cost to each of the costs in a covered row, or subtract
398     // the minimum cost from each of the costs in an uncovered column. As an
399     // optimization, do not actually modify the cost matrix yet, but track the
400     // adjustments that need to be made to each row and column.
401     for (int r = 0; r < rows; r++) {
402       if (rowsCovered[r]) {
404       }
405     }
406     for (int c = 0; c < cols; c++) {
407       if (!colsCovered[c]) {
409       }
410     }
411
412     // Since the cost matrix is not being updated yet, the minimum uncovered
413     // cost per row must be updated.
414     for (int r = 0; r < rows; r++) {
415       if (!colsCovered[leastInRowIndex[r]]) {
416         // The least value in this row was in an uncovered column, meaning that
417         // it would have had the minimum value subtracted from it, and therefore
418         // will still be the minimum value in that row.
419         leastInRow[r] -= min;
420       } else {
421         // The least value in this row was in a covered column and would not
422         // have had the minimum value subtracted from it, so the minimum value
423         // could be some in another column.
424         for (int c = 0; c < cols; c++) {
427             leastInRowIndex[r] = c;
428           }
429         }
430       }
431     }
432   }
433
434   /**
435    * Find a zero cost assignment which is not covered. If there are no zero cost
436    * assignments which are uncovered, then null will be returned.
437    * @return pair of row and column indices of an uncovered zero or null
438    */
439   private Pair<Integer, Integer> findUncoveredZero() {
440     for (int r = 0; r < rows; r++) {
441       if (leastInRow[r] == 0) {
442         return new Pair<Integer, Integer>(r, leastInRowIndex[r]);
443       }
444     }
445     return null;
446   }
447
448   /**
449    * A specified row has become covered, and a specified column has become
450    * uncovered. The least value per row may need to be updated.
451    * @param row the index of the row which was just covered
452    * @param col the index of the column which was just uncovered
453    */
454   private void updateMin(int row, int col) {
455     // If the row is covered we want to ignore it as far as least values go.
456     leastInRow[row] = Float.POSITIVE_INFINITY;
457
458     for (int r = 0; r < rows; r++) {
459       // Since the column has only just been uncovered, it could not have any
461       // and covered costs do not count toward row minimums. Therefore, we do
463       if (!rowsCovered[r] && cost[r][col] < leastInRow[r]) {
464         leastInRow[r] = cost[r][col];
465         leastInRowIndex[r] = col;
466       }
467     }
468   }
469
470   /**
471    * Find a starred zero in a specified row. If there are no starred zeroes in
472    * the specified row, then null will be returned.
473    * @param r the index of the row to be searched
474    * @return pair of row and column indices of starred zero or null
475    */
476   private Pair<Integer, Integer> starInRow(int r) {
477     for (int c = 0; c < cols; c++) {
478       if (mask[r][c] == STAR) {
479         return new Pair<Integer, Integer>(r, c);
480       }
481     }
482     return null;
483   }
484
485   /**
486    * Find a starred zero in the specified column. If there are no starred zeroes
487    * in the specified row, then null will be returned.
488    * @param c the index of the column to be searched
489    * @return pair of row and column indices of starred zero or null
490    */
491   private Pair<Integer, Integer> starInCol(int c) {
492     for (int r = 0; r < rows; r++) {
493       if (mask[r][c] == STAR) {
494         return new Pair<Integer, Integer>(r, c);
495       }
496     }
497     return null;
498   }
499
500   /**
501    * Find a primed zero in the specified row. If there are no primed zeroes in
502    * the specified row, then null will be returned.
503    * @param r the index of the row to be searched
504    * @return pair of row and column indices of primed zero or null
505    */
506   private Pair<Integer, Integer> primeInRow(int r) {
507     for (int c = 0; c < cols; c++) {
508       if (mask[r][c] == PRIME) {
509         return new Pair<Integer, Integer>(r, c);
510       }
511     }
512     return null;
513   }
514 }