Lua中的table.sort算法原理
table.sort的介紹
table.sort是Lua自帶的一個排序函數,函數原型爲:
table.sort(list[, comp])
其中list是目標table,comp是一個可選參數,可以自定義比較函數;當不提供comp函數時則默認按照升序進行排序;這裏需要注意table.sort是一個不穩定的排序算法;同時排序的table必須是一個數組,並且數組的索引必須是連續的;
table.sort的算法原理
首先來看Lua源碼中的sort方法的實現原理:
1: static int sort (lua_State *L) {
2: int n = aux_getn(L, 1);
3: luaL_checkstack(L, 40, ""); /* assume array is smaller than 2^40 */
4: if (!lua_isnoneornil(L, 2)) /* is there a 2nd argument? */
5: luaL_checktype(L, 2, LUA_TFUNCTION);
6: lua_settop(L, 2); /* make sure there is two arguments */
7: auxsort(L, 1, n);
8: return 0;
9: }
這個方法主要是獲取數組的大小,並對參數進行相應的校驗;之後調用auxsort方法來實現排序算法;
1: static void auxsort (lua_State *L, int l, int u) {
2: while (l < u) { /* for tail recursion */
3: int i, j;
4: /* sort elements a[l], a[(l+u)/2] and a[u] */
5: lua_rawgeti(L, 1, l);
6: lua_rawgeti(L, 1, u);
7: if (sort_comp(L, -1, -2)) /* a[u] < a[l]? */
8: set2(L, l, u); /* swap a[l] - a[u] */
9: else
10: lua_pop(L, 2);
11: if (u-l == 1) break; /* only 2 elements */
12: i = (l+u)/2;
13: lua_rawgeti(L, 1, i);
14: lua_rawgeti(L, 1, l);
15: if (sort_comp(L, -2, -1)) /* a[i]<a[l]? */
16: set2(L, i, l);
17: else {
18: lua_pop(L, 1); /* remove a[l] */
19: lua_rawgeti(L, 1, u);
20: if (sort_comp(L, -1, -2)) /* a[u]<a[i]? */
21: set2(L, i, u);
22: else
23: lua_pop(L, 2);
24: }
25: if (u-l == 2) break; /* only 3 elements */
26: lua_rawgeti(L, 1, i); /* Pivot */
27: lua_pushvalue(L, -1);
28: lua_rawgeti(L, 1, u-1);
29: set2(L, i, u-1);
30: /* a[l] <= P == a[u-1] <= a[u], only need to sort from l+1 to u-2 */
31: i = l; j = u-1;
32: for (;;) { /* invariant: a[l..i] <= P <= a[j..u] */
33: /* repeat ++i until a[i] >= P */
34: while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
35: if (i>u) luaL_error(L, "invalid order function for sorting");
36: lua_pop(L, 1); /* remove a[i] */
37: }
38: /* repeat --j until a[j] <= P */
39: while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
40: if (j<l) luaL_error(L, "invalid order function for sorting");
41: lua_pop(L, 1); /* remove a[j] */
42: }
43: if (j<i) {
44: lua_pop(L, 3); /* pop pivot, a[i], a[j] */
45: break;
46: }
47: set2(L, i, j);
48: }
49: lua_rawgeti(L, 1, u-1);
50: lua_rawgeti(L, 1, i);
51: set2(L, u-1, i); /* swap pivot (a[u-1]) with a[i] */
52: /* a[l..i-1] <= a[i] == P <= a[i+1..u] */
53: /* adjust so that smaller half is in [j..i] and larger one in [l..u] */
54: if (i-l < u-i) {
55: j=l; i=i-1; l=i+2;
56: }
57: else {
58: j=i+1; i=u; u=j-2;
59: }
60: auxsort(L, j, i); /* call recursively the smaller one */
61: } /* repeat the routine for the larger one */
62: }
上面是Lua源碼中的sort方法的實現,Lua源碼中包含了很多對堆棧的操作,爲了更直觀的看到算法的實現,對上面的方法進行了僞代碼的實現:
1: sort(array list, int l, int u)
2: {
3: while(l < u)
4: {
5: if(list[u] < list[l])
6: swap(list[u], list[l])
7:
8: if(u - l == 1)
9: break
10: int i = (l + u)/2
11: if(list[i] < list[l])
12: swap(list[i], list[l])
13: else if(list[u] < list[i])
14: swap(list[u], list[l])
15:
16: if(u - l == 2)
17: break
18:
19: int p = i
20: swap(list[i], list[u-1])
21: i = l
22: j = u-1
23: for(;;)
24: {
25: while(++i, list[i] < list[p])
26: {
27: if(i > u)
error("invalid order function for sorting")
28: }
29: while(--j, list[j] > list[p])
30: {
27: if(j < l)
error("invalid order function for sorting")
32: }
33: if(j < i)
34: break
35:
36: swap(list[i], list[j])
37: }
38:
39: swap(list[u-1], list[i])
40: if (i-l < u-i)
41: j=l; i=i-1; l=i+2;
42: else
43: j=i+1; i=u; u=j-2;
44:
45: sort(list, j, i)
46: }
47: }
通過僞代碼可以看出sort的核心算法的本質是快速排序,所以說table.sort排序是不穩定的排序;第5-17行是使用三數取中的方法對排序算法進行了優化;第39-43行主要是通過對中間值所在位置的判斷,來設置下次排序的起始和結束索引;
下面看一下sort_comp方法的實現:
1: static int sort_comp (lua_State *L, int a, int b) {
2: if (!lua_isnil(L, 2)) { /* function? */
3: int res;
4: lua_pushvalue(L, 2);
5: lua_pushvalue(L, a-1); /* -1 to compensate function */
6: lua_pushvalue(L, b-2); /* -2 to compensate function and `a' */
7: lua_call(L, 2, 1);
8: res = lua_toboolean(L, -1);
9: lua_pop(L, 1);
10: return res;
11: }
12: else /* a < b? */
13: return lua_lessthan(L, a, b);
14: }
可以看到在第二行對table.sort的第二個參數comp進行了檢查,當comp參數爲空時,則使用系統提供的升序排序方法;
(注:以上都是自己的理解,歡迎各位大佬指正!)