LayerSetUp
上一篇講了 AnnotatedDataLayer 的構造函數部分,在完成對象的初始化後緊接着會調用每一層的 LayerSetUp(...)函數來對各層進行相應的設置。注意到 AnnotatedDataLayer以及其直接父類 DataLayer都沒有重寫 LayerSetUp函數,所以責任自然就落到了 爺爺類 BasePrefetchingDataLayer 身上了。
// Layer-level setup for the prefetching data layer.
//
// Stores the initial bottom/top blob vectors (InitializePrefetch() later
// re-runs DataLayerSetUp() with exactly these), runs the base-class setup,
// creates one forward (Ftype) and one backward (Btype) DataTransformer per
// transformer slot, selects a random seed, and finally launches the
// internal prefetch thread(s).
template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::LayerSetUp(const vector<Blob*>& bottom,
    const vector<Blob*>& top) {
  // Keep the blob vectors for the deferred DataLayerSetUp() call.
  bottom_init_ = bottom;
  top_init_ = top;
  // Parent-class setup first.
  BaseDataLayer<Ftype, Btype>::LayerSetUp(bottom, top);
  // Build the fwd/bwd transformer pair for every transformer slot.
  for (int t = 0; t < transf_num_; ++t) {
    bwd_data_transformers_.emplace_back(
        make_shared<DataTransformer<Btype>>(this->transform_param_, this->phase_));
    fwd_data_transformers_.emplace_back(
        make_shared<DataTransformer<Ftype>>(this->transform_param_, this->phase_));
  }
  // Seed selection: honor the parent solver's explicit random_seed when one
  // is set; otherwise draw the next global seed. Caffe::next_seed() is only
  // invoked when actually needed (it advances global seed state).
  const Solver* solver = this->parent_solver();
  uint64_t seed;
  if (solver == nullptr ||
      static_cast<uint64_t>(solver->param().random_seed()) == Caffe::SEED_NOT_SET) {
    seed = Caffe::next_seed();
  } else {
    seed = static_cast<uint64_t>(solver->param().random_seed());
  }
  // Spin up the internal prefetch thread(s).
  StartInternalThread(false, seed);
}
這個函數裏面最重要的部分就是:StartInternalThread(false, random_seed),由它啓動了內部的數據預取線程。啓動了多少個呢?這在上一篇文章裏面有作分析,像我的配置文件的話那就是4個線程。也就是說一個 solver 裏面就有4個數據預取線程,一共有4張卡就有16個數據預取線程。線程函數入口爲:InternalThreadEntryN。同樣因爲 AnnotatedDataLayer 和 DataLayer 都沒有重寫 InternalThreadEntryN,所以我們直接看 BasePrefetchingDataLayer 裏面重寫的 InternalThreadEntryN 就好了。
// Prefetch worker thread entry point (one invocation per prefetch thread).
// Loop: pop a free (already-consumed) Batch from the batch transformer,
// fill it from the data reader via load_batch(), and push it back as a
// full batch -- until the thread is asked to stop.
template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InternalThreadEntryN(size_t thread_id) {
const bool auto_mode = this->auto_mode(); // false in the configuration discussed here
if (auto_mode) {
iter0_.wait_reset(); // sample reader first
} else if (this->phase_ == TRAIN) {
iter0_.wait(); // NOTE(review): blocks on the iteration-0 event; the producer of this event is not visible here -- confirm
}
if (auto_mode && this->net_inititialized_flag_ != nullptr) {
this->net_inititialized_flag_->wait(); // wait until the net has finished initializing
}
InitializePrefetch(); // re-runs DataLayerSetUp(); do not overlook this call
start_reading(); // raises the start_reading flag so the reader side may proceed
// Best read side by side with the reader loop in datareader.cpp.
try {
while (!must_stop(thread_id)) {
// Map this worker thread to its queue id.
const size_t qid = this->queue_id(thread_id);
// Interaction with the batch transformer: take an empty batch to refill.
shared_ptr<Batch> batch = batch_transformer_->prefetched_pop_free(qid);
CHECK_EQ((size_t) -1L, batch->id()); // a free batch must still carry the "unset" id
// The payload comes from the data reader -- the component that actually
// talks to the backing database.
load_batch(batch.get(), thread_id, qid);
if (must_stop(thread_id)) {
break;
}
// Hand the filled batch to the batch transformer for further processing.
batch_transformer_->prefetched_push_full(qid, batch);
if (auto_mode) {
iter0_.set();
break;
}
}
} catch (boost::thread_interrupted&) {
// Normal shutdown path: interrupted while blocked on a queue.
}
}
核心邏輯在那個 while(...)循環裏面,但是在這之前千萬不要漏過了那個InitializePrefetch()函數。
// (Re)initializes the prefetch machinery: resizes the internal queues,
// then re-runs the layer-specific DataLayerSetUp() with the blob vectors
// saved in LayerSetUp().
//
// DataLayerSetUp() is pure virtual in BaseDataLayer and is NOT overridden
// by BasePrefetchingDataLayer itself, so the virtual call dispatches to
// the implementation in DataLayer / AnnotatedDataLayer -- which is where
// the data reader (areader) gets created.
template<typename Ftype, typename Btype>
void BasePrefetchingDataLayer<Ftype, Btype>::InitializePrefetch() {
  ResizeQueues();
  this->DataLayerSetUp(bottom_init_, top_init_);
}
函數內部會繼續調用 this->DataLayerSetUp,調誰的呢?調 AnnotatedDataLayer 的。
// AnnotatedDataLayer-specific setup: collects the batch samplers from the
// layer parameter and creates the DataReader that actually talks to the
// backing database (lmdb / leveldb).
// NOTE: this snippet is truncated in the original post -- the function's
// closing brace is not shown below.
template <typename Ftype, typename Btype>
void AnnotatedDataLayer<Ftype, Btype>::DataLayerSetUp(
const vector<Blob*>& bottom, const vector<Blob*>& top) {
const LayerParameter& param = this->layer_param();
const AnnotatedDataParameter& anno_data_param = param.annotated_data_param();
const int batch_size = param.data_param().batch_size(); // the batch size configured in the prototxt
const bool cache = this->cache_ && this->phase_ == TRAIN;
const bool shuffle = cache && this->shuffle_ && this->phase_ == TRAIN;
TBlob<Ftype> transformed_datum;
// Append every batch_sampler{...} block configured in the prototxt.
for (int i = 0; i < anno_data_param.batch_sampler_size(); ++i) {
batch_samplers_.push_back(anno_data_param.batch_sampler(i));
}
// auto_mode is false in the configuration discussed here, so this whole
// branch can be skipped when reading along.
if (this->auto_mode()) {
if (!sample_areader_) {
sample_areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
Caffe::solver_count(),
this->rank_,
this->parsers_num_,
this->threads_num(),
batch_size,
true,
false,
cache,
shuffle,
false);
} else if (!areader_) {
areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
Caffe::solver_count(),
this->rank_,
this->parsers_num_,
this->threads_num(),
batch_size,
false,
true,
cache,
shuffle,
this->phase_ == TRAIN);
}
} else if (!areader_) {
// The data reader itself: derived from InternalThread, it is the
// component that reads from the database.
areader_ = std::make_shared<DataReader<AnnotatedDatum>>(param,
Caffe::solver_count(),
this->rank_,
this->parsers_num_, // parser_threads: 4 (in the author's configuration)
this->threads_num(), // threads: 4 (in the author's configuration)
batch_size, // note: this is the per-solver batch_size
false,
false,
cache,
shuffle,
this->phase_ == TRAIN);
// NOTE(review): this call appears to only raise a "read" flag -- the
// reader threads were presumably already started when areader_ was
// constructed above, and are blocked waiting on this signal; confirm.
start_reading(); // forwards to areader's start_reading
}
最重要的就是創建了 DataReader 對象,這個 data reader 是真正與底層數據庫(lmdb、leveldb 等)打交道的。它類似於 BatchTransformer,也是直接派生自 InternalThread,重寫了 InternalThreadEntryN 線程函數。
// Reader thread entry point: opens the database, fetches one initial
// datum (pushed to init_ for the sampling phase), then streams records
// into the per-queue free/full blocking queues consumed by
// BasePrefetchingDataLayer.
template<typename DatumType>
void DataReader<DatumType>::InternalThreadEntryN(size_t thread_id) {
if (cache_) {
data_cache_->check_db(db_source_);
data_cache_->register_new_thread();
}
unique_ptr<db::DB> db; // unique_ptr: this thread is the sole owner of the DB handle
{
std::lock_guard<std::mutex> lock(db_mutex_); // serialize DB open across reader threads
db.reset(db::GetDB(backend_));
db->Open(db_source_, db::READ); // open the database for reading
}
CursorManager cm(db.get(),
this,
solver_count_,
solver_rank_,
parser_threads_num_, // test: 1 (in the author's configuration)
thread_id,
batch_size_,
cache_ && !sample_only_,
shuffle_ && !sample_only_,
epoch_count_required_); // train: true, test: false
shared_ptr<DatumType> init_datum = make_shared<DatumType>();
cm.fetch(init_datum.get());
init_->push(init_datum);
if (!sample_only_) {
start_reading_flag_.wait(); // block until the layer raises start_reading
}
cm.rewind();
size_t skip = skip_one_batch_ ? batch_size_ : 0UL;
size_t queue_id, ranked_rec, batch_on_solver, sample_count = 0UL;
shared_ptr<DatumType> datum = make_shared<DatumType>();
/*
Producer/consumer protocol with BasePrefetchingDataLayer:
1. the reader pops an already-consumed datum from "free", refills it
   with fresh data, then pushes it into "full";
2. the layer pops from "full", consumes the datum, then returns it
   to "free".
*/
try {
while (!must_stop(thread_id)) {
cm.next(datum);
// See comment below
ranked_rec = (size_t) datum->record_id() / cm.full_cycle();
batch_on_solver = ranked_rec * parser_threads_num_ + thread_id;
// Queue selection: derived from the record id so that successive
// batches are spread across the queues.
queue_id = batch_on_solver % queues_num_;
if (thread_id == 0 && skip > 0U) {
--skip;
continue;
}
full_push(queue_id, datum);
if (sample_only_) {
++sample_count;
if (sample_count >= batch_size_) {
// sample batch complete
break;
}
}
datum = free_pop(queue_id);
}
} catch (boost::thread_interrupted&) {
// Normal shutdown path: interrupted while blocked on a queue.
}
}
DataReader通過雙阻塞隊列(BBQ)和 BasePrefetchingDataLayer 進行交互。
vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> free_;
vector<shared_ptr<BlockingQueue<shared_ptr<DatumType>>>> full_;
在 BasePrefetchingDataLayer 的線程函數 while 循環中會不斷地調用 load_batch(batch.get(), thread_id, qid),而這個 load_batch 函數的數據正是源自 data reader。到這裏爲止,很多東西都通暢了。