LayerSetUp
上一篇讲了 AnnotatedDataLayer 的构造函数部分,在完成对象的初始化后紧接着会调用每一层的 LayerSetUp(...)函数来对各层进行相应的设置。注意到 AnnotatedDataLayer以及其直接父类 DataLayer都没有重写 LayerSetUp函数,所以责任自然就落到了 爷爷类 BasePrefetchingDataLayer 身上了。
template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::LayerSetUp(const vector<Blob*>& bottom,
const vector<Blob*>& top) {
bottom_init_ = bottom;
top_init_ = top;
//先调了父类的layersetup
BaseDataLayer<Ftype, Btype>::LayerSetUp(bottom, top);
//这里也很重要,设置了 bwd和 fwd
for (int i = 0; i < transf_num_; ++i) {
bwd_data_transformers_.emplace_back(
make_shared<DataTransformer<Btype>>(this->transform_param_, this->phase_));
fwd_data_transformers_.emplace_back(
make_shared<DataTransformer<Ftype>>(this->transform_param_, this->phase_));
}
const Solver* psolver = this->parent_solver();
const uint64_t random_seed = (psolver == nullptr ||
static_cast<uint64_t>(psolver->param().random_seed()) == Caffe::SEED_NOT_SET) ?
Caffe::next_seed() : static_cast<uint64_t>(psolver->param().random_seed());
//在这启动了预取内部线程
StartInternalThread(false, random_seed);
}
这个函数里面最重要的部分就是:StartInternalThread(false, random_seed),由它启动了内部的数据预取线程。启动了多少个呢?这在上一篇文章里面有作分析,像我的配置文件的话那就是4个线程。也就是说一个 solver 里面就有4个数据预取线程,一共有4张卡就有16个数据预取线程。线程函数入口为:InternalThreadEntryN。同样因为 AnnotatedDataLayer 和 DataLayer 都没有重写 InternalThreadEntryN,所以我们直接看 BasePrefetchingDataLayer 里面重写的 InternalThreadEntryN 就好了。
template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InternalThreadEntryN(size_t thread_id) {
const bool auto_mode = this->auto_mode(); //我的为 false
if (auto_mode) {
iter0_.wait_reset(); // sample reader first
} else if (this->phase_ == TRAIN) {
iter0_.wait(); //等待,等待什么??
}
if (auto_mode && this->net_inititialized_flag_ != nullptr) {
this->net_inititialized_flag_->wait();
}
InitializePrefetch(); //一定不要漏过这个函数
start_reading(); //置位 start_reading_flag
// 感觉这个函数要和datareader.cpp中的结合着读
try {
while (!must_stop(thread_id)) {
//队列id 和线程 id 的关系?
const size_t qid = this->queue_id(thread_id);
//哟,这里看到了和 batch_transformer 的交互
shared_ptr<Batch> batch = batch_transformer_->prefetched_pop_free(qid);
CHECK_EQ((size_t) -1L, batch->id());
// 数据从哪里来? 从 data reader 里面来,data reader 才是真正与数据库交互的
load_batch(batch.get(), thread_id, qid);
if (must_stop(thread_id)) {
break;
}
// 取得的数据交给 batch transformer 由他做进一步的处理
batch_transformer_->prefetched_push_full(qid, batch);
if (auto_mode) {
iter0_.set();
break;
}
}
} catch (boost::thread_interrupted&) {
}
}
核心逻辑在那个 while(...)循环里面,但是在这之前千万不要漏过了那个InitializePrefetch()函数。
template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InitializePrefetch() {
ResizeQueues();
//这里会调用 AnnodatedDataLayer 中的 DataLayerSetUp.这很合理,因为那里会创建 areader
//这个DataLayerSetUp 在 BaseDataLayer中是纯虚函数,在BasePrefetchingDataLayer
//中没有实现这个函数,它的儿子类DataLayer和孙子类AnnodatedDataLayer都做了实现.
this->DataLayerSetUp(bottom_init_, top_init_);
}
函数内部会继续调用 this->DataLayerSetup,调谁的呢?调 AnnotatedDataLayer 的。
template <typename Ftype, typename Btype>
void AnnotatedDataLayer<Ftype, Btype>::DataLayerSetUp(
const vector<Blob*>& bottom, const vector<Blob*>& top) {
const LayerParameter& param = this->layer_param();
const AnnotatedDataParameter& anno_data_param = param.annotated_data_param();
const int batch_size = param.data_param().batch_size(); //还不是在prototxt中设定的咯
const bool cache = this->cache_ && this->phase_ == TRAIN;
const bool shuffle = cache && this->shuffle_ && this->phase_ == TRAIN;
TBlob<Ftype> transformed_datum;
//根据你 prototxt 中设置的 batch_sampler{...}逐个添加
for (int i = 0; i < anno_data_param.batch_sampler_size(); ++i) {
batch_samplers_.push_back(anno_data_param.batch_sampler(i));
}
//auto_mode 为 false可以跳过这一大段不看
if (this->auto_mode()) {
if (!sample_areader_) {
sample_areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
Caffe::solver_count(),
this->rank_,
this->parsers_num_,
this->threads_num(),
batch_size,
true,
false,
cache,
shuffle,
false);
} else if (!areader_) {
areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
Caffe::solver_count(),
this->rank_,
this->parsers_num_,
this->threads_num(),
batch_size,
false,
true,
cache,
shuffle,
this->phase_ == TRAIN);
}
} else if (!areader_) {
//这就是传说中的data reader,它是一个数据读取器,派生自 InternalThread
areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
Caffe::solver_count(),
this->rank_,
this->parsers_num_, //parser_threads: 4
this->threads_num(), //threads: 4
batch_size, //注意,这是一个 solver 中的 batch_size
false,
false,
cache,
shuffle,
this->phase_ == TRAIN);
//这个函数貌似只是设置一个read的flag,so线程的初始化部分
//应该在上面创建areader_这一步就完成了,然后在等待这个读的信号?
start_reading(); //会调用areader的start_reading
}
最重要的就是创建了 DataReader 对象,这个 data reader 是真正与底层数据库(lmdb、leveldb 等)打交道的。它 类似于 BatchTransfer也是 直接派生自 InternalThread,重写了 InternalThreadEntryN线程函数。
template<typename DatumType>
void DataReader<DatumType>::InternalThreadEntryN(size_t thread_id) {
if (cache_) {
data_cache_->check_db(db_source_);
data_cache_->register_new_thread();
}
unique_ptr<db::DB> db; //unique_ptr也是智能指针,但它独占所指向的对象
{
std::lock_guard<std::mutex> lock(db_mutex_);
db.reset(db::GetDB(backend_));
db->Open(db_source_, db::READ); //打开数据库开始读数据
}
CursorManager cm(db.get(),
this,
solver_count_,
solver_rank_,
parser_threads_num_, //test:1
thread_id,
batch_size_,
cache_ && !sample_only_,
shuffle_ && !sample_only_,
epoch_count_required_); //train:true, test:false
shared_ptr<DatumType> init_datum = make_shared<DatumType>();
cm.fetch(init_datum.get());
init_->push(init_datum);
if (!sample_only_) {
start_reading_flag_.wait();
}
cm.rewind();
size_t skip = skip_one_batch_ ? batch_size_ : 0UL;
size_t queue_id, ranked_rec, batch_on_solver, sample_count = 0UL;
shared_ptr<DatumType> datum = make_shared<DatumType>();
/*
1.每一次datareader将free中已经被消费过的对象取出,填上新的数据,然后将其塞入full中;
2.每一次BasePrefetchingDataLayer将full中的数据取出并消费,然后将其塞入free中.
*/
try {
while (!must_stop(thread_id)) {
cm.next(datum);
// See comment below
ranked_rec = (size_t) datum->record_id() / cm.full_cycle();
batch_on_solver = ranked_rec * parser_threads_num_ + thread_id;
//这个queue_id感觉大有学问...
queue_id = batch_on_solver % queues_num_;
if (thread_id == 0 && skip > 0U) {
--skip;
continue;
}
full_push(queue_id, datum);
if (sample_only_) {
++sample_count;
if (sample_count >= batch_size_) {
// sample batch complete
break;
}
}
datum = free_pop(queue_id);
}
} catch (boost::thread_interrupted&) {
}
}
DataReader通过双阻塞队列(BBQ)和 BasePrefetchingDataLayer 进行交互。
vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> free_;
vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> full_;
在 BasePrefetchingDataLayer 的线程函数while 循环中会不断地调用 load_batch(batch.get(), thread_id, qid),正是这个 load_batch函数它的数据就源自 data reader,到这里为止好多东西都通畅了。