強大的C# Expression在一個函數求導問題中的簡單運用

 號稱面試的題目總是非常有趣的,這裏是又一個例子:

【原題出處

http://topic.csdn.net/u/20110928/15/B00A34FE-8544-42E2-A771-3C4A888DB85A.html

【問題梗概】

求一個函數的一階導數。

【代碼方案】

 

[csharp] view plaincopy
  1. namespace Derivative  
  2. {  
  3.     class Program  
  4.     {  
  5.         // 求一個節點表達的算式的導函數    
  6.         static Expression GetDerivative(Expression node)  
  7.         {  
  8.             if (node.NodeType == ExpressionType.Add  
  9.                 || node.NodeType == ExpressionType.Subtract)  
  10.             {   // 該節點在做加減法,套用加減法導數公式    
  11.                 BinaryExpression binexp = (BinaryExpression)node;  
  12.                 Expression dleft = GetDerivative(binexp.Left);  
  13.                 Expression dright = GetDerivative(binexp.Right);  
  14.                 BinaryExpression resbinexp;  
  15.   
  16.                 if (node.NodeType == ExpressionType.Add)  
  17.                     resbinexp = Expression.Add(dleft, dright);  
  18.                 else  
  19.                     resbinexp = Expression.Subtract(dleft, dright);  
  20.                 return resbinexp;  
  21.             }  
  22.             else if (node.NodeType == ExpressionType.Multiply)  
  23.             {   // 該節點在做乘法,套用乘法導數公式    
  24.                 BinaryExpression binexp = (BinaryExpression)node;  
  25.                 Expression left = binexp.Left;  
  26.                 Expression right = binexp.Right;  
  27.   
  28.                 Expression dleft = GetDerivative(left);  
  29.                 Expression dright = GetDerivative(right);  
  30.   
  31.                 return Expression.Add(Expression.Multiply(dleft, right),  
  32.                     Expression.Multiply(left, dright));  
  33.             }  
  34.             else if (node.NodeType == ExpressionType.Parameter)  
  35.             {   // 該節點是x本身(葉子節點),故而其導數即常數1    
  36.                 return Expression.Constant(1.0);  
  37.             }  
  38.             else if (node.NodeType == ExpressionType.Constant)  
  39.             {   // 該節點是一個常數(葉子節點),故其導數爲零    
  40.                 return Expression.Constant(0.0);  
  41.             }  
  42.             else if (node.NodeType == ExpressionType.Call)  
  43.             {  
  44.                 MethodCallExpression callexp = (MethodCallExpression)node;  
  45.                 Expression arg0 = callexp.Arguments[0];  
  46.                 // 一下一元函數求導後均需要乘上自變量的導數  
  47.                 Expression darg0 = GetDerivative(arg0);  
  48.                 if (callexp.Method.Name == "Exp")  
  49.                 {  
  50.                     // 指數函數的導數還是其本身  
  51.                     return Expression.Multiply(  
  52.                            Expression.Call(null, callexp.Method, arg0), darg0);  
  53.                 }  
  54.                 else if (callexp.Method.Name == "Sin")  
  55.                 {  
  56.                     // 正弦函數的倒數是餘弦函數  
  57.                     MethodInfo miCos = typeof(Math).GetMethod("Cos",   
  58.                                        BindingFlags.Public | BindingFlags.Static);  
  59.                     return Expression.Multiply(  
  60.                            Expression.Call(null, miCos, arg0), darg0);  
  61.                 }  
  62.                 else if (callexp.Method.Name == "Cos")  
  63.                 {  
  64.                     // 餘弦函數的導數是正弦函數的相反數  
  65.                     MethodInfo miSin = typeof(Math).GetMethod("Sin",   
  66.                                        BindingFlags.Public | BindingFlags.Static);  
  67.                     return Expression.Multiply(  
  68.                            Expression.Negate(Expression.Call(null, miSin, arg0)), darg0);  
  69.                 }  
  70.             }  
  71.   
  72.             throw new NotImplementedException();    // 其餘的尚未實現            
  73.         }  
  74.   
  75.         static Func<doubledouble> GetDerivative(Expression<Func<doubledouble>> func)  
  76.         {  
  77.             // 從Lambda表達式中獲得函數體    
  78.             Expression resBody = GetDerivative(func.Body);  
  79.   
  80.             // 需要續用Lambda表達式的自變量    
  81.             ParameterExpression parX = func.Parameters[0];  
  82.   
  83.             Expression<Func<doubledouble>> resFunc  
  84.                 = (Expression<Func<doubledouble>>)Expression.Lambda(resBody, parX);  
  85.   
  86.             Console.WriteLine("diff function = {0}", resFunc);  
  87.   
  88.             // 編譯成CLR的IL表達的函數    
  89.             return resFunc.Compile();  
  90.         }  
  91.   
  92.         static double GetDerivative(Expression<Func<doubledouble>> func, double x)  
  93.         {  
  94.             Func<doubledouble> diff = GetDerivative(func);  
  95.             return diff(x);  
  96.         }  
  97.   
  98.         static void Main(string[] args)  
  99.         {  
  100.             // 舉例:求出函數f(x) = cos(x*x)+sin(3*x)+exp(2*x)在x=2.0處的導數    
  101.             double y = GetDerivative(x => Math.Cos(x*x) + Math.Sin(3*x) + Math.Exp(2*x), 2.0);  
  102.             Console.WriteLine("f'(x) = {0}", y);  
  103.         }  
  104.     }  
  105. }    

 

 

【實現大意】

用表達式分解並遞歸求導(過程是相當容易的,比想象的還容易)。目前只是實現了一個最簡單的模型。

【優勢】

給出的是解析解,在求導運算方面沒有任何數值解的誤差,輸出運算也是瞬時的,時間複雜度僅和表達式複雜度相關。

【限制】

1. 函數只能以Lambda表達式輸入,只能是能求出解析解的表達式

2. 目前只實現了加減法和乘法

【後續擴展】

1. 實現其他運算符(沒有太大難度,只是比較繁瑣而已)

2. 表達式樹優化(也不太難的,根據情況定),最基本的可以從常數乘法開始……

3. 條件運算符的處理(這個會變得極難極複雜,但一定程度上實現分段函數求導),其他特殊情況(對求導還可以,如果考慮求不定積分問題可能會有很多特殊情況和hardcode)

4. 輸入端向字符串解析過渡;複雜運算符->逐漸向自定義的數據結構過渡?……

...

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章