轉自: http://blog.csdn.net/whinah/article/details/6419680
從理論上講,只要允許使用棧,所有的遞歸程序都可以轉化成迭代。
但是並非所有遞歸都必須用棧,不用堆棧也可以轉化成迭代的,大致有兩類
- 尾遞歸:可以通過簡單的變換,讓遞歸作爲最後一條語句,並且僅此一個遞歸調用。
- // recursive
- int fac1(int n) {
- if (n <= 0) return 1;
- return n * fac1(n-1);
- }
- // iterative
- int fac2(int n) {
- int i = 0, y = 1;
- for (; i <= n; ++i) y *= i;
- return y;
- }
- 自頂向下->自底向上:對程序的結構有深刻理解後,自底向上計算,比如 fibnacci 數列的遞歸->迭代轉化。
- // recursive, top-down
- int fib1(int n) {
- if (n <= 1) return 1;
- return fib1(n-1) + fib1(n-2);
- }
- // iterative, down-top
- int fib2(int n) {
- int f0 = 1, f1 = 1, i;
- for (i = 2; i <= n; ++i) {
- int f2 = f1 + f0;
- f0 = f1; f1 = f2;
- }
- return f1;
- }
對於非尾遞歸,就必須使用堆棧。可以簡單生硬地使用堆棧進行轉化:把函數調用和返回的地方翻譯成彙編代碼,然後把對硬件 stack 的 push, pop 操作轉化成對私有 stack 的 push, pop ,這其中需要特別注意的是對返回地址的 push/pop,對應的硬件指令一般是 call/ret。使用私有 stack 有兩個好處:
- 可以省去公用局部變量,也就是在任何一次遞歸調用中都完全相同的函數參數,再加上從這些參數計算出來的局部變量。
- 如果需要得到當前的遞歸深度,可以從私有 stack 直接拿到,而用遞歸一般需要一個單獨的 depth 變量,然後每次遞歸調用加 1。
我們把私有 stack 元素稱爲 Frame,那麼 Frame 中必須包含以下信息:
- 返回地址(對應於每個遞歸調用的下一條語句的地址)
- 對每次遞歸調用都不同的參數
通過實際操作,我發現,有一類遞歸的 Frame 可以省去返回地址!所以,這裏又分爲兩種情況:
- Frame 中可以省去返回地址的遞歸:僅有兩個遞歸調用,並且其中有一個是尾遞歸。
- // here used a function 'partition', but don't implement it
- tempalte<class RandIter>
- void QuickSort1(RandIter beg, RandIter end) {
- if (end - beg <= 1) return;
- RandIter pos = partition(beg, end);
- QuickSort1(beg, pos);
- QuickSort1(pos + 1, end);
- }
- tempalte<class RandIter>
- void QuickSort2(RandIter beg, RandIter end) {
- std::stack<std::pair<RandIter> > stk;
- stk.push({beg, end});
- while (!stk.empty()) {
- std::pair<RandIter, RandIter> ii = stk.top(); stk.pop();
- if (ii.second - ii.first) > 1) {
- RandIter pos = partition(beg, end);
- stk.push({ii.first, pos});
- stk.push({pos + 1, ii.second});
- }
- }
- }
- Frame 中必須包含返回地址的遞歸,這個比較複雜,所以我寫了個完整的示例:
- 以MergeSort爲例,因爲 MergeSort 是個後序過程,兩個遞歸調用中沒有任何一個是尾遞歸
- MergeSort3 使用了 GCC 的 Label As Value 特性,只能在 GCC 兼容的編譯器中使用
- 單純對於這個實例來說,返回地址其實只有兩種,返回地址爲 0 的情況可以通過判斷私有棧(varname=stk)是否爲空,stk爲空時等效於 retaddr == 0。如果要精益求精,一般情況下指針的最低位總是0,可以把這個標誌保存在指針的最低位,當然,如此的話就無法對 sizeof(T)==1 的對象如 char 進行排序了。
-
- #include <stdio.h>
- #include <string.h>
- # if 1
- #include <stack>
- #include <vector>
- template<class T>
- class MyStack : public std::stack<T, std::vector<T> >
- {
- };
- #else
- template<class T>
- class MyStack {
- union {
- char* a;
- T* p;
- };
- int n, t;
- public:
- explicit MyStack(int n=128) {
- this->n = n;
- this->t = 0;
- a = new char[n*sizeof(T)];
- }
- ~MyStack() {
- while (t > 0)
- pop();
- delete[] a;
- }
- void swap(MyStack<T>& y) {
- char* q = y.a; y.a = a; a = q;
- int z;
- z = y.n; y.n = n; n = z;
- z = y.t; y.t = t; t = z;
- }
- T& top() const {
- return p[t-1];
- }
- void pop() {
- --t;
- p[t].~T();
- }
- void push(const T& x) {
- x.print(); // debug
- p[t] = x;
- ++t;
- }
- int size() const { return t; }
- bool empty() const { return 0 == t; }
- bool full() const { return n == t; }
- };
- #endif
- template<class T>
- struct Frame {
- static T* base;
- T *beg, *tmp;
- int len;
- int retaddr;
- Frame(T* beg, T* tmp, int len, int retaddr)
- : beg(beg), tmp(tmp), len(len), retaddr(retaddr)
- {}
- void print() const { // for debug
- printf("%4d %4d %d/n", int(beg-base), len, retaddr);
- }
- };
- template<class T> T* Frame<T>::base;
- #define TOP(field) stk.top().field
- template<class T>
- bool issorted(const T* a, int n)
- {
- for (int i = 1; i < n; ++i) {
- if (a[i-1] > a[i]) return false;
- }
- return true;
- }
- template<class T>
- void mymerge(const T* a, int la, const T* b, int lb, T* c) {
- int i = 0, j = 0, k = 0;
- for (; i < la && j < lb; ++k) {
- if (b[j] < a[i])
- c[k] = b[j], ++j;
- else
- c[k] = a[i], ++i;
- }
- for (; i < la; ++i, ++k) c[k] = a[i];
- for (; j < lb; ++j, ++k) c[k] = b[j];
- }
- template<class T>
- void MergeSort1(T* beg, T* tmp, int len) {
- if (len > 1) {
- int mid = len / 2;
- MergeSort1(beg , tmp , mid);
- MergeSort1(beg+mid, tmp+mid, len-mid);
- mymerge(tmp, mid, tmp+mid, len-mid, beg);
- memcpy(tmp, beg, sizeof(T)*len);
- }
- else
- *tmp = *beg;
- }
- template<class T>
- void MergeSort2(T* beg0, T* tmp0, int len0) {
- int mid;
- int cnt = 0;
- Frame<T>::base = beg0;
- MyStack<Frame<T> > stk;
- stk.push(Frame<T>(beg0, tmp0, len0, 0));
- while (true) {
- ++cnt;
- if (TOP(len) > 1) {
- mid = TOP(len) / 2;
- stk.push(Frame<T>(TOP(beg), TOP(tmp), mid, 1));
- continue;
- L1:
- mid = TOP(len) / 2;
- stk.push(Frame<T>(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid, 2));
- continue;
- L2:
- mid = TOP(len) / 2;
- mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg));
- memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len));
- } else
- *TOP(tmp) = *TOP(beg);
- int retaddr0 = TOP(retaddr);
- stk.pop();
- switch (retaddr0) {
- case 0: return;
- case 1: goto L1;
- case 2: goto L2;
- }
- }
- }
- // This Implementation Use GCC's goto saved label value
- // Very similiar with recursive version
- template<class T>
- void MergeSort3(T* beg0, T* tmp0, int len0) {
- MyEntry:
- int mid;
- int retaddr;
- Frame<T>::base = beg0;
- MyStack<Frame<T> > stk;
- stk.push(Frame<T>(beg0, tmp0, len0, 0));
- #define Cat1(a,b) a##b
- #define Cat(a,b) Cat1(a,b)
- #define HereLabel() Cat(HereLable_, __LINE__)
- #define RecursiveCall(beg, tmp, len) /
- stk.push(Frame<T>(beg, tmp, len, (char*)&&HereLabel() - (char*)&&MyEntry)); /
- continue; /
- HereLabel():;
- //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // retaddr == 0 是最外層的遞歸調用,
- // 只要到達這一層時 retaddr 才爲 0,
- // 此時就可以返回了
- #define MyReturn /
- retaddr = TOP(retaddr); /
- stk.pop(); /
- if (0 == retaddr) { /
- return; /
- } /
- goto *((char*)&&MyEntry + retaddr);
- //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- while (true) {
- if (TOP(len) > 1) {
- mid = TOP(len) / 2;
- RecursiveCall(TOP(beg), TOP(tmp), mid);
- mid = TOP(len) / 2;
- RecursiveCall(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid);
- mid = TOP(len) / 2;
- mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg));
- memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len));
- } else
- *TOP(tmp) = *TOP(beg);
- MyReturn;
- }
- }
- template<class T>
- void MergeSortDriver(T* beg, int len, void (*mf)(T* beg_, T* tmp_, int len_))
- {
- T* tmp = new T[len];
- (*mf)(beg, tmp, len);
- delete[] tmp;
- }
- #define test(a,n,mf) /
- memcpy(a, b, sizeof(a[0])*n); /
- MergeSortDriver(a, n, &mf); /
- printf("sort by %s:", #mf); /
- for (i = 0; i < n; ++i) printf("% ld", a[i]); /
- printf("/n");
- int main(int argc, char* argv[])
- {
- int n = argc - 1;
- int i;
- long* a = new long[n];
- long* b = new long[n];
- for (i = 0; i < n; ++i)
- b[i] = strtol(argv[i+1], NULL, 10);
- test(a, n, MergeSort1);
- test(a, n, MergeSort2);
- test(a, n, MergeSort3);
- printf("All Successed/n");
- delete[] a;
- delete[] b;
- return 0;
- }