近日刷到Leetcode 4題目。感覺這個題目作爲標記hard的題目,還是很有意思的。
題目如下:
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
解法
那這道題是怎麼做的呢。Leetcode只用你寫調用的函數就好。所以它的時間複雜度是調用函數的時間複雜度,而不是整個程序的時間複雜度(整個程序的時間複雜度只要是O(m+n))。
解法是很容易想到的。因爲題目已經要求時間複雜度是O(log(m+n)),所以顯然可以使用二分的思想。這裏提供一個找到兩個序列中第k大數的思路。
每次比較兩個序列的中位數大小。相應的就可以丟掉其中一個數列的一半。比如兩個序列A、B長度爲10、20,而B的中位數比A的大,這就說明的B的中位數一定大於15個數。如果k<15的話,那就可以直接丟掉B序列的後面一半了。然後再次調用這個函數就可以了。
具體實現還有一些細節要考慮。最後只要找第(m+n)/2大的數就好了。
代碼附在最後。
時間複雜度
這裏來看時間複雜度。首先程序主體部分複雜度是O(log(n+m))的。但是我開始時候的困惑是數組vector的size這布操作,複雜度是O(n)嘛?
查閱相關資料可以看到,現在vector size()時間複雜度已經是常數了。說明vector的實現方式比string高端了一點。vector本身維護了size,而這也方便了越界檢查吧。
第二個巧妙之處在於,數組傳參時使用的是地址,而不是把每個參數都拷貝到對應一層。所以數據傳參的時間複雜度也是常數了。
題解代碼
class Solution {
public:
int find(vector<int>& nums1, vector<int>& nums2, int st1, int ed1, int st2, int ed2, int n) {
int l1 = ed1-st1+1;
int l2 = ed2-st2+1;
int m1=(st1+ed1)/2;
int m2=(st2+ed2)/2;
int big= l1+l2-n+1;
if(l1==0)
return st2+n-1+nums1.size();
if(l2==0)
return st1+n-1;
if(n==1)
return (nums1[st1]<nums2[st2])?st1:st2+nums1.size();
if(l1==1 && l2==1){
return (nums1[st1]>nums2[st2])?st1:st2+nums1.size();
}
if(l1==1)
return (nums2[st2+n-1]<=nums1[st1])?(st2+n-1+nums1.size()):
( (nums2[st2+n-2]<=nums1[st1])?st1:( st2+n-2+nums1.size() )
);
if(l2==1)
return (nums1[st1+n-1]<=nums2[st2])?(st1+n-1):
( (nums1[st1+n-2]<=nums2[st2])?(st2+nums1.size()):( st1+n-2 )
);
int small1=m1-st1, small2=m2-st2;
int large1=ed1-m1, large2=ed2-m2;
if(n<=small1){
return find(nums1,nums2,st1,m1-1,st2,ed2,n);
}
if(n<=small2){
return find(nums1,nums2,st1,ed1,st2,m2-1,n);
}
if( big<=large1 ){
return find(nums1,nums2,m1+1,ed1,st2,ed2,n-small1-1);
}
if( big<=large2 ){
return find(nums1,nums2,st1,ed1,m2+1,ed2,n-small2-1);
}
if( n<=(l1+l2)/2 ){
if(nums1[m1]>nums2[m2]){
return find(nums1,nums2,st1,m1,st2,ed2,n);
} else {
return find(nums1,nums2,st1,ed1,st2,m2,n);
}
} else {
if(nums1[m1]>nums2[m2]){
return find(nums1,nums2,st1,ed1,m2+1,ed2,n-small2-1);
} else {
return find(nums1,nums2,m1+1,ed1,st2,ed2,n-small1-1);
}
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int s1=nums1.size();
int s2=nums2.size();
if((s1+s2)%2==1){
int res=find(nums1,nums2,0,s1-1,0,s2-1,(s1+s2)/2+1);
if(res<s1)
return nums1[res];
else
return nums2[res-s1];
} else {
int res1=find(nums1,nums2,0,s1-1,0,s2-1,(s1+s2)/2);
int res2=find(nums1,nums2,0,s1-1,0,s2-1,(s1+s2)/2+1);
double aa,bb;
if(res1<s1)
aa= nums1[res1];
else
aa= nums2[res1-s1];
if(res2<s1)
bb= nums1[res2];
else
bb= nums2[res2-s1];
return (aa+bb)/2;
}
}
};