You are given an m * n matrix mat, whose rows are each sorted in non-decreasing order, and an integer k.
You are allowed to choose exactly 1 element from each row to form an array. Return the k-th smallest array sum among all possible arrays.
Example 1:
Input: mat = [[1,3,11],[2,4,6]], k = 5 Output: 7 Explanation: Choosing one element from each row, the first k smallest sums come from the arrays [1,2], [1,4], [3,2], [3,4], [1,6], so the 5th sum is 7.
Example 2:
Input: mat = [[1,3,11],[2,4,6]], k = 9 Output: 17
思路:最小值肯定是每一行index爲0 [0,0,0]
那麼下一層:j index 分別爲[1,0,0] [0,1,0] [0,0,1]
再下一層:[2,0,0][1,1,0][1,0,1] [1,1,0] [0,2,0] [0,1,1] [1,0,1] [0,1,1] [0,0,2]
注意要去重複,用hashset<List<Integer>> 去重複座標;
/**
 * Finds the k-th smallest sum obtainable by picking exactly one element from each row of a
 * matrix whose rows are sorted in non-decreasing order.
 *
 * <p>Approach: best-first search over index vectors. The smallest sum uses column 0 in every
 * row; each state's successors advance exactly one row's column by one. A min-heap ordered by
 * sum yields states in non-decreasing sum order, and a visited set guarantees each index
 * vector enters the heap at most once.
 */
class Solution {

    /** A search state: the chosen column index for every row, plus the sum of those elements. */
    public class Node {
        /** Column index chosen in each row; element i belongs to row i. */
        public List<Integer> indexs;
        /** Sum of the matrix elements selected by {@link #indexs}. */
        public int sum;

        public Node(List<Integer> indexs, int sum) {
            this.indexs = indexs;
            this.sum = sum;
        }
    }

    /**
     * Returns the k-th smallest array sum (1-based) over all ways of choosing one element per row.
     *
     * @param mat matrix whose rows are each sorted in non-decreasing order
     * @param k   1-based rank of the requested sum
     * @return the k-th smallest sum; {@code 0} if the matrix is null/empty or contains a
     *         null/empty row; {@code -1} if {@code k} is not positive or exceeds the number of
     *         possible arrays
     */
    public int kthSmallest(int[][] mat, int k) {
        if (mat == null || mat.length == 0) {
            return 0;
        }
        for (int[] row : mat) {
            // A row with nothing to pick from makes every selection impossible.
            if (row == null || row.length == 0) {
                return 0;
            }
        }
        if (k <= 0) {
            // No rank below 1 exists; answer immediately instead of exhausting the search space.
            return -1;
        }

        int rows = mat.length;
        // Integer.compare avoids the overflow that "a.sum - b.sum" suffers for distant sums.
        PriorityQueue<Node> pq = new PriorityQueue<>((a, b) -> Integer.compare(a.sum, b.sum));

        List<Integer> indexs = new ArrayList<>(rows);
        for (int i = 0; i < rows; i++) {
            indexs.add(0);
        }

        // States are marked visited when pushed, so the heap never holds duplicates.
        HashSet<List<Integer>> visited = new HashSet<>();
        visited.add(indexs);
        pq.offer(new Node(indexs, getSum(mat, indexs)));

        int count = 0;
        while (!pq.isEmpty()) {
            Node node = pq.poll();
            count++;
            if (count == k) {
                return node.sum;
            }
            for (Node neighbor : getNeighbors(node, mat, visited)) {
                pq.offer(neighbor);
            }
        }
        // Fewer than k distinct selections exist.
        return -1;
    }

    /**
     * Builds the not-yet-seen successors of {@code node}: for each row, the state obtained by
     * moving that row's column one step to the right. Newly produced states are recorded in
     * {@code visited}.
     */
    private List<Node> getNeighbors(Node node, int[][] mat, HashSet<List<Integer>> visited) {
        List<Integer> indexs = node.indexs;
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < indexs.size(); i++) {
            int col = indexs.get(i);
            // Bound by this row's own length so rows of differing lengths are handled.
            if (col + 1 >= mat[i].length) {
                continue;
            }
            List<Integer> newIndexs = new ArrayList<>(indexs);
            newIndexs.set(i, col + 1);
            // add() returns false when the state was already produced via another path.
            if (visited.add(newIndexs)) {
                // Only row i changed, so adjust the parent's sum instead of re-summing every row.
                int newSum = node.sum - mat[i][col] + mat[i][col + 1];
                nodes.add(new Node(newIndexs, newSum));
            }
        }
        return nodes;
    }

    /** Sums {@code mat[i][indexs.get(i)]} over all rows. */
    private int getSum(int[][] mat, List<Integer> indexs) {
        int sum = 0;
        for (int i = 0; i < mat.length; i++) {
            sum += mat[i][indexs.get(i)];
        }
        return sum;
    }
}