I have a set of scripts designed to administer a database cluster. I have converted these scripts to an object-oriented design and am having a hell of a time trying to write unit tests for them. Here's a simplified example of what I'm trying to achieve:
class DBServer:
    def __init__(self, hostname):
        self.hostname = hostname

    def get_instance_status(self):
        # <return True if the DB instance on this node is online, otherwise False>
        raise NotImplementedError

    def shutdown_instance(self):
        # <insert code here to stop the DB instance on that node>
        raise NotImplementedError


class DBCluster:
    def __init__(self, clustername):
        self.servers = [DBServer('hostname1'), DBServer('hostname2'),
                        DBServer('hostname3'), DBServer('hostname4')]

    def get_servers_online(self):
        # returns a list of DBServer objects where the DB instance is online
        serversonline = []
        for server in self.servers:
            if server.get_instance_status():
                serversonline.append(server)
        return serversonline

    def shutdown_cluster(self):
        # shut down every instance that is currently online
        for server in self.servers:
            if server.get_instance_status():
                server.shutdown_instance()
As you can see, a DBServer is an object with a hostname and some functions that can be run against each server. A DBCluster is an object with an array of DBServers, and some functions that manage the entire cluster.
Imagine a workflow where the user wants to get a list of servers in the cluster where the DB instance is online. Then the user should pick one and shut it down:
def shutdown_user_instance(node_number):
    # get list of online instances in cluster, display to user and shut down the chosen instance:
    cluster = DBCluster()
    servers = cluster.get_servers_online()
    print('Choose server to shut down:')
    i = 0
    for server in servers:
        print(str(i) + ': ' + server.hostname)
        i += 1
    reply = str(input('Option: '))
    instance = int(reply)
    servers[instance].shutdown_instance()
How do we unit test this function? I'm thinking it should mock the cluster.get_servers_online() method to return a list of mock DBServer objects. I can then set server.hostname on each object in that list, and then mock server.shutdown_instance(). The trouble is I don't know where to start. Something like this?
class TestCluster(unittest.TestCase):
    @patch(stack.DBCluster)
    @patch(stack.DBServer)
    @patch(builtins.input)
    @patch.object(stack.DBServer, 'shutdown_instance')
    def test_shutdown(self, mock_cluster, mock_server, mock_input, mock_shutdown):
        # for the purpose of this test, let's say nodes 0 and 2 are online, 1 and 3 are offline
        mock_server.get_instance_status.side_effect = [True, False, True, False]
        # How should I achieve this array of mocked objects to test??
        mock_cluster.get_servers_online.return_value = [mock_server, mock_server]
        mock_input.return_value = 0
        shutdown_user_instance()  # theoretically this should mock a shutdown of node 0
Here's a self-contained example that passes pytest:
from unittest.mock import patch, MagicMock


class DBCluster:
    def get_servers_online(self):
        pass


def shutdown_user_instance(node_number):
    # get list of online instances in cluster, display to user and shut down the chosen instance:
    cluster = DBCluster()
    servers = cluster.get_servers_online()
    print('Choose server to shut down:')
    i = 0
    for server in servers:
        print(str(i) + ': ' + server.hostname)
        i += 1
    reply = str(input('Option: '))
    instance = int(reply)
    servers[instance].shutdown_instance()


@patch.object(DBCluster, "__new__")  # same as @patch("stack.DBCluster")
@patch("builtins.input")
def test_shutdown(mock_input, mock_cluster):
    mock_servers = [MagicMock() for _ in range(3)]
    mock_cluster.return_value.get_servers_online.return_value = mock_servers
    mock_input.return_value = "1"
    shutdown_user_instance(0)
    assert [s.shutdown_instance.call_count for s in mock_servers] == [0, 1, 0]
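For comparison, here is a sketch of a variant that patches only the get_servers_online method on the class instead of the constructor, gives each mock server an explicit string hostname, and also checks the printed menu. It assumes a stand-in DBCluster like the one above, and the function is condensed with enumerate() in place of the manual counter. The host0/host1/host2 names are made up for the example, and contextlib.redirect_stdout is just one way to capture output (pytest's capsys fixture would also work).

```python
import contextlib
import io
from unittest.mock import MagicMock, patch


class DBCluster:
    # Minimal stand-in: the real method would query every server in the cluster
    def get_servers_online(self):
        raise NotImplementedError("would talk to real servers")


def shutdown_user_instance(node_number):
    # Same logic as the function above, condensed with enumerate()
    cluster = DBCluster()
    servers = cluster.get_servers_online()
    print('Choose server to shut down:')
    for i, server in enumerate(servers):
        print(str(i) + ': ' + server.hostname)
    reply = input('Option: ')
    servers[int(reply)].shutdown_instance()


def test_shutdown_patching_method():
    # Each mock server gets a real string hostname (made-up names),
    # so the menu text is predictable
    mock_servers = [MagicMock(hostname=f"host{i}") for i in range(3)]
    captured = io.StringIO()
    # Patch the method on the class rather than the constructor: DBCluster()
    # still builds a real object, but get_servers_online() returns our mocks
    with patch.object(DBCluster, "get_servers_online", return_value=mock_servers), \
            patch("builtins.input", return_value="1"), \
            contextlib.redirect_stdout(captured):
        shutdown_user_instance(0)
    # Only the server at index 1 (the user's choice) was shut down
    assert [s.shutdown_instance.call_count for s in mock_servers] == [0, 1, 0]
    # input() is patched, so its "Option: " prompt is never printed
    assert captured.getvalue() == (
        "Choose server to shut down:\n0: host0\n1: host1\n2: host2\n"
    )
```

Either approach works. Patching the constructor replaces the whole object, while patching a single method keeps the real object and fakes only the call that would reach the database.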
The key thing is that you only need to patch the symbols that are used directly in the code under test. Once you've patched DBCluster with a mock object, you can just use that mock and its mock attributes in your test. In particular, study this line and make sure you understand exactly what it's doing:
mock_cluster.return_value.get_servers_online.return_value = mock_servers
mock_cluster is a mock of the DBCluster constructor, so its return_value is what DBCluster() will return in the code under test, i.e. the value that cluster will have. In turn, we want that mock object's get_servers_online method to have a return_value that will then become servers in the code under test, and each of those elements is in turn a mock whose shutdown_instance attribute can be inspected after the code has executed:
assert [s.shutdown_instance.call_count for s in mock_servers] == [0, 1, 0]
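asserting that we shut down instance 1 but not instances 0 or 2. (I always use assert mock.call_count == 1 rather than mock.assert_called() because if you make a typo in the former, your test will break in an obvious way, whereas a typo in the latter can silently create a new mock attribute and let the test pass when it shouldn't. Since Python 3.5, Mock raises AttributeError for unknown attribute names starting with "assert" or "assret", but other typos still slip through.)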
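The difference between the two kinds of check can be seen in a few lines. This is an illustrative sketch, not part of the test above. call_cuont, called_once and assret_called_once are deliberate, hypothetical typos, and server_a/server_b are placeholder mocks.

```python
from unittest.mock import MagicMock

server_a = MagicMock()
server_a.shutdown_instance()        # called exactly once
server_b = MagicMock()              # shutdown_instance never called

# Typo in the call_count comparison ("call_cuont"): the attribute is auto-created
# as a child MagicMock, which does not compare equal to 1, so the assert fails
# loudly even though the code under test behaved correctly.
try:
    assert server_a.shutdown_instance.call_cuont == 1
    call_count_typo_passed = True
except AssertionError:
    call_count_typo_passed = False

# Typo that drops the "assert_" prefix ("called_once"): also auto-created, so
# calling it checks nothing and raises nothing, even though server_b's
# shutdown_instance was never called. A test using it would silently pass.
server_b.shutdown_instance.called_once()

# Mock does reject unknown names starting with "assert" or "assret",
# so this particular typo is caught with an AttributeError.
try:
    server_b.shutdown_instance.assret_called_once()
    assret_typo_caught = False
except AttributeError:
    assret_typo_caught = True

# The correctly spelled checks give the right answers
assert server_a.shutdown_instance.call_count == 1
assert server_b.shutdown_instance.call_count == 0
```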
Also pay careful attention to the patch/patch.object syntax: this is the easiest thing to get wrong when you're new to unit testing (and usually the trickiest to debug).
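As an illustration of both forms, here is a sketch that tests DBCluster.shutdown_cluster from the question with patch.object, plus a string-target patch of builtins.input. patch takes a dotted string naming where the object is looked up. patch.object takes the object itself plus an attribute name. The DBServer methods are stubbed with NotImplementedError as stand-ins for the real database calls, and "cluster1" is a made-up cluster name. Stacked decorators are applied bottom-up, so the mock from the decorator closest to the function arrives as the first argument after self. This is a common source of confusion.

```python
import unittest
from unittest.mock import patch


class DBServer:
    def __init__(self, hostname):
        self.hostname = hostname

    def get_instance_status(self):
        raise NotImplementedError("would talk to a real server")

    def shutdown_instance(self):
        raise NotImplementedError("would talk to a real server")


class DBCluster:
    def __init__(self, clustername):
        self.servers = [DBServer('hostname%d' % n) for n in range(1, 5)]

    def shutdown_cluster(self):
        for server in self.servers:
            if server.get_instance_status():
                server.shutdown_instance()


class TestCluster(unittest.TestCase):
    # patch.object(obj, "attribute"): pass the object itself, not a string.
    # Bottom decorator -> first mock argument, top decorator -> last.
    @patch.object(DBServer, "shutdown_instance")              # -> mock_shutdown
    @patch.object(DBServer, "get_instance_status",
                  side_effect=[True, False, True, False])    # -> mock_status
    def test_shutdown_cluster(self, mock_status, mock_shutdown):
        DBCluster("cluster1").shutdown_cluster()  # "cluster1" is a made-up name
        self.assertEqual(mock_status.call_count, 4)
        self.assertEqual(mock_shutdown.call_count, 2)  # nodes 0 and 2 were online

    # patch("module.name"): the target is a string, quotes included
    @patch("builtins.input", return_value="0")
    def test_string_target(self, mock_input):
        self.assertEqual(input("Option: "), "0")
        mock_input.assert_called_once_with("Option: ")
```

Run it with python -m unittest or pytest. Because the methods are patched on the class, all four DBServer instances share the same mocks, so call_count here counts calls across every server.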