【NVCaffe源码分析】AnnotatedDataLayer(2)

LayerSetUp

    上一篇讲了 AnnotatedDataLayer 的构造函数部分,在完成对象的初始化后紧接着会调用每一层的 LayerSetUp(...)函数来对各层进行相应的设置。注意到 AnnotatedDataLayer以及其直接父类 DataLayer都没有重写 LayerSetUp函数,所以责任自然就落到了 爷爷类 BasePrefetchingDataLayer 身上了。

template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::LayerSetUp(const vector<Blob*>& bottom,
    const vector<Blob*>& top) {
  bottom_init_ = bottom;
  top_init_ = top;
  //先调了父类的layersetup
  BaseDataLayer<Ftype, Btype>::LayerSetUp(bottom, top);

  //这里也很重要,设置了 bwd和 fwd
  for (int i = 0; i < transf_num_; ++i) {
    bwd_data_transformers_.emplace_back(
        make_shared<DataTransformer<Btype>>(this->transform_param_, this->phase_));
    fwd_data_transformers_.emplace_back(
        make_shared<DataTransformer<Ftype>>(this->transform_param_, this->phase_));
  }
  const Solver* psolver = this->parent_solver();
  const uint64_t random_seed = (psolver == nullptr ||
      static_cast<uint64_t>(psolver->param().random_seed()) == Caffe::SEED_NOT_SET) ?
          Caffe::next_seed() : static_cast<uint64_t>(psolver->param().random_seed());
  //在这启动了预取内部线程
  StartInternalThread(false, random_seed);
}

这个函数里面最重要的部分就是:StartInternalThread(false, random_seed),由它启动了内部的数据预取线程。启动了多少个呢?这在上一篇文章里面有作分析,像我的配置文件的话那就是4个线程。也就是说一个 solver 里面就有4个数据预取线程,一共有4张卡就有16个数据预取线程。线程函数入口为:InternalThreadEntryN。同样因为 AnnotatedDataLayer 和 DataLayer 都没有重写 InternalThreadEntryN,所以我们直接看 BasePrefetchingDataLayer 里面重写的 InternalThreadEntryN 就好了。

template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InternalThreadEntryN(size_t thread_id) {
  const bool auto_mode = this->auto_mode(); //我的为 false
  if (auto_mode) {
    iter0_.wait_reset();  // sample reader first
  } else if (this->phase_ == TRAIN) {
    iter0_.wait(); //等待,等待什么??
  }
  if (auto_mode && this->net_inititialized_flag_ != nullptr) {
    this->net_inititialized_flag_->wait();
  }
  
  InitializePrefetch(); //一定不要漏过这个函数
  start_reading(); //置位 start_reading_flag
  
  // 感觉这个函数要和datareader.cpp中的结合着读
  try {
    while (!must_stop(thread_id)) {
      //队列id 和线程 id 的关系?
      const size_t qid = this->queue_id(thread_id);
      //哟,这里看到了和 batch_transformer 的交互
      shared_ptr<Batch> batch = batch_transformer_->prefetched_pop_free(qid);
      CHECK_EQ((size_t) -1L, batch->id());
      // 数据从哪里来? 从 data reader 里面来,data reader 才是真正与数据库交互的
      load_batch(batch.get(), thread_id, qid); 
      if (must_stop(thread_id)) {
        break;
      }
      // 取得的数据交给 batch transformer 由他做进一步的处理
      batch_transformer_->prefetched_push_full(qid, batch);

      if (auto_mode) {
        iter0_.set();
        break;
      }
    }
  } catch (boost::thread_interrupted&) {
  }
}

核心逻辑在那个 while(...)循环里面,但是在这之前千万不要漏过了那个InitializePrefetch()函数。

template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InitializePrefetch() {
  ResizeQueues();
  //这里会调用 AnnodatedDataLayer 中的 DataLayerSetUp.这很合理,因为那里会创建 areader
  //这个DataLayerSetUp 在 BaseDataLayer中是纯虚函数,在BasePrefetchingDataLayer
  //中没有实现这个函数,它的儿子类DataLayer和孙子类AnnodatedDataLayer都做了实现.
  this->DataLayerSetUp(bottom_init_, top_init_); 
}

函数内部会继续调用 this->DataLayerSetup,调谁的呢?调 AnnotatedDataLayer 的。

template <typename Ftype, typename Btype>
void AnnotatedDataLayer<Ftype, Btype>::DataLayerSetUp(
    const vector<Blob*>& bottom, const vector<Blob*>& top) {
  const LayerParameter& param = this->layer_param();
  const AnnotatedDataParameter& anno_data_param = param.annotated_data_param();
  const int batch_size = param.data_param().batch_size(); //还不是在prototxt中设定的咯
  const bool cache = this->cache_ && this->phase_ == TRAIN;
  const bool shuffle = cache && this->shuffle_ && this->phase_ == TRAIN;
  TBlob<Ftype> transformed_datum;
  //根据你 prototxt 中设置的 batch_sampler{...}逐个添加
  for (int i = 0; i < anno_data_param.batch_sampler_size(); ++i) {
    batch_samplers_.push_back(anno_data_param.batch_sampler(i));
  }
  //auto_mode 为 false可以跳过这一大段不看
  if (this->auto_mode()) {
    if (!sample_areader_) {
      sample_areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
          Caffe::solver_count(),
          this->rank_,
          this->parsers_num_,
          this->threads_num(),
          batch_size,
          true,
          false,
          cache,
          shuffle,
          false);
    } else if (!areader_) {
      areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
          Caffe::solver_count(),
          this->rank_,
          this->parsers_num_,
          this->threads_num(),
          batch_size,
          false,
          true,
          cache,
          shuffle,
          this->phase_ == TRAIN);
    }
  } else if (!areader_) {
    //这就是传说中的data reader,它是一个数据读取器,派生自 InternalThread
    areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
        Caffe::solver_count(),
        this->rank_,
        this->parsers_num_,  //parser_threads: 4
        this->threads_num(), //threads: 4
        batch_size, //注意,这是一个 solver 中的 batch_size
        false,
        false,
        cache,
        shuffle,
        this->phase_ == TRAIN);
    //这个函数貌似只是设置一个read的flag,so线程的初始化部分
    //应该在上面创建areader_这一步就完成了,然后在等待这个读的信号?
    start_reading(); //会调用areader的start_reading
  }

最重要的就是创建了 DataReader 对象,这个 data reader 是真正与底层数据库(lmdb、leveldb 等)打交道的。它 类似于 BatchTransfer也是 直接派生自 InternalThread,重写了 InternalThreadEntryN线程函数。

template<typename DatumType>
void DataReader<DatumType>::InternalThreadEntryN(size_t thread_id) {
  if (cache_) {
    data_cache_->check_db(db_source_);
    data_cache_->register_new_thread();
  }

  unique_ptr<db::DB> db; //unique_ptr也是智能指针,但它独占所指向的对象
  {
    std::lock_guard<std::mutex> lock(db_mutex_);
    db.reset(db::GetDB(backend_));
    db->Open(db_source_, db::READ); //打开数据库开始读数据
  }

  CursorManager cm(db.get(),
      this,
      solver_count_,
      solver_rank_,
      parser_threads_num_, //test:1
      thread_id,
      batch_size_,
      cache_ && !sample_only_,
      shuffle_ && !sample_only_,
      epoch_count_required_); //train:true, test:false
  shared_ptr<DatumType> init_datum = make_shared<DatumType>();
  cm.fetch(init_datum.get());
  init_->push(init_datum);

  if (!sample_only_) {
    start_reading_flag_.wait();
  }
  cm.rewind(); 
  size_t skip = skip_one_batch_ ? batch_size_ : 0UL;

  size_t queue_id, ranked_rec, batch_on_solver, sample_count = 0UL;
  shared_ptr<DatumType> datum = make_shared<DatumType>();
  /*
  1.每一次datareader将free中已经被消费过的对象取出,填上新的数据,然后将其塞入full中;
  2.每一次BasePrefetchingDataLayer将full中的数据取出并消费,然后将其塞入free中.
  */
  try {
    while (!must_stop(thread_id)) {
      cm.next(datum);
      // See comment below
      ranked_rec = (size_t) datum->record_id() / cm.full_cycle();
      batch_on_solver = ranked_rec * parser_threads_num_ + thread_id;
      //这个queue_id感觉大有学问...
      queue_id = batch_on_solver % queues_num_;

      if (thread_id == 0 && skip > 0U) {
        --skip;
        continue;
      }

      full_push(queue_id, datum);

      if (sample_only_) {
        ++sample_count;
        if (sample_count >= batch_size_) {
          // sample batch complete
          break;
        }
      }
      datum = free_pop(queue_id);
    }
  } catch (boost::thread_interrupted&) {
  }
}

DataReader通过双阻塞队列(BBQ)和 BasePrefetchingDataLayer 进行交互。

  vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> free_;
  vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> full_;

在 BasePrefetchingDataLayer 的线程函数while 循环中会不断地调用 load_batch(batch.get(), thread_id, qid),正是这个 load_batch函数它的数据就源自 data reader,到这里为止好多东西都通畅了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章