Test if given input value does not break from while True loop

Time:10-13

I want to unit test a function called prompt_process. The function's role is to prompt whether to call another function, process, or exit the program altogether, based on the return value of input:

process.py

import sys


def process():
    pass


def prompt_process():
    answer = input("Do you want to process? (y/n)").lower()
    while True:
        if answer == 'y':
            return process()
        elif answer == 'n':
            sys.exit(0)
        else:
            answer = input("Invalid answer - please type y or n").lower()

Testing the if/elif clauses is simple enough using mocks:

test_process.py

import unittest
from unittest import mock

import process


class TestProcess(unittest.TestCase):
    @mock.patch('process.process')
    def test_prompt_process_calls_process_if_input_is_y(self, mock_process):
        with mock.patch('process.input', return_value='y'):
            process.prompt_process()
        mock_process.assert_called_once()
        
    def test_prompt_process_exits_if_input_is_n(self):
        with mock.patch('process.input', return_value='n'):
            with self.assertRaises(SystemExit):
                process.prompt_process()


if __name__ == '__main__':
    unittest.main()

But I don't know how to test the else clause of the function. Ideally, it should test that the loop keeps running for input values other than y and n, but since I can't "mock" a while True loop I'm not sure how to proceed.

Is there any solution for testing this, or is this something I shouldn't bother with in the first place?

CodePudding user response:

Calling sys.exit(0) inside a function is not considered best practice: there may be open files, pending transactions, or other transient operations still in progress.

I suggest moving the exit into its own function and calling that function from the one under test. The exit function can then be mocked, which makes prompt_process easy to test.

import sys


def process():
    pass


def leave():
    sys.exit(0)


def prompt_process():
    answer = input("Do you want to process? (y/n)").lower()
    while True:
        if answer == 'y':
            return process()
        elif answer == 'n':
            return leave()
        else:
            answer = input("Invalid answer - please type y or n").lower()
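As a rough sketch of how the refactored function might then be tested, the snippet below replaces leave with a mock and checks that it is called when the answer is n. The helper name check_leave_called_on_n is made up for illustration. Patching globals() is only there so the sketch runs as a single file; in a separate test file you would more likely use mock.patch('process.leave').

```python
import sys
from unittest import mock


def process():
    pass


def leave():
    sys.exit(0)


def prompt_process():
    answer = input("Do you want to process? (y/n)").lower()
    while True:
        if answer == 'y':
            return process()
        elif answer == 'n':
            return leave()
        else:
            answer = input("Invalid answer - please type y or n").lower()


def check_leave_called_on_n():
    """Answer 'n' and return the mock that stood in for leave()."""
    fake_leave = mock.Mock(name='leave')
    # Single-file stand-in for mock.patch('process.leave') (assumption:
    # in a real project the test lives in its own module).
    with mock.patch('builtins.input', return_value='n'), \
            mock.patch.dict(globals(), {'leave': fake_leave}):
        prompt_process()
    return fake_leave
```

Because leave is mocked, no SystemExit is raised and the test only has to assert on the mock, for example with check_leave_called_on_n().assert_called_once_with().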

CodePudding user response:

Non-invasive way

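One non-invasive way is to check that input is called a second time. Use side_effect instead of return_value, so that the first call returns an invalid answer and the second call raises an exception that the test can catch.

This makes a small assumption about the implementation: that the loop calls input again after an invalid answer.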

# Define a custom exception to ensure it's not raised elsewhere
class EndError(Exception):
    pass


# Add this method to TestProcess
def test_prompt_process_does_not_break_from_while_true_loop_if_input_is_invalid(self):
    # The first input() call returns 'invalid'; the second raises EndError
    with mock.patch('process.input', side_effect=('invalid', EndError)):
        with self.assertRaises(EndError):
            process.prompt_process()
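A variation on the same idea, sketched here under a few assumptions, is to give side_effect a sequence that ends with a valid answer, so the loop terminates on its own and no custom exception is needed. The helper run_with_inputs and the "processed" return value are made up for illustration, and builtins.input is patched instead of process.input only so the sketch runs as a single file.

```python
import sys
from unittest import mock


def process():
    # Hypothetical return value, so the call is observable without a mock
    return "processed"


def prompt_process():
    answer = input("Do you want to process? (y/n)").lower()
    while True:
        if answer == 'y':
            return process()
        elif answer == 'n':
            sys.exit(0)
        else:
            answer = input("Invalid answer - please type y or n").lower()


def run_with_inputs(*answers):
    """Feed the answers to prompt_process, one per input() call."""
    with mock.patch('builtins.input', side_effect=answers) as mock_input:
        result = prompt_process()
    # Return what prompt_process returned and how often it prompted
    return result, mock_input.call_count
```

Calling run_with_inputs('maybe', '', 'Y') prompts three times before process is reached. If the loop asked for more answers than the sequence provides, the mock would raise StopIteration, so a broken loop fails the test instead of hanging.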

Minimally invasive way (no-op function)

Another, minimally invasive way is to define a no-op function, call it at the top of the while loop, and mock it in the test to count iterations.

# process.py
def noop():
    pass


# process.py, inside prompt_process
while True:
    noop()  # Add this
    # ...


# test_process.py, as a method of TestProcess
# noop returns None on the first iteration and raises EndError on the second
@mock.patch('process.noop', side_effect=(None, EndError))
def test_prompt_process_does_not_break_from_while_true_loop_if_input_is_invalid(self, mock_noop):
    with mock.patch('process.input', return_value='invalid'):
        with self.assertRaises(EndError):
            process.prompt_process()
    self.assertEqual(mock_noop.call_count, 2)

Testing max iterations

If your implementation limits the number of tries, then you don't have to mess with exceptions.

# process.py, inside prompt_process (replaces the while True loop)
max_tries = 10
for i in range(max_tries):
    noop()
    # ...
    if i < max_tries - 1:
        answer = input("Invalid answer - please type y or n").lower()
    else:
        _ = input("Invalid answer - press enter to end")  # Either
        # print("Invalid answer")                         # or


# test_process.py, as a method of TestProcess
@mock.patch('process.noop')
def test_prompt_process_accepts_10_tries_if_input_is_invalid(self, mock_noop):
    with mock.patch('process.input', return_value='invalid') as mock_input:
        process.prompt_process()
    # 11 = 1 initial prompt + 10 prompts inside the loop (with the "Either" variant)
    self.assertEqual(mock_input.call_count, 11)  # This requires knowledge about the implementation
    self.assertEqual(mock_noop.call_count, 10)   # This only makes an assumption that noop is called
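For reference, here is one possible way to fill in the elided parts of the bounded loop as a complete function. This is only a sketch: it uses the print variant for the last try, drops noop, and assumes that returning None after the last try is acceptable. MAX_TRIES, the max_tries parameter, the count_prompts helper, and the "processed" return value are made up for illustration.

```python
import sys
from unittest import mock

MAX_TRIES = 10  # Assumed limit, mirroring max_tries above


def process():
    # Hypothetical return value, so the call is observable without a mock
    return "processed"


def prompt_process(max_tries=MAX_TRIES):
    """Check up to max_tries answers; return None if none of them is valid."""
    answer = input("Do you want to process? (y/n)").lower()
    for i in range(max_tries):
        if answer == 'y':
            return process()
        elif answer == 'n':
            sys.exit(0)
        elif i < max_tries - 1:
            answer = input("Invalid answer - please type y or n").lower()
        else:
            print("Invalid answer")
    return None


def count_prompts(max_tries=MAX_TRIES):
    """Always answer 'invalid'; return the result and the input() call count."""
    with mock.patch('builtins.input', return_value='invalid') as mock_input:
        result = prompt_process(max_tries)
    return result, mock_input.call_count
```

With the print variant, input is called once per try, so the call count equals max_tries and the test needs neither an exception nor a noop hook to stop the loop.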