Mocking complex data structures in Python

Time:02-02

I have a set of scripts designed to administer a database cluster. I have converted these scripts to an object-oriented design and am having a hell of a time trying to write unit tests. Here's a simplified example of what I'm trying to achieve:

class DBServer:
    def __init__(self,hostname):
        self.hostname=hostname

    def get_instance_status(self):
        # <return True if the DB instance on this node is online, otherwise False>
        raise NotImplementedError

    def shutdown_instance(self):
        # <insert code here to stop the DB instance on that node>
        raise NotImplementedError


class DBCluster:
    def __init__(self, clustername=None):
        self.clustername = clustername
        self.servers=[DBServer('hostname1'), DBServer('hostname2'), DBServer('hostname3'), DBServer('hostname4')]
    
    def get_servers_online(self):
        #returns an array of DBServer objects where the DB instance is online
        serversonline=[]
        for server in self.servers:
            if server.get_instance_status():
                serversonline.append(server)
        return serversonline
            
    def shutdown_cluster(self):
        for server in self.servers:
            if server.get_instance_status():
                server.shutdown_instance()

As you can see, a DBServer is an object with a hostname and some functions that can be run against each server. A DBCluster is an object with an array of DBServers, and some functions that manage the entire cluster.
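A minimal sketch of how get_servers_online can be checked without a database: build the cluster, then replace each entry in cluster.servers with its own mock whose get_instance_status returns a fixed value. This is a cut-down, runnable copy of the two classes; the clustername=None default, the NotImplementedError placeholder and the make_server helper are assumptions for illustration. Giving each node its own mock, rather than one shared mock with side_effect=[True, False, True, False], means the returned list holds distinct objects you can tell apart:

```python
from unittest.mock import MagicMock


class DBServer:
    def __init__(self, hostname):
        self.hostname = hostname

    def get_instance_status(self):
        # Placeholder: the real version would check the DB instance on self.hostname
        raise NotImplementedError


class DBCluster:
    def __init__(self, clustername=None):
        self.clustername = clustername
        self.servers = [DBServer("hostname%d" % n) for n in range(1, 5)]

    def get_servers_online(self):
        # Keep only the servers whose DB instance reports itself as online
        return [server for server in self.servers if server.get_instance_status()]


def make_server(hostname, online):
    # Hypothetical test helper: a DBServer-shaped mock with a fixed status
    server = MagicMock(spec=DBServer)
    server.hostname = hostname
    server.get_instance_status.return_value = online
    return server


cluster = DBCluster()
# One mock per node: list indexes 0 and 2 are online, 1 and 3 are offline
cluster.servers = [
    make_server("hostname1", True),
    make_server("hostname2", False),
    make_server("hostname3", True),
    make_server("hostname4", False),
]

online = cluster.get_servers_online()
print([server.hostname for server in online])  # ['hostname1', 'hostname3']
```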

Imagine a workflow where the user wants to get a list of servers in the cluster where the DB instance is online. Then the user should pick one and shut it down:

def shutdown_user_instance(node_number):
    #get list of online instances in cluster, display to user and shut down the chosen instance:
    cluster=DBCluster()
    servers=cluster.get_servers_online()
    print ('Choose server to shut down:')
    i=0
    for server in servers:
        print (str(i) + ': ' + server.hostname)
        i += 1
    reply=str(input('Option: '))
    instance=int(reply)
    servers[instance].shutdown_instance()

How do we unit test this function? I'm thinking it should mock the cluster.get_servers_online() function to return an array of mock DBServer objects. I can then set server.hostname on each object in that array, and then mock server.shutdown_instance(). The trouble is I don't know where to start. Something like this?

import unittest
from unittest.mock import patch

import stack

class TestCluster(unittest.TestCase):

    @patch('stack.DBCluster')
    @patch('stack.DBServer')
    @patch('builtins.input')
    @patch.object(stack.DBServer,'shutdown_instance')
    def test_shutdown(self, mock_shutdown, mock_input, mock_server, mock_cluster):
        mock_server.get_instance_status.side_effect=[True,False,True,False] #for the purpose of this test, let's say nodes 0 and 2 are online, 1 and 3 are offline
        mock_cluster.get_servers_online.return_value=[mock_server,mock_server] #How should I achieve this array of mocked objects to test??
        mock_input.return_value=0
        stack.shutdown_user_instance(0) #theoretically this should mock a shutdown of node 0

Answer:

Here's a self-contained example that passes pytest:

class DBCluster:
    def get_servers_online(self):
        pass

def shutdown_user_instance(node_number):
    #get list of online instances in cluster, display to user and shut down the chosen instance:
    cluster = DBCluster()
    servers=cluster.get_servers_online()
    print ('Choose server to shut down:')
    i=0
    for server in servers:
        print (str(i) + ': ' + server.hostname)
        i += 1
    reply=str(input('Option: '))
    instance=int(reply)
    servers[instance].shutdown_instance()


from unittest.mock import patch, MagicMock

@patch.object(DBCluster, "__new__")  # in this single-file example, stands in for @patch("stack.DBCluster")
@patch("builtins.input")
def test_shutdown(mock_input, mock_cluster):
    mock_servers = [MagicMock() for _ in range(3)]
    mock_cluster.return_value.get_servers_online.return_value = mock_servers
    mock_input.return_value = "1"
    shutdown_user_instance(0)
    assert [s.shutdown_instance.call_count for s in mock_servers] == [0, 1, 0]
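Because the question wrote its test as a unittest.TestCase and wanted a hostname on each mock server, here is a sketch of the same test in that style. It assumes everything sits in one module, so it patches DBCluster.__new__ the way the answer does; if the code lived in a stack module, @patch("stack.DBCluster") would be the usual target. The function is copied (with enumerate in place of the manual counter) so the block runs on its own, and redirect_stdout captures the menu so the hostnames can be checked:

```python
import io
import unittest
from contextlib import redirect_stdout
from unittest.mock import MagicMock, patch


class DBCluster:
    def get_servers_online(self):
        # Never reached in the test: DBCluster() is patched to return a mock
        raise NotImplementedError


def shutdown_user_instance(node_number):
    cluster = DBCluster()
    servers = cluster.get_servers_online()
    print('Choose server to shut down:')
    for i, server in enumerate(servers):
        print(str(i) + ': ' + server.hostname)
    reply = str(input('Option: '))
    servers[int(reply)].shutdown_instance()


class TestShutdown(unittest.TestCase):
    # Decorators apply bottom-up: the builtins.input mock is the first argument
    @patch.object(DBCluster, "__new__")  # in a real project: @patch("stack.DBCluster")
    @patch("builtins.input", return_value="1")
    def test_shutdown(self, mock_input, mock_cluster):
        # One mock per server, each with its own hostname
        mock_servers = [MagicMock(hostname="hostname%d" % n) for n in range(1, 4)]
        mock_cluster.return_value.get_servers_online.return_value = mock_servers

        out = io.StringIO()
        with redirect_stdout(out):
            shutdown_user_instance(0)

        # Only the server at index 1 was shut down
        self.assertEqual([s.shutdown_instance.call_count for s in mock_servers], [0, 1, 0])
        # The menu listed that server under its hostname
        self.assertIn("1: hostname2", out.getvalue())
        self.assertEqual(mock_input.call_count, 1)
```

Saved in a file named like test_*.py, it can be run with either pytest or python -m unittest.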

The key thing is that you only need to patch the symbols that are used directly in the code under test. Once you've patched DBCluster with a mock object, you can just use that mock and its mock attributes in your test. In particular, study this line and make sure you understand exactly what it's doing:

mock_cluster.return_value.get_servers_online.return_value = mock_servers
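One way to study it is to build the same chain outside any test. In this sketch mock_cluster is a plain MagicMock standing in for the patched DBCluster, and the three mock servers are placeholders:

```python
from unittest.mock import MagicMock

# Stand-in for the patched DBCluster (the mock_cluster argument in the test)
mock_cluster = MagicMock()
mock_servers = [MagicMock(), MagicMock(), MagicMock()]

# Calling mock_cluster(...) returns mock_cluster.return_value, and calling
# get_servers_online() on that object returns the list configured here
mock_cluster.return_value.get_servers_online.return_value = mock_servers

cluster = mock_cluster()                # the value cluster = DBCluster() gets
servers = cluster.get_servers_online()  # the value servers = ... gets

print(cluster is mock_cluster.return_value)    # True
print(servers is mock_servers)                 # True
print(mock_cluster("any", "args") is cluster)  # True: arguments don't change return_value
```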

mock_cluster is a mock of the DBCluster constructor, so its return_value is what DBCluster() returns in the code under test, i.e. the value that cluster will have. In turn, we want that mock object's get_servers_online to have a return_value that then becomes servers in the code under test. Each element of that list is itself a mock whose shutdown_instance attribute can be inspected after the code has run:

assert [s.shutdown_instance.call_count for s in mock_servers] == [0, 1, 0]

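This asserts that we shut down instance 1 but not instances 0 or 2. (I always use assert mock.call_count == 1 rather than mock.assert_called() because if you typo the former your test will break in an obvious way, whereas if you typo the latter your test can pass when it shouldn't. Python 3.5 and later reduce that risk: accessing a nonexistent attribute whose name starts with assert or assret raises AttributeError, unless the mock was created with unsafe=True.)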
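Here is a small sketch of both styles on a single MagicMock standing in for a server. The names cal_count and assert_caled_once are deliberate typos, included to show what each kind of mistake does:

```python
from unittest.mock import MagicMock

server = MagicMock()
server.shutdown_instance()

# The answer's style: compare call_count explicitly
assert server.shutdown_instance.call_count == 1

# Misspelling call_count gives back a fresh child mock, which never equals 1,
# so the check fails loudly instead of passing
try:
    assert server.shutdown_instance.cal_count == 1  # deliberate typo
    count_typo_caught = False
except AssertionError:
    count_typo_caught = True

# Built-in helper for the same "exactly once" check
server.shutdown_instance.assert_called_once()

# A misspelled assert_* helper raises AttributeError on Python 3.5+
try:
    server.shutdown_instance.assert_caled_once()  # deliberate typo
    helper_typo_caught = False
except AttributeError:
    helper_typo_caught = True

print(count_typo_caught, helper_typo_caught)  # True True
```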

Pay careful attention also to the patch/patch.object syntax -- this is the easiest thing to get wrong when you're new to unit testing (and usually the trickiest to debug).
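To make that syntax concrete: patch() takes its target as a "module.attribute" string, patch.object() takes the object itself plus the attribute name, and when decorators are stacked the mocks are passed in bottom-up order. The sketch below uses a placeholder DBServer (not the question's real class) and the builtin input; run_check is a hypothetical function used only for the demonstration:

```python
from unittest.mock import patch


class DBServer:
    def shutdown_instance(self):
        # Placeholder standing in for the real shutdown logic
        raise NotImplementedError


# Applied bottom-up: the input mock is the first argument,
# the shutdown_instance mock the second.
@patch.object(DBServer, "shutdown_instance")  # target: object + attribute name
@patch("builtins.input", return_value="0")    # target: "module.attribute" string
def run_check(mock_input, mock_shutdown):
    choice = input("Option: ")
    DBServer().shutdown_instance()
    return choice, mock_input.call_count, mock_shutdown.call_count


print(run_check())  # ('0', 1, 1)
```

Once run_check returns, both patches are undone, so DBServer.shutdown_instance is the original method again.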
