利用spark生成tfrecord文件

目前數據越來越多,數據一般存儲在hdfs上,但是目前許多深度學習算法是基於TensorFlow、pytorch等框架實現,使用單機python、java做數據轉換都比較慢,怎麼大規模把hdfs數據直接喂到TensorFlow中,在這裏TensorFlow提供了一種解決方案,利用spark生成tfrecord文件,項目名稱叫spark-tensorflow-connector,GitHub主頁在https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector 這下面,按照readme編譯jar包,放在自己的項目裏面做依賴即可使用,如果實在不想自己編譯jar包,也可以在這上面直接添加依賴下載https://mvnrepository.com/artifact/org.tensorflow/spark-tensorflow-connector,主要原理是在這個項目裏面寫了一些隱式轉換類,重寫了輸出的格式,對上層輸出的接口都比較簡單,提供了scala、python的接口,實際背後全部是依賴於proto,不得不佩服google的技術的強大以及推廣能力,下面看下怎麼使用:

 

官方例子:

package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf

import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

object TFRecordsExample {

  /**
   * Round-trip demo for spark-tensorflow-connector: write a small DataFrame
   * out as TFRecords ("Example" record type), then read it back both with
   * schema inference and with an explicit schema.
   */
  def main(args: Array[String]): Unit = {
    // Reduce log noise from Spark and Jetty.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    try {
      val path = "file/test-output.tfrecord"

      // Sample rows covering the feature types the connector supports.
      // Use the public Row factory instead of the internal catalyst
      // GenericRow class — same runtime values, stable public API.
      val testRows: Array[Row] = Array(
        Row(11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1"),
        Row(21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))

      val schema = StructType(List(
        StructField("id", IntegerType),
        StructField("IntegerCol", IntegerType),
        StructField("LongCol", LongType),
        StructField("FloatCol", FloatType),
        StructField("DoubleCol", DoubleType),
        StructField("VectorCol", ArrayType(DoubleType, true)),
        StructField("StringCol", StringType)))

      val rdd = spark.sparkContext.parallelize(testRows)

      // Save the DataFrame as TFRecords using the "Example" record type.
      val df: DataFrame = spark.createDataFrame(rdd, schema)
      df.write.format("tfrecords").option("recordType", "Example").save(path)

      // Read TFRecords back into a DataFrame.
      // The schema is inferred from the records when none is provided.
      val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
      importedDf1.show()

      // Read TFRecords back using an explicit (custom) schema.
      val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
      importedDf2.show()
    } finally {
      // Always release the local Spark context, even if a read/write fails.
      spark.stop()
    }
  }

}

 

讀取bert模型訓練的數據測試:

package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf

import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
object TFRcordsBert {

  /**
   * Reads TFRecord files produced by BERT pre-processing (SequenceExample
   * records with `input_ids`, `input_mask`, `label_ids` int features),
   * once with an inferred schema and once with an explicit one.
   */
  def main(args: Array[String]): Unit = {
    // Reduce log noise from Spark and Jetty.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    try {
      // Allow the input directory to be overridden on the command line;
      // defaults to the original hard-coded path for backward compatibility.
      val path = if (args.nonEmpty) args(0) else "/Users/shuubiasahi/Desktop/textclass/"

      // BERT pre-processing emits three integer-list features per example.
      val schema = StructType(List(
        StructField("input_ids", ArrayType(IntegerType, true)),
        StructField("input_mask", ArrayType(IntegerType, true)),
        StructField("label_ids", ArrayType(IntegerType, true))))

      // Read with schema inference from the SequenceExample records.
      val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").load(path)
      importedDf1.show()

      // Read with the explicit schema. Keep the same recordType as above:
      // the original omitted it, which silently falls back to "Example"
      // and does not match the SequenceExample data being read.
      val importedDf2: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").schema(schema).load(path)
      importedDf2.show()
    } finally {
      // Always release the local Spark context.
      spark.stop()
    }
  }
}

+--------------------+--------------------+---------+
|           input_ids|          input_mask|label_ids|
+--------------------+--------------------+---------+
|[101, 4281, 3566,...|[1, 1, 1, 1, 1, 1...|     [25]|
|[101, 3433, 5866,...|[1, 1, 1, 1, 0, 0...|     [40]|
|[101, 6631, 5277,...|[1, 1, 1, 1, 1, 1...|      [5]|

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章