背景
最近看项目中的概率抽奖算法实现,总感觉有问题,验证了多遍,发现没问题。但是代码不够精简,尤其使用了多次Random,这个还是可以优化的~
那么,如何实现高效的概率算法?如何衡量算法的准确性呢?
算法实现
特点:只随机一次,找出命中概率的 index。
// 入参示例:int[] array = {30, 40, 20, 10};
private static int getIndex(int[] array) {
int num = new Random().nextInt(100);
for (int i = 0; i < array.length; i++) {
if (num < array[i]) {
return i;
}
num -= array[i];
}
return -1;
}
如何衡量
// 计算概率,测试10000次
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < 10000; i++) {
// 单次获取概率结果
int index = getIndex(array);
// System.out.printf("中奖啦! 概率: %d, Index: %d%n", array[index], index);
Integer value = map.getOrDefault(array[index], 0);
map.put(array[index], value + 1);
}
System.out.println("----------------------------------------------");
System.out.printf("看看结果吧! 概率: %s%n", map);
System.out.println("----------------------------------------------");
观察如下结果,符合按概率随机的要求。
----------------------------------------------
看看结果吧! 概率: {20=2004, 40=3985, 10=1036, 30=2975}
----------------------------------------------
看看结果吧! 概率: {20=2053, 40=3932, 10=999, 30=3016}
----------------------------------------------
看看结果吧! 概率: {20=2044, 40=4015, 10=962, 30=2979}
----------------------------------------------
看看结果吧! 概率: {20=2040, 40=3955, 10=982, 30=3023}
----------------------------------------------
看看结果吧! 概率: {20=1968, 40=3976, 10=1009, 30=3047}
----------------------------------------------
看看结果吧! 概率: {20=1984, 40=4029, 10=989, 30=2998}
----------------------------------------------
健壮性
如果入参的总概率不满100%,那剩余的比例不能中奖,怎么实现呢?
- 总概率不满100%时,补全数组;
- 计算中奖的时候,再舍去补全的最后一组(当然,如果没补就不用舍去)。
完整代码如下:
package cn.eyeo.mall.start.lottery;
import org.junit.Test;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
* 抽奖测试类
*
* @author <a href="mailto:[email protected]">amos.wang</a>
* @date 2023/3/27
*/
public class LotteryDrawTests {
private static final int TOTAL_PROBABILITY = 100;
@Test
public void lottery() {
int[] array = {30, 40, 20};
probability(array);
}
private static void probability(int[] source) {
int sum = Arrays.stream(source).sum();
if (sum > TOTAL_PROBABILITY) {
throw new IllegalArgumentException("总概率不能>100");
}
System.out.println("----------------------------------------------");
System.out.println("原始数据: " + Arrays.toString(source));
// 不满100%则补全
int[] array = source;
if (sum < TOTAL_PROBABILITY) {
array = new int[source.length + 1];
System.arraycopy(source, 0, array, 0, source.length);
array[source.length] = TOTAL_PROBABILITY - sum;
}
System.out.println("补全后的数据: " + Arrays.toString(array));
// 计算概率,测试10000次
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < 10000; i++) {
// 单次获取概率结果
int index = getIndex(array);
if (index < source.length) {
// System.out.printf("中奖啦! 概率: %d, Index: %d%n", array[index], index);
Integer value = map.getOrDefault(array[index], 0);
map.put(array[index], value + 1);
}
}
System.out.println("----------------------------------------------");
System.out.printf("看看结果吧! 概率: %s%n", map);
}
private static int getIndex(int[] array) {
int num = new Random().nextInt(TOTAL_PROBABILITY);
for (int i = 0; i < array.length; i++) {
if (num < array[i]) {
return i;
}
num -= array[i];
}
return -1;
}
}