本文是参数服务器系列第二篇,介绍ps-lite的通讯模块 Van。html
本系列其余文章是:node
[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOfficepython
邮局里有了地址簿,就须要有货车来负责拉送物件,Van 就是整个Parameter Server的通讯模块,其特色以下。c++
VAN 目前有两个实现:git
首先给出 UML 图。github
下面咱们只给出Van对象关键变量和成员函数说明。编程
其主要变量以下:缓存
Node scheduler_ :Scheduler 节点参数,每个node都会记录Scheduler 节点的信息;服务器
Node my_node_ : 本节点参数。若是本节点是Scheduler,则 my_node_ 会指向上面的 scheduler_ ;微信
bool is_scheduler_ : 本节点是不是 scheduler;
std::unique_ptr< std::thread> receiver_thread_ :接收消息线程指针;
std::unique_ptr< std::thread> heartbeat_thread_ :发送心跳线程指针;
std::vector
Resender *resender_ = nullptr :从新发送消息指针;
std::atomic
std::unordered_map<std::string, int> connected_nodes_ : 记录了目前链接到哪些 nodes;
其主要函数功能以下:
start :创建通讯初始化;
Receiving :接收消息线程的处理函数;
Heartbeat :发送心跳线程的处理函数;
ProcessAddNodeCommandAtScheduler :scheduler 的 AddNode 消息处理函数;
ProcessHearbeat:心跳包处理函数;
ProcessDataMsg :数据消息(push & pull)处理函数;
ProcessAddNodeCommand :worker 和 server 的 AddNode 消息处理函数;
ProcessBarrierCommand :Barrier 消息处理函数;
PS Lite 定义的三种角色采用多线程机制工做,每一个线程承担特定的职责,在所属的 Van 实例启动时被建立。
具体描述以下:
详细代码(摘要)以下:
class Van { public: static Van *Create(const std::string &type); virtual void Start(int customer_id); int Send(const Message &msg); virtual void Stop(); inline int GetTimestamp() { return timestamp_++; } inline bool IsReady() { return ready_; } protected: //连结节点 virtual void Connect(const Node &node) = 0; //绑定到本身节点之上 virtual int Bind(const Node &node, int max_retry) = 0; //接收消息,用阻塞方式 virtual int RecvMsg(Message *msg) = 0; //发送消息 virtual int SendMsg(const Message &msg) = 0; /** * \brief pack meta into a string */ void PackMeta(const Meta &meta, char **meta_buf, int *buf_size); /** * \brief pack meta into protobuf */ void PackMetaPB(const Meta &meta, PBMeta *pb); /** * \brief unpack meta from a string */ void UnpackMeta(const char *meta_buf, int buf_size, Meta *meta); Node scheduler_; Node my_node_; bool is_scheduler_; std::mutex start_mu_; private: /** thread function for receving */ void Receiving(); /** thread function for heartbeat */ void Heartbeat(); // node's address string (i.e. ip:port) -> node id // this map is updated when ip:port is received for the first time std::unordered_map<std::string, int> connected_nodes_; // maps the id of node which is added later to the id of node // which is with the same ip:port and added first std::unordered_map<int, int> shared_node_mapping_; /** whether it is ready for sending */ std::atomic<bool> ready_{false}; std::atomic<size_t> send_bytes_{0}; size_t recv_bytes_ = 0; int num_servers_ = 0; int num_workers_ = 0; /** the thread for receiving messages */ std::unique_ptr<std::thread> receiver_thread_; /** the thread for sending heartbeat */ std::unique_ptr<std::thread> heartbeat_thread_; std::vector<int> barrier_count_; /** msg resender */ Resender *resender_ = nullptr; int drop_rate_ = 0; std::atomic<int> timestamp_{0}; int init_stage = 0; //如下是处理各类类型消息 void ProcessAddNodeCommandAtScheduler(Message *msg, Meta *nodes, Meta *recovery_nodes); void ProcessTerminateCommand(); void ProcessAddNodeCommand(Message *msg, Meta *nodes, Meta *recovery_nodes); void ProcessBarrierCommand(Message *msg); void ProcessHearbeat(Message *msg); void ProcessDataMsg(Message *msg); //更新本地NodeID void UpdateLocalID(Message *msg, std::unordered_set<int> *deadnodes_set, Meta *nodes, Meta *recovery_nodes); const char *heartbeat_timeout_val = Environment::Get()->find("PS_HEARTBEAT_TIMEOUT"); int heartbeat_timeout_ = heartbeat_timeout_val ? atoi(heartbeat_timeout_val) : 0; DISALLOW_COPY_AND_ASSIGN(Van); };
Van对象的初始化函数做用就是依据本地节点类型的不一样,作不一样设置,从而启动端口,创建到scheduler的连结,启动接收消息线程,心跳线程等,这样就能够进行通讯了。具体以下:
receiver_thread_
,执行Van::Receiving
;关于7,8两点的进一步说明就是:
具体代码以下:
void Van::Start(int customer_id) { // get scheduler info start_mu_.lock(); if (init_stage == 0) { // 初始化scheduler_这个成员变量 scheduler_.hostname = std::string( CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_URI"))); scheduler_.port = atoi(CHECK_NOTNULL(Environment::Get()->find("DMLC_PS_ROOT_PORT"))); scheduler_.role = Node::SCHEDULER; scheduler_.id = kScheduler; // 确认本节点是scheduler节点 is_scheduler_ = Postoffice::Get()->is_scheduler(); // get my node info if (is_scheduler_) { // 初始化本节点,由于是scheduler,因此直接赋值 my_node_ = scheduler_; } else { auto role = Postoffice::Get()->is_worker() ? Node::WORKER : Node::SERVER; const char* nhost = Environment::Get()->find("DMLC_NODE_HOST"); std::string ip; if (nhost) ip = std::string(nhost); if (ip.empty()) { const char* itf = Environment::Get()->find("DMLC_INTERFACE"); std::string interface; if (itf) interface = std::string(itf); if (interface.size()) { GetIP(interface, &ip); } else { GetAvailableInterfaceAndIP(&interface, &ip); } } int port = GetAvailablePort(); const char* pstr = Environment::Get()->find("PORT"); if (pstr) port = atoi(pstr); my_node_.hostname = ip; my_node_.role = role; my_node_.port = port; // cannot determine my id now, the scheduler will assign it later // set it explicitly to make re-register within a same process possible my_node_.id = Node::kEmpty; my_node_.customer_id = customer_id; } // bind. //绑定接口,把本节点绑定到ip:port这个socket上,理论来讲这个函数就是初始化了receiver_ my_node_.port = Bind(my_node_, is_scheduler_ ? 0 : 40); // connect to the scheduler // 链接上scheduler_,因为本节点就是scheduler_,其实就是初始化senders_,因为发送的节点不少,因此这里是一个map<int,void*> // 在这里就是senders_[1] = socket_1, socket_1中的body设置一点字符“ps1***”, 注意连接不是sendMsg Connect(scheduler_); // for debug use if (Environment::Get()->find("PS_DROP_MSG")) { drop_rate_ = atoi(Environment::Get()->find("PS_DROP_MSG")); } // start receiver // 开启一个接收消息的线程,这里就是处理消息 receiver_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Receiving, this)); init_stage++; } start_mu_.unlock(); if (!is_scheduler_) { // let the scheduler know myself // worker和server节点会经过 ADD_NODE 消息把本地节点的信息告诉scheduler,好比角色,ip,port... Message msg; Node customer_specific_node = my_node_; customer_specific_node.customer_id = customer_id; msg.meta.recver = kScheduler; msg.meta.control.cmd = Control::ADD_NODE; msg.meta.control.node.push_back(customer_specific_node); msg.meta.timestamp = timestamp_++; Send(msg); } // wait until ready // 等待 ready_ 从false变成true,当是scheduler的时候,必需要有等worker和server节点过来,否则一直都是阻塞在这,若是是 worker/server,则是等待 scheduler 发送系统allready消息。 while (!ready_.load()) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } start_mu_.lock(); if (init_stage == 1) { // resender if (Environment::Get()->find("PS_RESEND") && atoi(Environment::Get()->find("PS_RESEND")) != 0) { int timeout = 1000; if (Environment::Get()->find("PS_RESEND_TIMEOUT")) { timeout = atoi(Environment::Get()->find("PS_RESEND_TIMEOUT")); } // 若是设置了超时重传,就初始化resender_这个变量 resender_ = new Resender(timeout, 10, this); } if (!is_scheduler_) { // start heartbeat thread // 初始化心跳线程 heartbeat_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Heartbeat, this)); } init_stage++; } start_mu_.unlock(); }
咱们首先介绍后台线程是如何运行,而后会具体分析如何处理各类消息。
ps-lite 启动了一个后台线程 receiver_thread_ 进行接受/处理消息。
// start receiver receiver_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Receiving, this));
receiver_thread_ 使用 Receiving 函数进行消息处理。
除了传递参数的数据消息外,各个节点之间控制信息有:
所以在 Receiving 之中会调用 不一样处理函数处理不一样类型的消息:
线程内有两个变量,由于其是在 while (true) 循环以外,因此属于线程内的全局变量,这点在阅读代码时候须要注意。
Receiving 逻辑以下:
具体代码以下
void Van::Receiving() { Meta nodes; // 如下两个能够认为是全局变量 Meta recovery_nodes; // store recovery nodes 储存康复重启的节点 recovery_nodes.control.cmd = Control::ADD_NODE; // 康复重启节点的control.cmd 都设置为 ADD_NODE while (true) { Message msg; int recv_bytes = RecvMsg(&msg); //利用receiver_ 变量拿到消息 // For debug, drop received message if (ready_.load() && drop_rate_ > 0) { unsigned seed = time(NULL) + my_node_.id; if (rand_r(&seed) % 100 < drop_rate_) { LOG(WARNING) << "Drop message " << msg.DebugString(); continue; } } CHECK_NE(recv_bytes, -1); recv_bytes_ += recv_bytes; //收到的字节数累加 if (Postoffice::Get()->verbose() >= 2) { PS_VLOG(2) << msg.DebugString(); } // duplicated message if (resender_ && resender_->AddIncomming(msg)) continue; //重传确认机制 if (!msg.meta.control.empty()) { //若是是控制类型的消息 // control msg auto& ctrl = msg.meta.control; if (ctrl.cmd == Control::TERMINATE) { ProcessTerminateCommand(); break; } else if (ctrl.cmd == Control::ADD_NODE) { ProcessAddNodeCommand(&msg, &nodes, &recovery_nodes); //当执行到这个位置的时候继续跳转 } else if (ctrl.cmd == Control::BARRIER) { ProcessBarrierCommand(&msg); } else if (ctrl.cmd == Control::HEARTBEAT) { ProcessHearbeat(&msg); // 发回Heartbeat的ACK } else { LOG(WARNING) << "Drop unknown typed message " << msg.DebugString(); } } else { //非控制类型的消息处理方式 ProcessDataMsg(&msg); } } }
ADD_NODE 是 worker / server 用来向 scheduler 注册自身的控制消息。
先回忆下注册基本思路。
ProcessAddNodeCommand 逻辑以下。
具体代码以下:
void Van::ProcessAddNodeCommand(Message* msg, Meta* nodes, Meta* recovery_nodes) { auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的id std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//转存到dead_set之中 auto& ctrl = msg->meta.control; //拿到收到消息里面的control信息 UpdateLocalID(msg, &dead_set, nodes, recovery_nodes); if (is_scheduler_) { // Scheduler 节点 ProcessAddNodeCommandAtScheduler(msg, nodes, recovery_nodes); } else { // Worker & Server 节点 for (const auto& node : ctrl.node) { std::string addr_str = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(addr_str) == connected_nodes_.end()) { // 现有节点会在本身链接之中查找这个新节点,发现现有链接中没有这个新节点 // 若是是新节点,则会链接现有节点(非同类型) Connect(node); // 与新节点进行链接 connected_nodes_[addr_str] = node.id; // 加入已经链接的节点 } if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_; if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_; } ready_ = true; } }
此函数做用是更新节点内部的node id 信息,也是分为两种状况,函数逻辑以下:
具体代码以下:
void Van::UpdateLocalID(Message* msg, std::unordered_set<int>* deadnodes_set, Meta* nodes, Meta* recovery_nodes) { auto& ctrl = msg->meta.control; size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers(); // assign an id if (msg->meta.sender == Meta::kEmpty) { //若是sender未设定,则处理此message的必定是Scheduler CHECK(is_scheduler_); CHECK_EQ(ctrl.node.size(), 1); //msg中的control命令中的节点集合就是worker本身,因此就是1个节点 if (nodes->control.node.size() < num_nodes) { //没有到齐 nodes->control.node.push_back(ctrl.node[0]); } else { //若是全部work和server到齐了,就进入else // some node dies and restarts CHECK(ready_.load()); for (size_t i = 0; i < nodes->control.node.size() - 1; ++i) { const auto& node = nodes->control.node[i]; if (deadnodes_set->find(node.id) != deadnodes_set->end() && node.role == ctrl.node[0].role) { auto& recovery_node = ctrl.node[0]; // assign previous node id recovery_node.id = node.id; recovery_node.is_recovery = true; nodes->control.node[i] = recovery_node; recovery_nodes->control.node.push_back(recovery_node); break; } } } } // update my id / 对普通的node,更新其rank,scheduler 节点不会起做用(由于找不到)。 // schedule发给此work节点的消息,若是发现本地的ip和port和消息中的某个一点重合,那么就把本地节点的ID(初始化时候没有ID,只是等于Empty)改成schedule发过来的 node id。 for (size_t i = 0; i < ctrl.node.size(); ++i) { const auto& node = ctrl.node[i]; if (my_node_.hostname == node.hostname && my_node_.port == node.port) { if (getenv("DMLC_RANK") == nullptr || my_node_.id == Meta::kEmpty) { my_node_ = node; std::string rank = std::to_string(Postoffice::IDtoRank(node.id)); #ifdef _MSC_VER _putenv_s("DMLC_RANK", rank.c_str()); #else setenv("DMLC_RANK", rank.c_str(), true); #endif } } } }
ProcessAddNodeCommandAtScheduler 是在 Scheduler 以内运行,是对控制类型消息的处理。
对于Scheduler节点来讲,scheduler收到全部worker和server的ADD_NODE的消息后进行节点id分配并应答,即,须要设定 最新的全部node的 全局rank 并发送给全部Worker和Server。
nodes->control.node.size() == num_nodes
):
ready_ = true
; 即 scheduler 是一个 ready 状态了,无论 worker 和 server 是否确认收到ADD_NODE消息。!recovery_nodes->control.node.empty()
,这就代表是处理某些重启节点的注册行为:
CHECK_EQ(recovery_nodes->control.node.size(), 1)
来确认重启节点为 1 个)。具体代码以下:
void Van::ProcessAddNodeCommandAtScheduler(Message* msg, Meta* nodes, Meta* recovery_nodes) { recovery_nodes->control.cmd = Control::ADD_NODE; time_t t = time(NULL); size_t num_nodes = Postoffice::Get()->num_servers() + Postoffice::Get()->num_workers(); // scheduler收到全部worker和server的ADD_NODE的消息后进行节点id分配并应答 if (nodes->control.node.size() == num_nodes) { // 节点收集彻底 // sort the nodes according their ip and port, 根据IP和port给worker,server排个序 std::sort(nodes->control.node.begin(), nodes->control.node.end(), [](const Node& a, const Node& b) { return (a.hostname.compare(b.hostname) | (a.port < b.port)) > 0; }); // assign node rank for (auto& node : nodes->control.node) { // 创建链接、更新心跳时间戳,给 scheduler全部链接的节点分配全局 rank。 std::string node_host_ip = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) { //若是ip:port不存在van_中的话 CHECK_EQ(node.id, Node::kEmpty); //判断是否是初始化节点 int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); //若是是sever的话,就id产生一个id号,num_servers_初始化为0 node.id = id; //将这个新节点的id赋值为id Connect(node); //链接这个新节点, 即创建一个socket, 而后senders_[id] = sender; 就是将目标id的socket存放起来后面使用 Postoffice::Get()->UpdateHeartbeat(node.id, t);//更新心跳包 connected_nodes_[node_host_ip] = id; //既然 worker, server 已经发message来了,scheduler要把这个节点做为已经连接的节点 } else { int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); shared_node_mapping_[id] = connected_nodes_[node_host_ip]; node.id = connected_nodes_[node_host_ip]; } if (node.role == Node::SERVER) num_servers_++;//更新rank if (node.role == Node::WORKER) num_workers_++; } nodes->control.node.push_back(my_node_); //把本节点放到里面 nodes->control.cmd = Control::ADD_NODE; Message back; back.meta = *nodes; // 向全部的worker和server发送ADD_NODE消息 for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) { int recver_id = r; if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) { back.meta.recver = recver_id; back.meta.timestamp = timestamp_++; Send(back); } } ready_ = true; //scheduler已经准备好了 } else if (!recovery_nodes->control.node.empty()) { // 节点没有收集彻底 auto dead_nodes = Postoffice::Get()->GetDeadNodes(heartbeat_timeout_);//查出心跳包超时的id std::unordered_set<int> dead_set(dead_nodes.begin(), dead_nodes.end());//转存到dead_set // send back the recovery node CHECK_EQ(recovery_nodes->control.node.size(), 1); Connect(recovery_nodes->control.node[0]); Postoffice::Get()->UpdateHeartbeat(recovery_nodes->control.node[0].id, t); Message back; for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) { if (r != recovery_nodes->control.node[0].id && dead_set.find(r) != dead_set.end()) { // do not try to send anything to dead node continue; } // only send recovery_node to nodes already exist // but send all nodes to the recovery_node back.meta = (r == recovery_nodes->control.node[0].id) ? *nodes : *recovery_nodes; back.meta.recver = r; back.meta.timestamp = timestamp_++; Send(back); } } }
此部分流程逻辑以下:
+ Scheduler | Worker | + | + | | | | | | v | | Postoffice::Start +----> Van::Start | | + | | | | | | | | v | | Connect--do nothing | | + | v | | | | Postoffice::Start +-----> Van::Start | | + v | | receiver_thread_ +---+ | | + | | v | | | Connect--to scheduler | | | + | | | | | | | | | | | | | | | v | | | receiver_thread_ +----->+ | | | + | | | | | | | | | | | | | | v | | | <---------------------------------------+ Send | | | | ADD_NODE + | | v | | | | | | | | ProcessAddNodeCommand | | | | + | | | | | | | | | | All nodes OK | | | | | | | | v | | | | | set rank | | | wait until ready | | | | + | | | | | +----------------------------------------------------------------> | | | | ADD_NODE response(nodes info) | | | | | | ProcessAddNodeCommand | | | v | | | | | | <--------------+ | wait until ready | | ready_ = true | + | | | | <---------------+ +-------------------+ v | | | | +--------------------+ v | | | v | | | v Postoffice::Barrier | | Postoffice::Barrier +
手机以下,左侧是 Scheduler,右侧是 worker:
其互联过程能够分为3步:
第一步:worker/server节点初始化的时候,向schedular节点发送一个链接信息,假定自身是节点 2;
if (!is_scheduler_) { // let the scheduler know myself Message msg; Node customer_specific_node = my_node_; customer_specific_node.customer_id = customer_id; msg.meta.recver = kScheduler; msg.meta.control.cmd = Control::ADD_NODE; msg.meta.control.node.push_back(customer_specific_node); msg.meta.timestamp = timestamp_++; Send(msg); //发送给schedular, 创建连接信息。 }
第二步:Scheduler 节点收到信息后,在 ProcessAddNodeCommandAtScheduler 之中,首先会和 节点 2 创建一个链接。会向全部已经和schedular创建链接的worker节点/server节点 广播此 "节点的加入信息“,并把 节点 2 请求链接的信息放入meta信息中。
// assign node rank for (auto& node : nodes->control.node) { std::string node_host_ip = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(node_host_ip) == connected_nodes_.end()) { int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); node.id = id; Connect(node); // 链接这个新节点, 即创建一个socket, 而后senders_[id] = sender; 就是将目标id的socket存放起来后面使用 Postoffice::Get()->UpdateHeartbeat(node.id, t); connected_nodes_[node_host_ip] = id; } else { int id = node.role == Node::SERVER ? Postoffice::ServerRankToID(num_servers_) : Postoffice::WorkerRankToID(num_workers_); shared_node_mapping_[id] = connected_nodes_[node_host_ip]; node.id = connected_nodes_[node_host_ip]; } if (node.role == Node::SERVER) num_servers_++; if (node.role == Node::WORKER) num_workers_++; } nodes->control.node.push_back(my_node_); nodes->control.cmd = Control::ADD_NODE; Message back; back.meta = *nodes; // 向全部已经和schedular创建链接的worker节点/server节点 广播此 "节点的加入信息“,并把 节点 2 请求链接的信息放入meta信息中。 for (int r : Postoffice::Get()->GetNodeIDs(kWorkerGroup + kServerGroup)) { int recver_id = r; if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) { back.meta.recver = recver_id; back.meta.timestamp = timestamp_++; Send(back); } }
第三步:现有worker/server节点收到这个命令后,在 ProcessAddNodeCommand 之中 会和 节点 2 造成链接。
for (const auto& node : ctrl.node) { std::string addr_str = node.hostname + ":" + std::to_string(node.port); if (connected_nodes_.find(addr_str) == connected_nodes_.end()) { // 现有链接中没有这个新节点 Connect(node); // 与新节点进行链接 connected_nodes_[addr_str] = node.id; } if (!node.is_recovery && node.role == Node::SERVER) ++num_servers_; if (!node.is_recovery && node.role == Node::WORKER) ++num_workers_;
至此,整个过程就描述完了。每一个新节点加入后,已经加入的节点都会经过schedular节点和这个新节点创建链接。
咱们接下来分析心跳机制。
为了记录网络的可达性,PS Lite 设计了心跳机制。具体而言:
具体以下:
std::unordered_map<int, time_t> heartbeats_ 就是存储了心跳关联的节点的活跃信息。键为节点编号,值为上次收到其 HEARTBEAT 消息的时间戳。
UpdateHeartbeat 会按期更新心跳。
void UpdateHeartbeat(int node_id, time_t t) { std::lock_guard<std::mutex> lk(heartbeat_mu_); heartbeats_[node_id] = t; } std::unordered_map<int, time_t> heartbeats_;
在这两种节点中,启动了一个线程,每个 Worker/Server 节点,每隔 PS_HEARTBEAT_INTERVAL 秒向 Scheduler 发送一条 HEARTBEAT 消息:
if (!is_scheduler_) { // start heartbeat thread heartbeat_thread_ = std::unique_ptr<std::thread>(new std::thread(&Van::Heartbeat, this)); }
具体心跳函数是:
void Van::Heartbeat() { const char* val = Environment::Get()->find("PS_HEARTBEAT_INTERVAL"); const int interval = val ? atoi(val) : kDefaultHeartbeatInterval; while (interval > 0 && ready_.load()) { std::this_thread::sleep_for(std::chrono::seconds(interval)); Message msg; msg.meta.recver = kScheduler; msg.meta.control.cmd = Control::HEARTBEAT; msg.meta.control.node.push_back(my_node_); msg.meta.timestamp = timestamp_++; Send(msg); } }
Scheduler 节点收到后 HEARTBEAT 消息后,响应一个 HEARTBEAT 消息。UpdateHeartbeat 会按期更新心跳。
void Van::ProcessHearbeat(Message* msg) { auto& ctrl = msg->meta.control; time_t t = time(NULL); for (auto& node : ctrl.node) { Postoffice::Get()->UpdateHeartbeat(node.id, t); if (is_scheduler_) { Message heartbeat_ack; heartbeat_ack.meta.recver = node.id; heartbeat_ack.meta.control.cmd = Control::HEARTBEAT; heartbeat_ack.meta.control.node.push_back(my_node_); heartbeat_ack.meta.timestamp = timestamp_++; // send back heartbeat Send(heartbeat_ack); } } }
Scheduler 在处理 ADD_NODE 消息时候,会看看是否已经有死亡节点,具体判经过当前时间戳与心跳包接收时间戳之差判断是否alive。
std::vector<int> Postoffice::GetDeadNodes(int t) { std::vector<int> dead_nodes; if (!van_->IsReady() || t == 0) return dead_nodes; time_t curr_time = time(NULL); const auto& nodes = is_scheduler_ ? GetNodeIDs(kWorkerGroup + kServerGroup) : GetNodeIDs(kScheduler); { std::lock_guard<std::mutex> lk(heartbeat_mu_); for (int r : nodes) { auto it = heartbeats_.find(r); if ((it == heartbeats_.end() || it->second + t < curr_time) && start_time_ + t < curr_time) { dead_nodes.push_back(r); } } } return dead_nodes; }
逻辑以下:
+----------------------------------------------------+ | Scheduler | | | | | | | | heartbeats_ | | | | receiver_thread_+--------> ProcessHearbeat | | ^ + ^ + | | | | | | | | | | | | | | | | | | | +----------------------------------------------------+ | | | | | | | | RESPONSE | | | +-------------------------------------+ | | | | | | +-------------------------------+ | | | | | HEARTBEAT | | RESPONSE HEARTBEAT | | | | | | +-----------------------------------------+ +-----------------------------------------+ | Worker | | | | Server | | | | | | | | | | | | | | | | | | | | | | | | | | | | heartbeats_ | | | | heartbeats_ | | | | + | | | + | | | heartbeat_thread_+----> Heartbeat | | | heartbeat_thread_+--> Heartbeat | | | | | | | | | v | | v | | receiver_thread_ +---> ProcessHearbeat | | receiver_thread_ +--> ProcessHearbeat | | | | | | | | | | | | | +-----------------------------------------+ +-----------------------------------------+
ProcessTerminateCommand 会处理结束消息,具体就是设定 ready_ 为 false。
这样就预示着 Van 状态不对,不能够继续处理。
void Van::ProcessTerminateCommand() { PS_VLOG(1) << my_node().ShortDebugString() << " is stopped"; ready_ = false; } inline bool IsReady() { return ready_; }
在分布式系统中,通讯也是不可靠的,丢包、延时都是必须考虑的场景。PS Lite 设计了 Resender类来提升通讯的可靠性,它引入了 ACK 机制。即:
定义以下,其中 send_buff_ 就是发送缓存,用来存储发送了的消息列表。acked_ 就是已经确认的消息。
class Resender { std::thread* monitor_; std::unordered_set<uint64_t> acked_; std::atomic<bool> exit_{false}; std::mutex mu_; int timeout_; int max_num_retry_; Van* van_; using Time = std::chrono::milliseconds; // the buffer entry struct Entry { Message msg; Time send; int num_retry = 0; }; std::unordered_map<uint64_t, Entry> send_buff_; };
监控线程以及函数以下以下,就是被唤醒时候,从send_buff_(本地缓存)找到每一个消息的发送时间戳和当前时间,找出超时的消息进行重发,并累加其重试次数。 :
monitor_ = new std::thread(&Resender::Monitoring, this); void Monitoring() { while (!exit_) { std::this_thread::sleep_for(Time(timeout_)); std::vector<Message> resend; Time now = Now(); mu_.lock(); for (auto& it : send_buff_) { if (it.second.send + Time(timeout_) * (1+it.second.num_retry) < now) { resend.push_back(it.second.msg); ++it.second.num_retry; CHECK_LT(it.second.num_retry, max_num_retry_); } } mu_.unlock(); for (const auto& msg : resend) van_->Send(msg); } }
当 Van 发送消息时候,若是配置了重传,就调用AddOutgoing函数把消息加入到发送缓存。
int Van::Send(const Message& msg) { int send_bytes = SendMsg(msg); CHECK_NE(send_bytes, -1); send_bytes_ += send_bytes; if (resender_) resender_->AddOutgoing(msg); if (Postoffice::Get()->verbose() >= 2) { PS_VLOG(2) << msg.DebugString(); } return send_bytes; }
下面函数就是加入到发送缓存。
/** * \brief add an outgoining message * */ void AddOutgoing(const Message& msg) { if (msg.meta.control.cmd == Control::ACK) return; CHECK_NE(msg.meta.timestamp, Meta::kEmpty) << msg.DebugString(); auto key = GetKey(msg); std::lock_guard<std::mutex> lk(mu_); // already buffered, which often due to call Send by the monitor thread if (send_buff_.find(key) != send_buff_.end()) return; auto& ent = send_buff_[key]; ent.msg = msg; ent.send = Now(); ent.num_retry = 0; }
下面函数有两个做用:
/** * \brief add an incomming message * \brief return true if msg has been added before or a ACK message */ bool AddIncomming(const Message& msg) { // a message can be received by multiple times if (msg.meta.control.cmd == Control::TERMINATE) { return false; } else if (msg.meta.control.cmd == Control::ACK) { mu_.lock(); auto key = msg.meta.control.msg_sig; auto it = send_buff_.find(key); if (it != send_buff_.end()) send_buff_.erase(it); mu_.unlock(); return true; } else { mu_.lock(); auto key = GetKey(msg); auto it = acked_.find(key); bool duplicated = it != acked_.end(); if (!duplicated) acked_.insert(key); mu_.unlock(); // send back ack message (even if it is duplicated) Message ack; ack.meta.recver = msg.meta.sender; ack.meta.sender = msg.meta.recver; ack.meta.control.cmd = Control::ACK; ack.meta.control.msg_sig = key; van_->Send(ack); // warning if (duplicated) LOG(WARNING) << "Duplicated message: " << msg.DebugString(); return duplicated; } }
ProcessDataMsg 用来处理 worker 发过来的数据消息(就是worker向server更新梯度),具体是取得对应的Customer后,调用 Customer 的方法进行处理,直接将msg
放入处理队列中。
咱们会放在 Customer 之中进行介绍。
void Van::ProcessDataMsg(Message* msg) { // data msg int app_id = msg->meta.app_id; int customer_id = Postoffice::Get()->is_worker() ? msg->meta.customer_id : app_id; auto* obj = Postoffice::Get()->GetCustomer(app_id, customer_id, 5); obj->Accept(*msg); // 这里给 Customer 添加消息 }
ZMQVan是基于zeromq的Van的实现,即为用zmq库实现了链接的底层细节(zmq库是一个开源库,对socket进行了优良的封装,他使得Socket编程更加简单、简洁和性能更高)。
ZMQVan定义以下:
ZMQVan 继承于Van ,在这个类的基础上加了两个成员变量,分别是:
具体以下:
class ZMQVan : public Van { void *context_ = nullptr; /** * \brief node_id to the socket for sending data to this node */ std::unordered_map<int, void*> senders_; std::mutex mu_; void *receiver_ = nullptr; };
Van类 有以下函数会调用到 ZMQVan 或者被 ZMQVan 调用。
Send 函数就是调用 ZMQVan 的 SendMsg 函数进行发送消息,发送以后若是设定了ACK机制,则会调用 resender_->AddOutgoing。
int Van::Send(const Message& msg) { int send_bytes = SendMsg(msg); CHECK_NE(send_bytes, -1); send_bytes_ += send_bytes; if (resender_) resender_->AddOutgoing(msg); if (Postoffice::Get()->verbose() >= 2) { PS_VLOG(2) << msg.DebugString(); } return send_bytes; }
Meta封装了元数据,发送者,接受者,时间戳,请求仍是响应等。
/** * \brief meta info of a message */ struct Meta { /** \brief the empty value */ static const int kEmpty; /** \brief an int head */ int head; /** \brief the unique id of the application of messsage is for*/ int app_id; /** \brief customer id*/ int customer_id; /** \brief the timestamp of this message */ int timestamp; /** \brief the node id of the sender of this message */ int sender; /** \brief the node id of the receiver of this message */ int recver; /** \brief whether or not this is a request message*/ bool request; /** \brief whether or not a push message */ bool push; /** \brief whether or not a pull message */ bool pull; /** \brief whether or not it's for SimpleApp */ bool simple_app; /** \brief an string body */ std::string body; /** \brief data type of message.data[i] */ std::vector<DataType> data_type; /** \brief system control message */ Control control; /** \brief the byte size */ int data_size = 0; /** \brief message priority */ int priority = 0; };
为了缓解通讯压力,ps-lite 使用了Protobuf对 Meta 进行数据压缩。
就是按照 protobuf 来进行数据压缩。
void Van::PackMeta(const Meta& meta, char** meta_buf, int* buf_size) { // convert into protobuf PBMeta pb; pb.set_head(meta.head); if (meta.app_id != Meta::kEmpty) pb.set_app_id(meta.app_id); if (meta.timestamp != Meta::kEmpty) pb.set_timestamp(meta.timestamp); if (meta.body.size()) pb.set_body(meta.body); pb.set_push(meta.push); pb.set_pull(meta.pull); pb.set_request(meta.request); pb.set_simple_app(meta.simple_app); pb.set_priority(meta.priority); pb.set_customer_id(meta.customer_id); for (auto d : meta.data_type) pb.add_data_type(d); if (!meta.control.empty()) { auto ctrl = pb.mutable_control(); ctrl->set_cmd(meta.control.cmd); if (meta.control.cmd == Control::BARRIER) { ctrl->set_barrier_group(meta.control.barrier_group); } else if (meta.control.cmd == Control::ACK) { ctrl->set_msg_sig(meta.control.msg_sig); } for (const auto& n : meta.control.node) { auto p = ctrl->add_node(); p->set_id(n.id); p->set_role(n.role); p->set_port(n.port); p->set_hostname(n.hostname); p->set_is_recovery(n.is_recovery); p->set_customer_id(n.customer_id); } } // to string *buf_size = pb.ByteSize(); *meta_buf = new char[*buf_size + 1]; CHECK(pb.SerializeToArray(*meta_buf, *buf_size)) << "failed to serialize protobuf"; }
按照protobuf 预先生成的 PBMeta 格式进行解压。
void Van::UnpackMeta(const char* meta_buf, int buf_size, Meta* meta) { // to protobuf PBMeta pb; CHECK(pb.ParseFromArray(meta_buf, buf_size)) << "failed to parse string into protobuf"; // to meta meta->head = pb.head(); meta->app_id = pb.has_app_id() ? pb.app_id() : Meta::kEmpty; meta->timestamp = pb.has_timestamp() ? pb.timestamp() : Meta::kEmpty; meta->request = pb.request(); meta->push = pb.push(); meta->pull = pb.pull(); meta->simple_app = pb.simple_app(); meta->priority = pb.priority(); meta->body = pb.body(); meta->customer_id = pb.customer_id(); meta->data_type.resize(pb.data_type_size()); for (int i = 0; i < pb.data_type_size(); ++i) { meta->data_type[i] = static_cast<DataType>(pb.data_type(i)); } if (pb.has_control()) { const auto& ctrl = pb.control(); meta->control.cmd = static_cast<Control::Command>(ctrl.cmd()); meta->control.barrier_group = ctrl.barrier_group(); meta->control.msg_sig = ctrl.msg_sig(); for (int i = 0; i < ctrl.node_size(); ++i) { const auto& p = ctrl.node(i); Node n; n.role = static_cast<Node::Role>(p.role()); n.port = p.port(); n.hostname = p.hostname(); n.id = p.has_id() ? p.id() : Node::kEmpty; n.is_recovery = p.is_recovery(); n.customer_id = p.customer_id(); meta->control.node.push_back(n); } } else { meta->control.cmd = Control::EMPTY; } }
PackMetaPB 从注释看,是字节跳动提交的,主要用于 ibverbs_van.h,因此咱们不作深刻研究。
void Van::PackMetaPB(const Meta& meta, PBMeta* pb) { pb->set_head(meta.head); if (meta.app_id != Meta::kEmpty) pb->set_app_id(meta.app_id); if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp); if (meta.body.size()) pb->set_body(meta.body); pb->set_push(meta.push); pb->set_request(meta.request); pb->set_simple_app(meta.simple_app); pb->set_priority(meta.priority); pb->set_customer_id(meta.customer_id); for (auto d : meta.data_type) pb->add_data_type(d); if (!meta.control.empty()) { auto ctrl = pb->mutable_control(); ctrl->set_cmd(meta.control.cmd); if (meta.control.cmd == Control::BARRIER) { ctrl->set_barrier_group(meta.control.barrier_group); } else if (meta.control.cmd == Control::ACK) { ctrl->set_msg_sig(meta.control.msg_sig); } for (const auto& n : meta.control.node) { auto p = ctrl->add_node(); p->set_id(n.id); p->set_role(n.role); p->set_port(n.port); p->set_hostname(n.hostname); p->set_is_recovery(n.is_recovery); p->set_customer_id(n.customer_id); } } pb->set_data_size(meta.data_size); }
ZMQVan 有以下重要的派生函数。
Bind 逻辑以下:
int Bind(const Node& node, int max_retry) override { receiver_ = zmq_socket(context_, ZMQ_ROUTER); int local = GetEnv("DMLC_LOCAL", 0); std::string hostname = node.hostname.empty() ? "*" : node.hostname; int use_kubernetes = GetEnv("DMLC_USE_KUBERNETES", 0); if (use_kubernetes > 0 && node.role == Node::SCHEDULER) { hostname = "0.0.0.0"; } std::string addr = local ? "ipc:///tmp/" : "tcp://" + hostname + ":"; int port = node.port; unsigned seed = static_cast<unsigned>(time(NULL) + port); for (int i = 0; i < max_retry + 1; ++i) { auto address = addr + std::to_string(port); if (zmq_bind(receiver_, address.c_str()) == 0) break; if (i == max_retry) { port = -1; } else { port = 10000 + rand_r(&seed) % 40000; } } return port; }
主要就是初始化 Sender_,逻辑以下:
具体以下:
void Connect(const Node& node) override { int id = node.id; auto it = senders_.find(id); if (it != senders_.end()) { zmq_close(it->second); // 若是找到了对应socket就关闭socket } // worker doesn't need to connect to the other workers. same for server if ((node.role == my_node_.role) && (node.id != my_node_.id)) { return; } void *sender = zmq_socket(context_, ZMQ_DEALER); //创建一个socket //若是自己是scheduler,则一开始就是知道本身的id = 1,因此这个if条件就是说把本身的id绑定到socket上 if (my_node_.id != Node::kEmpty) { std::string my_id = "ps" + std::to_string(my_node_.id); zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size()); const char* watermark = Environment::Get()->find("DMLC_PS_WATER_MARK"); if (watermark) { const int hwm = atoi(watermark); zmq_setsockopt(sender, ZMQ_SNDHWM, &hwm, sizeof(hwm)); } } // connect std::string addr = "tcp://" + node.hostname + ":" + std::to_string(node.port); if (GetEnv("DMLC_LOCAL", 0)) { addr = "ipc:///tmp/" + std::to_string(node.port); } if (zmq_connect(sender, addr.c_str()) != 0) { //将sender这个socket和目标地址链接 LOG(FATAL) << "connect to " + addr + " failed: " + zmq_strerror(errno); } senders_[id] = sender; //将目标id的socket存放起来后面使用 }
逻辑以下:
int SendMsg(const Message& msg) override { std::lock_guard<std::mutex> lk(mu_); // find the socket int id = msg.meta.recver; CHECK_NE(id, Meta::kEmpty); auto it = senders_.find(id); if (it == senders_.end()) { LOG(WARNING) << "there is no socket to node " << id; return -1; } void *socket = it->second; // send meta int meta_size; char* meta_buf; PackMeta(msg.meta, &meta_buf, &meta_size); int tag = ZMQ_SNDMORE; int n = msg.data.size(); if (n == 0) tag = 0; zmq_msg_t meta_msg; zmq_msg_init_data(&meta_msg, meta_buf, meta_size, FreeData, NULL); while (true) { if (zmq_msg_send(&meta_msg, socket, tag) == meta_size) break; if (errno == EINTR) continue; return -1; } // zmq_msg_close(&meta_msg); int send_bytes = meta_size; // send data for (int i = 0; i < n; ++i) { zmq_msg_t data_msg; SArray<char>* data = new SArray<char>(msg.data[i]); int data_size = data->size(); zmq_msg_init_data(&data_msg, data->data(), data->size(), FreeData, data); if (i == n - 1) tag = 0; while (true) { if (zmq_msg_send(&data_msg, socket, tag) == data_size) break; if (errno == EINTR) continue; return -1; } // zmq_msg_close(&data_msg); send_bytes += data_size; } return send_bytes; }
RecvMsg 就是在绑定的端口上接受消息。
接受消息时候,会判断是第几个消息,而后作不一样的处理。
int RecvMsg(Message* msg) override { msg->data.clear(); size_t recv_bytes = 0; for (int i = 0; ; ++i) { zmq_msg_t* zmsg = new zmq_msg_t; CHECK(zmq_msg_init(zmsg) == 0) << zmq_strerror(errno); while (true) { if (zmq_msg_recv(zmsg, receiver_, 0) != -1) break; if (errno == EINTR) { std::cout << "interrupted"; continue; } return -1; } char* buf = CHECK_NOTNULL((char *)zmq_msg_data(zmsg)); size_t size = zmq_msg_size(zmsg); recv_bytes += size; if (i == 0) { // identify msg->meta.sender = GetNodeID(buf, size); msg->meta.recver = my_node_.id; CHECK(zmq_msg_more(zmsg)); zmq_msg_close(zmsg); delete zmsg; } else if (i == 1) { // task UnpackMeta(buf, size, &(msg->meta)); zmq_msg_close(zmsg); bool more = zmq_msg_more(zmsg); delete zmsg; if (!more) break; } else { // zero-copy SArray<char> data; data.reset(buf, size, [zmsg, size](char* buf) { zmq_msg_close(zmsg); delete zmsg; }); msg->data.push_back(data); if (!zmq_msg_more(zmsg)) { break; } } } return recv_bytes; }
GetNodeID 函数是
/** * return the node id given the received identity * \return -1 if not find */ int GetNodeID(const char* buf, size_t size) { if (size > 2 && buf[0] == 'p' && buf[1] == 's') { int id = 0; size_t i = 2; for (; i < size; ++i) { if (buf[i] >= '0' && buf[i] <= '9') { id = id * 10 + buf[i] - '0'; } else { break; } } if (i == size) return id; } return Meta::kEmpty; }
咱们最后进行一下总结:
邮局里有了地址簿,就须要有货车来负责拉送物件,Van 就是整个Parameter Server的通讯模块,其特色以下。
★★★★★★关于生活和技术的思考★★★★★★
微信公众帐号:罗西的思考
若是您想及时获得我的撰写文章的消息推送,或者想看看我的推荐的技术资料,敬请关注。