parl.remote.client
Source code for parl.remote.client
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cloudpickle
import datetime
import os
import socket
import sys
import threading
import zmq
import parl
import time
import glob
import multiprocessing as mp
from parl.utils import to_str, to_byte, get_ip_address, logger, isnotebook
from parl.remote.utils import get_subfiles_recursively
from parl.remote import remote_constants
from parl.remote.grpc_heartbeat import HeartbeatServerThread, HeartbeatServerProcess
from parl.remote.utils import get_version


class Client(object):
    """Base class for the remote client.

    For each training task, there is a global client in the cluster which
    submits jobs to the master node. Different `@parl.remote_class` objects
    connect to the same global client in a training task.

    Attributes:
        submit_job_socket (zmq.Context.socket): A socket which submits jobs
            to the master node.
        pyfiles (bytes): A serialized dictionary containing the code of the
            python files in the local working directory.
        executable_path (str): File path of the executable python script.
        start_time (time): A timestamp recording the start time of the
            program.
    """

    def __init__(self, master_address, process_id, distributed_files=[]):
        """
        Args:
            master_address (str): IP address of the master node.
            process_id (str): Process id in which the client is created.
                Use os.getpid() to get the process id.
            distributed_files (list): A list of files to be distributed to
                all remote instances (e.g. the configuration file for
                initialization).
        """
        self.dead_job_queue = mp.Queue()
        self.client_is_alive = mp.Value('i', True)
        self._create_heartbeat_server()

        th = threading.Thread(
            target=self._update_job_status, args=(self.dead_job_queue, ))
        th.setDaemon(True)
        th.start()

        self.master_address = master_address
        self.process_id = process_id
        self.ctx = zmq.Context()
        self.lock = threading.Lock()
        self.log_monitor_url = None
        self.threads = []
        self.executable_path = self.get_executable_path()

        self._create_sockets(master_address)
        self.connected_to_master = True
        self.check_env_consistency()

        self.instance_count = 0
        self.instance_id_to_job = dict()

        thread = threading.Thread(target=self._update_client_status_to_master)
        thread.setDaemon(True)
        thread.start()
        self.threads.append(thread)

        self.pyfiles = self.read_local_files(distributed_files)

    def destroy(self):
        """Destructor function"""
        self.connected_to_master = False
        self.dead_job_queue.put('exit')
        self.master_heartbeat_thread.exit()
        for th in self.threads:
            th.join()
        self.ctx.destroy()

        self.client_is_alive.value = False
        self.job_heartbeat_process.join()

    def get_executable_path(self):
        """Return the current executable path."""
        mod = sys.modules['__main__']
        if hasattr(mod, '__file__'):
            executable_path = os.path.abspath(mod.__file__)
        else:
            executable_path = os.getcwd()
        executable_path = executable_path[:executable_path.rfind('/')]
        return executable_path

    def read_local_files(self, distributed_files=[]):
        """Read local python code and store it in a dictionary, which will
        then be sent to the job.

        Args:
            distributed_files (list): A list of files to be distributed to
                all remote instances (e.g. the configuration file for
                initialization). Glob patterns of file names are supported,
                e.g. distributed_files = ['./*.npy', './test*'].

        Returns:
            A cloudpickled dictionary containing the python code in the
            current working directory.
        """
        pyfiles = dict()
        pyfiles['python_files'] = {}
        pyfiles['other_files'] = {}

        user_files = []
        user_empty_subfolders = []
        for distributed_file in distributed_files:
            parsed_list = glob.glob(distributed_file)
            if not parsed_list:
                raise ValueError(
                    "no local file is matched with '{}', please check your input".format(distributed_file))
            for pathname in parsed_list:
                if os.path.isdir(pathname):
                    pythonfiles, otherfiles, emptysubfolders = get_subfiles_recursively(pathname)
                    user_files.extend(pythonfiles)
                    user_files.extend(otherfiles)
                    user_empty_subfolders.extend(emptysubfolders)
                else:
                    user_files.append(pathname)

        if isnotebook():
            main_folder = './'
        else:
            main_file = sys.argv[0]
            main_folder = './'
            sep = os.sep
            if sep in main_file:
                main_folder = sep.join(main_file.split(sep)[:-1])

        code_files = filter(lambda x: x.endswith('.py'), os.listdir(main_folder))

        for file_name in code_files:
            file_path = os.path.join(main_folder, file_name)
            assert os.path.exists(file_path)
            with open(file_path, 'rb') as code_file:
                code = code_file.read()
                pyfiles['python_files'][file_name] = code

        for file_name in set(user_files):
            assert os.path.exists(file_name)
            assert not os.path.isabs(file_name), \
                "[XPARL] Please do not distribute a file with absolute path."
            with open(file_name, 'rb') as f:
                content = f.read()
                pyfiles['other_files'][file_name] = content

        pyfiles['empty_subfolders'] = set(user_empty_subfolders)
        return cloudpickle.dumps(pyfiles)

    def _create_sockets(self, master_address):
        """Each client starts with one socket:

        (1) submit_job_socket: submits jobs to the master node.
        """
        # submit_job_socket: submits jobs to the master
        self.submit_job_socket = self.ctx.socket(zmq.REQ)
        self.submit_job_socket.linger = 0
        self.submit_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))
        self.start_time = time.time()

        def master_heartbeat_exit_callback_func():
            logger.warning("[Client] Cannot connect to the master. "
                           "Please check if it is still alive.")
            logger.warning("Client exits replying heartbeat for master.")
            self.connected_to_master = False

        self.master_heartbeat_thread = HeartbeatServerThread(
            heartbeat_exit_callback_func=master_heartbeat_exit_callback_func)
        self.master_heartbeat_thread.setDaemon(True)
        self.master_heartbeat_thread.start()

        self.reply_master_heartbeat_address = self.master_heartbeat_thread.get_address()
        self.threads.append(self.master_heartbeat_thread)
        self.client_id = self.reply_master_heartbeat_address.replace(':', '_') + \
            '_' + str(int(time.time()))

        # check if the master is connected properly
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.reply_master_heartbeat_address),
                to_byte(socket.gethostname()),
                to_byte(self.client_id),
            ])
            message = self.submit_job_socket.recv_multipart()
            self.log_monitor_url = to_str(message[1])
        except zmq.error.Again as e:
            logger.warning("[Client] Cannot connect to the master, please "
                           "check if the master is started and ensure the "
                           "input address {} is correct.".format(master_address))
            self.connected_to_master = False
            raise Exception("Client cannot connect to the master, please "
                            "check if the master is started and ensure the "
                            "input address {} is correct.".format(master_address))

    def _update_job_status(self, dead_job_queue):
        while True:
            instance_id = dead_job_queue.get()
            # the client calls the destroy function
            if isinstance(instance_id, str) and instance_id == 'exit':
                break
            logger.error("[Client] lost connection with a remote instance. "
                         "ID: {}".format(instance_id))
            job_is_alive = self.instance_id_to_job[instance_id]
            job_is_alive.value = False

    def check_env_consistency(self):
        """Verify that the parl and python versions, as well as some other
        packages in the client process, match those of the master process."""
        self.submit_job_socket.send_multipart([remote_constants.CHECK_VERSION_TAG])
        message = self.submit_job_socket.recv_multipart()
        tag = message[0]
        if tag == remote_constants.NORMAL_TAG:
            client_parl_version = parl.__version__
            client_python_version_major = str(sys.version_info.major)
            client_python_version_minor = str(sys.version_info.minor)
            assert client_parl_version == to_str(message[1]) \
                and client_python_version_major == to_str(message[2]) \
                and client_python_version_minor == to_str(message[3]), \
                "Version mismatch: the 'master' is of version 'parl={}, python={}.{}'. " \
                "However, 'parl={}, python={}.{}' is provided in your environment.".format(
                    to_str(message[1]), to_str(message[2]), to_str(message[3]),
                    client_parl_version, client_python_version_major,
                    client_python_version_minor)

            client_pyarrow_version = str(get_version('pyarrow'))
            master_pyarrow_version = to_str(message[4])
            if client_pyarrow_version != master_pyarrow_version:
                if master_pyarrow_version == 'None':
                    error_message = '"pyarrow" is provided in your current environment, ' \
                        'however, it is not found in "master"\'s environment. ' \
                        'To use "pyarrow" for serialization, please install ' \
                        '"pyarrow={}" in "master"\'s environment!'.format(client_pyarrow_version)
                elif client_pyarrow_version == 'None':
                    error_message = '"pyarrow" is provided in "master"\'s environment, ' \
                        'however, it is not found in your current environment. ' \
                        'To use "pyarrow" for serialization, please install ' \
                        '"pyarrow={}" in your current environment!'.format(master_pyarrow_version)
                else:
                    error_message = "Version mismatch: the 'master' is of version 'pyarrow={}'. " \
                        "However, 'pyarrow={}' is provided in your current environment.".format(
                            master_pyarrow_version, client_pyarrow_version)
                raise Exception(error_message)
        else:
            raise NotImplementedError

    def _update_client_status_to_master(self):
        while self.connected_to_master:
            elapsed_time = datetime.timedelta(seconds=int(time.time() - self.start_time))
            client_status = {
                'file_path': self.executable_path,
                'actor_num': self.actor_num.value,
                'time': str(elapsed_time),
                'log_monitor_url': self.log_monitor_url
            }

            self.lock.acquire()
            try:
                self.submit_job_socket.send_multipart([
                    remote_constants.CLIENT_STATUS_UPDATE_TAG,
                    to_byte(self.reply_master_heartbeat_address),
                    cloudpickle.dumps(client_status)
                ])
                message = self.submit_job_socket.recv_multipart()
            except zmq.error.Again as e:
                self.connected_to_master = False
            finally:
                self.lock.release()

            time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)

    def _check_job(self, job_ping_address, max_memory, gpu):
        """Check if this job is still alive before establishing a connection
        with it.

        Returns:
            instance_id (str): a unique instance id; -1 if the job is not
                ready for connection.
        """
        # job_ping_socket: sends a ping signal to the job
        job_ping_socket = self.ctx.socket(zmq.REQ)
        job_ping_socket.linger = 0
        job_ping_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
        job_ping_socket.connect("tcp://" + job_ping_address)
        instance_id = self._generate_instance_id()
        try:
            job_ping_socket.send_multipart([
                remote_constants.HEARTBEAT_TAG,
                to_byte(self.job_heartbeat_server_addr),
                to_byte(str(max_memory)),
                to_byte(gpu),
                to_byte(instance_id)
            ])
            job_ping_socket.recv_multipart()
        except zmq.error.Again:
            logger.error("[Client] connects to a finished job, will try again, "
                         "job_ping_address: {}".format(job_ping_address))
            instance_id = -1
        finally:
            job_ping_socket.close(0)
        return instance_id

    def _create_heartbeat_server(self):
        """Create the grpc-based heartbeat server in a subprocess."""
        job_heartbeat_port = mp.Value('i', 0)
        self.actor_num = mp.Value('i', 0)
        self.job_heartbeat_process = HeartbeatServerProcess(
            job_heartbeat_port, self.actor_num, self.client_is_alive,
            self.dead_job_queue)
        self.job_heartbeat_process.daemon = True
        self.job_heartbeat_process.start()
        assert job_heartbeat_port.value != 0, "fail to initialize heartbeat server for jobs."
        self.job_heartbeat_server_addr = "{}:{}".format(get_ip_address(), job_heartbeat_port.value)

    def _generate_instance_id(self):
        """Return a unique instance id for the remote instance."""
        self.instance_count += 1
        unique_id = f"{self.instance_count:05}"
        return unique_id

    def submit_job(self, max_memory, n_gpu, job_is_alive):
        """Send a job to the Master node.

        When a `@parl.remote_class` object is created, the global client
        sends a job to the master node. The master node then allocates a
        vacant job from its job pool to the remote object.

        Args:
            max_memory (float): Maximum memory (MB) that can be used by each
                remote instance; the default value is None (unlimited).
            n_gpu (int): Number of GPUs that can be used in this remote
                instance.
            job_is_alive (mp.Value): A shared flag marking whether the
                allocated job is still alive.

        Returns:
            An ``InitializedJob`` that has information about the available
            job address, or None if no vacant resource is available yet.
        """
        if self.connected_to_master:
            while True:
                self.lock.acquire()
                n_cpu = 0 if n_gpu > 0 else 1
                self.submit_job_socket.send_multipart([
                    remote_constants.CLIENT_SUBMIT_TAG,
                    to_byte(self.reply_master_heartbeat_address),
                    to_byte(self.client_id),
                    to_byte(str(n_cpu)),
                    to_byte(str(n_gpu))
                ])
                message = self.submit_job_socket.recv_multipart()
                self.lock.release()

                tag = message[0]

                if tag == remote_constants.NORMAL_TAG:
                    job_info = cloudpickle.loads(message[1])
                    job_ping_address = job_info.ping_heartbeat_address
                    self.lock.acquire()
                    instance_id = self._check_job(job_ping_address, max_memory,
                                                  job_info.allocated_gpu.gpu)
                    self.lock.release()
                    if instance_id != -1:
                        self.instance_id_to_job[instance_id] = job_is_alive
                        return job_info
                # no vacant CPU resources, cannot submit a new job
                elif tag == remote_constants.CPU_TAG:
                    # wait 1 second to avoid requesting at a high frequency.
                    time.sleep(1)
                    return None
                # no vacant GPU resources, cannot submit a new job
                elif tag == remote_constants.GPU_TAG:
                    # wait 1 second to avoid requesting at a high frequency.
                    time.sleep(1)
                    return None
                elif tag == remote_constants.REJECT_GPU_JOB_TAG:
                    error_message = "[Client] Request fails. It is not allowed to request CPU resource from a GPU cluster."
                    logger.error(error_message)
                    raise Exception(error_message)
                elif tag == remote_constants.REJECT_CPU_JOB_TAG:
                    error_message = "[Client] Request fails. It is not allowed to request GPU resource from a CPU cluster."
                    logger.error(error_message)
                    raise Exception(error_message)
                elif tag == remote_constants.REJECT_INVALID_GPU_JOB_TAG:
                    error_message = "[Client] requested {} GPUs, but was rejected.".format(n_gpu)
                    logger.error(error_message)
                    raise Exception(error_message)
                else:
                    raise NotImplementedError
        else:
            raise Exception("Client cannot submit the job to the master, "
                            "please check if the master is connected.")
        return None


GLOBAL_CLIENT = None


def connect(master_address, distributed_files=[]):
    """Create a global client which connects to the master node.

    .. code-block:: python

        parl.connect(master_address='localhost:1234')

    Args:
        master_address (str): The address of the Master node to connect to.
        distributed_files (list): A list of files to be distributed to all
            remote instances (e.g. the configuration file for initialization).

    Raises:
        Exception: An exception is raised if the master node is not started.
    """
    assert len(master_address.split(":")) == 2, \
        "Please input the address in {ip}:{port} format"
    assert isinstance(distributed_files, list), "`distributed_files` should be a list."
    global GLOBAL_CLIENT
    cur_process_id = os.getpid()
    if GLOBAL_CLIENT is None:
        GLOBAL_CLIENT = Client(master_address, cur_process_id, distributed_files)
    else:
        if GLOBAL_CLIENT.process_id != cur_process_id:
            GLOBAL_CLIENT = Client(master_address, cur_process_id, distributed_files)
    logger.info("Remote actors log url: {}".format(GLOBAL_CLIENT.log_monitor_url))


def get_global_client():
    """Get the global client.

    To support process-based programming, a new global client is created in
    each new process.

    Returns:
        The global client.
    """
    global GLOBAL_CLIENT
    assert GLOBAL_CLIENT is not None, \
        "Cannot get the client to submit the job, have you connected to the " \
        "cluster by calling parl.connect(master_ip, master_port)?"
    cur_process_id = os.getpid()
    if GLOBAL_CLIENT.process_id != cur_process_id:
        GLOBAL_CLIENT = Client(GLOBAL_CLIENT.master_address, cur_process_id)
    return GLOBAL_CLIENT


def disconnect():
    """Disconnect the global client from the master node."""
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT is not None:
        GLOBAL_CLIENT.destroy()
        GLOBAL_CLIENT = None
        logger.info("The client is disconnected from the master node.")
    else:
        logger.info("No client to be released. Please make sure that you have called `parl.connect`.")
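The module's public entry points are ``connect``, ``get_global_client``, and ``disconnect``. The following is a minimal usage sketch, assuming a cluster has already been started locally (e.g. with PARL's ``xparl start --port 1234`` command); the ``Actor`` class and its ``sum`` method are hypothetical names used purely for illustration:

.. code-block:: python

    import parl

    @parl.remote_class
    class Actor(object):
        def sum(self, a, b):
            return a + b

    # Create the global client. Python files in the current working
    # directory are serialized and shipped to the remote jobs, along with
    # any extra files matched by the glob patterns in distributed_files.
    parl.connect('localhost:1234', distributed_files=['./*.npy'])

    actor = Actor()         # instantiated inside a job allocated by the master
    print(actor.sum(1, 2))  # the method call is executed remotely; prints 3

    parl.disconnect()       # destroy the global client and stop its heartbeats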
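The process-id checks in ``connect`` and ``get_global_client`` exist because the client's ZMQ sockets and heartbeat threads cannot be safely shared across processes: each process gets its own ``Client`` bound to the same master address. A short sketch of this behaviour, assuming a fork-based multiprocessing start method (the default on Linux) and the hypothetical cluster address from the example above:

.. code-block:: python

    import multiprocessing as mp
    import os

    import parl
    from parl.remote.client import get_global_client

    def show_client():
        # On fork, the child inherits GLOBAL_CLIENT from the parent, but its
        # os.getpid() differs, so get_global_client() replaces it with a
        # fresh Client connected to the same master address.
        client = get_global_client()
        print(os.getpid(), client.process_id)  # the two ids match again

    if __name__ == '__main__':
        parl.connect('localhost:1234')
        p = mp.Process(target=show_client)
        p.start()
        p.join()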