Source code for queens.utils.remote_operations

#
# SPDX-License-Identifier: LGPL-3.0-or-later
# Copyright (c) 2024-2025, QUEENS contributors.
#
# This file is part of QUEENS.
#
# QUEENS is free software: you can redistribute it and/or modify it under the terms of the GNU
# Lesser General Public License as published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version. QUEENS is distributed in the hope that it will
# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You
# should have received a copy of the GNU Lesser General Public License along with QUEENS. If not,
# see <https://www.gnu.org/licenses/>.
#
"""Module supplies functions to conduct operation on remote resource."""

import atexit
import json
import logging
import pickle
import socket
import time
import uuid
from functools import partial
from pathlib import Path

import cloudpickle
from fabric import Connection
from invoke.exceptions import UnexpectedExit

from queens.utils.path import PATH_TO_QUEENS, is_empty
from queens.utils.rsync import assemble_rsync_command
from queens.utils.run_subprocess import start_subprocess

_logger = logging.getLogger(__name__)

DEFAULT_PACKAGE_MANAGER = "mamba"
FALLBACK_PACKAGE_MANAGER = "conda"
SUPPORTED_PACKAGE_MANAGERS = [DEFAULT_PACKAGE_MANAGER, FALLBACK_PACKAGE_MANAGER]


[docs] class RemoteConnection(Connection): """This is class wrapper around the Connection class of fabric. Attributes: remote_python (str): Path to Python with installed (editable) QUEENS (see remote_queens_repository) remote_queens_repository (str, Path): Path to the QUEENS source code on the remote host """ def __init__(self, host, remote_python, remote_queens_repository, user=None, gateway=None): """Initialize RemoteConnection object. Args: host (str): address of remote host remote_python (str, Path): Path to Python with installed (editable) QUEENS (see remote_queens_repository) remote_queens_repository (str, Path): Path to the QUEENS source code on the remote host user (str): Username on remote machine gateway (dict,Connection,None): An object to use as a proxy or gateway for this connection. See docs of Fabric's Connection object for details. """ if isinstance(gateway, dict): gateway = Connection(**gateway) super().__init__(host, user=user, gateway=gateway) self.remote_python = remote_python _logger.debug("remote python path: %s", self.remote_python) self.remote_queens_repository = remote_queens_repository _logger.debug("remote queens repository: %s", self.remote_queens_repository)
[docs] def open(self): """Initiate the SSH connection.""" super().open() atexit.register(self.close)
[docs] def start_cluster( self, workload_manager, dask_cluster_kwargs, dask_cluster_adapt_kwargs, experiment_dir, ): """Start a Dask Cluster remotely using an ssh connection. Args: workload_manager (str): Workload manager ("pbs" or "slurm") on cluster dask_cluster_kwargs (dict): collection of keyword arguments to be forwarded to DASK Cluster dask_cluster_adapt_kwargs (dict): collection of keyword arguments to be forwarded to DASK Cluster adapt method experiment_dir (str): directory holding all data of QUEENS experiment on remote Returns: return_value (obj): Return value of function """ _logger.info("Starting Dask cluster on %s", self.host) python_cmd = ( "source /etc/profile;" f"{self.remote_python} " f"{Path(self.remote_queens_repository) / 'queens' / 'utils' / 'start_dask_cluster.py'} " f"--workload-manager {workload_manager} " f"--dask-cluster-kwargs '{json.dumps(dask_cluster_kwargs)}' " f"--dask-cluster-adapt-kwargs '{json.dumps(dask_cluster_adapt_kwargs)}' " f"--experiment-dir {experiment_dir}" ) _logger.debug("Starting cluster with command:") _logger.debug("%s", python_cmd) _, stdout, stderr = self.client.exec_command(python_cmd, get_pty=True) return stdout, stderr
[docs] def run_function(self, func, *func_args, wait=True, **func_kwargs): """Run a python function remotely using an ssh connection. Args: func (Function): function that is executed func_args: Additional arguments for the functools.partial function wait (bool): Flag to decide whether to wait for result of function func_kwargs: Additional keyword arguments for the functools.partial function Returns: return_value (obj): Return value of function """ _logger.info("Running %s on %s", func.__name__, self.host) func_file_name = f"temp_func_{str(uuid.uuid4())}.pickle" output_file_name = f"output_{str(uuid.uuid4())}.pickle" python_cmd = ( f"{self.remote_python} -c 'import pickle; from pathlib import Path;" f'file = open("{func_file_name}", "rb");' f"func = pickle.load(file); file.close();" f'Path("{func_file_name}").unlink(); ' f"result = func();" f'file = open("{output_file_name}", "wb");' f"pickle.dump(result, file); file.close();'" ) partial_func = partial(func, *func_args, **func_kwargs) # insert function arguments with open(func_file_name, "wb") as file: cloudpickle.dump(partial_func, file) # pickle function by value self.put(func_file_name) # upload local function file Path(func_file_name).unlink() # delete local function file if not wait: _, stdout, stderr = self.client.exec_command(python_cmd, get_pty=True) return stdout, stderr try: result = self.run(python_cmd, in_stream=False, hide=True) # run function remote except UnexpectedExit as unexpected_exit: _logger.debug(unexpected_exit.result.stdout) _logger.debug(unexpected_exit.result.stderr) raise unexpected_exit _logger.debug(result.stdout) _logger.debug(result.stderr) self.get(output_file_name) # download result self.run(f"rm {output_file_name}", in_stream=False) # delete remote files with open(output_file_name, "rb") as file: # read return value from output file return_value = pickle.load(file) Path(output_file_name).unlink() # delete local output file return return_value
[docs] def get_free_local_port(self): """Get a free port on localhost.""" return get_port()
[docs] def get_free_remote_port(self): """Get a free port on remote host.""" return self.run_function(get_port)
[docs] def open_port_forwarding(self, local_port=None, remote_port=None): """Open port forwarding. Args: local_port (int): free local port remote_port (int): free remote port Returns: local_port (int): used local port remote_port (int): used remote port """ if local_port is None: local_port = self.get_free_local_port() if remote_port is None: remote_port = self.get_free_remote_port() proxyjump = "" if self.gateway is not None: proxyjump = f"-J {self.gateway.user}@{self.gateway.host}:{self.gateway.port}" cmd = ( f"ssh {proxyjump} -f -N -L {local_port}:{self.host}:{remote_port} " f"{self.user}@{self.host}" ) _logger.debug("\nOpening port-forwarding '%s'\n", cmd) start_subprocess(cmd) _logger.debug("Port-forwarding opened successfully.") kill_cmd = f'pkill -f "{cmd}"' atexit.register(start_subprocess, kill_cmd) return local_port, remote_port
[docs] def create_remote_directory(self, remote_directory): """Make a directory (including parents) on the remote host. Args: remote_directory (Path, str): path of the directory that will be created """ _logger.debug("Creating folder %s on %s@%s.", remote_directory, self.user, self.host) result = self.run(f"mkdir -v -p {remote_directory}", in_stream=False) stdout = result.stdout if stdout: _logger.debug(stdout) else: _logger.debug("%s already exists on %s@%s.", remote_directory, self.user, self.host)
[docs] def sync_remote_repository(self): """Synchronize local and remote QUEENS source files.""" _logger.info("Syncing remote QUEENS repository with local one...") start_time = time.time() self.create_remote_directory(self.remote_queens_repository) source = f"{PATH_TO_QUEENS}/" self.copy_to_remote( source, self.remote_queens_repository, exclude=".git", filters=":- .gitignore" ) _logger.info("Sync of remote repository was successful.") _logger.info("It took: %s s.\n", time.time() - start_time)
[docs] def copy_to_remote(self, source, destination, verbose=True, exclude=None, filters=None): """Copy files or folders to remote. Args: source (str, Path, list): paths to copy destination (str, Path): destination relative to host verbose (bool): true for verbose exclude (str, list): options to exclude filters (str): filters for rsync """ if not is_empty(source): host = f"{self.user}@{self.host}" _logger.debug("Copying from %s to %s", source, destination) remote_shell_command = None if self.gateway is not None: remote_shell_command = f"ssh {self.gateway.user}@{self.gateway.host} ssh" _logger.debug("Using remote shell command %s", remote_shell_command) rsync_cmd = assemble_rsync_command( source, destination, verbose=verbose, archive=True, exclude=exclude, filters=filters, rsh=remote_shell_command, host=host, rsync_options=["--out-format='%n'", "--checksum"], ) # Run rsync command result = self.local(rsync_cmd, in_stream=False) _logger.debug(result.stdout) _logger.debug("Copying complete.") else: _logger.debug("List of source files was empty. Did not copy anything.")
[docs] def build_remote_environment( self, package_manager=DEFAULT_PACKAGE_MANAGER, ): """Build remote QUEENS environment. Args: package_manager(str, optional): Package manager used for the creation of the environment ("mamba" or "conda") """ if package_manager not in SUPPORTED_PACKAGE_MANAGERS: raise ValueError( f"The package manager '{package_manager}' is not supported.\n" f"Supported package managers are: {SUPPORTED_PACKAGE_MANAGERS}" ) remote_connect = f"{self.user}@{self.host}" # check if requested package_manager is installed on remote machine: def package_manager_exists_remote(package_manager_name): """Check if requested package manager exists on remote. Args: package_manager_name (string): name of package manager """ result_which = self.run(f"which {package_manager_name}") if result_which.stderr: message = ( f"Could not find requested package manager '{package_manager_name}' " f"on '{remote_connect}'." ) if package_manager_name == DEFAULT_PACKAGE_MANAGER: _logger.warning(message) _logger.warning( "Trying to fall back to the '%s' package manager.", FALLBACK_PACKAGE_MANAGER ) package_manager_exists_remote(package_manager_name=FALLBACK_PACKAGE_MANAGER) else: raise RuntimeError(message) return False return True if not package_manager_exists_remote(package_manager_name=package_manager): package_manager = FALLBACK_PACKAGE_MANAGER _logger.info("Build remote QUEENS environment...") start_time = time.time() environment_name = Path(self.remote_python).parents[1].name command_string = ( f"cd {self.remote_queens_repository}; " f"{package_manager} remove --name {environment_name} --all -y;" f"{package_manager} env create -f environment.yml --name {environment_name}; " f"{package_manager} activate {environment_name};" f"pip install -e ." ) result = self.run(command_string, in_stream=False) _logger.debug(result.stdout) _logger.info("Build of remote queens environment was successful.") _logger.info("It took: %s s.\n", time.time() - start_time)
[docs] def get_port(): """Get free port. Returns: int: free port """ sock = socket.socket() sock.bind(("", 0)) return int(sock.getsockname()[1])
VALID_CONNECTION_TYPES = {"remote_connection": RemoteConnection}