I’ve been working on a project for my company, https://gitlab.com/paessler-labs/prtg-pyprobe, which uses the asyncio library for Python quite heavily. It’s a great framework, and so far we’ve had very few issues with it. Recently, though, I wrote some new code whose dependencies I wanted to mock for unit testing, and that turned out to be a bit more challenging than I expected.
Here is the code I wanted to test:
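```python
import time

import aioboto3

# Fragment from the sensor's async work() method; task_data and
# s3_bucket_data are provided by the surrounding code.
start = time.time()
async with aioboto3.resource(
    "s3", aws_access_key_id=task_data["aws_access_key"], aws_secret_access_key=task_data["aws_secret_key"]
) as s3:
    all_buckets = s3.buckets.all()
    # all() returns an async iterable, so the buckets are counted with `async for`
    i = 0
    async for _ in all_buckets:
        i += 1
    s3_bucket_data.add_channel(
        name="Total Buckets", mode="integer", kind="Custom", customunit="buckets", value=i
    )
    s3_bucket_data.message = f"Your AWS account has {i} buckets."
# Elapsed query time in milliseconds
end = (time.time() - start) * 1000
s3_bucket_data.add_channel(name="Total Query Time", mode="float", kind="TimeResponse", value=end)
```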
And here is what the unit test ended up looking like:
```python
import asyncio

import asynctest
import pytest


# task_data() and the s3_total_sensor fixture are defined elsewhere in the test suite.
@pytest.mark.asyncio
class TestS3TotalWork:
    @asynctest.patch("aioboto3.resource")
    async def test_sensor_s3_total(self, aioboto_mock, s3_total_sensor):
        # Async iterable that stands in for the result of s3.buckets.all()
        buckets = asynctest.MagicMock()
        buckets.__aiter__.return_value = ["bucket1", "bucket2", "bucket3"]
        # aioboto3.resource(...) -> async context manager -> s3 -> s3.buckets.all()
        aioboto_mock.return_value.__aenter__.return_value.buckets.all.return_value = buckets
        s3_total_queue = asyncio.Queue()
        await s3_total_sensor.work(task_data=task_data(), q=s3_total_queue)
        queue_result = await s3_total_queue.get()
        aioboto_mock.assert_called_once_with("s3", aws_access_key_id="1123124", aws_secret_access_key="jkh2089")
        assert queue_result["message"] == "Your AWS account has 3 buckets."
        assert {
            "customunit": "buckets",
            "kind": "Custom",
            "mode": "integer",
            "name": "Total Buckets",
            "value": 3,
        } in queue_result["channel"]
```
To test this, I used the asynctest library (https://pypi.org/project/asynctest/), which was really helpful for mocking async iterables and async context managers.
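For what it's worth, on Python 3.8 and later the standard library's unittest.mock can do the same job without asynctest: MagicMock supports `__aenter__`, `__aexit__`, `__aiter__` and `__anext__`, and there is an AsyncMock class. Below is a minimal sketch under a few assumptions: `count_buckets()` is a hypothetical stand-in for the sensor code, and the mock is passed in directly instead of patching `aioboto3.resource`, so it runs without aioboto3 or AWS credentials. In a real test you would still patch `aioboto3.resource`; keep in mind that `unittest.mock.patch` creates an AsyncMock when the target is an async function, which changes how the chain behaves.

```python
import asyncio
from unittest.mock import MagicMock


async def count_buckets(resource):
    # Hypothetical stand-in for the sensor code; `resource` plays the role of
    # aioboto3.resource so the sketch runs without aioboto3 or AWS credentials.
    async with resource("s3", aws_access_key_id="key", aws_secret_access_key="secret") as s3:
        i = 0
        async for _ in s3.buckets.all():
            i += 1
    return i


def build_resource_mock():
    # Async iterable: MagicMock supports __aiter__ on Python 3.8+.
    buckets = MagicMock()
    buckets.__aiter__.return_value = ["bucket1", "bucket2", "bucket3"]

    # Plain MagicMock for the object bound by `async with ... as s3`, so that
    # s3.buckets.all() is an ordinary (non-awaitable) call.
    s3 = MagicMock()
    s3.buckets.all.return_value = buckets

    # resource(...) returns a MagicMock whose awaited __aenter__ yields s3.
    resource = MagicMock()
    resource.return_value.__aenter__.return_value = s3
    return resource


resource_mock = build_resource_mock()
total = asyncio.run(count_buckets(resource_mock))
resource_mock.assert_called_once_with(
    "s3", aws_access_key_id="key", aws_secret_access_key="secret"
)
print(total)  # 3
```

The explicit `s3 = MagicMock()` is a deliberate choice: if you leave unittest.mock to auto-create the value returned by `__aenter__`, it may be an AsyncMock, whose methods return coroutines rather than the async iterable the loop expects.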
The tricky part for me was figuring out which mock returns what, and when, for this call:

```python
all_buckets = s3.buckets.all()
```
Here is how I ended up thinking about it:
```python
@asynctest.patch("aioboto3.resource")
...
aioboto_mock.return_value.__aenter__.return_value.buckets.all.return_value = buckets
```
- `aioboto_mock` is the mock that replaces the library's `resource` function, `aioboto3.resource`.
- Its `return_value` is the async context manager that `aioboto3.resource("s3", ...)` returns.
- Stepping into that context calls its `__aenter__` method.
- We need another `return_value` here, because `__aenter__`'s return value is what the context binds to `s3`.
- Finally, `buckets.all.return_value` is what `s3.buckets.all()` returns.
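To check that chain link by link, here is a small sketch that gives each intermediate object a name. It uses the standard library's unittest.mock (Python 3.8+) rather than asynctest, and it assigns `s3` explicitly as a plain MagicMock; that is my own choice, made so that `s3.buckets.all()` stays an ordinary call under unittest.mock.

```python
import asyncio
from unittest.mock import MagicMock

aioboto_mock = MagicMock()                    # stands in for the patched aioboto3.resource
context_manager = aioboto_mock.return_value   # what aioboto3.resource("s3", ...) returns
s3 = MagicMock()                              # what `async with ... as s3` should bind
context_manager.__aenter__.return_value = s3  # __aenter__ is awaited on entry; its result is s3
buckets = MagicMock()
s3.buckets.all.return_value = buckets         # what s3.buckets.all() returns


async def walk_chain():
    # Enter the mocked context the same way the sensor code does
    async with aioboto_mock("s3") as bound:
        return bound, bound.buckets.all()


bound, all_buckets = asyncio.run(walk_chain())
print(bound is s3, all_buckets is buckets)  # True True
```

Collapsing the named intermediates back into one assignment gives the same long `aioboto_mock.return_value.__aenter__.return_value...` chain used in the test.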
```python
buckets = asynctest.MagicMock()
buckets.__aiter__.return_value = ["bucket1", "bucket2", "bucket3"]
```
The `buckets.all()` method returns an async iterable, so the mock that stands in for the result of `s3.buckets.all()` has to be a `MagicMock` whose `__aiter__` method is configured to return the values the `async for` loop should see.
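A quick way to see what `__aiter__.return_value` does is to loop over the mock directly. This sketch swaps in unittest.mock.MagicMock (Python 3.8+) for asynctest.MagicMock, and `collect()` is a hypothetical helper that runs the same kind of `async for` as the sensor code:

```python
import asyncio
from unittest.mock import MagicMock

buckets = MagicMock()
buckets.__aiter__.return_value = ["bucket1", "bucket2", "bucket3"]


async def collect(async_iterable):
    # Same `async for` pattern as the sensor, collecting items instead of counting them
    return [item async for item in async_iterable]


first = asyncio.run(collect(buckets))
second = asyncio.run(collect(buckets))  # each loop gets a fresh iterator over the list
print(first)   # ['bucket1', 'bucket2', 'bucket3']
print(second)  # ['bucket1', 'bucket2', 'bucket3']
```

Because the mock builds a new iterator from the list each time `__aiter__` is called, the same mock can be looped over more than once, which would matter if the code under test iterated `all_buckets` twice.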
Totally simple, right? :D Hope this helps anyone else trying to figure out how to mock async context managers and async iterables!