Weighted Random Algorithm: Notes on Randomly Selecting N Objects by Weight

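1. Introduction
Given an array of M objects, each with a weight attribute W (the weights are not normalized, i.e. they need not sum to 1), randomly select N objects so that the selection probability follows the weight distribution (or some distribution derived from the weights).
2. Original (first) idea
2.1 Weight mapping
First traverse the array once, give each object a lower and upper bound Wmin and Wmax (its interval on the cumulative weight axis), and compute the sum Wtotal. Draw a random number in 0~Wtotal, then use binary search to find the object whose interval contains it (Waverage, computed from Wtotal and the array size, can give the binary search a better starting guess).
If the magnitudes are small, you can simply allocate a block of memory that maps every value in 0~Wtotal directly to its object, which reduces the lookup to a single array access.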
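A minimal sketch of that lookup-table variant, assuming non-negative int weights whose total is small enough to allocate one int per unit of weight. The class name LookupTableSampler is hypothetical:

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical helper: precomputed table for small integer weights.
// table[v] is the index of the object whose interval [low, high) contains v,
// so each draw is one nextInt(total) plus one array read (O(1)).
final class LookupTableSampler {
    private final int[] table;

    LookupTableSampler(int[] weights) {
        int total = 0;
        for (int w : weights) {
            if (w < 0) {
                throw new IllegalArgumentException("negative weight: " + w);
            }
            total = Math.addExact(total, w); // assumption: the total is small enough for an int[]
        }
        if (total == 0) {
            throw new IllegalArgumentException("total weight must be positive");
        }
        table = new int[total];
        int pos = 0;
        for (int i = 0; i < weights.length; i++) {
            // object i owns slots [pos, pos + weights[i]); a zero weight owns no slot
            Arrays.fill(table, pos, pos + weights[i], i);
            pos += weights[i];
        }
    }

    /** Draws one index with probability weights[i] / total (with replacement). */
    int sample(Random random) {
        return table[random.nextInt(table.length)];
    }

    /** Number of slots, i.e. the total weight. */
    int totalWeight() {
        return table.length;
    }

    public static void main(String[] args) {
        LookupTableSampler sampler = new LookupTableSampler(new int[] {1, 0, 3});
        Random random = new Random(42);
        int[] counts = new int[3];
        for (int i = 0; i < 40_000; i++) {
            counts[sampler.sample(random)]++;
        }
        // expected shares: roughly 1/4, 0, 3/4
        System.out.println(Arrays.toString(counts));
    }
}
```

The trade-off is memory: the table needs Wtotal entries, so it only pays off for small integer weights; the interval-plus-binary-search approach keeps memory proportional to M.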

The Java implementation is as follows:

package com.kowalski;


import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Created by kowalski.zhang on 2018/5/14
 */
public class Algorithm {

  public static void main(String... args) {

  }

  /**
   * Draws takeNum objects with replacement; each draw picks an object
   * with probability weight / totalWeight.
   */
  public static <T extends Weight> List<T> getRandomListByWeight(List<T> sourceList, int takeNum){

    if(sourceList == null || sourceList.isEmpty() || takeNum <= 0){
      return new ArrayList<>();
    }

    // fill each object's [lowWeight, highWeight) interval and get the total weight
    int total = getTotalWeightAndFillWeight(sourceList);
    if(total <= 0){
      // nextInt(0) would throw; with no positive weight nothing can be drawn
      return new ArrayList<>();
    }

    // ThreadLocalRandom.current() must be called in the using thread, not cached in a static field
    ThreadLocalRandom random = ThreadLocalRandom.current();
    return IntStream.range(0, takeNum).map(i -> search(sourceList, random.nextInt(total)))
        .filter(resIndex -> resIndex != -1).mapToObj(sourceList::get).collect(Collectors.toList());
  }

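  /***
   * Binary search (could be optimized): find the object whose [lowWeight, highWeight) contains randomNum
   * @param sourceList list whose intervals were filled by getTotalWeightAndFillWeight
   * @param randomNum random number in [0, total)
   * @param <T>
   * @return the matching index, or -1 if none
   */
  public static<T extends Weight> int search(List<T> sourceList, int randomNum){
    int low = 0;
    int high = sourceList.size() - 1;
    int middle;

    while(low <= high){
      middle = (low + high) / 2;
      if(sourceList.get(middle).getLowWeight() > randomNum){
        // lower bound is greater than the key: the key is in the left half
        high = middle - 1;
      }else if(sourceList.get(middle).getHighWeight() <= randomNum){
        // the (exclusive) upper bound is not greater than the key: the key is in the right half;
        // using <= keeps a value equal to highWeight from being counted twice
        low = middle + 1;
      }else{
        return middle;
      }
    }

    return -1;      // still not found: return -1
  }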

  public static<T extends Weight> int getTotalWeightAndFillWeight(List<T> sourceList){
    if(sourceList == null || sourceList.isEmpty()){
      return 0;
    }
    int total = 0;
    for(T source:sourceList){
      /** fill this object's [lowWeight, highWeight) interval and accumulate the total */
      source.setLowWeight(total);
      total += source.getWeight();
      source.setHighWeight(total);
    }
    return total;
  }

  public static class Weight {

    public Weight(int weight) {
      this.weight = weight;
    }

    public int getWeight() {
      return weight;
    }

    public void setWeight(int weight) {
      this.weight = weight;
    }

    public int getLowWeight() {
      return lowWeight;
    }

    public void setLowWeight(int lowWeight) {
      this.lowWeight = lowWeight;
    }

    public int getHighWeight() {
      return highWeight;
    }

    public void setHighWeight(int highWeight) {
      this.highWeight = highWeight;
    }

    /** weight */
    private int weight;
    /** lower bound (inclusive) */
    private int lowWeight;
    /** upper bound (exclusive) */
    private int highWeight;
  }
}

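3. The original (first) idea falls apart

3.1 Problems:
3.1.1 Taking N distinct objects

How can we take N objects without repetition?
* Throw out each drawn object and reassign the lower and upper bounds of the rest? That is certainly too slow.
* Swap each drawn object with the last object (treating the tail as a forbidden zone), so the next random number is drawn from 0~(Wtotal - Wtaken). If the drawn object's range is smaller than the tail object's, the drawn object can simply be replaced by the tail object; but if the removed object's range is larger, the problem gets complicated... array adjustments and so on, so the efficiency may not be acceptable either.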
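To make the cost of the first option concrete, here is a minimal baseline sketch (hypothetical class NaiveWithoutReplacement; weights assumed non-negative). Instead of rebuilding the intervals it rescans the remaining weights on every draw, which is the same O(M) work per draw, so taking N objects costs O(N·M):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical baseline: weighted sampling without replacement by
// "remove the drawn object, then draw again from the remaining total".
final class NaiveWithoutReplacement {

    static List<Integer> take(double[] weights, int n, Random random) {
        double[] remainingWeights = weights.clone(); // drawn objects are set to 0
        List<Integer> picked = new ArrayList<>();
        while (picked.size() < n) {
            // O(M): re-summing the remaining weights on every draw is the cost at issue
            double remaining = 0;
            for (double w : remainingWeights) {
                if (w > 0) {
                    remaining += w;
                }
            }
            if (remaining <= 0) {
                break; // nothing with positive weight is left
            }
            double r = random.nextDouble() * remaining; // uniform in [0, remaining)
            int chosen = -1;
            for (int i = 0; i < remainingWeights.length; i++) {
                if (remainingWeights[i] <= 0) {
                    continue;
                }
                chosen = i; // the last positive entry also absorbs floating-point rounding
                r -= remainingWeights[i];
                if (r < 0) {
                    break;
                }
            }
            picked.add(chosen);
            remainingWeights[chosen] = 0; // "throw out" the drawn object
        }
        return picked;
    }

    public static void main(String[] args) {
        System.out.println(take(new double[] {5, 0, 1, 3}, 10, new Random(7)));
    }
}
```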

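3.1.2 Not just following each weight's absolute share of the total

We may need certain weights to never be drawn (or the probability of small weights to be reduced even further), and the distribution rule may need to change dynamically.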
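One way to express such rules, sketched under the assumption that a rule is simply a transformation applied to the raw weights before any sampling (the class name and both parameters are hypothetical): weights below a threshold become 0 and can never be drawn, and raising the rest to an exponent greater than 1 shrinks the share of small weights further.

```java
import java.util.Arrays;

// Hypothetical pre-processing rule applied to the weights before sampling.
final class WeightTransform {

    /**
     * @param threshold weights strictly below this become 0 (never drawn)
     * @param exponent  > 1 suppresses small weights further; 1 keeps the original ratios
     */
    static double[] transform(double[] weights, double threshold, double exponent) {
        return Arrays.stream(weights)
                .map(w -> w < threshold ? 0.0 : Math.pow(w, exponent))
                .toArray();
    }

    public static void main(String[] args) {
        // shares change from 1/7, 2/7, 4/7 to 0, 4/20, 16/20
        System.out.println(Arrays.toString(transform(new double[] {1, 2, 4}, 1.5, 2)));
    }
}
```

Because such a rule only rewrites the weights, it can be swapped at run time without touching the sampler itself.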

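4. Exploring a new idea

To take N distinct random objects that follow the weight distribution more efficiently, the approach above basically no longer works.
It was a bright afternoon... freshly awake and still half asleep, I sketched this thing on paper: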
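Figure explained: each weight is drawn as a line whose length corresponds to its value. As a line on the far right sweeps leftward, it picks up the weight lines in the order it touches them. But how can a short one (like W4 of Obj4) ever get picked? Just add a random number to each weight! A large W may draw a small random number and a small W may draw a large one, and then take the top k! (My head was buzzing...)
That is how the concept of a dispersed weight came about: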
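A minimal sketch of the idea in its simple-add form, which corresponds to SIMPLE_ADD in the code below: each object gets key = w + U·d, with U uniform in [0, 1) and a dispersion amount d, and the k largest keys win. The class name is hypothetical, and a full sort stands in for the quicksort-based TopK for brevity:

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
import java.util.stream.IntStream;

// Hypothetical sketch: "dispersed weight" = weight + uniform noise, then top k.
final class DispersedWeightSketch {

    /** Returns the indices of the k largest dispersed weights, highest first. */
    static int[] take(double[] weights, int k, double d, Random random) {
        double[] key = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            // a small weight can still beat a large one if it draws more noise
            key[i] = weights[i] + random.nextDouble() * d;
        }
        // each index appears once, so the k results are always distinct
        return IntStream.range(0, weights.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> key[i]).reversed())
                .limit(k)
                .mapToInt(Integer::intValue)
                .toArray();
    }

    public static void main(String[] args) {
        Random random = new Random(2018);
        int smallWins = 0;
        for (int t = 0; t < 10_000; t++) {
            // weights 1 and 2, d = 2: index 0 ranks first only when U0 - U1 > 0.5 (probability 0.125)
            if (take(new double[] {1, 2}, 1, 2.0, random)[0] == 0) {
                smallWins++;
            }
        }
        System.out.println("small weight ranked first: " + smallWins / 10_000.0);
        System.out.println(Arrays.toString(take(new double[] {1, 5, 3, 2}, 3, 2.0, random)));
    }
}
```

The dispersion amount d controls the trade-off: with d = 0 the result is just the k heaviest objects, and as d grows relative to the weights the selection approaches uniform.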

Here is the Java code:

package com.kowalski;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Created by kowalski.zhang on 2018/5/21
 */
public class FinalFinalAlgorithm {

  public static void main(String... args) {

    List<Demo> demos;
    /** generate test data: 100,000 objects with random weights in [0, 100) */
    demos = IntStream.range(0, 100000).mapToObj(i -> new Demo(0, StrictMath.random() * 100))
        .collect(Collectors.toList());

    long time1 = System.currentTimeMillis();
    /**take*/
    List<Demo> topK = Taker.take(demos, 10000);
    System.out.println(System.currentTimeMillis() - time1);

    int mm = 0;
    double max = 0;
    for(Demo demo:topK){
      mm ++;
      System.out.println("W:" + demo.getWeight() + " D:" + demo.getDispersedWeight());
      if(demo.getWeight() > max){
        max = demo.getWeight();
      }
    }

    System.out.println("max " + max);

  }

  public static class Weight implements Comparable<Weight>, Serializable {

    private static final long serialVersionUID = 2816396154208421520L;

    public Weight(double weight) {
      this.weight = weight;
    }

    public double getWeight() {

      return weight;
    }

    public void setWeight(double weight) {
      this.weight = weight;
    }

    public double getDispersedWeight() {
      return dispersedWeight;
    }

    public void setDispersedWeight(double dispersedWeight) {
      this.dispersedWeight = dispersedWeight;
    }

    /**
     * weight
     */
    private double weight;
    /**
     * weight after dispersion (the random sort key)
     */
    private double dispersedWeight;

    /** Orders by dispersed weight; Double.compare keeps the Comparable contract (equal keys compare as 0). */
    @Override public int compareTo(Weight o) {
      return Double.compare(this.getDispersedWeight(), o.getDispersedWeight());
    }
  }

  /***
   * Entity class (example data object)
   */
  public static class Demo extends Weight {
    private static final long serialVersionUID = 3499760378453027078L;
    private Integer num;

    public Demo(Integer num, double weight) {
      super(weight);
      this.num = num;
    }

    public Integer getNum() {
      return num;
    }

    public void setNum(Integer num) {
      this.num = num;
    }
  }

  /***
   * taker
   */
  public static class Taker {

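    /**
     * take
     * @param source source data
     * @param takeNum number of objects to take
     * @param isDestruction whether the source list may be modified (reordered) in place
     * @param dispersedType dispersion method (ABSOLUTE_FOLLOW_WEIGHT when null)
     * @param <T>
     * @return the takeNum objects with the largest dispersed weights, highest first
     */
    public static<T extends Weight> List<T> take(List<T> source, int takeNum, boolean isDestruction,
        DispersedTypeEnum dispersedType){

      List<T> rest = Collections.emptyList();
      try {
        // copy first, so a non-destructive take does not overwrite the source's dispersed weights
        List<T> target = isDestruction ? source : deepCopy(source);
        fillDispersedWeight(target, dispersedType == null?
            DispersedTypeEnum.ABSOLUTE_FOLLOW_WEIGHT:dispersedType);
        rest = Sort.topKWithSortByQuickSort(target, takeNum);
      } catch (IOException | ClassNotFoundException e) {
        e.printStackTrace();
      }
      return rest;
    }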

    /**
     * take; modifies the source list in place
     * @param source source data
     * @param takeNum number of objects to take
     * @param dispersedType dispersion method (ABSOLUTE_FOLLOW_WEIGHT when null)
     * @param <T>
     * @return
     */
    public static<T extends Weight> List<T> take(List<T> source, int takeNum, DispersedTypeEnum dispersedType) {
      return take(source, takeNum, true, dispersedType);
    }
    /**
     * take; modifies the source list in place and uses ABSOLUTE_FOLLOW_WEIGHT
     * @param source source data
     * @param takeNum number of objects to take
     * @param <T>
     * @return
     */
    public static<T extends Weight> List<T> take(List<T> source, int takeNum) {
      return take(source, takeNum, true, null);
    }

    /**
     * take; modifies the source list in place and orders all elements
     * @param source source data
     * @param <T>
     * @return
     */
    public static<T extends Weight> List<T> take(List<T> source) {
      if(source == null || source.isEmpty()){
        return Collections.emptyList();
      }
      return take(source, source.size(), true, null);
    }

    /**
     * Does not modify the source list.
     * Note: does not fill dispersed weights; call fillDispersedWeight first.
     * @param source
     * @param takeNum
     * @param <T>
     * @return
     */
    public static<T extends Weight> List<T> takeWithOutDestruction(List<T> source, int takeNum)
        throws IOException, ClassNotFoundException {
      return Sort.topKWithSortByQuickSort(deepCopy(source), takeNum);
    }

    /**
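     * Modifies (reorders) the source list in place.
     * Note: does not fill dispersed weights; call fillDispersedWeight first.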
     * @param source
     * @param takeNum
     * @param <T>
     * @return
     */
    public static<T extends Weight> List<T> takeWithDestruction(List<T> source, int takeNum){
      return Sort.topKWithSortByQuickSort(source, takeNum);
    }

    /**
     * Get the average weight
     * @param source
     * @param <T>
     * @return
     */
    public static<T extends Weight> double getAverageWeight(List<T> source){
      if(source == null || source.isEmpty()){
        return 0;
      }
      return getTotalWeight(source) / source.size();
    }

    /**
     * Get the total weight
     * @param source
     * @param <T>
     * @return
     */
    public static<T extends Weight> double getTotalWeight(List<T> source){
      if(source == null || source.isEmpty()){
        return 0;
      }
      return source.stream().mapToDouble(Weight::getWeight).sum();
    }

    /**
     * Fill in the dispersed weights
     * @param source
     * @param dispersedType dispersion method
     * @param <T>
     */
    public static<T extends Weight> void fillDispersedWeight(List<T> source, DispersedTypeEnum dispersedType){
      if(source == null || source.isEmpty()){
        return ;
      }
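      /** dispersion amount */
      Double dispersedNum = dispersedType.getDispersedNum();
//      Method method = dispersedType.getMethod();
      /** fill in the dispersed weights */
      switch (dispersedType){
        case ABSOLUTE_FOLLOW_WEIGHT:
          // key = U * (w / total): larger weights tend to rank higher, but the chance of ranking first
          // is not exactly w / total (weights 1 and 2 give 1/4 and 3/4, not 1/3 and 2/3)
          double totalWeight = getTotalWeight(source);
          for(T t:source){
            t.setDispersedWeight((StrictMath.random() * (t.getWeight()/totalWeight)));
          }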
          break;
        case SIMPLE_ADD:
        case SIMPLE_ADD_AVERAGE:
          if(dispersedNum == null){
            dispersedNum = getAverageWeight(source);
          }
          for(T t:source){
            //        try {
            //          t.setDispersedWeight((double) method.invoke(DispersedMethod.class, t.getWeight(), dispersedNum));
            //        } catch (IllegalAccessException | InvocationTargetException e) {
            //          e.printStackTrace();
            //        }
            t.setDispersedWeight((StrictMath.random() * dispersedNum) + (t.getWeight()));
          }
          break;
      }
    }

    /***
     * Deep copy of the list via serialization (the list and its elements must be Serializable)
     * @param src
     * @param <T>
     * @return
     * @throws IOException
     * @throws ClassNotFoundException
     */
    public static <T extends Weight> List<T> deepCopy(List<T> src) throws IOException, ClassNotFoundException {
      ByteArrayOutputStream byteOut = new ByteArrayOutputStream();
      ObjectOutputStream out = new ObjectOutputStream(byteOut);
      out.writeObject(src);

      ByteArrayInputStream byteIn = new ByteArrayInputStream(byteOut.toByteArray());
      ObjectInputStream in = new ObjectInputStream(byteIn);
      @SuppressWarnings("unchecked")
      List<T> dest = (List<T>) in.readObject();
      return dest;
    }
  }

//  public static final class DispersedMethod{
//    private DispersedMethod() {
//    }
//
//    public static double simpleAdd(double weight, double dispersedNum){
////      System.out.println("aaa");
//      return (StrictMath.random() * dispersedNum) + weight;
//    }
//  }

  public enum DispersedTypeEnum {

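    ABSOLUTE_FOLLOW_WEIGHT(0, "absolutely follow weight", null),// intended: probability follows weight / total weight; the key U * w / total only approximates this
    SIMPLE_ADD_AVERAGE(1, "simple additive dispersion, average amount", null/*, "simpleAdd"*/),
    SIMPLE_ADD(2, "simple additive dispersion", 100.0d/*, "simpleAdd"*/);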

    private Integer type;
    private String desc;
    private Double dispersedNum;
//    private Method method;
    /**
     * dispersedNum: dispersion amount; if null, the average weight is used
     */
    DispersedTypeEnum(Integer type, String desc, Double dispersedNum/*, String methodName*/){
      this.type = type;
      this.desc = desc;
      this.dispersedNum = dispersedNum;
//      try {
//        this.method = DispersedMethod.class.getMethod(methodName, double.class, double.class);
//      } catch (NoSuchMethodException e) {
//        e.printStackTrace();
//      }
    }

    public Integer getType() {
      return type;
    }

    public void setType(Integer type) {
      this.type = type;
    }

    public String getDesc() {
      return desc;
    }

    public Double getDispersedNum() {
      return dispersedNum;
    }

    public void setDispersedNum(Double dispersedNum) {
      this.dispersedNum = dispersedNum;
    }

    private static final Map<Integer, DispersedTypeEnum> map = new HashMap<>();

    static {
      for (DispersedTypeEnum enums : DispersedTypeEnum.values()) {
        map.put(enums.getType(), enums);
      }
    }

    public static DispersedTypeEnum getEnumValue(int code) {
      return map.get(code);
    }

    public static String getDescByType(int code) {
      return map.get(code).getDesc();
    }

    public static Double getDispersedNumByType(int code) {
      return map.get(code).getDispersedNum();
    }

//    public Method getMethod() {
//      return method;
//    }
//
//    public void setMethod(Method method) {
//      this.method = method;
//    }

    public void setDesc(String desc) {
      this.desc = desc;
    }
  }
  /***
   * Sorting utilities
   */
  public static class Sort{
    /**
     * top k, with the k results also sorted in descending order (quicksort-based)
     * @param source
     * @param k
     * @param <T>
     * @return
     */
    public static<T extends Weight> List<T> topKWithSortByQuickSort(List<T> source, int k) {
      if(source == null || source.isEmpty()){
        return Collections.emptyList();
      }
      int index;
      int rank;
      int start = 0;
      int end = source.size() - 1;
      while (end > start) {
        index = partition(source, start, end);
        rank = index + 1;
        if (rank >= k) {
          end = index - 1;
        } else if ((index - start) > (end - index)) {
          quickSort(source, index + 1, end);
          end = index - 1;
        } else {
          quickSort(source, start, index - 1);
          start = index + 1;
        }
      }
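      // clamp k so that k > size (or k < 0) does not throw IndexOutOfBoundsException
      return source.subList(0, Math.max(0, Math.min(k, source.size())));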
    }

    public static <T extends Weight> int partition(List<T> lst, int start, int end) {
      T x;
      x = lst.get(start);
      int i = start;
      for (int j = start + 1; j <= end; j++) {
        if (lst.get(j).compareTo(x) > 0) {
          i = i + 1;
          swap(lst, i, j);
        }
      }
      swap(lst, start, i);
      return i;
    }

    public static  <T extends Weight> void swap(List<T> lst, int p, int q) {
      T temp = lst.get(p);
      lst.set(p, lst.get(q));
      lst.set(q, temp);
    }

    public static<T extends Weight> void quickSort(List<T> lst, int start, int end) {
      if (start < end) {
        int index = partition(lst, start, end);
        quickSort(lst, start, index - 1);
        quickSort(lst, index + 1, end);
      }
    }
  }
}

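The TopK in the code is adapted from quicksort. It not only selects the top k but also sorts those k results in descending order of dispersed weight. Many versions online only locate the element at position N and return the first N elements unsorted; the sorting is crucial!
I'm quite happy with the efficiency, hehe~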
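For comparison, the same "top k, sorted" result can also be obtained with a size-k min-heap, which costs O(M log k) and never reorders the input. This is a swapped-in technique, not the author's quicksort version; the class HeapTopK is hypothetical, and a Comparator stands in for compareTo on dispersed weights:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical alternative to the quicksort-based TopK: a bounded min-heap.
final class HeapTopK {

    /** Returns the k largest items according to cmp, largest first; the input is not modified. */
    static <T> List<T> topKSorted(List<T> items, int k, Comparator<T> cmp) {
        if (k <= 0) {
            return new ArrayList<>();
        }
        // min-heap under cmp: the head is the smallest of the current top k
        PriorityQueue<T> heap = new PriorityQueue<>(k, cmp);
        for (T item : items) {
            if (heap.size() < k) {
                heap.offer(item);
            } else if (cmp.compare(item, heap.peek()) > 0) {
                heap.poll(); // evict the smallest kept item
                heap.offer(item);
            }
        }
        List<T> result = new ArrayList<>(heap);
        result.sort(cmp.reversed()); // sort the k survivors, largest first
        return result;
    }

    public static void main(String[] args) {
        List<Integer> top = topKSorted(Arrays.asList(4, 9, 1, 7, 3), 3, Comparator.naturalOrder());
        System.out.println(top); // [9, 7, 4]
    }
}
```

The quicksort-based version works in place and avoids the heap's allocations; the heap version only ever holds k items, so it can also consume objects one at a time.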

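5. Open questions
Different dispersion methods could be specified to shape the final distribution, e.g. a chi-square or normal distribution (my math isn't up to it... there is a lot I can't work out). There may well be a better scheme; anyone interested is welcome to study it together (and if you're good at math and have used another dispersion method, please show me the way~).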
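One concrete candidate for such a scheme, named plainly as a swapped-in technique rather than the author's: the Efraimidis-Spirakis key U^(1/w) from weighted random sampling. It slots into the same "dispersed weight + TopK" framework, taking the k largest such keys yields weighted sampling without replacement, and for k = 1 the chance that object i ranks first is exactly w_i / Σw. The sketch below (hypothetical class DispersionCheck) compares it with the U·w/Wtotal key used by ABSOLUTE_FOLLOW_WEIGHT on weights 1 and 2: analytically the heavier object ranks first with probability 3/4 under U·w/Wtotal, but 2/3 under U^(1/w), which matches the weight ratio.

```java
import java.util.Random;

// Hypothetical check: which per-object random key makes "ranked first" follow w / total?
final class DispersionCheck {

    /** Fraction of trials in which the weight-2 object beats the weight-1 object. */
    static double winRateOfHeavier(boolean useEsKey, int trials, Random random) {
        double w0 = 1, w1 = 2, total = w0 + w1;
        int heavierWins = 0;
        for (int t = 0; t < trials; t++) {
            double k0 = key(w0, total, useEsKey, random);
            double k1 = key(w1, total, useEsKey, random);
            if (k1 > k0) {
                heavierWins++;
            }
        }
        return (double) heavierWins / trials;
    }

    static double key(double w, double total, boolean useEsKey, Random random) {
        double u = random.nextDouble(); // uniform in [0, 1)
        // Efraimidis-Spirakis key u^(1/w) versus the U * w / total key
        return useEsKey ? Math.pow(u, 1.0 / w) : u * w / total;
    }

    public static void main(String[] args) {
        Random random = new Random(2018);
        System.out.println("U*w/total: " + winRateOfHeavier(false, 1_000_000, random)); // about 0.75
        System.out.println("U^(1/w)  : " + winRateOfHeavier(true, 1_000_000, random));  // about 2/3
    }
}
```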
email:[email protected]
