【NVCaffe源碼分析】AnnotatedDataLayer(2)

LayerSetUp

    上一篇講了 AnnotatedDataLayer 的構造函數部分,在完成對象的初始化後緊接着會調用每一層的 LayerSetUp(...)函數來對各層進行相應的設置。注意到 AnnotatedDataLayer以及其直接父類 DataLayer都沒有重寫 LayerSetUp函數,所以責任自然就落到了 爺爺類 BasePrefetchingDataLayer 身上了。

template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::LayerSetUp(const vector<Blob*>& bottom,
    const vector<Blob*>& top) {
  bottom_init_ = bottom;
  top_init_ = top;
  //先調了父類的layersetup
  BaseDataLayer<Ftype, Btype>::LayerSetUp(bottom, top);

  //這裏也很重要,設置了 bwd和 fwd
  for (int i = 0; i < transf_num_; ++i) {
    bwd_data_transformers_.emplace_back(
        make_shared<DataTransformer<Btype>>(this->transform_param_, this->phase_));
    fwd_data_transformers_.emplace_back(
        make_shared<DataTransformer<Ftype>>(this->transform_param_, this->phase_));
  }
  const Solver* psolver = this->parent_solver();
  const uint64_t random_seed = (psolver == nullptr ||
      static_cast<uint64_t>(psolver->param().random_seed()) == Caffe::SEED_NOT_SET) ?
          Caffe::next_seed() : static_cast<uint64_t>(psolver->param().random_seed());
  //在這啓動了預取內部線程
  StartInternalThread(false, random_seed);
}

這個函數裏面最重要的部分就是:StartInternalThread(false, random_seed),由它啓動了內部的數據預取線程。啓動了多少個呢?這在上一篇文章裏面有作分析,像我的配置文件的話那就是4個線程。也就是說一個 solver 裏面就有4個數據預取線程,一共有4張卡就有16個數據預取線程。線程函數入口爲:InternalThreadEntryN。同樣因爲 AnnotatedDataLayer 和 DataLayer 都沒有重寫 InternalThreadEntryN,所以我們直接看 BasePrefetchingDataLayer 裏面重寫的 InternalThreadEntryN 就好了。

template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InternalThreadEntryN(size_t thread_id) {
  const bool auto_mode = this->auto_mode(); //我的爲 false
  if (auto_mode) {
    iter0_.wait_reset();  // sample reader first
  } else if (this->phase_ == TRAIN) {
    iter0_.wait(); //等待,等待什麼??
  }
  if (auto_mode && this->net_inititialized_flag_ != nullptr) {
    this->net_inititialized_flag_->wait();
  }
  
  InitializePrefetch(); //一定不要漏過這個函數
  start_reading(); //置位 start_reading_flag
  
  // 感覺這個函數要和datareader.cpp中的結合着讀
  try {
    while (!must_stop(thread_id)) {
      //隊列id 和線程 id 的關係?
      const size_t qid = this->queue_id(thread_id);
      //喲,這裏看到了和 batch_transformer 的交互
      shared_ptr<Batch> batch = batch_transformer_->prefetched_pop_free(qid);
      CHECK_EQ((size_t) -1L, batch->id());
      // 數據從哪裏來? 從 data reader 裏面來,data reader 纔是真正與數據庫交互的
      load_batch(batch.get(), thread_id, qid); 
      if (must_stop(thread_id)) {
        break;
      }
      // 取得的數據交給 batch transformer 由他做進一步的處理
      batch_transformer_->prefetched_push_full(qid, batch);

      if (auto_mode) {
        iter0_.set();
        break;
      }
    }
  } catch (boost::thread_interrupted&) {
  }
}

核心邏輯在那個 while(...)循環裏面,但是在這之前千萬不要漏過了那個InitializePrefetch()函數。

template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InitializePrefetch() {
  ResizeQueues();
  //這裏會調用 AnnodatedDataLayer 中的 DataLayerSetUp.這很合理,因爲那裏會創建 areader
  //這個DataLayerSetUp 在 BaseDataLayer中是純虛函數,在BasePrefetchingDataLayer
  //中沒有實現這個函數,它的兒子類DataLayer和孫子類AnnodatedDataLayer都做了實現.
  this->DataLayerSetUp(bottom_init_, top_init_); 
}

函數內部會繼續調用 this->DataLayerSetup,調誰的呢?調 AnnotatedDataLayer 的。

template <typename Ftype, typename Btype>
void AnnotatedDataLayer<Ftype, Btype>::DataLayerSetUp(
    const vector<Blob*>& bottom, const vector<Blob*>& top) {
  const LayerParameter& param = this->layer_param();
  const AnnotatedDataParameter& anno_data_param = param.annotated_data_param();
  const int batch_size = param.data_param().batch_size(); //還不是在prototxt中設定的咯
  const bool cache = this->cache_ && this->phase_ == TRAIN;
  const bool shuffle = cache && this->shuffle_ && this->phase_ == TRAIN;
  TBlob<Ftype> transformed_datum;
  //根據你 prototxt 中設置的 batch_sampler{...}逐個添加
  for (int i = 0; i < anno_data_param.batch_sampler_size(); ++i) {
    batch_samplers_.push_back(anno_data_param.batch_sampler(i));
  }
  //auto_mode 爲 false可以跳過這一大段不看
  if (this->auto_mode()) {
    if (!sample_areader_) {
      sample_areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
          Caffe::solver_count(),
          this->rank_,
          this->parsers_num_,
          this->threads_num(),
          batch_size,
          true,
          false,
          cache,
          shuffle,
          false);
    } else if (!areader_) {
      areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
          Caffe::solver_count(),
          this->rank_,
          this->parsers_num_,
          this->threads_num(),
          batch_size,
          false,
          true,
          cache,
          shuffle,
          this->phase_ == TRAIN);
    }
  } else if (!areader_) {
    //這就是傳說中的data reader,它是一個數據讀取器,派生自 InternalThread
    areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
        Caffe::solver_count(),
        this->rank_,
        this->parsers_num_,  //parser_threads: 4
        this->threads_num(), //threads: 4
        batch_size, //注意,這是一個 solver 中的 batch_size
        false,
        false,
        cache,
        shuffle,
        this->phase_ == TRAIN);
    //這個函數貌似只是設置一個read的flag,so線程的初始化部分
    //應該在上面創建areader_這一步就完成了,然後在等待這個讀的信號?
    start_reading(); //會調用areader的start_reading
  }

最重要的就是創建了 DataReader 對象,這個 data reader 是真正與底層數據庫(lmdb、leveldb 等)打交道的。它 類似於 BatchTransfer也是 直接派生自 InternalThread,重寫了 InternalThreadEntryN線程函數。

template<typename DatumType>
void DataReader<DatumType>::InternalThreadEntryN(size_t thread_id) {
  if (cache_) {
    data_cache_->check_db(db_source_);
    data_cache_->register_new_thread();
  }

  unique_ptr<db::DB> db; //unique_ptr也是智能指針,但它獨佔所指向的對象
  {
    std::lock_guard<std::mutex> lock(db_mutex_);
    db.reset(db::GetDB(backend_));
    db->Open(db_source_, db::READ); //打開數據庫開始讀數據
  }

  CursorManager cm(db.get(),
      this,
      solver_count_,
      solver_rank_,
      parser_threads_num_, //test:1
      thread_id,
      batch_size_,
      cache_ && !sample_only_,
      shuffle_ && !sample_only_,
      epoch_count_required_); //train:true, test:false
  shared_ptr<DatumType> init_datum = make_shared<DatumType>();
  cm.fetch(init_datum.get());
  init_->push(init_datum);

  if (!sample_only_) {
    start_reading_flag_.wait();
  }
  cm.rewind(); 
  size_t skip = skip_one_batch_ ? batch_size_ : 0UL;

  size_t queue_id, ranked_rec, batch_on_solver, sample_count = 0UL;
  shared_ptr<DatumType> datum = make_shared<DatumType>();
  /*
  1.每一次datareader將free中已經被消費過的對象取出,填上新的數據,然後將其塞入full中;
  2.每一次BasePrefetchingDataLayer將full中的數據取出並消費,然後將其塞入free中.
  */
  try {
    while (!must_stop(thread_id)) {
      cm.next(datum);
      // See comment below
      ranked_rec = (size_t) datum->record_id() / cm.full_cycle();
      batch_on_solver = ranked_rec * parser_threads_num_ + thread_id;
      //這個queue_id感覺大有學問...
      queue_id = batch_on_solver % queues_num_;

      if (thread_id == 0 && skip > 0U) {
        --skip;
        continue;
      }

      full_push(queue_id, datum);

      if (sample_only_) {
        ++sample_count;
        if (sample_count >= batch_size_) {
          // sample batch complete
          break;
        }
      }
      datum = free_pop(queue_id);
    }
  } catch (boost::thread_interrupted&) {
  }
}

DataReader通過雙阻塞隊列(BBQ)和 BasePrefetchingDataLayer 進行交互。

  vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> free_;
  vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> full_;

在 BasePrefetchingDataLayer 的線程函數while 循環中會不斷地調用 load_batch(batch.get(), thread_id, qid),正是這個 load_batch函數它的數據就源自 data reader,到這裏爲止好多東西都通暢了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章