讓數百萬臺手機訓練同一個模型?Google把這套框架開源了

作者 | 琥珀

出品 | AI科技大本營(公衆號id:rgznai100)

【導語】據瞭解,全球有 30 億臺智能手機和 70 億臺邊緣設備。每天,這些電話與設備之間的交互不斷產生新的數據。傳統的數據分析和機器學習模式,都需要在處理數據之前集中收集數據至服務器,然後進行機器學習訓練並得到模型參數,最終獲得更好的產品。

但如果這些需要聚合的數據敏感且昂貴的話,那麼這種中心化的數據收集手段可能就不太適用了。

去掉這一步驟,直接在生成數據的邊緣設備上進行數據分析和機器學習訓練呢?

近日,Google 開源了一款名爲 TensorFlow Federated (TFF)的框架,可用於去中心化(decentralized)數據的機器學習及運算實驗。它實現了一種稱爲聯邦學習(Federated Learning,FL)的方法,將爲開發者提供分佈式機器學習,以便在沒有數據離開設備的情況下,便可在多種設備上訓練共享的 ML 模型。其中,通過加密方式提供多一層的隱私保護,並且設備上模型訓練的權重與用於連續學習的中心模型共享。

傳送門:https://www.tensorflow.org/federated/

實際上,早在 2017 年 4 月,Google AI 團隊就推出了聯邦學習的概念。這種被稱爲聯邦學習的框架目前已應用在 Google 內部用於訓練神經網絡模型,例如智能手機中虛擬鍵盤的下一詞預測和音樂識別搜索功能。

圖注:每臺手機都在本地訓練模型(A);將用戶更新信息聚合(B);然後形成改進的共享模型(C)。

DeepMind 研究員 Andrew Trask 隨後發推稱讚:“Google 已經開源了 Federated Learning……可在數以百萬計的智能手機上共享模型訓練!”

讓我們一起來看看使用教程:

從一個著名的圖像數據集 MNIST 開始。MNIST 的原始數據集爲 NIST,其中包含 81 萬張手寫的數字,由 3600 個志願者提供,目標是建立一個識別數字的 ML 模型。

傳統手段是立即將 ML 算法應用於整個數據集。但實際上,如果數據提供者不願意將原始數據上傳到中央服務器,就無法將所有數據聚合在一起。

TFF 的優勢就在於,可以先選擇一個 ML 模型架構,然後輸入數據進行訓練,同時保持每個數據提供者的數據是獨立且保存在本地。

下面顯示的是通過調用 TFF 的 FL API,使用已由 GitHub 上的“Leaf”項目處理的 NIST 數據集版本來分隔每個數據提供者所寫的數字:

GitHub 傳送鏈接:https://github.com/TalwalkarLab/leaf

# Load simulation data.

source, _ = tff.simulation.datasets.emnist.load_data()

defclient_data(n):

  dataset = source.create_tf_dataset_for_client(source.client_ids[n])

  return mnist.keras_dataset_from_emnist(dataset).repeat(10).batch(20)

# Wrap a Keras model for use with TFF.

defmodel_fn():

  return tff.learning.from_compiled_keras_model(

      mnist.create_simple_keras_model(), sample_batch)

# Simulate a few rounds of training with the selected client devices.

trainer = tff.learning.build_federated_averaging_process(model_fn)

state = trainer.initialize()

for _ in range(5):

  state, metrics = trainer.next(state, train_data)

  print (metrics.loss)

除了可調用 FL API 外,TFF 還帶有一組較低級的原語(primitive),稱之爲 Federated Core (FC) API。這個 API 支持在去中心化的數據集上表達各種計算。

使用 FL 進行機器學習模型訓練僅是第一步;其次,我們還需要對這些數據進行評估,這時就需要 FC API 了。

假設我們有一系列傳感器可用於捕獲溫度讀數,並希望無需上傳數據便可計算除這些傳感器上的平均溫度。調用 FC 的 API,就可以表達一種新的數據類型,例如指出 tf.float32,該數據位於分佈式的客戶端上。

READINGS_TYPE = tff.FederatedType(tf.float32, tff.CLIENTS)

然後在該類型的數據上定義聯邦平均數。

@tff.federated_computation(READINGS_TYPE)

defget_average_temperature(sensor_readings):

  return tff.federated_average(sensor_readings)

之後,TFF 就可以在去中心化的數據環境中運行。從開發者的角度來講,FL 算法可以看做是一個普通的函數,它恰好具有駐留在不同位置(分別在各個客戶端和協調服務中的)輸入和輸出。

例如,使用了 TFF 之後,聯邦平均算法的一種變體:

參考鏈接:https://arxiv.org/abs/1602.05629

@tff.federated_computation(

  tff.FederatedType(DATASET_TYPE, tff.CLIENTS),

  tff.FederatedType(MODEL_TYPE, tff.SERVER, all_equal=True),

  tff.FederatedType(tf.float32, tff.SERVER, all_equal=True))

deffederated_train(client_data, server_model, learning_rate):

  return tff.federated_average(

      tff.federated_map(local_train, [

          client_data,

          tff.federated_broadcast(server_model),

          tff.federated_broadcast(learning_rate)]))

目前已開放教程,可以先在模型上試驗現有的 FL 算法,也可以爲 TFF 庫提供新的聯邦數據集和模型,還可以添加新的 FL 算法實現,或者擴展現有 FL 算法的新功能。

據瞭解,在 FL 推出之前,Google 還推出了 TensorFlow Privacy,一個機器學習框架庫,旨在讓開發者更容易訓練具有強大隱私保障的 AI 模型。目前二者可以集成,在差異性保護用戶隱私的基礎上,還能通過聯邦學習(FL)技術快速訓練模型。

最後附上 TF Dev Summit’19 上,TensorFlow Federated (TFF)的發佈會現場視頻:

參考鏈接:https://medium.com/tensorflow/introducing-tensorflow-federated-a4147aa20041

(本文爲 AI科技大本營原創文章,轉載請微信聯繫 1092722531)

4 月13日-4 月14日,CSDN 將在北京主辦“Python 開發者日( 2019 )”,匯聚十餘位來自阿里巴巴IBM英偉達等國內外一線科技公司的Python技術專家,還有數百位來自各行業領域的Python開發者。目前購票通道已開啓,早鳥票限量發售中,3 月15日之前可享受優惠價 299 元(售完即止)。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章