// 不等概率不放回 — weighted (unequal-probability) random sampling without replacement

import java.util.ArrayList;

import java.util.List;

import java.util.Random;


/**
 * Weighted random sampling without replacement ("不等概率不放回").
 *
 * <p>Usage: pass the weights {@code rates} and the number of samples {@code k}. Each draw picks one
 * index that has not been picked yet, with probability proportional to its weight among the items
 * still available. For example, with weights {@code [1, 2, 3, 4, 5]} and {@code k = 2}, if the
 * items with weights 1 and 3 are drawn, the returned indices are {@code 0} (the index of weight 1)
 * and {@code 2} (the index of weight 3), listed in the order they were drawn.
 *
 * <p>The weights are stored in a 1-based implicit binary tree (the children of slot {@code i} are
 * {@code 2i} and {@code 2i + 1}) in which every node also records the total weight of its subtree.
 * A draw walks down from the root and then refreshes the totals on the path back up, so each draw
 * costs O(log n) and a whole sample costs O(n + k log n).
 *
 * @author xutaoyang
 */
public class UnequalWithoutReplacementKRandom {

    /** Source of randomness used by {@link #randKWithoutReplacement(List, int)}. */
    static Random rand = new Random();

    /**
     * Draws {@code k} distinct indices from {@code rates} using the shared {@link #rand} generator;
     * each draw is weighted by the rates of the items not drawn yet.
     *
     * @param rates non-negative, finite weights; only their relative sizes matter
     * @param k number of distinct indices to draw, {@code 0 <= k <= rates.size()}
     * @return the drawn indices into {@code rates}, in draw order
     * @throws IllegalArgumentException if {@code rates} is null or empty, {@code k} is out of range,
     *     a rate is null, negative, NaN or infinite, the rates sum to infinity, or fewer than
     *     {@code k} rates are positive
     */
    public static List<Integer> randKWithoutReplacement(List<Double> rates, int k) {
        return randKWithoutReplacement(rates, k, rand);
    }

    /**
     * Same as {@link #randKWithoutReplacement(List, int)} but draws from the supplied generator, so
     * results are reproducible with a seeded {@link Random}.
     *
     * @param rates non-negative, finite weights; only their relative sizes matter
     * @param k number of distinct indices to draw, {@code 0 <= k <= rates.size()}
     * @param random generator used for every draw
     * @return the drawn indices into {@code rates}, in draw order
     * @throws IllegalArgumentException if {@code rates} is null or empty, {@code k} is out of range,
     *     a rate is null, negative, NaN or infinite, the rates sum to infinity, or fewer than
     *     {@code k} rates are positive
     * @throws NullPointerException if {@code random} is null
     */
    public static List<Integer> randKWithoutReplacement(List<Double> rates, int k, Random random) {
        if (null == rates || rates.isEmpty()) {
            throw new IllegalArgumentException(
                    "<<UnequalWithoutReplacementKRandom>> the rates list is null or empty");
        }
        if (k < 0) {
            throw new IllegalArgumentException(
                    "<<UnequalWithoutReplacementKRandom>> k is negative: " + k);
        }
        // k == rates.size() is valid: it draws every item, in a weighted random order.
        if (k > rates.size()) {
            throw new IllegalArgumentException(
                    "<<UnequalWithoutReplacementKRandom>> k is bigger than rates' size");
        }
        if (random == null) {
            throw new NullPointerException("<<UnequalWithoutReplacementKRandom>> random is null");
        }

        List<Node> heap = buildHeap(rates);
        if (Double.isInfinite(heap.get(1).totalWeight)) {
            throw new IllegalArgumentException(
                    "<<UnequalWithoutReplacementKRandom>> the sum of rates overflows to infinity");
        }
        // A zero-weight item can never be drawn, so k distinct draws need k positive weights;
        // without this check the sampler would be forced to repeat an index.
        int positive = 0;
        for (int slot = 1; slot < heap.size(); slot++) {
            if (heap.get(slot).weight > 0) {
                positive++;
            }
        }
        if (positive < k) {
            throw new IllegalArgumentException("<<UnequalWithoutReplacementKRandom>> only " + positive
                    + " rates are positive, cannot draw " + k + " distinct samples");
        }

        List<Integer> result = new ArrayList<Integer>(k);
        for (int drawn = 0; drawn < k; drawn++) {
            result.add(heapPop(heap, random));
        }
        return result;
    }

    /**
     * Builds the 1-based sampling tree: slot 0 is unused, slot {@code i + 1} holds {@code rates[i]}.
     *
     * @throws IllegalArgumentException if a rate is null, negative, NaN or infinite
     */
    private static List<Node> buildHeap(List<Double> rates) {
        List<Node> heap = new ArrayList<Node>(rates.size() + 1);
        heap.add(null); // keeps the children of slot i at 2i and 2i + 1
        int index = 0;
        for (Double rate : rates) {
            // !(rate >= 0) also rejects NaN.
            if (rate == null || !(rate >= 0) || Double.isInfinite(rate)) {
                throw new IllegalArgumentException("<<UnequalWithoutReplacementKRandom>> invalid rate at index "
                        + index + ": " + rate);
            }
            heap.add(new Node(rate, index));
            index++;
        }
        // Children always have larger slot numbers than their parent, so a backward pass sees
        // every child's total before the parent's.
        for (int slot = heap.size() - 1; slot >= 1; slot--) {
            refreshTotal(heap, slot);
        }
        return heap;
    }

    /** Recomputes the subtree total of {@code slot} from its own weight and its children's totals. */
    private static void refreshTotal(List<Node> heap, int slot) {
        Node node = heap.get(slot);
        int left = slot << 1;
        double total = node.weight;
        if (left + 1 < heap.size()) {
            total += heap.get(left + 1).totalWeight;
        }
        if (left < heap.size()) {
            total += heap.get(left).totalWeight;
        }
        node.totalWeight = total;
    }

    /**
     * Draws one item with probability proportional to its current weight, zeroes that weight so the
     * item cannot be drawn again, and returns its index in the caller's rates list.
     *
     * <p>Precondition: some remaining item has a positive weight, i.e. the root total is positive.
     */
    private static int heapPop(List<Node> heap, Random random) {
        int size = heap.size();
        // gas is a uniform point in [0, total); walking down subtracts the mass of every part skipped.
        double gas = heap.get(1).totalWeight * random.nextDouble();
        int i = 1;
        while (true) {
            Node node = heap.get(i);
            if (gas < node.weight) {
                break; // the point falls inside this node's own weight (so that weight is positive)
            }
            gas -= node.weight;
            int left = i << 1;
            int right = left + 1;
            double leftTotal = left < size ? heap.get(left).totalWeight : 0.0;
            double rightTotal = right < size ? heap.get(right).totalWeight : 0.0;
            // Only descend into subtrees that still hold positive weight, so a drained node is never
            // reached and an index past the end of the tree is never read.
            if (leftTotal > 0 && (gas < leftTotal || !(rightTotal > 0))) {
                i = left;
            } else if (rightTotal > 0) {
                gas -= leftTotal;
                i = right;
            } else {
                // Rounding left gas slightly past this subtree's mass. Every visited subtree has a
                // positive total and neither child does, so this node's own weight is positive.
                break;
            }
        }

        Node chosen = heap.get(i);
        chosen.weight = 0;
        // Recompute (instead of subtracting) the totals up to the root: no error accumulates across
        // draws, and a subtree's total is exactly 0 once all of its weights are 0.
        for (int slot = i; slot >= 1; slot >>= 1) {
            refreshTotal(heap, slot);
        }
        return chosen.value;
    }

    /** One slot of the sampling tree. */
    private static class Node {

        /** This item's weight; set to 0 once the item has been drawn. */
        double weight;

        /** Index of this item in the caller's rates list. */
        final int value;

        /** Sum of {@code weight} over this node and all of its descendants in the tree. */
        double totalWeight;

        Node(double weight, int value) {
            this.weight = weight;
            this.value = value;
            this.totalWeight = weight;
        }
    }
}



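/*
 * Residue of the web page this file was copied from (not part of the program):
 * 發表評論 (Post a comment)
 * 所有評論 (All comments)
 * 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
 * (No comments yet. Want to be the first? Type in the comment box above and click Publish.)
 * 相關文章 (Related articles)
 */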