從多項分佈採樣的Java實現

版權聲明:本文爲博主原創文章,未經博主允許不得轉載。 https://blog.csdn.net/xiao_xia_/article/details/78248438
思路:
將每個概率值對應到[0,1]區間內的各個子區間(概率值大小體現在子區間的長度上),每次採樣時,按照均勻分佈隨機生成一個[0,1]區間內的值,其落到哪個區間,則該區間概率值對應的元素即爲被採樣的元素;


算法:
1、先對概率值從大到小排列(不是必要過程,是便於加速的技巧,這樣每次查找時優先檢測隨機數是否落在大概率的區間內,減少比較次數);
2、生成一個[0,1)區間內的隨機數x (注意,Rand().nextDouble()得到的是[0,1)區間內的數,而wikipedia給出的算法中要求生成的是(0,1)區間的數);
3、將x與概率值列表中的各值pi逐個比較,並累加已比較過的前i-1個概率值的累加和sum:
若x落在[sum, sum+pi)區間內,則pi對應的元素被採樣並返回 (注意區間的開閉應該參考步驟2中的情況);
否則,將pi累加入sum,繼續將x與p(i+1)比較;

Tips:
若程序退出時仍未採到合法樣本,則可能給定的概率分佈不滿足∑pi=1的條件(且x剛好落在[1-sum, 1)區間內);

應用場景:
機器學習(如強化學習)中,利用softmax函數定義policy,根據多項分佈選擇對應的action(使得agent有較大概率選到當前模型下的最佳action,又有一定的機率去探索其他action);
softmax policy的另一種替代方式是epsilon-learning中用epsilon來控制探索和利用的機率的方式,即以epsilon的概率進行探索(隨機選一個action),以1-epsilon的概率進行利用(選當前模型下最佳action);

算法代碼:

/**
        * sample from amultinomial distribution
        * https://en.wikipedia.org/wiki/Multinomial_distribution#Sampling_from_a_multinomial_distribution
        *@parampdist a list of <item,probablity>
        *@returnthe selected item, i.e. result belongs to pdist.getFirstValues
        *@authorqxliuOct 16, 2017 11:47:33 AM
        */
       publicstaticintsampleFromMultinomialDistribution(List<TwoTuple<Integer,Double>>pdist){
              List<TwoTuple<Integer,Double>>pidxlist=newArrayList<>(pdist);//avoid changing pdist
              intitemNum=pidxlist.size();
              Collections.sort(pidxlist,newReRanker().setIsDesc(true));
              Randomrand=newRandom();
              doublerd=rand.nextDouble();//a random double in [0,1)
              doublesum=0;
              intsampledIdx=-1;
              for(intk=0;k<itemNum;k++){
                     doublepk=pidxlist.get(k).getSecond();
                     if(rd>=sum&&rd<sum+pk){
                           sampledIdx=pidxlist.get(k).getFirst();//====
                           break;
                     }
                     sum+=pk;
              }
              if(sampledIdx<0&&sum!=1){
                     thrownewIllegalArgumentException("error distribution! sampledIdx="+sampledIdx+", distribution="+pidxlist);
              }
              returnsampledIdx;
       }
測試代碼:
/**
 *@authorqxliu2017 Oct10, 2017 4:34:12 PM
 *
 */
publicstaticvoidmain(String[]args){             
              intsampleNum=10;//打算採樣的次數(決定最終採得的樣本數)
              List<TwoTuple<Integer,Double>>pdist=newArrayList<>();//定義多項式分佈
              pdist.add(newTwoTuple<Integer,Double>(1,0.5));//每個元素爲:樣本的標記(比如id),樣本被選中的概率;
              pdist.add(newTwoTuple<Integer,Double>(2,0.3));//應該要求概率分佈之和爲1,但算法裏並未檢查概率值總和是否爲1;
              pdist.add(newTwoTuple<Integer,Double>(3,0.2));//若概率值之和不爲1,則有可能報錯(當隨機數出現在[sum, 1)區間內時),也可能不報錯;
//            pdist.add(new TwoTuple<Integer, Double>(9, 0.5));
//            pdist.add(new TwoTuple<Integer, Double>(7, 0.5));
              for(inti=0;i<sampleNum;i++){
                     System.out.println(sampleFromMultinomialDistribution(pdist));
              }
       }

輸出:(10次採樣的結果,也可能是滿足該分佈的其他情況;由於採樣次數少,有時結果也可能看起來不滿足原分佈)
1
3
1
2
1
2
2
2
1
2
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章