Python Mocks - Testing
Python's unittest.mock module provides the Mock class and the patch() function, which let you substitute any object or behavior with a mock. This helps isolate the system under test. It also lets you safely reach code paths that only run when exceptions are raised.
- Mock can replace any object, because it creates any attribute or method on the fly as soon as you access it.
- Calling one of these auto-created methods returns another Mock by default. This lets you chain calls and use mocks in complex code bases.
- Mock lets you inspect how it was called: how many times, and with which arguments.
- Use patch to scope and apply your mock in a test:
  - Use patch() as a decorator to apply the mock for the scope of the entire test function.
  - Use patch() as a context manager within the test function to narrow down its mocking scope.
  - Use patch.object() to narrow the scope further, to specific members of an object.
- Use Mock(spec=["member1_name", "property1_name"]) to stop the mock from auto-creating attributes on the fly. The mock then sticks to the names listed in spec.
- Use unittest.mock.create_autospec(<your class>) to easily create a mock with the exact specification of a class or module.
- Use patch("<path.to.object>", autospec=True) to create a mock that matches the exact spec of the patched object. Accessing a member that doesn't exist on the original object raises AttributeError instead of creating it on the fly. Calls with the wrong signature raise TypeError.
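The following minimal sketch shows the spec and autospec options from the list above. The Account class is hypothetical and exists only to give the mocks something to spec against. The printed messages are whatever AttributeError or TypeError text unittest.mock produces.

```python
from unittest.mock import Mock, create_autospec


class Account:
    """Hypothetical class used only to illustrate spec and autospec."""

    def __init__(self, owner: str) -> None:
        self.owner = owner

    def deposit(self, amount: float) -> float:
        raise NotImplementedError  # real logic is irrelevant for the mock


# A plain Mock accepts any attribute name, even a typo.
loose = Mock()
loose.depositt(10)  # silently succeeds

# spec restricts attribute access to the listed names.
strict = Mock(spec=["deposit", "owner"])
strict.deposit(10)
try:
    strict.depositt(10)
except AttributeError as exc:
    print("spec caught typo:", exc)

# create_autospec copies the class's attributes *and* method signatures.
account_mock = create_autospec(Account, instance=True)
account_mock.deposit(25.0)
account_mock.deposit.assert_called_once_with(25.0)
try:
    account_mock.deposit()  # missing the required 'amount' argument
except TypeError as exc:
    print("autospec caught bad call:", exc)
```

create_autospec only sees attributes defined on the class. Instance attributes assigned in __init__ (like owner here) are not part of the autospec.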
from unittest.mock import Mock

my_mock = Mock()

# on-the-fly attributes
print(my_mock.neo_name)
print(type(my_mock))
print(type(my_mock.neo_name))
# Output:
# <Mock name='mock.neo_name' id='4432181536'>
# <class 'unittest.mock.Mock'>
# <class 'unittest.mock.Mock'>

# on-the-fly method
resp = my_mock.do_something(1, 2)
print(type(resp))
# Output: <class 'unittest.mock.Mock'>

# assertions
my_mock.do_something.assert_called()
my_mock.do_something.assert_called_once()
my_mock.do_something.assert_called_with(1, 2)

my_mock.do_something(3, 4)  # returns <Mock name='mock.do_something()' id='4432504992'>
# my_mock.do_something.assert_called_once()  # would now fail: do_something was called twice
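The commented-out assert_called_once() above fails because do_something has been called twice by then. The sketch below makes that failure explicit. It also shows how auto-created children let a chained call work without any setup. It uses a fresh mock named demo, an illustrative name chosen so it doesn't disturb my_mock.

```python
from unittest.mock import Mock

demo = Mock()  # illustrative name, separate from my_mock above

# Every attribute access and every call result is itself a Mock,
# so a chained call works without any setup.
result = demo.client.session().get("/status")
assert isinstance(result, Mock)

# Each step of the chain is recorded and can be asserted on.
demo.client.session.assert_called_once_with()

# A second call makes assert_called_once() raise AssertionError.
demo.do_something(1, 2)
demo.do_something(3, 4)
try:
    demo.do_something.assert_called_once()
except AssertionError as exc:
    print("as expected:", exc)
```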
Mock - inspections
print(my_mock.do_something.call_count)
print(my_mock.do_something.call_args)  # most recent call
print(my_mock.method_calls)  # all method calls so far
# Output:
# 2
# call(3, 4)
# [call.do_something(1, 2), call.do_something(3, 4)]
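A mock keeps the full call history, not just call_count and call_args. The short sketch below uses a hypothetical api mock to show the call_args_list, mock_calls, assert_has_calls() and reset_mock() members of unittest.mock.

```python
from unittest.mock import Mock, call

api = Mock()  # hypothetical API client
api.fetch("AAPL")
api.fetch("MSFT", retries=2)

# call_args_list keeps every call to this one method, oldest first.
assert api.fetch.call_args_list == [call("AAPL"), call("MSFT", retries=2)]

# mock_calls on the parent also records calls made on return values,
# e.g. call.fetch().json(), which method_calls leaves out.
api.fetch("GOOG").json()
print(api.mock_calls)

# assert_has_calls checks that the given calls appear consecutively.
api.fetch.assert_has_calls([call("AAPL"), call("MSFT", retries=2)])

# reset_mock() clears the recorded history but keeps return_value/side_effect.
api.reset_mock()
assert api.fetch.call_count == 0
```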
Mock - return value
Using the return_value attribute of a Mock (it can also be passed to the constructor), you can explicitly state what a mocked function should return when called.
# mocking datetime
from datetime import datetime


def is_tradable() -> bool:
    """Function returns True if trades can execute on that day."""
    today = datetime.today()
    return 0 <= today.weekday() < 5  # keep it simple: True on weekdays (Monday=0 to Friday=4)


# unit test - by mocking different days
from unittest.mock import Mock

wed = datetime(year=2025, month=1, day=1)  # a Wednesday
sat = datetime(year=2025, month=1, day=4)  # a Saturday

# Now replace the global name `datetime` with a Mock and set the return value.
# This works because is_tradable() looks up `datetime` at call time.
datetime = Mock()
datetime.today.return_value = wed
assert is_tradable()  # should pass

datetime.today.return_value = sat
assert not is_tradable()  # should pass as well
This way you can run tests any day of the week and expect consistent results.
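return_value can also be passed to the constructor. Setting it on nested return_value attributes configures a whole call chain at once. In the sketch below, count_rows() and the shape of the db connection are hypothetical stand-ins for code under test.

```python
from unittest.mock import Mock

# return_value passed to the constructor instead of set afterwards.
get_price = Mock(return_value=101.5)
assert get_price("AAPL") == 101.5

# Configure a chain: db.connect().cursor().fetchone() -> (42,)
db = Mock()
db.connect.return_value.cursor.return_value.fetchone.return_value = (42,)


def count_rows(database) -> int:
    """Hypothetical code under test that walks the chain."""
    cursor = database.connect().cursor()
    return cursor.fetchone()[0]


assert count_rows(db) == 42

# configure_mock() accepts the same dotted path as a keyword argument.
db2 = Mock()
db2.configure_mock(**{"connect.return_value.cursor.return_value.fetchone.return_value": (7,)})
assert count_rows(db2) == 7
```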
Mock - side effects
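Side effects are a step up from return values: side_effect controls what a mocked function does when it is called. Set it to an exception to make the call raise it. Set it to an iterable to return (or raise) successive values, or to a function to compute the result. For example, you can simulate timeouts and other exceptions and assert how your code reacts to them.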
from unittest.mock import Mock

from requests.exceptions import Timeout

# mock the requests module and a response object
requests = Mock()
resp = Mock()


def get_holidays() -> dict | None:
    """Get holidays in a year"""
    r = requests.get("http://some_api/holidays")
    if r.status_code == 200:
        return r.json()
    return None


resp.status_code = 200
resp.json.return_value = {"jan": [1, 16]}
requests.get.return_value = resp
assert get_holidays() == {"jan": [1, 16]}

# check for timeout: side_effect must be the exception itself
# (side_effect = sleep(5) would just sleep once and set side_effect to None)
requests.get.side_effect = Timeout
try:
    get_holidays()
except Timeout:
    print("Timeout raised as expected")
else:
    raise AssertionError("expected get_holidays() to raise Timeout")
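side_effect also accepts an iterable (one outcome per call) or a callable. The sketch below assumes a hypothetical fetch_with_retry() helper. It uses the built-in TimeoutError instead of requests' Timeout so that it runs without third-party packages.

```python
from unittest.mock import Mock


def fetch_with_retry(fetch, attempts: int = 3):
    """Hypothetical helper: retry fetch() when it raises TimeoutError."""
    for attempt in range(attempts):
        try:
            return fetch()
        except TimeoutError:
            if attempt == attempts - 1:
                raise  # out of attempts: let the caller see the timeout


# Iterable side_effect: each call consumes the next item; exception classes
# or instances in the sequence are raised instead of returned.
flaky = Mock(side_effect=[TimeoutError, TimeoutError("slow"), {"jan": [1, 16]}])
assert fetch_with_retry(flaky) == {"jan": [1, 16]}
assert flaky.call_count == 3

# Callable side_effect: called with the mock's arguments; its return value
# becomes the mock's return value.
double = Mock(side_effect=lambda x: x * 2)
assert double(21) == 42
```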
Patch
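So far we have mocked objects in the local scope. The patch() function lets a test mock any object from any scope. You pass the object's dotted import path as a string, e.g. patch("package.module.object"), and the object is replaced with a MagicMock for the duration of the patch. Patch the name where it is looked up, i.e. in the module under test: patching "holidays.requests" replaces requests only as seen from holidays.py. MagicMock subclasses Mock and preconfigures magic methods such as __len__() and __str__(). A mock can therefore also stand in for containers, iterables, and context managers, which lets you write better tests.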
Calling patch() as a decorator
When you use patch() as a decorator, the patch is active for the whole test function. The mock is passed to the function as an extra argument.
# Store in holidays.py file
import requests
from datetime import datetime


def is_tradable() -> bool:
    """Function returns True if trades can execute on that day."""
    today = datetime.today()
    return 0 <= today.weekday() < 5  # keep it simple, return True if weekday.


def get_holidays() -> dict | None:
    """Get holidays in a year"""
    r = requests.get("http://some_api/holidays")
    if r.status_code == 200:
        return r.json()
    return None


# Store in a test file, e.g. test_holidays.py
from unittest.mock import patch
from requests.exceptions import Timeout
from holidays import get_holidays
import pytest


@patch("holidays.requests")  # patch the requests module as imported in holidays.py
def test_get_holidays_timeout(mock_obj):  # mock_obj is the MagicMock that replaces holidays.requests
    mock_obj.get.side_effect = Timeout  # set return values, side effects etc.
    with pytest.raises(Timeout):
        get_holidays()


test_get_holidays_timeout()  # pytest would normally discover and run this
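patch() decorators can be stacked. The sketch below patches two standard-library functions, and describe_env() is a hypothetical function under test. It shows the one non-obvious rule: decorators apply bottom-up, so the bottom patch's mock is the first argument.

```python
import os
from unittest.mock import patch


def describe_env() -> str:
    """Hypothetical function under test; looks up os.getcwd/os.cpu_count at call time."""
    return f"{os.getcwd()} ({os.cpu_count()} CPUs)"


# Decorators apply bottom-up, so the mock for the *bottom* patch
# is the *first* extra argument.
@patch("os.cpu_count", return_value=4)
@patch("os.getcwd", return_value="/fake/dir")
def test_describe_env(mock_getcwd, mock_cpu_count):
    assert describe_env() == "/fake/dir (4 CPUs)"
    mock_getcwd.assert_called_once_with()
    mock_cpu_count.assert_called_once_with()


test_describe_env()
```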
Patch as context manager
This lets you mock an object with a narrower scope than the decorator does: the patch is active only inside the with block.
from unittest.mock import patch
from requests.exceptions import Timeout
from holidays import get_holidays
import pytest


def logger(content):
    # used as side_effect: receives the same arguments as requests.get()
    print(f"Log: {content}")
    raise Timeout


def test_get_holidays_withlog():
    with patch("holidays.requests") as mock_requests:
        mock_requests.get.side_effect = logger  # specify any function that needs to be executed when called
        with pytest.raises(Timeout):
            get_holidays()


test_get_holidays_withlog()
# Output: Log: http://some_api/holidays
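The sketch below confirms that the context-manager form restores the original object on exit, using os.getcwd as an arbitrary target. It also shows the equivalent start()/stop() form, which is handy in setUp()/tearDown().

```python
import os
from unittest.mock import patch

original_getcwd = os.getcwd

with patch("os.getcwd", return_value="/fake/dir") as mock_getcwd:
    # Inside the block the module attribute is the MagicMock itself.
    assert os.getcwd() == "/fake/dir"
    assert os.getcwd is mock_getcwd

# On exit, patch() puts the original object back (even if the block raised).
assert os.getcwd is original_getcwd
mock_getcwd.assert_called_once_with()

# The same patcher can be driven manually, e.g. from setUp()/tearDown().
patcher = patch("os.getcwd", return_value="/other")
patcher.start()
try:
    assert os.getcwd() == "/other"
finally:
    patcher.stop()
assert os.getcwd is original_getcwd
```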
Patch a particular property or method of an object
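Use patch.object() to narrow the mock scope even further, down to a specific property or method. Unlike patch(), it takes the object itself plus the attribute name as a string, rather than a dotted import path.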
from unittest.mock import patch
from requests.exceptions import Timeout
from holidays import get_holidays, requests  # the same requests module object that holidays.py uses
import pytest


def logger(content):
    print(f"Log: {content}")
    raise Timeout


def test_get_holidays_withlog_patch_obj():
    # replace only requests.get, leaving the rest of the module untouched
    with patch.object(requests, "get", side_effect=logger):
        with pytest.raises(Timeout):
            get_holidays()


test_get_holidays_withlog_patch_obj()
# Output: Log: http://some_api/holidays
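patch.object() is often applied to a class rather than a module. Here is a sketch with a hypothetical PriceFeed class whose network call we want to avoid.

```python
from unittest.mock import patch


class PriceFeed:
    """Hypothetical client whose network call we want to avoid in tests."""

    def fetch(self, symbol: str) -> float:
        raise RuntimeError("network access not allowed in tests")

    def price_in_cents(self, symbol: str) -> int:
        return round(self.fetch(symbol) * 100)


def test_price_in_cents():
    # patch.object takes the object itself plus the attribute name as a string.
    with patch.object(PriceFeed, "fetch", return_value=101.5) as mock_fetch:
        assert PriceFeed().price_in_cents("AAPL") == 10150
    # The MagicMock is not bound as a method, so 'self' is not recorded.
    mock_fetch.assert_called_once_with("AAPL")


test_price_in_cents()
```

Because the replacement MagicMock is not bound as a method, the recorded call has no self argument. With patch.object(..., autospec=True), the mock would instead record the instance as the first argument.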