Brad's PyNotes

Unittest Module: Built-in Testing Framework for Python

TL;DR

The unittest module provides the TestCase base class, assertion methods, setUp/tearDown fixtures, test discovery, and test runners; its unittest.mock submodule adds mock objects for isolating code from its dependencies.

Interesting!

Python’s unittest module follows the xUnit pattern used by testing frameworks across many languages. Its test discovery imports modules whose file names match “test*.py” and automatically runs every method whose name starts with “test” in classes that inherit from unittest.TestCase.

Basic Test Structure

python code snippet start

import unittest

class TestBasicOperations(unittest.TestCase):
    """Exercise the four basic arithmetic operators on small integers."""

    def test_addition(self):
        """Test basic addition"""
        self.assertEqual(2 + 2, 4)

    def test_subtraction(self):
        """Test basic subtraction"""
        self.assertEqual(5 - 3, 2)

    def test_multiplication(self):
        """Test basic multiplication"""
        self.assertEqual(3 * 4, 12)

    def test_division(self):
        """Test basic division"""
        # The / operator is true division: the quotient is a float even when exact.
        self.assertEqual(10 / 2, 5.0)

# Run tests
if __name__ == '__main__':
    # Collect the TestCase classes defined in this module and run them;
    # unittest.main() then exits with a status reflecting the outcome.
    unittest.main(module='__main__')

python code snippet end

Common Assertion Methods

python code snippet start

import unittest

class TestAssertions(unittest.TestCase):
    """Tour of the most commonly used TestCase assertion methods."""

    def test_equality(self):
        """Test equality assertions"""
        self.assertEqual(1, 1)
        self.assertNotEqual(1, 2)

    def test_truth_values(self):
        """Test boolean assertions"""
        self.assertTrue(True)
        self.assertFalse(False)
        self.assertIsNone(None)
        self.assertIsNotNone("not none")

    def test_membership(self):
        """Test membership assertions"""
        self.assertIn(1, [1, 2, 3])
        self.assertNotIn(4, [1, 2, 3])

    def test_exceptions(self):
        """Test exception assertions"""
        with self.assertRaises(ValueError):
            int("not a number")

        with self.assertRaises(ZeroDivisionError):
            1 / 0

    def test_almost_equal(self):
        """Test floating point comparisons"""
        # 0.1 + 0.2 evaluates to 0.30000000000000004 in binary floating
        # point, so assertEqual would fail; assertAlmostEqual rounds the
        # difference to 7 decimal places by default.
        self.assertAlmostEqual(0.1 + 0.2, 0.3)
        # places=1 passes when round(first - second, 1) == 0. A difference of
        # 0.04 is safely inside that; pick examples away from a difference of
        # exactly 0.05 (e.g. 1.1 vs 1.15), whose outcome depends on float
        # representation error rather than on the assertion's intent.
        self.assertAlmostEqual(1.1, 1.14, places=1)
        self.assertNotAlmostEqual(1.1, 1.2, places=1)
        # delta= passes when abs(first - second) <= delta.
        self.assertAlmostEqual(1.0, 1.001, delta=0.01)

    def test_containers(self):
        """Test container assertions"""
        self.assertListEqual([1, 2, 3], [1, 2, 3])
        self.assertDictEqual({'a': 1}, {'a': 1})
        self.assertSetEqual({1, 2}, {2, 1})

python code snippet end

setUp and tearDown Methods

python code snippet start

import unittest
import tempfile
import os

class TestFileOperations(unittest.TestCase):
    """Per-test fixture demo: each test gets a fresh directory holding test.txt."""

    def setUp(self):
        """Set up test fixtures before each test method"""
        # A brand-new temporary directory per test keeps tests isolated.
        self.test_dir = tempfile.mkdtemp()
        self.test_file = os.path.join(self.test_dir, 'test.txt')
        with open(self.test_file, 'w') as handle:
            handle.write('Hello, World!')

    def tearDown(self):
        """Clean up after each test method"""
        from shutil import rmtree

        # Delete the directory together with everything created inside it.
        rmtree(self.test_dir)

    def test_file_exists(self):
        """Test if file was created"""
        was_created = os.path.exists(self.test_file)
        self.assertTrue(was_created)

    def test_file_content(self):
        """Test file content"""
        with open(self.test_file, 'r') as handle:
            text = handle.read()
        self.assertEqual(text, 'Hello, World!')

# Class-level setup
class TestDatabase(unittest.TestCase):
    """Class-level fixture demo: setUpClass/tearDownClass run once per class."""

    @classmethod
    def setUpClass(cls):
        """Set up class fixtures once for all test methods"""
        name = "test_database"
        cls.database = name
        print(f"Setting up {name}")

    @classmethod
    def tearDownClass(cls):
        """Clean up class fixtures after all tests"""
        name = cls.database
        print(f"Tearing down {name}")

    def test_connection(self):
        """Test database connection"""
        # The attribute was stored on the class, so instances see it too.
        self.assertIsNotNone(self.database)

python code snippet end

Testing Real Code

python code snippet start

# calculator.py
class Calculator:
    """Minimal calculator used as the system under test."""

    def add(self, a, b):
        """Return the sum of a and b."""
        return a + b

    def divide(self, a, b):
        """Return a / b, raising ValueError when b equals zero."""
        if b != 0:
            return a / b
        raise ValueError("Cannot divide by zero")

    def get_history(self):
        """Return the operation history; this example always returns []."""
        return []

# test_calculator.py
import unittest
from calculator import Calculator

class TestCalculator(unittest.TestCase):
    """Unit tests for Calculator; each test gets its own fresh instance."""

    def setUp(self):
        """Create calculator instance for each test"""
        self.calc = Calculator()

    def test_add_positive_numbers(self):
        """Test adding positive numbers"""
        self.assertEqual(self.calc.add(3, 5), 8)

    def test_add_negative_numbers(self):
        """Test adding negative numbers"""
        self.assertEqual(self.calc.add(-3, -5), -8)

    def test_add_mixed_numbers(self):
        """Test adding positive and negative numbers"""
        self.assertEqual(self.calc.add(10, -3), 7)

    def test_divide_normal(self):
        """Test normal division"""
        self.assertEqual(self.calc.divide(10, 2), 5.0)

    def test_divide_by_zero(self):
        """Test division by zero raises exception"""
        with self.assertRaises(ValueError) as ctx:
            self.calc.divide(10, 0)
        # The context manager keeps the caught exception for inspection.
        self.assertEqual(str(ctx.exception), "Cannot divide by zero")

    def test_history_starts_empty(self):
        """Test that history starts empty"""
        self.assertEqual(len(self.calc.get_history()), 0)

python code snippet end

Mocking with unittest.mock

python code snippet start

import unittest
from unittest.mock import Mock, patch, MagicMock
import requests

# Code to test
class WeatherService:
    """Look up temperatures from a remote weather API."""

    def get_temperature(self, city):
        """Return the 'temperature' field of the API's JSON reply for *city*.

        Raises:
            requests.HTTPError: if the API answers with an error status.
            KeyError: if the JSON payload has no 'temperature' key.
        """
        # NOTE: no timeout is passed, so an unresponsive server can block this
        # call indefinitely; production code should supply one.
        response = requests.get(f"http://api.weather.com/{city}")
        # Fail loudly on error responses instead of trying to read a
        # temperature out of an error page or error payload.
        response.raise_for_status()
        data = response.json()
        return data['temperature']

    def is_hot(self, city):
        """Return True when the temperature reported for *city* exceeds 30."""
        return self.get_temperature(city) > 30

# Test with mocks
class TestWeatherService(unittest.TestCase):
    """Tests for WeatherService that replace all network access with mocks."""

    def setUp(self):
        self.service = WeatherService()

    @patch('requests.get')
    def test_get_temperature(self, mock_get):
        """Test temperature retrieval with mocked API call"""
        # Fake HTTP response whose .json() yields a canned payload.
        fake_response = Mock()
        fake_response.json.return_value = {'temperature': 25}
        mock_get.return_value = fake_response

        self.assertEqual(self.service.get_temperature('Sydney'), 25)
        # The service must have built the URL from the city name.
        mock_get.assert_called_once_with("http://api.weather.com/Sydney")

    @patch.object(WeatherService, 'get_temperature')
    def test_is_hot_true(self, mock_get_temp):
        """Test is_hot returns True for high temperature"""
        mock_get_temp.return_value = 35
        self.assertTrue(self.service.is_hot('Dubai'))
        mock_get_temp.assert_called_once_with('Dubai')

    @patch.object(WeatherService, 'get_temperature')
    def test_is_hot_false(self, mock_get_temp):
        """Test is_hot returns False for low temperature"""
        mock_get_temp.return_value = 20
        self.assertFalse(self.service.is_hot('London'))
        mock_get_temp.assert_called_once_with('London')

python code snippet end

Test Organization with Test Suites

python code snippet start

import unittest

# Create test suite
def create_test_suite():
    """Create a test suite with specific tests"""
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    # Pick individual methods by constructing the TestCase with a method name.
    for method_name in ('test_add_positive_numbers', 'test_divide_normal'):
        suite.addTest(TestCalculator(method_name))

    # Let the loader gather every test method of a whole class.
    suite.addTests(loader.loadTestsFromTestCase(TestWeatherService))
    return suite

# Run specific suite
if __name__ == '__main__':
    import sys

    runner = unittest.TextTestRunner(verbosity=2)
    suite = create_test_suite()
    result = runner.run(suite)
    # TextTestRunner.run() only reports results; propagate failures to the
    # exit status so shells and CI see a non-zero code when tests fail
    # (unittest.main() does this automatically, a bare runner does not).
    sys.exit(not result.wasSuccessful())

python code snippet end

Skipping Tests and Expected Failures

python code snippet start

import unittest
import sys

class TestConditional(unittest.TestCase):
    """Demonstrates unittest's skip and expected-failure decorators."""

    @unittest.skip("Temporarily disabled")
    def test_temporarily_disabled(self):
        """This test is skipped"""
        self.fail("This shouldn't run")

    @unittest.skipIf(sys.version_info < (3, 8), "Python 3.8+ required")
    def test_python38_feature(self):
        """Test that requires Python 3.8+"""
        # skipIf can only guard features that are looked up at *runtime*.
        # New syntax (such as the 3.8 walrus operator) is rejected while the
        # module is being compiled, before any decorator runs, so it cannot
        # be skipped this way. math.prod, added in Python 3.8, is a runtime
        # library feature, so skipIf genuinely protects older interpreters.
        import math

        data = [1, 2, 3, 4, 5]
        self.assertEqual(math.prod(data), 120)

    @unittest.skipUnless(sys.platform.startswith("win"), "Windows only")
    def test_windows_only(self):
        """Test that only runs on Windows"""
        import winreg  # Windows-specific module
        self.assertIsNotNone(winreg)

    @unittest.expectedFailure
    def test_known_bug(self):
        """Test for a known bug that's not fixed yet"""
        self.assertEqual(1, 2)  # This will fail but is expected

python code snippet end

Practical Testing Examples

Testing a User Class

python code snippet start

class User:
    """A user account with input validation and a login counter."""

    def __init__(self, username, email):
        # Validate the inputs before storing any state.
        if not username:
            raise ValueError("Username cannot be empty")
        if '@' not in email:
            raise ValueError("Invalid email format")

        self.username, self.email = username, email
        self.is_active = True
        self.login_count = 0

    def login(self):
        """Count a login; raise RuntimeError if the account is disabled."""
        if self.is_active:
            self.login_count += 1
            return
        raise RuntimeError("User account is disabled")

    def disable(self):
        """Deactivate the account so subsequent logins are refused."""
        self.is_active = False

class TestUser(unittest.TestCase):
    """Behavioural tests for the User class."""

    @staticmethod
    def _new_user():
        """Return a fresh, valid, active user for tests that need one."""
        return User("testuser", "test@example.com")

    def test_valid_user_creation(self):
        """Test creating a valid user"""
        user = User("john_doe", "john@example.com")
        self.assertEqual(user.username, "john_doe")
        self.assertEqual(user.email, "john@example.com")
        self.assertTrue(user.is_active)
        self.assertEqual(user.login_count, 0)

    def test_empty_username_raises_error(self):
        """Test that empty username raises ValueError"""
        with self.assertRaises(ValueError) as ctx:
            User("", "test@example.com")
        self.assertIn("Username cannot be empty", str(ctx.exception))

    def test_invalid_email_raises_error(self):
        """Test that invalid email raises ValueError"""
        with self.assertRaises(ValueError) as ctx:
            User("testuser", "invalid-email")
        self.assertIn("Invalid email format", str(ctx.exception))

    def test_successful_login(self):
        """Test successful login increments counter"""
        user = self._new_user()
        user.login()
        self.assertEqual(user.login_count, 1)

    def test_disabled_user_cannot_login(self):
        """Test disabled user cannot login"""
        user = self._new_user()
        user.disable()

        with self.assertRaises(RuntimeError) as ctx:
            user.login()
        self.assertIn("User account is disabled", str(ctx.exception))

    def test_multiple_logins(self):
        """Test multiple logins increment counter correctly"""
        user = self._new_user()
        for _ in range(5):
            user.login()
        self.assertEqual(user.login_count, 5)

python code snippet end

Testing File Operations

python code snippet start

import unittest
import tempfile
import os
import json

class ConfigManager:
    """Load, edit and persist a flat JSON configuration file."""

    def __init__(self, config_file):
        self.config_file = config_file
        self.config = {}

    def load(self):
        """Replace the in-memory config with the file's JSON contents.

        A missing file is not an error: the current config is left unchanged.
        """
        # EAFP: attempting the open avoids the race between an exists()
        # check and the subsequent open().
        try:
            f = open(self.config_file, 'r', encoding='utf-8')
        except FileNotFoundError:
            return
        with f:
            self.config = json.load(f)

    def save(self):
        """Write the config to the file as JSON.

        The config is serialized before the file is opened, so a value that
        json cannot encode raises TypeError without truncating or partially
        overwriting the existing file.
        """
        payload = json.dumps(self.config)
        with open(self.config_file, 'w', encoding='utf-8') as f:
            f.write(payload)

    def get(self, key, default=None):
        """Return the value for *key*, or *default* if it is not set."""
        return self.config.get(key, default)

    def set(self, key, value):
        """Set *key* to *value* in memory; call save() to persist it."""
        self.config[key] = value

class TestConfigManager(unittest.TestCase):
    """Tests for ConfigManager against a real temporary JSON file."""

    def setUp(self):
        """Create temporary config file for each test"""
        # delete=False keeps the (empty) file on disk after close(), giving the
        # manager a real path to work with; tearDown removes it.
        handle = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json')
        handle.close()
        self.temp_file = handle
        self.config_manager = ConfigManager(handle.name)

    def tearDown(self):
        """Clean up temporary file"""
        path = self.temp_file.name
        # Some tests delete the file themselves, so only remove it if present.
        if os.path.exists(path):
            os.unlink(path)

    def test_load_nonexistent_file(self):
        """Test loading non-existent config file"""
        # Remove the file first to simulate a missing config.
        os.unlink(self.temp_file.name)
        self.config_manager.load()
        self.assertEqual(self.config_manager.config, {})

    def test_save_and_load(self):
        """Test saving and loading configuration"""
        writer = self.config_manager
        writer.set('database_url', 'sqlite:///test.db')
        writer.set('debug', True)
        writer.save()

        # A second, independent instance shows the values reached the disk.
        reader = ConfigManager(self.temp_file.name)
        reader.load()

        self.assertEqual(reader.get('database_url'), 'sqlite:///test.db')
        self.assertTrue(reader.get('debug'))

    def test_get_with_default(self):
        """Test getting values with default"""
        manager = self.config_manager
        self.assertEqual(manager.get('nonexistent', 'default'), 'default')
        self.assertIsNone(manager.get('nonexistent'))

python code snippet end

Advanced Testing Patterns

Custom Assertions

python code snippet start

class CustomAssertions(unittest.TestCase):
    """Shows how to add domain-specific assertion helpers to a TestCase."""

    def assertBetween(self, value, min_val, max_val, msg=None):
        """Fail unless min_val <= value <= max_val (inclusive on both ends).

        *msg*, if given, is appended to the standard failure message, the
        same way the built-in assertions combine them by default.
        """
        if not (min_val <= value <= max_val):
            standard_msg = f"{value} is not between {min_val} and {max_val}"
            # self.fail raises self.failureException (AssertionError unless a
            # subclass overrides it), so the runner records a test failure.
            self.fail(f"{standard_msg} : {msg}" if msg else standard_msg)

    def assertValidEmail(self, email, msg=None):
        """Fail unless *email* matches a simple address pattern."""
        import re

        pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        # fullmatch, not match: with re.match the trailing '$' also matches
        # just before a final newline, so "user@example.com\n" would pass.
        if not re.fullmatch(pattern, email):
            standard_msg = f"{email} is not a valid email address"
            self.fail(f"{standard_msg} : {msg}" if msg else standard_msg)

    def test_custom_assertions(self):
        """Test using custom assertions"""
        self.assertBetween(50, 1, 100)
        self.assertValidEmail("test@example.com")

        with self.assertRaises(AssertionError):
            self.assertBetween(150, 1, 100)

python code snippet end

Running Tests

python code snippet start

# Command line test discovery
# python -m unittest discover

# Run specific test file
# python -m unittest test_calculator.py

# Run specific test class
# python -m unittest test_calculator.TestCalculator

# Run specific test method
# python -m unittest test_calculator.TestCalculator.test_add_positive_numbers

# Verbose output
# python -m unittest -v

# Test runner with custom configuration
if __name__ == '__main__':
    # Runner options: verbose per-test output, capture stdout/stderr while
    # tests run, stop at the first failure, and silence warnings.
    runner_options = {
        'verbosity': 2,
        'buffer': True,
        'failfast': True,
        'warnings': 'ignore',
    }
    unittest.main(**runner_options)

python code snippet end

Best Practices

python code snippet start

class TestBestPractices(unittest.TestCase):
    """Illustrates habits that keep a test suite readable and reliable."""

    def test_descriptive_names(self):
        """Use descriptive test names that explain what is being tested"""
        # Good: test_user_login_increments_counter
        # Bad:  test_login

    def test_one_assertion_per_concept(self):
        """Each test should focus on one specific behavior"""
        account = User("testuser", "test@example.com")
        account.login()
        # The single behaviour under test: one login bumps the counter to 1.
        self.assertEqual(account.login_count, 1)

    def test_independent_tests(self):
        """Tests should be independent and not rely on other tests"""
        # Every test builds its own data and must not depend on the order
        # in which tests run.

    def test_edge_cases(self):
        """Test edge cases and boundary conditions"""
        # Probe empty inputs, null values, maximum values and similar limits;
        # here the edge case is an empty username.
        with self.assertRaises(ValueError):
            User("", "test@example.com")

    def test_arrange_act_assert_pattern(self):
        """Follow the Arrange-Act-Assert pattern"""
        account = User("testuser", "test@example.com")  # Arrange
        account.login()                                 # Act
        self.assertEqual(account.login_count, 1)        # Assert

python code snippet end

The unittest module provides a robust foundation for testing Python applications with comprehensive assertion methods, fixtures, and mocking capabilities.

Test code is organized with Python's ordinary module and package system, and tests of error paths (such as assertRaises checks) build on the same exception-handling patterns used in application code.

Reference: Python Unittest Module Documentation