转自: http://blog.csdn.net/whinah/article/details/6419680
从理论上讲,只要允许使用栈,所有的递归程序都可以转化成迭代。
但是并非所有递归都必须用栈,不用堆栈也可以转化成迭代的,大致有两类
- 尾递归:可以通过简单的变换,让递归作为最后一条语句,并且仅此一个递归调用。
- // recursive
- int fac1(int n) {
- if (n <= 0) return 1;
- return n * fac1(n-1);
- }
- // iterative
- int fac2(int n) {
- int i = 0, y = 1;
- for (; i <= n; ++i) y *= i;
- return y;
- }
- 自顶向下->自底向上:对程序的结构有深刻理解后,自底向上计算,比如 fibnacci 数列的递归->迭代转化。
- // recursive, top-down
- int fib1(int n) {
- if (n <= 1) return 1;
- return fib1(n-1) + fib1(n-2);
- }
- // iterative, down-top
- int fib2(int n) {
- int f0 = 1, f1 = 1, i;
- for (i = 2; i <= n; ++i) {
- int f2 = f1 + f0;
- f0 = f1; f1 = f2;
- }
- return f1;
- }
对于非尾递归,就必须使用堆栈。可以简单生硬地使用堆栈进行转化:把函数调用和返回的地方翻译成汇编代码,然后把对硬件 stack 的 push, pop 操作转化成对私有 stack 的 push, pop ,这其中需要特别注意的是对返回地址的 push/pop,对应的硬件指令一般是 call/ret。使用私有 stack 有两个好处:
- 可以省去公用局部变量,也就是在任何一次递归调用中都完全相同的函数参数,再加上从这些参数计算出来的局部变量。
- 如果需要得到当前的递归深度,可以从私有 stack 直接拿到,而用递归一般需要一个单独的 depth 变量,然后每次递归调用加 1。
我们把私有 stack 元素称为 Frame,那么 Frame 中必须包含以下信息:
- 返回地址(对应于每个递归调用的下一条语句的地址)
- 对每次递归调用都不同的参数
通过实际操作,我发现,有一类递归的 Frame 可以省去返回地址!所以,这里又分为两种情况:
- Frame 中可以省去返回地址的递归:仅有两个递归调用,并且其中有一个是尾递归。
- // here used a function 'partition', but don't implement it
- tempalte<class RandIter>
- void QuickSort1(RandIter beg, RandIter end) {
- if (end - beg <= 1) return;
- RandIter pos = partition(beg, end);
- QuickSort1(beg, pos);
- QuickSort1(pos + 1, end);
- }
- tempalte<class RandIter>
- void QuickSort2(RandIter beg, RandIter end) {
- std::stack<std::pair<RandIter> > stk;
- stk.push({beg, end});
- while (!stk.empty()) {
- std::pair<RandIter, RandIter> ii = stk.top(); stk.pop();
- if (ii.second - ii.first) > 1) {
- RandIter pos = partition(beg, end);
- stk.push({ii.first, pos});
- stk.push({pos + 1, ii.second});
- }
- }
- }
- Frame 中必须包含返回地址的递归,这个比较复杂,所以我写了个完整的示例:
- 以MergeSort为例,因为 MergeSort 是个后序过程,两个递归调用中没有任何一个是尾递归
- MergeSort3 使用了 GCC 的 Label As Value 特性,只能在 GCC 兼容的编译器中使用
- 单纯对于这个实例来说,返回地址其实只有两种,返回地址为 0 的情况可以通过判断私有栈(varname=stk)是否为空,stk为空时等效于 retaddr == 0。如果要精益求精,一般情况下指针的最低位总是0,可以把这个标志保存在指针的最低位,当然,如此的话就无法对 sizeof(T)==1 的对象如 char 进行排序了。
-
- #include <stdio.h>
- #include <string.h>
- # if 1
- #include <stack>
- #include <vector>
- template<class T>
- class MyStack : public std::stack<T, std::vector<T> >
- {
- };
- #else
- template<class T>
- class MyStack {
- union {
- char* a;
- T* p;
- };
- int n, t;
- public:
- explicit MyStack(int n=128) {
- this->n = n;
- this->t = 0;
- a = new char[n*sizeof(T)];
- }
- ~MyStack() {
- while (t > 0)
- pop();
- delete[] a;
- }
- void swap(MyStack<T>& y) {
- char* q = y.a; y.a = a; a = q;
- int z;
- z = y.n; y.n = n; n = z;
- z = y.t; y.t = t; t = z;
- }
- T& top() const {
- return p[t-1];
- }
- void pop() {
- --t;
- p[t].~T();
- }
- void push(const T& x) {
- x.print(); // debug
- p[t] = x;
- ++t;
- }
- int size() const { return t; }
- bool empty() const { return 0 == t; }
- bool full() const { return n == t; }
- };
- #endif
- template<class T>
- struct Frame {
- static T* base;
- T *beg, *tmp;
- int len;
- int retaddr;
- Frame(T* beg, T* tmp, int len, int retaddr)
- : beg(beg), tmp(tmp), len(len), retaddr(retaddr)
- {}
- void print() const { // for debug
- printf("%4d %4d %d/n", int(beg-base), len, retaddr);
- }
- };
- template<class T> T* Frame<T>::base;
- #define TOP(field) stk.top().field
- template<class T>
- bool issorted(const T* a, int n)
- {
- for (int i = 1; i < n; ++i) {
- if (a[i-1] > a[i]) return false;
- }
- return true;
- }
- template<class T>
- void mymerge(const T* a, int la, const T* b, int lb, T* c) {
- int i = 0, j = 0, k = 0;
- for (; i < la && j < lb; ++k) {
- if (b[j] < a[i])
- c[k] = b[j], ++j;
- else
- c[k] = a[i], ++i;
- }
- for (; i < la; ++i, ++k) c[k] = a[i];
- for (; j < lb; ++j, ++k) c[k] = b[j];
- }
- template<class T>
- void MergeSort1(T* beg, T* tmp, int len) {
- if (len > 1) {
- int mid = len / 2;
- MergeSort1(beg , tmp , mid);
- MergeSort1(beg+mid, tmp+mid, len-mid);
- mymerge(tmp, mid, tmp+mid, len-mid, beg);
- memcpy(tmp, beg, sizeof(T)*len);
- }
- else
- *tmp = *beg;
- }
- template<class T>
- void MergeSort2(T* beg0, T* tmp0, int len0) {
- int mid;
- int cnt = 0;
- Frame<T>::base = beg0;
- MyStack<Frame<T> > stk;
- stk.push(Frame<T>(beg0, tmp0, len0, 0));
- while (true) {
- ++cnt;
- if (TOP(len) > 1) {
- mid = TOP(len) / 2;
- stk.push(Frame<T>(TOP(beg), TOP(tmp), mid, 1));
- continue;
- L1:
- mid = TOP(len) / 2;
- stk.push(Frame<T>(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid, 2));
- continue;
- L2:
- mid = TOP(len) / 2;
- mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg));
- memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len));
- } else
- *TOP(tmp) = *TOP(beg);
- int retaddr0 = TOP(retaddr);
- stk.pop();
- switch (retaddr0) {
- case 0: return;
- case 1: goto L1;
- case 2: goto L2;
- }
- }
- }
- // This Implementation Use GCC's goto saved label value
- // Very similiar with recursive version
- template<class T>
- void MergeSort3(T* beg0, T* tmp0, int len0) {
- MyEntry:
- int mid;
- int retaddr;
- Frame<T>::base = beg0;
- MyStack<Frame<T> > stk;
- stk.push(Frame<T>(beg0, tmp0, len0, 0));
- #define Cat1(a,b) a##b
- #define Cat(a,b) Cat1(a,b)
- #define HereLabel() Cat(HereLable_, __LINE__)
- #define RecursiveCall(beg, tmp, len) /
- stk.push(Frame<T>(beg, tmp, len, (char*)&&HereLabel() - (char*)&&MyEntry)); /
- continue; /
- HereLabel():;
- //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // retaddr == 0 是最外层的递归调用,
- // 只要到达这一层时 retaddr 才为 0,
- // 此时就可以返回了
- #define MyReturn /
- retaddr = TOP(retaddr); /
- stk.pop(); /
- if (0 == retaddr) { /
- return; /
- } /
- goto *((char*)&&MyEntry + retaddr);
- //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- while (true) {
- if (TOP(len) > 1) {
- mid = TOP(len) / 2;
- RecursiveCall(TOP(beg), TOP(tmp), mid);
- mid = TOP(len) / 2;
- RecursiveCall(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid);
- mid = TOP(len) / 2;
- mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg));
- memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len));
- } else
- *TOP(tmp) = *TOP(beg);
- MyReturn;
- }
- }
- template<class T>
- void MergeSortDriver(T* beg, int len, void (*mf)(T* beg_, T* tmp_, int len_))
- {
- T* tmp = new T[len];
- (*mf)(beg, tmp, len);
- delete[] tmp;
- }
- #define test(a,n,mf) /
- memcpy(a, b, sizeof(a[0])*n); /
- MergeSortDriver(a, n, &mf); /
- printf("sort by %s:", #mf); /
- for (i = 0; i < n; ++i) printf("% ld", a[i]); /
- printf("/n");
- int main(int argc, char* argv[])
- {
- int n = argc - 1;
- int i;
- long* a = new long[n];
- long* b = new long[n];
- for (i = 0; i < n; ++i)
- b[i] = strtol(argv[i+1], NULL, 10);
- test(a, n, MergeSort1);
- test(a, n, MergeSort2);
- test(a, n, MergeSort3);
- printf("All Successed/n");
- delete[] a;
- delete[] b;
- return 0;
- }