[05] 通過P/Invoke加速C#程序

通過P/Invoke加速C#程序

任何語言都會提供FFI機制(Foreign Function Interface, 叫法不太一樣), 大多數的FFI機制是和C API. C#提供了P/Invoke來和操作系統, 第三方擴展進行交互.

FFI通常用來和老的代碼交互, 例如有大量的遺留代碼, 重寫成本太高, 可以導出C接口, 然後新系統和老系統交互; 還有一種用處就是優化, 將某一部分功能挪到C/C++(或者其他Native語言)裏面, 通過特殊的優化, 對系統進行加速.

所有的FFI均存在額外的開銷, 除了C++和C這種語言交互. 託管語言和非託管語言, 託管語言和託管語言交互的成本都不小. C#和C的交互, 主要的成本有兩塊:

  • 參數傳遞的成本

    C#裏面的字符串是UTF-16編碼的, 但是在C裏面一般使用ASCII或者兼容的編碼, 所以調用之前需要先做一次轉換.

    內存佈局不一樣的參數, 會有額外的開銷.

  • 調用的額外開銷

    P/Invoke 的開銷介於每個呼叫 10 到 30 x86 指令之間。 除了此固定成本外,封送還會產生額外的開銷。 在託管代碼和非託管代碼中具有相同的表示形式的可聲明類型之間沒有封送成本。 例如,int 和 Int32 之間沒有翻譯費用。

    可以理解爲10-30個時鐘週期, 比虛函數調用成本要高一些.

以上是P/Invoke優化的基礎知識. 只要調用的函數執行的時間較長, 參數的轉換足夠少, 那麼進行P/Invoke優化就是有意義的.

某遊戲服務器使用了AES-ECB加密算法進行通訊協議的加密. 算法一直沒改, 實現修改了好幾次, 因爲整個編碼過程中, 會產生多個臨時byte[]對象, 所以一直想要優化掉.

下面這個版本是C# Slice的版本, 希望把加密後的內容放到我準備好的Slice裏面(IByteBuffer). 但是其中有一個MemoryStream還是無法處理, 這個對象內部還是會產生byte[].

public static int AesEncrypt(byte[] src, int offset, int count, byte[] dest, int destOffset, byte[] Key0)
{
    using Rijndael rm = Rijndael.Create();
    rm.Key = Key0;
    rm.Mode = CipherMode.ECB;
    rm.Padding = PaddingMode.PKCS7;

    using ICryptoTransform cTransform = rm.CreateEncryptor();
    using var memoryStream = new MemoryStream(dest, destOffset, count + 32);
    using var writer = new CryptoStream(memoryStream, cTransform, CryptoStreamMode.Write);
    writer.Write(src, offset, count);
    writer.FlushFinalBlock();

    return (int)memoryStream.Position;
}

花了好長時間去研究.NET內部的實現, 沒找到解決辦法.

所以這時候就把眼睛轉向了P/Invoke和C++. 好在可以先通過C#的版本生成一個輸入輸出樣本, 然後C++嘗試着去跑通整個輸入輸出.

下面是C++的版本:

aes_ech.h

#pragma once
#include <openssl/aes.h>
#include <assert.h>
#include <string.h>

#ifdef WIN32
#define __DLLIMPORT __declspec(dllimport)
#define __DLLEXPORT __declspec(dllexport)
#else
#define __DLLIMPORT
#define __DLLEXPORT 
#endif

extern "C"
{
__DLLEXPORT int AesEcbEncrypt(unsigned char* key, int key_size,
		unsigned char* source, int source_length,
		unsigned char* dest);

__DLLEXPORT int AesEcbDecrypt(unsigned char* key, int key_size,
		unsigned char* source, int source_length,
		unsigned char* dest);
}


static inline int pkcs7padding(unsigned char* data, int length) {
	int padding = AES_BLOCK_SIZE - length % AES_BLOCK_SIZE;
	int destSize = length + padding;
	for (int index = length; index < destSize; ++index) {
		data[index] = padding;
	}
	return destSize;
}

static inline int Encrypt(unsigned char* key, int keyLength,
			unsigned char* src, int srcLength,
			unsigned char* dest) {
	int paddingLength = pkcs7padding(src, srcLength);

	AES_KEY aes_key;
	AES_set_encrypt_key(reinterpret_cast<const unsigned char*>(&key[0]),
		keyLength * 8, &aes_key);

	unsigned char* encrypted = dest;

	for (int block = 0; block < paddingLength; block += AES_BLOCK_SIZE) {
		AES_ecb_encrypt(reinterpret_cast<const unsigned char*>(&src[block]),
			reinterpret_cast<unsigned char*>(&encrypted[block]),
			&aes_key, AES_ENCRYPT);
	}

	return paddingLength;
}

static inline int pkcs7unpadding(unsigned char* data, int dataLength) {
	int padding = data[dataLength - 1];
	return dataLength - padding;
}

static inline int Decrypt(unsigned char *key, int keyLength,
			unsigned char* encrypted, int encryptedLength,
			unsigned char* decrypted) {
	AES_KEY aes_key;
	AES_set_decrypt_key(reinterpret_cast<const unsigned char*>(&key[0]),
		keyLength * 8, &aes_key);

	int decrypted_length = encryptedLength;

	for (int block = 0; block < encryptedLength;
		block += AES_BLOCK_SIZE) {
		AES_ecb_encrypt(reinterpret_cast<const unsigned char*>(&encrypted[block]),
			reinterpret_cast<unsigned char*>(&decrypted[block]),
			&aes_key, AES_DECRYPT);
	}

	return pkcs7unpadding(decrypted, encryptedLength);
}

aes_ecb.cpp

#include "aes_ecb.h"

extern "C" 
{
__DLLEXPORT int AesEcbEncrypt(unsigned char* key, int key_size,
    unsigned char* source, int source_length,
    unsigned char* dest) {
    return ::Encrypt(key, key_size, source, source_length, dest);
}

__DLLEXPORT int AesEcbDecrypt(unsigned char* key, int key_size,
    unsigned char* source, int source_length,
    unsigned char* dest) {
    return ::Decrypt(key, key_size, source, source_length, dest);
}
}

C#的P/Invoke封裝, 以及測試代碼:

using System;
using System.Runtime.InteropServices;
using System.Text;

namespace AesPInvoke
{
    static class AesWin
    {
        [DllImport("AESECB.dll", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbEncrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);

        [DllImport("AESECB.dll", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbDecrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);
    }
    static class AesLinux 
    {
        [DllImport("AESECB.so", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbEncrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);

        [DllImport("AESECB.so", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbDecrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);
    }

    static class Aes 
    {
        public unsafe delegate int AesFunc(byte* key, int key_size, byte* source, int source_length, byte* dest);
        static AesFunc encrypt;
        static AesFunc decrypt;
        public static AesFunc AesEncrpt => encrypt;
        public static AesFunc AesDecrypt => decrypt;
        static unsafe Aes() 
        {
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) 
            {
                encrypt = AesLinux.AesEcbEncrypt;
                decrypt = AesLinux.AesEcbDecrypt;
            }
            else 
            {
                encrypt = AesWin.AesEcbEncrypt;
                decrypt = AesWin.AesEcbDecrypt;
            }
        }
    }

    class Program
    {

        private static bool Compare(ArraySegment<byte> a, ArraySegment<byte> b) 
        {
            if (a.Count != b.Count)
            {
                return false;
            }
            for (int i = 0; i < a.Count; ++i) 
            {
                if (a[i] != b[i]) return false;
            }
            return true;
        }

        static  unsafe void Main(string[] args)
        {
            byte[] origin = new byte[] {
                                    0x06, 0x04, 0x34, 0x35, 0x32, 0x56, 0x0a, 0x10, 0x08, 0xf9, 0xeb, 0x06,
                                    0x10, 0x93, 0x12, 0x18, 0x85, 0x1a, 0x20, 0x89, 0xdf, 0xf6, 0xd3, 0x01
            };
            byte[] dest = new byte[] {0x0f, 0xd9, 0x52, 0x10, 0x11, 0x4b, 0xcc, 0xe5,
                              0x48, 0x9d, 0x47, 0x2a, 0x69, 0xa4, 0x19, 0xcc,
                              0x08, 0x6b, 0x7d, 0xe9, 0x65, 0x26, 0x53, 0x10,
                              0x5c, 0xc9, 0x2f, 0xa8, 0x02, 0x43, 0x32, 0x8f};

            var originSegment = new ArraySegment<byte>(origin);
            var destSegment = new ArraySegment<byte>(dest);

            byte[] key = Encoding.UTF8.GetBytes("12345678876543211234567887654abc");

            byte[] input = new byte[origin.Length + 32];
            Array.Copy(origin, input, origin.Length);

            byte[] output = new byte[origin.Length + 32];

            fixed(byte* keyPointer = key) 
            fixed(byte* inputPointer = input)
            fixed(byte* outputPointer = output)
            {
                var length = Aes.AesEncrpt(keyPointer, key.Length, inputPointer, origin.Length, outputPointer);
                var data = new ArraySegment<byte>(output, 0, length);
                Console.WriteLine("{0}", Compare(destSegment, data));
            }

            input = new byte[dest.Length];
            Array.Copy(dest, input, dest.Length);
            output = new byte[dest.Length];

            fixed(byte* keyPointer = key) 
            fixed(byte* inputPointer = input)
            fixed(byte* outputPointer = output)
            {
                var length = Aes.AesDecrypt(keyPointer, key.Length, inputPointer, dest.Length, outputPointer);
                var data = new ArraySegment<byte>(output, 0, length);
                Console.WriteLine("{0}", Compare(originSegment, data));
            }

            Console.WriteLine("Hello World!");
        }
    }
}

跑通測試之後, 就可以集成到系統裏面去, 把託管實現給替換掉. 一次可以把多餘的AllocArray, 和加速同時完成.

C++版本的AES ECB加密使用了OpenSSL庫, 好處是工業級實現, 而且還有可能會有AES-NI加速, Windows上面只需要通過vcpkg就可以方便的移植過來, Linux上面本身就有這個庫.

大部分C#代碼都可以跑得非常快, 一般情況下是不需要進行這種極端優化. 但是某遊戲服務器是一個比較特殊的服務器, 其服務器只有一個進程, 一個進程內需要跑IO密集, 計算密集(加解密,物理,戰鬥等), 還要承擔GC的負擔, 所以才採用了這種優化方式.

參考:

  1. P/Invoke
  2. P/Invoke開銷
  3. OpenSSL AES
  4. AES-NI Performance
  5. vcpkg

通過P/Invoke加速C#程序

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章