Lua中的 table.sort 算法原理

Lua中的table.sort算法原理

table.sort的介紹

table.sort是Lua自帶的一個排序函數,函數原型爲:

table.sort(list[, comp])

其中list是目標table,comp是一個可選參數,可以自定義比較函數;當不提供comp函數時則默認按照升序進行排序;這裏需要注意table.sort是一個不穩定的排序算法;同時排序的table必須是一個數組,並且數組的索引必須是連續的;

table.sort的算法原理

首先來看Lua源碼中的sort方法的實現原理:

1:  	static int sort (lua_State *L) {
2:  	  int n = aux_getn(L, 1);
3:  	  luaL_checkstack(L, 40, "");  /* assume array is smaller than 2^40 */
4:  	  if (!lua_isnoneornil(L, 2))  /* is there a 2nd argument? */
5:  	    luaL_checktype(L, 2, LUA_TFUNCTION);
6:  	  lua_settop(L, 2);  /* make sure there is two arguments */
7:  	  auxsort(L, 1, n);
8:  	  return 0;
9:  	}

這個方法主要是獲取數組的大小,並對參數進行相應的校驗;之後調用auxsort方法來實現排序算法;

1:  static void auxsort (lua_State *L, int l, int u) {
2:    while (l < u) {  /* for tail recursion */
3:      int i, j;
4:      /* sort elements a[l], a[(l+u)/2] and a[u] */
5:      lua_rawgeti(L, 1, l);
6:      lua_rawgeti(L, 1, u);
7:      if (sort_comp(L, -1, -2))  /* a[u] < a[l]? */
8:        set2(L, l, u);  /* swap a[l] - a[u] */
9:      else
10:        lua_pop(L, 2);
11:      if (u-l == 1) break;  /* only 2 elements */
12:      i = (l+u)/2;
13:      lua_rawgeti(L, 1, i);
14:      lua_rawgeti(L, 1, l);
15:      if (sort_comp(L, -2, -1))  /* a[i]<a[l]? */
16:        set2(L, i, l);
17:      else {
18:        lua_pop(L, 1);  /* remove a[l] */
19:        lua_rawgeti(L, 1, u);
20:        if (sort_comp(L, -1, -2))  /* a[u]<a[i]? */
21:          set2(L, i, u);
22:        else
23:          lua_pop(L, 2);
24:      }
25:      if (u-l == 2) break;  /* only 3 elements */
26:      lua_rawgeti(L, 1, i);  /* Pivot */
27:      lua_pushvalue(L, -1);
28:      lua_rawgeti(L, 1, u-1);
29:      set2(L, i, u-1);
30:      /* a[l] <= P == a[u-1] <= a[u], only need to sort from l+1 to u-2 */
31:      i = l; j = u-1;
32:      for (;;) {  /* invariant: a[l..i] <= P <= a[j..u] */
33:        /* repeat ++i until a[i] >= P */
34:        while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
35:          if (i>u) luaL_error(L, "invalid order function for sorting");
36:          lua_pop(L, 1);  /* remove a[i] */
37:        }
38:        /* repeat --j until a[j] <= P */
39:        while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
40:          if (j<l) luaL_error(L, "invalid order function for sorting");
41:          lua_pop(L, 1);  /* remove a[j] */
42:        }
43:        if (j<i) {
44:          lua_pop(L, 3);  /* pop pivot, a[i], a[j] */
45:          break;
46:        }
47:        set2(L, i, j);
48:      }
49:      lua_rawgeti(L, 1, u-1);
50:      lua_rawgeti(L, 1, i);
51:      set2(L, u-1, i);  /* swap pivot (a[u-1]) with a[i] */
52:      /* a[l..i-1] <= a[i] == P <= a[i+1..u] */
53:      /* adjust so that smaller half is in [j..i] and larger one in [l..u] */
54:      if (i-l < u-i) {
55:        j=l; i=i-1; l=i+2;
56:      }
57:      else {
58:        j=i+1; i=u; u=j-2;
59:      }
60:      auxsort(L, j, i);  /* call recursively the smaller one */
61:    }  /* repeat the routine for the larger one */
62:  }

上面是Lua源碼中的sort方法的實現,Lua源碼中包含了很多對堆棧的操作,爲了更直觀的看到算法的實現,對上面的方法進行了僞代碼的實現:

1:  sort(array list, int l, int u)
2:  {
3:  	while(l < u)
4:  	{
5:  	    if(list[u] < list[l])
6:  	      	swap(list[u], list[l])
7:  
8:  	    if(u - l == 1)
9:  	      	break
10:  	    int i = (l + u)/2
11:  	    if(list[i] < list[l])
12:  	      	swap(list[i], list[l])
13:  	    else if(list[u] < list[i])
14:  	      	swap(list[u], list[l])
15:  
16:  	    if(u - l == 2)
17:  	      	break
18:  
19:  	    int p = i
20:  	    swap(list[i], list[u-1])
21:  	    i = l
22:  	    j = u-1
23:  	    for(;;)
24:  	    {
25:  			while(++i, list[i] < list[p])
26:  			{
27:  				if(i > u)
						error("invalid order function for sorting")
28:  			}
29:  			while(--j, list[j] > list[p])
30:  			{
27:  				if(j < l)
						error("invalid order function for sorting")
32:  			}
33:  			if(j < i)
34:  	        	break
35:  
36:  	    	swap(list[i], list[j])
37:  	    }
38:  
39:  	    swap(list[u-1], list[i])
40:  	    if (i-l < u-i)
41:  	    	j=l; i=i-1; l=i+2;
42:  	    else 
43:  	    	j=i+1; i=u; u=j-2;
44:  
45:  	    sort(list, j, i)
46:    	}
47:  }

通過僞代碼可以看出sort的核心算法的本質是快速排序,所以說table.sort排序是不穩定的排序;第5-17行是使用三數取中的方法對排序算法進行了優化;第39-43行主要是通過對中間值所在位置的判斷,來設置下次排序的起始和結束索引;

下面看一下sort_comp方法的實現:

1:  	static int sort_comp (lua_State *L, int a, int b) {
2:  	if (!lua_isnil(L, 2)) {  /* function? */
3:  		int res;
4:  		lua_pushvalue(L, 2);
5:  		lua_pushvalue(L, a-1);  /* -1 to compensate function */
6:  		lua_pushvalue(L, b-2);  /* -2 to compensate function and `a' */
7:  		lua_call(L, 2, 1);
8:  		res = lua_toboolean(L, -1);
9:  		lua_pop(L, 1);
10:  		return res;
11:  	}
12:  	else  /* a < b? */
13:  		return lua_lessthan(L, a, b);
14:  	}

可以看到在第二行對table.sort的第二個參數comp進行了檢查,當comp參數爲空時,則使用系統提供的升序排序方法;

(注:以上都是自己的理解,歡迎各位大佬指正!)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章