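# 算法导论第三版, 9.3 - Introduction to Algorithms (CLRS), 3rd edition, Section 9.3: selection in worst-case linear time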
import random
import math
def partition(a, p, r, x):
    """Rearrange the segment A[p..r] around the value x and return how many are < x.

    Uses the book's 1-based inclusive convention: A[p..r] corresponds to the
    Python slice a[p-1:r], with 1 <= p <= r <= len(a).  After the call the
    segment holds, in order, the elements smaller than x, the elements equal
    to x, then the elements larger than x.  No element is lost or inserted,
    and elements outside the segment are left untouched.

    Args:
        a: list to rearrange in place.
        p, r: 1-based inclusive bounds of the segment.
        x: pivot value; normally an element of the segment.

    Returns:
        The number of elements of A[p..r] that are strictly smaller than x.
    """
    segment = a[p - 1:r]
    low = [m for m in segment if m < x]
    # keep every copy of x so duplicate values are not silently dropped
    equal = [m for m in segment if m == x]
    high = [m for m in segment if m > x]
    a[p - 1:r] = low + equal + high
    return len(low)


def median(a):
    """Return the lower median of a, i.e. the element of 1-based rank (n+1)//2.

    For odd n this is the middle element.  For even n it is the smaller of the
    two middle elements.  a itself is not modified.

    Raises:
        IndexError: if a is empty.
    """
    # floor division keeps the index an int (true division breaks on Python 3)
    return sorted(a)[(len(a) + 1) // 2 - 1]


def select(a, i):
    """Return the i-th smallest element of a (1-based) in worst-case linear time.

    This is the median-of-medians SELECT algorithm.  The 1st smallest is the
    minimum and the len(a)-th smallest is the maximum, so the result equals
    sorted(a)[i - 1].  Duplicate values are allowed, and a itself is not
    modified.

    Raises:
        ValueError: if i is not in 1..len(a) (this includes an empty a).
    """
    n = len(a)
    if not 1 <= i <= n:
        raise ValueError(
            "i must satisfy 1 <= i <= len(a), got i=%r, len(a)=%d" % (i, n))
    if n == 1:
        return a[0]
    # 1. divide into ceil(n/5) groups of 5 (the last group may be shorter)
    groups = [a[start:start + 5] for start in range(0, n, 5)]
    # 2. find the (lower) median of each group
    medians = [median(g) for g in groups]
    # 3. recursively find the median of the medians; x is an element of a
    x = select(medians, (len(medians) + 1) // 2)
    # 4. partition a copy around x, giving [< x] + [== x] + [> x]
    work = list(a)
    num_low = partition(work, 1, n, x)
    num_equal = work.count(x)  # >= 1 because x is taken from a
    # 5. recurse into the side that contains rank i; both sides are strictly
    #    smaller than a because the block of values equal to x is non-empty
    if i <= num_low:
        # too many elements are smaller than x: the answer is among them
        return select(work[:num_low], i)
    if i <= num_low + num_equal:
        # rank i falls inside the run of values equal to x
        return x
    # not enough elements are <= x: search the larger side with an adjusted rank
    return select(work[num_low + num_equal:], i - num_low - num_equal)
if __name__ == "__main__":
    # Self-test: compare select() against sorting on random inputs.  It covers
    # every rank 1..len(a) (including the maximum), varying list lengths, and
    # lists with many duplicate values.
    result = []
    for _ in range(10000):
        size = random.randint(1, 30)
        if random.random() < 0.5:
            a = random.sample(range(100), size)  # distinct values
        else:
            a = [random.randint(0, 9) for j in range(size)]  # many duplicates
        order = random.randint(1, len(a))
        expected = sorted(a)[order - 1]
        result.append(select(a, order) == expected)
    print(all(result))