How can I run all tests serially per host, while each host is handled in its own single thread?
Maybe some code will explain what I'm trying to do.
conftest.py
import logging  # was missing: the `logger` fixture below uses logging.*
from typing import Dict

from paramiko import SSHClient, AutoAddPolicy
from pytest import Metafunc, fixture  # Metafunc import was missing (public in pytest >= 7)

from shared import Host, HOSTS

# Log format that includes the thread name, so parallel execution is visible.
FMT = '[%(name)s %(levelname)-7s(%(lineno)-4d)][%(threadName)s] %(message)s'
@fixture(scope="session")
def _ssh_con(request) -> Dict[Host, SSHClient]:
    """Open one shared SSH connection per parametrized host for the session.

    Yields a mapping of Host -> connected SSHClient; every connection is
    closed again during session teardown.
    """
    print('>>>>>>>> setup')
    cons: Dict[Host, SSHClient] = {}
    for item in request.session.items:
        # Robustness fix: items that are not parametrized have no callspec,
        # so the original `item.callspec` would raise AttributeError.
        callspec = getattr(item, "callspec", None)
        host = callspec.params.get("host") if callspec else None
        if host is not None and host not in cons:
            con = SSHClient()
            # NOTE(review): AutoAddPolicy blindly trusts unknown host keys —
            # fine for a lab setup, not for production.
            con.set_missing_host_key_policy(AutoAddPolicy())
            con.connect(host.ip, host.port, host.username, host.password)
            cons[host] = con
    print('>>>>>>>> setup done')
    yield cons
    print('<<<<<<<<<< teardown')
    for value in cons.values():
        value.close()
    print('<<<<<<<<<< teardown done')
@fixture(autouse=True)
def ssh(host: Host, _ssh_con: Dict[Host, SSHClient], logger) -> SSHClient:
    """Hand each test the session-wide SSH client for its 'host' parameter."""
    # Bug fix: the original called an undefined `rp_logger`; use the
    # injected `logger` fixture instead.
    logger.info(f'yield {host}')
    yield _ssh_con[host]
def pytest_generate_tests(metafunc: Metafunc):
    """Parametrize every test with each host in HOSTS (test ids are the host names)."""
    metafunc.parametrize('host', HOSTS, ids=str)
@fixture(scope="session")
def logger() -> logging.Logger:
    """Session-wide logger writing thread-aware records via a stream handler."""
    logger = logging.getLogger('Tester')
    logger.setLevel(logging.DEBUG)
    # Guard against duplicate handlers (and thus duplicated log lines) if the
    # fixture body is ever executed more than once in the same process.
    if not logger.handlers:
        hdlr = logging.StreamHandler()
        hdlr.setFormatter(logging.Formatter(FMT))
        logger.addHandler(hdlr)
    return logger
shared.py
from dataclasses import dataclass, field
from typing import List
@dataclass()
class Host:
    """Connection details for a single SSH target machine."""

    name: str
    ip: str
    port: int = field(repr=False, default=22)
    username: str = 'myusername'
    password: str = field(repr=False, default='myuserpassowrd')

    def __hash__(self) -> int:
        # Hash on the IP alone; equal hosts necessarily share an IP, so this
        # stays consistent with the dataclass-generated __eq__.
        return hash(self.ip)

    def __str__(self) -> str:
        # Display name only — used as the pytest parametrize id.
        return self.name
# Lab machines under test; names and addresses follow the Host-N convention.
HOSTS: List[Host] = [
    Host(f'Host-{i}', f'192.168.195.{i}')
    for i in range(1, 4)
]
tests.py
from time import sleep
from paramiko import SSHClient
from shared import Host, HOSTS

# Artificial per-test delay (seconds) that makes the serial/parallel
# execution order easy to see in the log output.
SLEEP_TIME = 2
def test_ip(ssh: SSHClient, host: Host, logger):
    """The interface IP reported by the host must match its configured IP."""
    logger.info(f"test_ip[{host}][{ssh}]")
    # NOTE(review): the interface name ens33 is hard-coded — confirm it is
    # correct for every host in HOSTS.
    command = "ip a s ens33 | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1"
    sleep(SLEEP_TIME)
    _, stdout, _ = ssh.exec_command(command)
    output = stdout.read().decode('utf8').strip()
    # Fix the garbled message and include the actual value for debugging.
    assert output == host.ip, f"unexpected IP {output!r}, expected {host.ip!r}"
def test_machine_name(host: Host, ssh: SSHClient, logger):
    """A file named after the host must exist in /tmp on that host."""
    logger.info(f"test_machine_name[{host}][{ssh}]")
    command = f"ls /tmp | grep {host.name}"
    sleep(SLEEP_TIME)
    _, stdout, _ = ssh.exec_command(command)
    output = stdout.read().decode('utf8').strip()
    # Informative failure message, consistent with test_ip.
    assert output, f"didn't find a file named {host.name!r} in /tmp"
What I want to achieve is the following:
Create all ssh connections for the session
start pytest_runtestloop
start Thread-1, Thread-2, Thread-3
for each thread start all tests in sequential order
Teardown all ssh connections for the session
I tried to use pytest-parallel
and pytest-xdist
(which doesn't fit my use case)
I also tried to write my own plugin, but I'm not able to get it right.
In the log output, the thread name is always MainThread, which shows the tests are not actually running in my worker threads.
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Dict, List
from shared import Host

# When True, each host's tests run on a dedicated single-worker executor
# (one named thread per host); when False they run inline, which is handy
# for debugging.
WITH_SUB_THREAD = True
def _run_hosts_tests(session, host, tests):
    """Run one host's tests strictly in collection order.

    With WITH_SUB_THREAD the tests run on a dedicated single-worker
    executor thread named Worker_<host>; otherwise they run inline.
    """
    if WITH_SUB_THREAD:
        # A single worker keeps the host's tests sequential.
        with ThreadPoolExecutor(1, f"Worker_{host}") as executor:
            for test_idx in tests:
                item = session.items[test_idx]
                # Bug fix: wait on each future so a failing test surfaces
                # instead of being silently swallowed by the executor.
                executor.submit(item.runtest).result()
    else:
        for test_idx in tests:
            session.items[test_idx].runtest()
def _done_callback(future: Future):
try:
result = future.result()
print(f"[\033[92;1mOK\033[0m] {result}")
return result
except Exception as e:
print(f"[\033[91;1mERR\033[0m] {e}")
raise e
class ParallelRunner:
    """Pytest plugin: run each host's tests serially on its own thread, hosts in parallel."""

    def __init__(self):
        # host -> indices into session.items, in collection order.
        self._tests_mapper: Dict[Host, List[int]] = {}

    def pytest_collection_finish(self, session):
        """Group the collected test indices by their 'host' parameter."""
        for idx, item in enumerate(session.items):
            host = item.callspec.getparam('host')
            self._tests_mapper.setdefault(host, []).append(idx)

    def pytest_runtestloop(self, session):
        """Replace the default run loop with one worker thread per host."""
        if (
            session.testsfailed
            and not session.config.option.continue_on_collection_errors
        ):
            raise session.Interrupted(
                "%d error%s during collection"
                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
            )
        if session.config.option.collectonly:
            return True
        with ThreadPoolExecutor(len(self._tests_mapper), 'Worker') as executor:
            for host, tests in self._tests_mapper.items():
                executor.submit(_run_hosts_tests, session, host, tests) \
                    .add_done_callback(_done_callback)
        # Bug fix: pytest_runtestloop is a firstresult hook — returning True
        # tells pytest the loop was handled, so its built-in (serial) loop
        # does not run every item a second time.
        return True
CodePudding user response:
The answer was to store the connections in a global variable.
Although I'm sure there is a better solution, in the meantime I will put this workaround here.
conftest.py
from threading import Lock

# Bug fix: the original snippet used `mutex` without ever defining it.
# One module-level lock guards the shared per-host connection cache.
mutex = Lock()
_cons: Dict[Host, SSHClient] = dict()
@fixture(scope="session")
def _ssh_con(request: SubRequest) -> Dict[Host, SSHClient]:
    """Lazily open one SSH connection per host, shared across all threads.

    The mutex makes the lazy initialisation safe when several worker
    threads request the fixture at once. Using `with` (instead of explicit
    acquire/release) guarantees the lock is released even if a connection
    attempt raises — the original would deadlock every later caller.
    """
    with mutex:
        if not _cons:
            for item in request.session.items:
                host = item.callspec.params.get("host")
                if host not in _cons:
                    con = SSHClient()
                    con.set_missing_host_key_policy(AutoAddPolicy())
                    con.connect(host.ip, host.port, host.username, host.password)
                    _cons[host] = con
    # No `global` needed: _cons is only mutated, never rebound.
    return _cons
def pytest_sessionfinish(session, exitstatus):
    """Close every shared SSH connection at the end of the session."""
    # `with` releases the lock even if a close() raises; the original
    # acquire/release pair would leave the mutex held forever.
    with mutex:
        for con in _cons.values():
            con.close()
        # Drop the closed clients so a later session would reconnect
        # instead of reusing dead sockets.
        _cons.clear()