The Hungarian algorithm is a combinatorial optimization algorithm that solves the assignment problem in polynomial time.
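For reference, here is the usual formulation. Each of n rows (for example, workers) is assigned one-to-one to one of n columns (tasks), and c_ij is the cost of assigning row i to column j:

$$
\begin{aligned}
\min_{x}\quad & \sum_{i=1}^{n}\sum_{j=1}^{n} c_{ij}\,x_{ij} \\
\text{s.t.}\quad & \sum_{j=1}^{n} x_{ij} = 1, \qquad i = 1,\dots,n \\
& \sum_{i=1}^{n} x_{ij} = 1, \qquad j = 1,\dots,n \\
& x_{ij} \in \{0, 1\}
\end{aligned}
$$

There are n! feasible assignments (permutations σ with x_{iσ(i)} = 1), so exhaustive search blows up quickly. A profit matrix P can be handled the same way the implementation's `make_cost_matrix` does it: subtract every entry from the largest one. For any assignment σ the two totals then differ only by a constant, so minimizing cost is the same as maximizing profit:

$$
c_{ij} = p_{\max} - p_{ij} \quad\Longrightarrow\quad \sum_{i=1}^{n} c_{i\sigma(i)} = n\,p_{\max} - \sum_{i=1}^{n} p_{i\sigma(i)}
$$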

The algorithm was proposed by Harold W. Kuhn and published in 1955 in the journal Naval Research Logistics Quarterly as "The Hungarian Method for the Assignment Problem".

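For an explanation of the algorithm, there is a lecturer on Bilibili (Beijing University of Chemical Technology open course, Optimization Methods) who covers it very well. The principles and steps are explained clearly, and a concrete example is worked through by hand, which helps a lot with understanding.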
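The row/column reduction and line-covering steps are easy to trace by hand on a small example. Below is a minimal pure-Python sketch on the 3×3 cost matrix that the test script further down also uses. The helper names are hypothetical and do not come from the package. The covering lines (row 1 and column 1) are chosen by hand here, whereas the package computes them in its `CoverZeros` class. The final zero assignment uses plain backtracking in place of the package's Step 4 heuristic.

```python
def reduce_matrix(cost):
    """Steps 1-2: subtract each row's minimum, then each column's minimum."""
    rows = []
    for row in cost:
        row_min = min(row)
        rows.append([v - row_min for v in row])
    col_mins = [min(column) for column in zip(*rows)]
    return [[v - col_mins[j] for j, v in enumerate(row)] for row in rows]


def adjust_by_min_uncovered(matrix, covered_rows, covered_cols):
    """Step 3 adjustment: m = smallest uncovered entry; subtract m from uncovered
    entries, leave entries covered by one line alone, add m where two lines cross."""
    m = min(v for i, row in enumerate(matrix) if i not in covered_rows
            for j, v in enumerate(row) if j not in covered_cols)
    return [[v + m * ((i in covered_rows) + (j in covered_cols) - 1)
             for j, v in enumerate(row)]
            for i, row in enumerate(matrix)]


def find_zero_assignment(matrix, row=0, used=frozenset()):
    """Backtracking search for one zero per row in distinct columns (small n only)."""
    if row == len(matrix):
        return []
    for col, value in enumerate(matrix[row]):
        if value == 0 and col not in used:
            rest = find_zero_assignment(matrix, row + 1, used | {col})
            if rest is not None:
                return [(row, col)] + rest
    return None


cost = [[4, 2, 8], [4, 3, 7], [3, 1, 6]]
reduced = reduce_matrix(cost)
print(reduced)                        # [[1, 0, 2], [0, 0, 0], [1, 0, 1]]
# Only 2 lines (row 1, column 1) cover every zero, fewer than n = 3,
# so no complete assignment of zeros exists yet.
print(find_zero_assignment(reduced))  # None
adjusted = adjust_by_min_uncovered(reduced, covered_rows={1}, covered_cols={1})
print(adjusted)                       # [[0, 0, 1], [0, 1, 0], [0, 0, 0]]
assignment = find_zero_assignment(adjusted)
print(assignment)                     # [(0, 0), (1, 2), (2, 1)]
print(sum(cost[r][c] for r, c in assignment))  # 12
```

The pairs found here differ from the script's expected `[(0, 1), (1, 0), (2, 2)]`, but both assignments cost 12. The optimum of this matrix is not unique.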

Below is a Python implementation of the algorithm (by Thom Dedecko, MIT License; it requires NumPy):

```python
#!/usr/bin/python
"""
Implementation of the Hungarian (Munkres) Algorithm using Python and NumPy
References: http://www.ams.jhu.edu/~castello/362/Handouts/hungarian.pdf
        http://weber.ucsd.edu/~vcrawfor/hungar.pdf
        http://en.wikipedia.org/wiki/Hungarian_algorithm
        http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html
        http://www.clapper.org/software/python/munkres/
"""

from random import shuffle

# Module Information.
__version__ = "1.1.1"
__author__ = "Thom Dedecko"
__url__ = "http://github.com/tdedecko/hungarian-algorithm"
__copyright__ = "(c) 2010 Thom Dedecko"
__license__ = "MIT License"


class HungarianError(Exception):
    pass


# Import numpy. Error if fails
try:
    import numpy as np
except ImportError:
    raise HungarianError("NumPy is not installed.")


class Hungarian:
    """
    Implementation of the Hungarian (Munkres) Algorithm using NumPy.

    Usage:
        hungarian = Hungarian(cost_matrix)
        hungarian.calculate()
    or
        hungarian = Hungarian()
        hungarian.calculate(cost_matrix)

    Handle Profit matrix:
        hungarian = Hungarian(profit_matrix, is_profit_matrix=True)
    or
        cost_matrix = Hungarian.make_cost_matrix(profit_matrix)

    The matrix will be automatically padded if it is not square.
    For that numpy's pad function is used, which adds 0's to any row/column that is added.

    Get results and total potential after calculation:
        hungarian.get_results()
        hungarian.get_total_potential()
    """

    def __init__(self, input_matrix=None, is_profit_matrix=False):
        """
        input_matrix is a List of Lists.
        input_matrix is assumed to be a cost matrix unless is_profit_matrix is True.
        """
        if input_matrix is not None:
            # Save input
            my_matrix = np.array(input_matrix)
            self._input_matrix = np.array(input_matrix)
            self._maxColumn = my_matrix.shape[1]
            self._maxRow = my_matrix.shape[0]

            # Pad with 0s to make the matrix square. Otherwise stays unaltered.
            matrix_size = max(self._maxColumn, self._maxRow)
            pad_rows = matrix_size - self._maxRow
            pad_columns = matrix_size - self._maxColumn
            my_matrix = np.pad(my_matrix, ((0, pad_rows), (0, pad_columns)), 'constant', constant_values=0)

            # Convert profit matrix to cost matrix if necessary
            if is_profit_matrix:
                my_matrix = self.make_cost_matrix(my_matrix)

            self._cost_matrix = my_matrix
            self._size = len(my_matrix)
            self._shape = my_matrix.shape

            # Results from algorithm.
            self._results = []
            self._totalPotential = 0
        else:
            self._cost_matrix = None

    def get_results(self):
        """Get results after calculation."""
        return self._results

    def get_total_potential(self):
        """Returns expected value after calculation."""
        return self._totalPotential

    def calculate(self, input_matrix=None, is_profit_matrix=False):
        """
        Implementation of the Hungarian (Munkres) Algorithm.

        input_matrix is a List of Lists.
        input_matrix is assumed to be a cost matrix unless is_profit_matrix is True.
        """
        # Handle invalid and new matrix inputs.
        if input_matrix is None and self._cost_matrix is None:
            raise HungarianError("Invalid input")
        elif input_matrix is not None:
            self.__init__(input_matrix, is_profit_matrix)

        result_matrix = self._cost_matrix.copy()

        # Step 1: Subtract row mins from each row.
        for index, row in enumerate(result_matrix):
            result_matrix[index] -= row.min()

        # Step 2: Subtract column mins from each column.
        for index, column in enumerate(result_matrix.T):
            result_matrix[:, index] -= column.min()

        # Step 3: Use minimum number of lines to cover all zeros in the matrix.
        # If the total covered rows+columns is not equal to the matrix size then adjust matrix and repeat.
        total_covered = 0
        while total_covered < self._size:
            # Find minimum number of lines to cover all zeros in the matrix and find total covered rows and columns.
            cover_zeros = CoverZeros(result_matrix)
            covered_rows = cover_zeros.get_covered_rows()
            covered_columns = cover_zeros.get_covered_columns()
            total_covered = len(covered_rows) + len(covered_columns)

            # If the total covered rows+columns is not equal to the matrix size then adjust it by min uncovered num (m).
            if total_covered < self._size:
                result_matrix = self._adjust_matrix_by_min_uncovered_num(result_matrix, covered_rows, covered_columns)

        # Step 4: Starting with the top row, work your way downwards as you make assignments.
        # Find single zeros in rows or columns.
        # Add them to final result and remove them and their associated row/column from the matrix.
        expected_results = min(self._maxColumn, self._maxRow)
        zero_locations = (result_matrix == 0)
        while len(self._results) != expected_results:

            # If number of zeros in the matrix is zero before finding all the results then an error has occurred.
            if not zero_locations.any():
                raise HungarianError("Unable to find results. Algorithm has failed.")

            # Find results and mark rows and columns for deletion
            matched_rows, matched_columns = self.__find_matches(zero_locations)

            # Make arbitrary selection
            total_matched = len(matched_rows) + len(matched_columns)
            if total_matched == 0:
                matched_rows, matched_columns = self.select_arbitrary_match(zero_locations)

            # Delete rows and columns
            for row in matched_rows:
                zero_locations[row] = False
            for column in matched_columns:
                zero_locations[:, column] = False

            # Save Results
            self.__set_results(zip(matched_rows, matched_columns))

        # Calculate total potential
        value = 0
        for row, column in self._results:
            value += self._input_matrix[row, column]
        self._totalPotential = value

    @staticmethod
    def make_cost_matrix(profit_matrix):
        """
        Converts a profit matrix into a cost matrix.
        Expects NumPy objects as input.
        """
        # Subtract profit matrix from a matrix made of the max value of the profit matrix
        matrix_shape = profit_matrix.shape
        offset_matrix = np.ones(matrix_shape, dtype=int) * profit_matrix.max()
        cost_matrix = offset_matrix - profit_matrix
        return cost_matrix

    def _adjust_matrix_by_min_uncovered_num(self, result_matrix, covered_rows, covered_columns):
        """Subtract m from every uncovered number and add m to every element covered with two lines."""
        # Calculate minimum uncovered number (m)
        elements = []
        for row_index, row in enumerate(result_matrix):
            if row_index not in covered_rows:
                for index, element in enumerate(row):
                    if index not in covered_columns:
                        elements.append(element)
        min_uncovered_num = min(elements)

        # Add m to every covered element
        adjusted_matrix = result_matrix
        for row in covered_rows:
            adjusted_matrix[row] += min_uncovered_num
        for column in covered_columns:
            adjusted_matrix[:, column] += min_uncovered_num

        # Subtract m from every element
        m_matrix = np.ones(self._shape, dtype=int) * min_uncovered_num
        adjusted_matrix -= m_matrix

        return adjusted_matrix

    def __find_matches(self, zero_locations):
        """Returns rows and columns with matches in them."""
        marked_rows = np.array([], dtype=int)
        marked_columns = np.array([], dtype=int)

        # Mark rows and columns with matches
        # Iterate over rows
        for index, row in enumerate(zero_locations):
            row_index = np.array([index])
            if np.sum(row) == 1:
                column_index, = np.where(row)
                marked_rows, marked_columns = self.__mark_rows_and_columns(marked_rows, marked_columns, row_index,
                                                                           column_index)

        # Iterate over columns
        for index, column in enumerate(zero_locations.T):
            column_index = np.array([index])
            if np.sum(column) == 1:
                row_index, = np.where(column)
                marked_rows, marked_columns = self.__mark_rows_and_columns(marked_rows, marked_columns, row_index,
                                                                           column_index)

        return marked_rows, marked_columns

    @staticmethod
    def __mark_rows_and_columns(marked_rows, marked_columns, row_index, column_index):
        """Check if column or row is marked. If not marked then mark it."""
        new_marked_rows = marked_rows
        new_marked_columns = marked_columns
        if not (marked_rows == row_index).any() and not (marked_columns == column_index).any():
            new_marked_rows = np.insert(marked_rows, len(marked_rows), row_index)
            new_marked_columns = np.insert(marked_columns, len(marked_columns), column_index)
        return new_marked_rows, new_marked_columns

    @staticmethod
    def select_arbitrary_match(zero_locations):
        """Selects row column combination with minimum number of zeros in it."""
        # Count number of zeros in row and column combinations
        rows, columns = np.where(zero_locations)
        zero_count = []
        for index, row in enumerate(rows):
            total_zeros = np.sum(zero_locations[row]) + np.sum(zero_locations[:, columns[index]])
            zero_count.append(total_zeros)

        # Get the row column combination with the minimum number of zeros.
        indices = zero_count.index(min(zero_count))
        row = np.array([rows[indices]])
        column = np.array([columns[indices]])

        return row, column

    def __set_results(self, result_lists):
        """Set results during calculation."""
        # Check if results values are out of bound from input matrix (because of matrix being padded).
        # Add results to results list.
        for result in result_lists:
            row, column = result
            if row < self._maxRow and column < self._maxColumn:
                new_result = (int(row), int(column))
                self._results.append(new_result)


class CoverZeros:
    """
    Use minimum number of lines to cover all zeros in the matrix.
    Algorithm based on: http://weber.ucsd.edu/~vcrawfor/hungar.pdf
    """

    def __init__(self, matrix):
        """
        Input a matrix and save it as a boolean matrix to designate zero locations.
        Run calculation procedure to generate results.
        """
        # Find zeros in matrix
        self._zero_locations = (matrix == 0)
        self._shape = matrix.shape

        # Choices starts without any choices made.
        self._choices = np.zeros(self._shape, dtype=bool)

        self._marked_rows = []
        self._marked_columns = []

        # Marks rows and columns
        self.__calculate()

        # Draw lines through all unmarked rows and all marked columns.
        self._covered_rows = list(set(range(self._shape[0])) - set(self._marked_rows))
        self._covered_columns = self._marked_columns

    def get_covered_rows(self):
        """Return list of covered rows."""
        return self._covered_rows

    def get_covered_columns(self):
        """Return list of covered columns."""
        return self._covered_columns

    def __calculate(self):
        """
        Calculates minimum number of lines necessary to cover all zeros in a matrix.
        Algorithm based on: http://weber.ucsd.edu/~vcrawfor/hungar.pdf
        """
        while True:
            # Erase all marks.
            self._marked_rows = []
            self._marked_columns = []

            # Mark all rows in which no choice has been made.
            for index, row in enumerate(self._choices):
                if not row.any():
                    self._marked_rows.append(index)

            # If no marked rows then finish.
            if not self._marked_rows:
                return True

            # Mark all columns not already marked which have zeros in marked rows.
            num_marked_columns = self.__mark_new_columns_with_zeros_in_marked_rows()

            # If no new marked columns then finish.
            if num_marked_columns == 0:
                return True

            # While there is some choice in every marked column.
            while self.__choice_in_all_marked_columns():
                # Some choice in every marked column.

                # Mark all rows not already marked which have choices in marked columns.
                num_marked_rows = self.__mark_new_rows_with_choices_in_marked_columns()

                # If no new marks then finish.
                if num_marked_rows == 0:
                    return True

                # Mark all columns not already marked which have zeros in marked rows.
                num_marked_columns = self.__mark_new_columns_with_zeros_in_marked_rows()

                # If no new marked columns then finish.
                if num_marked_columns == 0:
                    return True

            # No choice in one or more marked columns.
            # Find a marked column that does not have a choice.
            choice_column_index = self.__find_marked_column_without_choice()

            while choice_column_index is not None:
                # Find a zero in the column indexed that does not have a row with a choice.
                choice_row_index = self.__find_row_without_choice(choice_column_index)

                # Check if an available row was found.
                new_choice_column_index = None
                if choice_row_index is None:
                    # Find a good row to accommodate swap. Find its column pair.
                    choice_row_index, new_choice_column_index = \
                        self.__find_best_choice_row_and_new_column(choice_column_index)

                    # Delete old choice.
                    self._choices[choice_row_index, new_choice_column_index] = False

                # Set zero to choice.
                self._choices[choice_row_index, choice_column_index] = True

                # Loop again if choice is added to a row with a choice already in it.
                choice_column_index = new_choice_column_index

    def __mark_new_columns_with_zeros_in_marked_rows(self):
        """Mark all columns not already marked which have zeros in marked rows."""
        num_marked_columns = 0
        for index, column in enumerate(self._zero_locations.T):
            if index not in self._marked_columns:
                if column.any():
                    row_indices, = np.where(column)
                    zeros_in_marked_rows = (set(self._marked_rows) & set(row_indices)) != set()
                    if zeros_in_marked_rows:
                        self._marked_columns.append(index)
                        num_marked_columns += 1
        return num_marked_columns

    def __mark_new_rows_with_choices_in_marked_columns(self):
        """Mark all rows not already marked which have choices in marked columns."""
        num_marked_rows = 0
        for index, row in enumerate(self._choices):
            if index not in self._marked_rows:
                if row.any():
                    # Each row holds at most one choice.
                    column_index, = np.where(row)
                    if column_index[0] in self._marked_columns:
                        self._marked_rows.append(index)
                        num_marked_rows += 1
        return num_marked_rows

    def __choice_in_all_marked_columns(self):
        """Return True if there is a choice in all marked columns, False otherwise."""
        for column_index in self._marked_columns:
            if not self._choices[:, column_index].any():
                return False
        return True

    def __find_marked_column_without_choice(self):
        """Find a marked column that does not have a choice."""
        for column_index in self._marked_columns:
            if not self._choices[:, column_index].any():
                return column_index

        raise HungarianError(
            "Could not find a column without a choice. Failed to cover matrix zeros. Algorithm has failed.")

    def __find_row_without_choice(self, choice_column_index):
        """Find a row without a choice in it for the column indexed. If a row does not exist then return None."""
        row_indices, = np.where(self._zero_locations[:, choice_column_index])
        for row_index in row_indices:
            if not self._choices[row_index].any():
                return row_index

        # All rows have choices. Return None.
        return None

    def __find_best_choice_row_and_new_column(self, choice_column_index):
        """
        Find a row index to use for the choice so that the column that needs to be changed is optimal.
        Return a random row and column if unable to find an optimal selection.
        """
        row_indices, = np.where(self._zero_locations[:, choice_column_index])
        for row_index in row_indices:
            column_indices, = np.where(self._choices[row_index])
            column_index = column_indices[0]
            if self.__find_row_without_choice(column_index) is not None:
                return row_index, column_index

        # Cannot find optimal row and column. Return a random row and column.
        shuffle(row_indices)
        column_index, = np.where(self._choices[row_indices[0]])
        return row_indices[0], column_index[0]


if __name__ == '__main__':
    profit_matrix = [
        [62, 75, 80, 93, 95, 97],
        [75, 80, 82, 85, 71, 97],
        [80, 75, 81, 98, 90, 97],
        [78, 82, 84, 80, 50, 98],
        [90, 85, 85, 80, 85, 99],
        [65, 75, 80, 75, 68, 96]]

    hungarian = Hungarian(profit_matrix, is_profit_matrix=True)
    hungarian.calculate()
    print("Expected value:\t\t543")
    print("Calculated value:\t", hungarian.get_total_potential())  # = 543
    print("Expected results:\n\t[(0, 4), (2, 3), (5, 5), (4, 0), (1, 1), (3, 2)]")
    print("Results:\n\t", hungarian.get_results())
    print("-" * 80)

    cost_matrix = [
        [4, 2, 8],
        [4, 3, 7],
        [3, 1, 6]]
    hungarian = Hungarian(cost_matrix)
    print('calculating...')
    hungarian.calculate()
    print("Expected value:\t\t12")
    print("Calculated value:\t", hungarian.get_total_potential())  # = 12
    print("Expected results:\n\t[(0, 1), (1, 0), (2, 2)]")
    print("Results:\n\t", hungarian.get_results())
    print("-" * 80)

    profit_matrix = [
        [62, 75, 80, 93, 0, 97],
        [75, 0, 82, 85, 71, 97],
        [80, 75, 81, 0, 90, 97],
        [78, 82, 0, 80, 50, 98],
        [0, 85, 85, 80, 85, 99],
        [65, 75, 80, 75, 68, 0]]
    hungarian = Hungarian()
    hungarian.calculate(profit_matrix, is_profit_matrix=True)
    print("Expected value:\t\t523")
    print("Calculated value:\t", hungarian.get_total_potential())  # = 523
    print("Expected results:\n\t[(0, 3), (2, 4), (3, 0), (5, 2), (1, 5), (4, 1)]")
    print("Results:\n\t", hungarian.get_results())
    print("-" * 80)
```

Output:

```text
Expected value:         543
Calculated value:        543
Expected results:
        [(0, 4), (2, 3), (5, 5), (4, 0), (1, 1), (3, 2)]
Results:
        [(0, 4), (2, 3), (5, 5), (4, 0), (1, 1), (3, 2)]
--------------------------------------------------------------------------------
calculating...
Expected value:         12
Calculated value:        12
Expected results:
        [(0, 1), (1, 0), (2, 2)]
Results:
        [(0, 1), (1, 0), (2, 2)]
--------------------------------------------------------------------------------
Expected value:         523
Calculated value:        523
Expected results:
        [(0, 3), (2, 4), (3, 0), (5, 2), (1, 5), (4, 1)]
Results:
        [(0, 3), (2, 4), (3, 0), (5, 2), (1, 5), (4, 1)]
--------------------------------------------------------------------------------
```

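Here, "Expected value" and "Expected results" are the optimal total and assignment computed in advance. For all three test cases the program's output matches them, which indicates that the implementation works correctly on these inputs.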
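The expected totals can also be cross-checked independently. The sketch below assumes small square matrices and simply enumerates all n! assignments (720 for the 6×6 cases), keeping the best total. `brute_force_total` is a hypothetical helper and is not part of the package above. It is only practical for tiny matrices.

```python
from itertools import permutations


def brute_force_total(matrix, maximize=False):
    """Best total over all n! one-to-one row->column assignments of a square matrix."""
    n = len(matrix)
    totals = (sum(matrix[row][perm[row]] for row in range(n))
              for perm in permutations(range(n)))
    return max(totals) if maximize else min(totals)


profit_1 = [[62, 75, 80, 93, 95, 97],
            [75, 80, 82, 85, 71, 97],
            [80, 75, 81, 98, 90, 97],
            [78, 82, 84, 80, 50, 98],
            [90, 85, 85, 80, 85, 99],
            [65, 75, 80, 75, 68, 96]]
cost_2 = [[4, 2, 8],
          [4, 3, 7],
          [3, 1, 6]]
profit_3 = [[62, 75, 80, 93, 0, 97],
            [75, 0, 82, 85, 71, 97],
            [80, 75, 81, 0, 90, 97],
            [78, 82, 0, 80, 50, 98],
            [0, 85, 85, 80, 85, 99],
            [65, 75, 80, 75, 68, 0]]

print(brute_force_total(profit_1, maximize=True))  # 543
print(brute_force_total(cost_2))                    # 12
print(brute_force_total(profit_3, maximize=True))  # 523
```

This confirms only the optimal totals, not the particular pairs. An optimum need not be unique (the 3×3 cost matrix, for example, has several assignments costing 12), so a correct solver may report different pairs than the expected list.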
