Home > database >  Python parallel testing with serial grouped tests
Python parallel testing with serial grouped tests

Time:06-17


How can I run all tests serially per host, with each host's tests running in their own thread?

Maybe some code will explain what I'm trying to do.

conftest.py

import logging
from typing import Dict

from paramiko import AutoAddPolicy, SSHClient
from pytest import fixture

from shared import HOSTS, Host

# Log line layout: logger name, padded level and line number, then the name
# of the emitting thread so output from different worker threads can be told apart.
FMT = (
    '[%(name)s %(levelname)-7s(%(lineno)-4d)]'
    '[%(threadName)s] %(message)s'
)

@fixture(scope="session")
def _ssh_con(request) -> Dict[Host, SSHClient]:
    """Open one SSH connection per distinct ``host`` parameter in the session.

    Walks ``request.session.items``, connects once for each new ``host``
    param, and yields the ``Host -> SSHClient`` mapping. Every client that
    was created is closed when the fixture is torn down.
    """
    print('>>>>>>>> setup')
    cons: Dict[Host, SSHClient] = {}
    try:
        for item in request.session.items:
            # Items that are not parametrized have no ``callspec`` at all.
            callspec = getattr(item, "callspec", None)
            host = callspec.params.get("host") if callspec is not None else None
            if host is None or host in cons:
                continue
            con = SSHClient()
            con.set_missing_host_key_policy(AutoAddPolicy())
            # Register before connecting so a client whose connect() fails
            # is closed by the cleanup below as well.
            cons[host] = con
            con.connect(host.ip, host.port, host.username, host.password)

        print('>>>>>>>> setup done')
        yield cons
    finally:
        # Code after ``yield`` only runs if setup reached the yield; the
        # ``finally`` also covers a connect() failure part-way through the
        # loop, so connections opened before it are not leaked.
        print('<<<<<<<<<< teardown')
        for value in cons.values():
            value.close()
        print('<<<<<<<<<< teardown done')


@fixture(autouse=True)
def ssh(host: Host, _ssh_con: Dict[Host, SSHClient], logger) -> SSHClient:
    """Provide the session-wide SSH connection for the current ``host`` param.

    Autouse, so every test gets the ``host`` parametrization and a connection
    looked up from the ``_ssh_con`` mapping.
    """
    # Log through the injected ``logger`` fixture (the only logger in scope here).
    logger.info(f'yield {host}')
    yield _ssh_con[host]


def pytest_generate_tests(metafunc: "Metafunc"):
    """Parametrize every test that uses the ``host`` fixture over ``HOSTS``.

    ``ids=str`` makes each test id the host's ``str()`` form.
    """
    # The annotation is a string because ``Metafunc`` is not imported here
    # (pytest >= 7 exposes it as ``pytest.Metafunc``).
    # Only tests that actually depend on ``host`` (directly or through an
    # autouse fixture) can be parametrized on it; pytest errors otherwise.
    if 'host' in metafunc.fixturenames:
        metafunc.parametrize('host', HOSTS, ids=str)

@fixture(scope="session")
def logger() -> logging.Logger:
    """Return the shared 'Tester' logger, set to DEBUG with a stream handler.

    ``logging.getLogger`` returns a process-wide object, so the handler is
    attached only when the logger has none yet. Without the check, every new
    session in the same process (e.g. repeated ``pytest.main`` calls) would
    add another handler and duplicate each log line.
    """
    log = logging.getLogger('Tester')
    log.setLevel(logging.DEBUG)

    if not log.handlers:
        hdlr = logging.StreamHandler()
        hdlr.setFormatter(logging.Formatter(FMT))
        log.addHandler(hdlr)
    return log

shared.py

from dataclasses import dataclass, field
from typing import List


@dataclass
class Host:
    """Connection details for one SSH target.

    Instances hash by ``ip`` so they can key dicts of connections, and
    ``str()`` gives the bare ``name``. ``port`` and ``password`` are hidden
    from ``repr``.
    """

    name: str
    ip: str
    port: int = field(default=22, repr=False)
    # NOTE(review): placeholder credentials baked in as defaults — confirm
    # they are overridden for real hosts.
    username: str = 'myusername'
    password: str = field(default='myuserpassowrd', repr=False)

    def __str__(self):
        return self.name

    def __hash__(self):
        # Equal hosts (all fields equal) share an ip, so this stays
        # consistent with the dataclass-generated __eq__.
        return hash(self.ip)


# Machines under test; port and credentials fall back to the Host defaults.
HOSTS: List[Host] = [
    Host(name='Host-1', ip='192.168.195.1'),
    Host(name='Host-2', ip='192.168.195.2'),
    Host(name='Host-3', ip='192.168.195.3'),
]

tests.py

from time import sleep
from paramiko import SSHClient
from shared import Host, HOSTS


# Seconds each test sleeps before issuing its remote command.
SLEEP_TIME: int = 2

def test_ip(ssh: SSHClient, host: Host, logger):
    """Check that interface ``ens33`` on the remote host carries ``host.ip``."""
    logger.info(f"test_ip[{host}][{ssh}]")
    command = "ip a s ens33 | grep 'inet ' | awk '{print $2}' | cut -d'/' -f1"
    sleep(SLEEP_TIME)
    _, stdout, _ = ssh.exec_command(command)
    # The pipeline prints one IPv4 address per line; an interface can carry
    # secondary addresses, so look for ours among them rather than requiring
    # the whole output to equal a single address.
    addresses = stdout.read().decode('utf8').split()
    assert host.ip in addresses, (
        f"got unexpected IP {addresses} or didn't get any IP"
    )


def test_machine_name(host: Host, ssh: SSHClient, logger):
    """Check that ``/tmp`` on the remote host lists a file containing ``host.name``."""
    import shlex

    logger.info(f"test_machine_name[{host}][{ssh}]")
    # ``-F`` matches the name literally instead of as a regex, ``--`` stops a
    # leading '-' being read as an option, and ``shlex.quote`` keeps the name
    # a single shell word.
    command = f"ls /tmp | grep -F -- {shlex.quote(host.name)}"
    sleep(SLEEP_TIME)
    _, stdout, _ = ssh.exec_command(command)
    output = stdout.read().decode('utf8').strip()
    assert output, "didn't find file with host name"

What I want to achieve is the following:
Create all ssh connections for the session
start pytest_runtestloop
start Thread-1, Thread-2, Thread-3
for each thread start all tests in sequential order
Teardown all ssh connections for the session

I tried pytest-parallel and pytest-xdist, but neither fits my use case.
I also tried to write my own plugin (below), but I can't get it right:
in the log output, the thread name shows as MainThread ¯\_(ツ)_/¯

from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Dict, List

from shared import Host

WITH_SUB_THREAD = True


def _run_hosts_tests(session, host, tests):
    if WITH_SUB_THREAD:
        with ThreadPoolExecutor(1, f"Worker_{host}") as executor:
            for test_idx in tests:
                item = session.items[test_idx]
                executor.submit(item.runtest)
    else:
        for test_idx in tests:
            item = session.items[test_idx]
            item.runtest()


def _done_callback(future: Future):
    try:
        result = future.result()
        print(f"[\033[92;1mOK\033[0m] {result}")
        return result
    except Exception as e:
        print(f"[\033[91;1mERR\033[0m] {e}")
        raise e


class ParallelRunner:
    """pytest plugin: run each host's tests in order, with hosts in parallel.

    After collection the items are grouped by their ``host`` parameter. The
    run loop then starts one worker thread per host group.
    """

    def __init__(self):
        # host -> indexes into ``session.items``, in collection order
        self._tests_mapper: Dict[Host, List[int]] = dict()

    def pytest_collection_finish(self, session):
        """Group collected items by their ``host`` parameter.

        Items without a ``callspec`` (not parametrized), or without a
        ``host`` param, are grouped under ``None``.
        """
        mapper: Dict[Host, List[int]] = {}
        for idx, item in enumerate(session.items):
            callspec = getattr(item, 'callspec', None)
            host = callspec.params.get('host') if callspec is not None else None
            mapper.setdefault(host, []).append(idx)
        self._tests_mapper = mapper

    def pytest_runtestloop(self, session):
        """Replace pytest's default run loop with one thread per host group.

        Always returns ``True`` once it handles the loop.
        ``pytest_runtestloop`` is a first-result hook: returning ``None``
        would let pytest's built-in loop run every test a second time.
        """
        if (
            session.testsfailed
            and not session.config.option.continue_on_collection_errors
        ):
            raise session.Interrupted(
                "%d error%s during collection"
                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
            )

        if session.config.option.collectonly:
            return True

        # ThreadPoolExecutor rejects max_workers=0, so stop here when nothing
        # was collected.
        if not self._tests_mapper:
            return True

        with ThreadPoolExecutor(len(self._tests_mapper), 'Worker') as executor:
            for host, tests in self._tests_mapper.items():
                future = executor.submit(_run_hosts_tests, session, host, tests)
                future.add_done_callback(_done_callback)
        return True

Answer:

The answer was to store the connections in a global variable.
I'm sure there is a better solution, but in the meantime I'll post this workaround here.

conftest.py

from threading import Lock

# Connections shared by every thread for the whole run. ``_ssh_con`` fills
# the cache lazily and ``pytest_sessionfinish`` closes it.
_cons: Dict[Host, SSHClient] = dict()
# Serializes every access to ``_cons``.
mutex = Lock()


def _close_all_connections() -> None:
    """Close every cached connection and empty the cache. Caller holds ``mutex``."""
    for con in _cons.values():
        con.close()
    _cons.clear()


@fixture(scope="session")
def _ssh_con(request: "SubRequest") -> Dict[Host, SSHClient]:
    """Return the module-level ``Host -> SSHClient`` cache, connecting on first use.

    The cache lives outside the fixture, presumably so one connection per host
    survives the fixture being set up more than once (e.g. from several worker
    threads) — confirm. The cache is closed in ``pytest_sessionfinish``.
    """
    # The annotation is a string because ``SubRequest`` is not imported here.
    with mutex:
        if not _cons:
            try:
                for item in request.session.items:
                    # Items that are not parametrized have no ``callspec``.
                    callspec = getattr(item, "callspec", None)
                    host = callspec.params.get("host") if callspec is not None else None
                    if host is None or host in _cons:
                        continue
                    con = SSHClient()
                    con.set_missing_host_key_policy(AutoAddPolicy())
                    # Cache before connecting so a failed client is closed below too.
                    _cons[host] = con
                    con.connect(host.ip, host.port, host.username, host.password)
            except BaseException:
                # Don't leave a half-filled cache: a later call would treat it
                # as "already connected" and hand out missing/unusable clients.
                _close_all_connections()
                raise
        return _cons


def pytest_sessionfinish(session, exitstatus):
    """Close all cached SSH connections once the test session is over."""
    with mutex:
        _close_all_connections()
  • Related