Premade Estimators
本文檔介紹了 TensorFlow 編程環境,並展示了怎麼用 Premade Estimators 來解決 Iris 分類問題。
文章目錄
需要安裝的包
在運行本文的代碼之前,你需要安裝以下包:
- 安裝TensorFlow。
- 如果你是在virtualenv或Anaconda中安裝的TensorFlow,請activate你的TensorFlow環境。
- 運行下面的代碼安裝、更新pandas包:
pip install pandas
獲取本文的代碼
通過以下步驟來獲得本文的代碼:
- 使用下面的命令將TensorFlow Model repository克隆到本地:
git clone https://github.com/tensorflow/models
- 將目錄切到本文使用的代碼的位置:
cd models/samples/core/get_started/
本文使用的代碼是 premade_estimator.py。該程序使用 iris_data.py 代碼去fetch訓練數據。
運行本文的代碼
使用下面的方式來運行本文的代碼。例如:
python premade_estimator.py
程序會在訓練過程中輸出訓練日誌,然後程序會輸出在測試集上的測試結果。
...
Prediction is "Setosa" (99.6%), expected "Setosa"
Prediction is "Versicolor" (99.8%), expected "Versicolor"
Prediction is "Virginica" (97.9%), expected "Virginica"
如果程序運行的結果有誤,請進行如下檢查:
- TensorFlow安裝有沒有問題
- TensorFlow的版本是否正確
- 確定activate了安裝TensorFlow的環境
1. TensorFlow 編程環境介紹
在使用 TensorFlow 編程之前,讓我們首先研究下 TF 編程環境。如下圖所示,TensorFlow 提供了一個包含多個 API 層的編程堆棧:
我們強烈推薦使用下列 API 編寫 TF 程序:
- Estimators:代表一個完整的模型。Estimator API 提供一些方法來訓練模型、評估模型性能並生成預測。
- Datasets:構建數據輸入管道。Dataset API 提供了一些方法來加載數據(E),操作數據(T),並將數據饋送到您的模型中(L)。Dataset API 與 Estimator API 完美兼容。注:構建數據輸入管道的過程其實就是構建 ETL 的過程。
2. Iris 分類問題
Iris數據集包含了3種鳶尾花、150個樣本,每個樣本包含鳶尾花的四個特徵值:花蕊長度(cm)、花蕊寬度(cm)、花瓣長度(cm)、花瓣寬度(cm)。
1.1 Iris數據集簡介
Iris 數據集包含4個特徵和1個標籤。4個特徵分別爲:
- 花蕊長度(sepal length)
- 花蕊寬度(sepal width)
- 花瓣長度(petal length)
- 花瓣寬度(petal width)
標籤表明了 Iris 的品種,共有三個品種:
- Iris setosa (0)
- Iris versicolor (1)
- Iris virginica (2)
下面的表是數據集的一個片段:
花蕊長度(cm) | 花蕊寬度(cm) | 花瓣長度(cm) | 花瓣寬度(cm) | 品種(標籤) |
---|---|---|---|---|
5.1 | 3.3 | 1.7 | 0.5 | 0 (setosa) |
5.0 | 2.3 | 3.3 | 1.0 | 1 (versicolor) |
6.4 | 2.8 | 5.6 | 2.2 | 2 (virginica) |
1.2 分類算法
本文檔使用一個 DNN 來實現對 Iris 的分類,該分類器的概況如下:
- 2 個隱藏層
- 每一個隱藏層包含 10 個節點
特徵、隱藏層、預測值的概況如下圖所示:
1.3 進行預測
訓練好模型之後,我們便可以使用訓練好的模型來預測標籤未知的鳶尾花的品種。例如,預測結果的格式如下所示:
- 0.03 for Iris Setosa
- 0.95 for Iris Versicolor
- 0.02 for Iris Virginica
3. 使用 Estimator 編程實現 Iris 分類
Estimator 是 TensorFlow 中的高階 API。它會處理 initialization、logging、saving、restoring 等細節,以便研究人員專注於模型。
Estimator API 中有不少的內置 Estimator。當然,除了這些內置 Estimator,你可以自定義 Estimator。推薦在解決問題時將內置 Estimator 作爲一個 baseline。
使用內置 Estimator 解決問題時,一般遵循以下流程:
- 創建一個或多個輸入函數。
- 定義模型的 feature columns。
- 實例化 Estimator,指定 feature columns 和各種超參數。
- 調用 Estimator 對象的一個或多個方法,傳遞合適的輸入函數作爲數據源。
下面詳細介紹下怎麼用內置 Estimator 來解決 Iris 分類問題。
3.1 創建 input 函數
首先創建輸入函數來爲訓練、評估、預測過程提供數據。
輸入函數的返回值爲 tf.data.Dataset
對象,其輸出一個兩元素的元組:
features
- Python 字典,其中:- 每個鍵都是特徵的名稱。
- 每個值都是包含此特徵所有值的數組。
label
- 包含每個樣本的標籤值的數組。
爲了向您展示輸入函數的格式,請查看下面這個簡單的實現:
def input_evaluation_set():
features = {'SepalLength': np.array([6.4, 5.0]),
'SepalWidth': np.array([2.8, 2.3]),
'PetalLength': np.array([5.6, 3.3]),
'PetalWidth': np.array([2.2, 1.0])}
labels = np.array([2, 1])
return features, labels
輸入函數可以以您需要的任何方式生成 features
字典和 label
列表。但是,我們推薦使用 TensorFlow 的 Dataset API,它可以解析各種數據。作爲一個高階 API,Dataset API 包含以下類:
各個類如下所示:
Dataset
:創建、變換數據集的方法 的基類。您還可以通過該類從 內存中的數據 或 Python 生成器 初始化數據集。TextLineDataset
:從文本文件讀取行。TFRecordDataset
:從 TFRecord 文件讀取 records。FixedLengthRecordDataset
:從二進制文件讀取固定大小的 records。Iterator
: 提供一次訪問一個數據集元素的方法。
爲了簡化此示例,我們採用 pandas 來從 csv 文件加載數據到內存,然後從內存中構建數據輸入管道。
以下是本文的示例在訓練過程中使用的輸入函數。詳見 iris_data.py
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels)) # 這裏的 features 是一個pandas DataFrame,labels 是一個 pandas Series
# Shuffle, repeat, and batch the examples.
return dataset.shuffle(1000).repeat().batch(batch_size)
3.2 定義 feature columns
feature column 是一個對象,用於說明模型應該如何使用特徵字典中的原始輸入數據。在構建 Estimator 模型時,您會向其傳遞一個特徵列的列表,其中包含您希望模型使用的每個特徵。tf.feature_column
模塊提供很多用於向模型表示數據的選項。
對於 Iris,4個原始的特徵是數值,所以我們將構建一個 feature column 列表,以告知 Estimator 模型將這 4 個特徵都表示爲 32 位浮點值。因此,創建 feature column 的代碼如下:
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
Feature columns 的功能遠超上面的示例。我們將在後面有詳細的介紹。
我們已經介紹了希望模型如何表示原始特徵,現在可以構建 Estimator 了。
3.3 實例化 estimator
Iris 分類問題是一個經典的分類問題。TensorFlow中內置了很多分類 Estimators,包括:
tf.estimator.DNNClassifier
tf.estimator.DNNLinearCombinedClassifier
tf.estimator.LinearClassifier
對於Iris問題,tf.estimator.DNNClassifier
似乎是最好的選擇。
# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer.
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
# Two hidden layers of 10 nodes each.
hidden_units=[10, 10],
# The model must choose between 3 classes.
n_classes=3)
3.4 訓練、評估及預測
我們已經有一個 Estimator 對象,現在可以調用方法來執行下列操作:
- 訓練模型
- 評估經過訓練的模型
- 使用經過訓練的模型進行預測
訓練模型
通過調用 Estimator 的 train
方法來訓練模型:
# 訓練模型
classifier.train(
input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
steps=args.train_steps)
在上面,我們使用 lambda
來對 input_fn 函數進行了包裝。steps
參數用來告訴 train
方法在指定的 training steps 後停止訓練。
評估經過訓練的模型
模型已經過訓練,現在我們可以對模型性能進行一些統計。
# 評估模型
eval_result = classifier.evaluate(
input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
與 train
方法的調用不同,我們沒有給 evaluate
傳遞 steps
參數。因爲我們的 eval_input_fn
只生成一個 epoch 的數據。
運行上面的代碼,可以得到下面的結果:
Test set accuracy: 0.967
eval_result
字典也包含了 average_loss
(mean loss per sample)、loss
(mean loss per mini-batch)和 estimator 的 global_step
(the number of training iterations it underwent)。
使用經過訓練的模型進行預測(inference)
我們已經有一個經過訓練的模型(在測試集有比較好的效果)。我們現在可以使用訓練好的模型去預測 Iris 花的品種。與訓練、評估類似,我們通過調用 predict 方法來進行預測:
# 使用訓練好的模型產生預測
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
'SepalLength': [5.1, 5.9, 6.9],
'SepalWidth': [3.3, 3.0, 3.1],
'PetalLength': [1.7, 4.2, 5.4],
'PetalWidth': [0.5, 1.5, 2.1],
}
predictions = classifier.predict(
input_fn=lambda:iris_data.eval_input_fn(predict_x,
batch_size=args.batch_size))
predict
方法返回一個Python迭代器,給每一個 example 生成一個預測結果字典。
template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')
for pred_dict, expec in zip(predictions, expected):
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print(template.format(iris_data.SPECIES[class_id],
100 * probability, expec))
運行上面的代碼,產生下面的結果:
...
Prediction is "Setosa" (99.6%), expected "Setosa"
Prediction is "Versicolor" (99.8%), expected "Versicolor"
Prediction is "Virginica" (97.9%), expected "Virginica"
總結
使用內置 Estimator 可以快速創建出一個基礎模型。
關於 Estimator,我們推薦以下閱讀資料:
- Checkpoints:瞭解如何保存和恢復模型。
- Datasets:瞭解如何將數據導入模型中。
- Creating Custom Estimators:瞭解如何針對特定問題,編寫自定義 Estimator。