目前數據越來越多,數據一般存儲在HDFS上,但是許多深度學習算法是基於TensorFlow、PyTorch等框架實現的,使用單機Python、Java做數據轉換都比較慢。怎麼大規模地把HDFS數據直接喂到TensorFlow中?在這裏TensorFlow提供了一種解決方案:利用Spark生成TFRecord文件,項目名稱叫spark-tensorflow-connector,GitHub主頁在 https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector 。按照README編譯jar包,放在自己的項目裏面做依賴即可使用;如果實在不想自己編譯jar包,也可以在這裏直接添加依賴下載: https://mvnrepository.com/artifact/org.tensorflow/spark-tensorflow-connector 。主要原理是在這個項目裏面寫了一些隱式轉換類,重寫了輸出的格式,對上層暴露的接口都比較簡單,提供了Scala、Python的接口,實際背後全部是依賴於proto,不得不佩服Google技術的強大以及推廣能力。下面看下怎麼使用:
官方例子:
package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
/**
 * Round-trips a small hand-built DataFrame through the "tfrecords" data source:
 * writes two rows as TFRecord `Example` records, then reads them back twice,
 * once with an inferred schema and once with the explicit schema used for writing.
 *
 * The "tfrecords" format is resolved at runtime, so the data source providing it
 * must be on the classpath.
 */
object TFRecordsExample {

  /**
   * Entry point.
   *
   * @param args command-line arguments; not used
   */
  def main(args: Array[String]): Unit = {
    // Quiet down Spark and Jetty so the `show()` output is readable.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    try {
      val path = "file/test-output.tfrecord"

      // Column layout of the sample rows; field order must match the value order below.
      val schema = StructType(List(
        StructField("id", IntegerType),
        StructField("IntegerCol", IntegerType),
        StructField("LongCol", LongType),
        StructField("FloatCol", FloatType),
        StructField("DoubleCol", DoubleType),
        StructField("VectorCol", ArrayType(DoubleType, true)),
        StructField("StringCol", StringType)))

      // Built with the public Row factory rather than the internal catalyst GenericRow class.
      val testRows: Seq[Row] = Seq(
        Row(11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1"),
        Row(21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))

      val rdd = spark.sparkContext.parallelize(testRows)

      // Save DataFrame as TFRecords.
      // NOTE: no save mode is set, so Spark's default applies and a re-run fails
      // if `path` already exists; delete the directory between runs.
      val df: DataFrame = spark.createDataFrame(rdd, schema)
      df.write.format("tfrecords").option("recordType", "Example").save(path)

      // Read TFRecords into a DataFrame.
      // The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
      val importedDf1: DataFrame =
        spark.read.format("tfrecords").option("recordType", "Example").load(path)
      importedDf1.show()

      // Read TFRecords into a DataFrame using the custom schema.
      // "Example" is the connector's documented default; spelled out so both reads visibly agree.
      val importedDf2: DataFrame =
        spark.read.format("tfrecords").option("recordType", "Example").schema(schema).load(path)
      importedDf2.show()
    } finally {
      // Always release the local Spark context, even when a read or write throws.
      spark.stop()
    }
  }
}
讀取BERT模型訓練數據的測試:
package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
/**
 * Reads TFRecord files from a local directory and prints them twice: once with
 * the schema inferred by the "tfrecords" data source, and once with an explicit
 * schema of three integer-array columns (`input_ids`, `input_mask`, `label_ids`).
 *
 * The "tfrecords" format is resolved at runtime, so the data source providing it
 * must be on the classpath.
 */
object TFRcordsBert {

  /**
   * Entry point.
   *
   * @param args command-line arguments; not used
   */
  def main(args: Array[String]): Unit = {
    // Quiet down Spark and Jetty so the `show()` output is readable.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    try {
      // Directory holding the TFRecord files to inspect (machine-specific path).
      val path = "/Users/shuubiasahi/Desktop/textclass/"

      // Explicit schema for the second read.
      val schema = StructType(List(
        StructField("input_ids", ArrayType(IntegerType, true)),
        StructField("input_mask", ArrayType(IntegerType, true)),
        StructField("label_ids", ArrayType(IntegerType, true))))

      // First pass: let the connector infer the schema.
      // NOTE(review): this read asks for "SequenceExample" while the read below uses
      // "Example"; confirm which record type the files were actually written with and
      // make both reads agree.
      val importedDf1: DataFrame =
        spark.read.format("tfrecords").option("recordType", "SequenceExample").load(path)
      importedDf1.show()

      // Second pass: apply the explicit schema. The original relied on the connector's
      // documented default record type ("Example"); it is spelled out here so the
      // difference from the first read is visible.
      val importedDf2: DataFrame =
        spark.read.format("tfrecords").option("recordType", "Example").schema(schema).load(path)
      importedDf2.show()
    } finally {
      // Always release the local Spark context, even when a read throws.
      spark.stop()
    }
  }
}
+--------------------+--------------------+---------+
| input_ids| input_mask|label_ids|
+--------------------+--------------------+---------+
|[101, 4281, 3566,...|[1, 1, 1, 1, 1, 1...| [25]|
|[101, 3433, 5866,...|[1, 1, 1, 1, 0, 0...| [40]|
|[101, 6631, 5277,...|[1, 1, 1, 1, 1, 1...| [5]|