I have the following statement in one of the methods under unit test.
db_employees = self.db._session.query(Employee).filter(Employee.dept ==
new_employee.dept).all()
I want db_employees to get mock list of employees. I tried to achieve this using:
# The mock library exports MagicMock; there is no MagickMock.
m = MagicMock()
m.return_value.filter().all().return_value = employees
where employees is a list of Employee objects. But this did not work: when I print the value of any attribute of the result, I get a mock object instead of the real value. This is how the code looks:
class Database(object):
    """Create an engine for ``db`` and keep one open session on ``_session``."""

    def __init__(self, user=None, passwd=None, db="sqlite:////tmp/emp.db"):
        """Bind the declarative metadata to the engine and open a session.

        ``user`` and ``passwd`` are accepted but not used by this constructor.
        Raises ValueError when ``create_engine`` fails for any reason.
        """
        try:
            engine = create_engine(db)
        except Exception:
            raise ValueError("Database '%s' does not exist." % db)

        def _enable_foreign_keys(conn, record):
            # Runs on every new connection once registered below.
            conn.execute('pragma foreign_keys=ON')

        uses_sqlite = 'sqlite://' in db
        if uses_sqlite:
            event.listen(engine, 'connect', _enable_foreign_keys)

        Base.metadata.bind = engine
        DBSession = sessionmaker(bind=engine)
        self._session = DBSession()
class TestEmployee(MyEmployee):
    """Test case that patches Session.add and Session.query around an update."""

    def setUp(self):
        self.db = emp.database.Database(db=options.connection)
        self.db._session._autoflush()

    # Decorators apply bottom-up, so the 'query' mock is the first argument.
    @mock.patch.object(session.Session, 'add')
    @mock.patch.object(session.Session, 'query')
    def test_update(self, mock_query, mock_add):
        employees = [{'id': 1,
                      'name': 'Pradeep',
                      'department': 'IT',
                      'manager': 'John'}]
        mock_add.side_effect = self.add_side_effect
        mock_query.return_value = self.query_results()
        self.update_employees(employees)

    def add_side_effect(self, instance, _warn=True):
        # Code to mock add
        # Values will be stored in a dict which will be used to
        # check with expected value.
        pass

    def query_results(self):
        """Return the object that the patched query() call hands back."""
        # NOTE(review): employee, department and self.count are not defined in
        # this snippet -- presumably fixtures set up elsewhere; confirm.
        m = MagicMock()
        # test_update assigns this mock to mock_query.return_value, so the code
        # under test calls m.filter(...) directly. The leading `.return_value`
        # here configures m(...) instead, which that call chain never reaches.
        if self.count == 0:
            m.return_value.filter.return_value.all.return_value = [employee]
        else:
            m.return_value.filter.return_value.all.return_value = [department]
        return m
I use query_results because the method under test calls query twice: first on the employee table and then on the department table.
How do I mock this chained function call?
m = Mock()
m.session.query().filter().all.return_value = employees
https://docs.python.org/3/library/unittest.mock.html
I found a solution to a similar problem where I needed to mock out a nested set of filtering calls.
Given code under test similar to the following:
interesting_cats = (session.query(Cats)
.filter(Cat.fur_type == 'furry')
.filter(Cat.voice == 'meowrific')
.filter(Cat.color == 'orande')
.all())
You can setup mocks like the following:
mock_session_response = MagicMock()
# This is the magic - create a mock loop
mock_session_response.filter.return_value = mock_session_response
# We can exit the loop with a call to 'all'
mock_session_response.all.return_value = provided_cats
mock_session = MagicMock(spec=Session)
mock_session.query.return_value = mock_session_response
You should patch the query() method of the Database's _session attribute and configure it to give you the right answer. You can do it in a lot of ways, but IMHO the cleanest way is to patch DBSession's query static reference. I don't know from which module you imported DBSession, so I'll patch the local reference.
The other aspect is the mock configuration: we will set query's return value, which in your case becomes the object that has the filter() method.
class TestEmployee(MyEmployee):
    """Test case that patches DBSession.add/query through this module's name."""

    def setUp(self):
        self.db = emp.database.Database(db=options.connection)
        self.db._session._autoflush()
        self.log_add = {}

    # mock.patch takes one dotted target string; mock.patch.object would need
    # a separate (target, attribute) pair. Decorators apply bottom-up, so the
    # 'query' mock is the first argument.
    @mock.patch(__name__ + '.DBSession.add')
    @mock.patch(__name__ + '.DBSession.query')
    def test_update(self, mock_query, mock_add):
        employees = [{'id': 1,
                      'name': 'Pradeep',
                      'department': 'IT',
                      'manager': 'John'}]
        mock_add.side_effect = self.add_side_effect
        mock_query.return_value = self.query_results()
        self.update_employees(employees)
        # .... your test here

    def add_side_effect(self, instance, _warn=True):
        # ... storing data
        self.log_add[...] = [...]

    def query_results(self):
        """Return the object that the patched query() call hands back."""
        m = MagicMock()
        # Placeholder strings stand in for the real result lists.
        value = "[department]"
        if not self.count:
            value = "[employee]"
        # query(...) returns m, so .filter(...).all() resolves to `value`.
        m.filter.return_value.all.return_value = value
        return m
Related
I want to unit test a method that includes a database (SQL Server) call.
I don't want the test to connect to the actual database.
I use unittest for testing, I have done some research and it seems that Mocking could do the trick but not sure about the syntax.
The select statement on the code below returns some integers. I guess that mocking will target the "cursor.execute" and "cursor.fetchall()" parts of the code.
from databaselibrary.Db import Db
class RandomClass():
    """Compare the query ids stored for a date with those found in a file."""

    def __init__(self, database):
        self.database = database  # Main DB for inserting data

    def check_file_status(self, trimmed_file_data, file_date):
        """Return True when the DB's query ids for ``file_date`` differ from the file's.

        ``trimmed_file_data`` is a mapping whose keys are compared against the
        ``query_id`` attribute of each fetched row.
        """
        cursor = self.database.cursor()
        # NOTE(review): file_date is interpolated straight into the SQL text;
        # switch to a bound parameter once the driver's placeholder style is known.
        cursor.execute(f"""SELECT DISTINCT query_id
                           FROM wordcloud_count
                           WHERE date = '{file_date}'""")
        queries_in_db = {row.query_id for row in cursor.fetchall()}
        queries_in_file = set(trimmed_file_data.keys())
        return queries_in_db != queries_in_file

    def run(self):
        print("Hello")
if __name__ == "__main__":
    # Credentials are separated from the host by '@' (user:password@server).
    connection_string = 'sql://user:password@server/database'
    database = Db(connection_string, autocommit=True)
    random = RandomClass(database)
    random.run()
The test class could look like that:
import unittest
from unittest.mock import Mock, patch
from project.RandomClass import RandomClass
from datetime import datetime
class testRandomClass(unittest.TestCase):
    """First attempt at testing check_file_status without a real database."""

    def setUp(self):
        # A plain string stands in for the database object here.
        self.test_class = RandomClass("don't want to put actual database here")

    # NOTE(review): the next line reads like a disabled @patch decorator; as
    # written it is a comment and patches nothing.
    #patch("project.RandomClass.check_file_status",return_value={123, 1234})
    def test_check_file_status(self):
        keys = {'1234':'2','123':'1','111':'5'}
        result = self.test_class.check_file_status(keys, datetime(1900, 1, 1, 23, 59, 59))
        self.assertTrue(result)
You should mock the db connection object, and the cursor. Then, set the return value of the cursor to return the expected value. I've tested the below code, and used class Row to mock the rows returned from fetchall call:
import unittest
from unittest.mock import MagicMock
from datetime import datetime
from project.RandomClass import RandomClass
class Row(object):
    """Minimal stand-in for a fetched row: exposes only ``query_id``."""

    def __init__(self, x):
        self.query_id = x
class testRandomClass(unittest.TestCase):
    """Test check_file_status against a mocked connection and cursor."""

    def setUp(self):
        dbc = MagicMock(name="dbconn")
        cursor = MagicMock(name="cursor")
        # fetchall() yields two rows whose query_id values are 1 and 2.
        cursor.fetchall.return_value = [Row(1), Row(2)]
        # The mocked connection hands out the cursor configured above.
        dbc.cursor.return_value = cursor
        self.test_class = RandomClass(dbc)

    def test_check_file_status(self):
        # These keys differ from the mocked row ids, so a truthy result is expected.
        keys = {'1234': '2', '123': '1', '111': '5'}
        result = self.test_class.check_file_status(keys, datetime(1900, 1, 1, 23, 59, 59))
        self.assertTrue(result)
Since in your RandomClass you iterate rows and get their query_id, you need to use a class (or a named tuple) as the row objects returned by the mock.
You should create the row objects you expected, and set them as the return value of fetchall.
my project has a file called config.py which has, among others, the following code:
class Secret(Enum):
    """Names of the secrets that can be looked up through get_secret()."""

    DATABASE_A = 'name_of_secret_database_A'
    # Enum rejects a reused member name, so the second entry gets its own name.
    DATABASE_B = 'name_of_secret_database_B'

    def secret(self):
        """Return get_secret() for this member's value, or {} if the value is empty."""
        if self.value:
            return get_secret(self.value)
        return {}
def get_secret(secret_name):
    """Fetch ``secret_name`` from the 'secretsmanager' client and decode it.

    Returns the result of ``loads`` on the 'SecretString' entry, falling back
    to the literal "{}" when that entry is missing.
    """
    session = Session()
    client = session.client(
        service_name='secretsmanager',
        region_name='us-east-1',
    )
    secret_value = client.get_secret_value(SecretId=secret_name)
    return loads(secret_value.get('SecretString', "{}"))
I need to somehow mock get_secret in tests with pytest for all enum calls, for example Secret.DATABASE_A.secret().
You can use monkeypatch to override the behaviour of get_secret(). I have made the get_secret() method a static method of the Secret class, but you can make it part of any module you want and import it as well. Just make sure you change it in the monkeypatch.setattr() call as well.
import pytest
from enum import Enum
class Secret(Enum):
    """Names of the secrets, with the lookup kept on the class so tests can replace it."""

    DATABASE_A = 'name_of_secret_database_A'
    DATABASE_B = 'name_of_secret_database_B'

    def secret(self):
        """Return Secret.get_secret() for this member's value, or {} if the value is empty."""
        if self.value:
            return Secret.get_secret(self.value)
        return {}

    @staticmethod
    def get_secret(secret_name):
        """Fetch ``secret_name`` from the 'secretsmanager' client and decode it.

        Falls back to the literal "{}" when 'SecretString' is missing.
        """
        session = Session()
        client = session.client(
            service_name='secretsmanager',
            region_name='us-east-1',
        )
        secret_value = client.get_secret_value(SecretId=secret_name)
        return loads(secret_value.get('SecretString', "{}"))
def test_secret_method(monkeypatch):
    """Check that secret() returns whatever the replaced get_secret returns."""

    def get_secret(secret_name):
        # Replacement that never runs the original lookup.
        return "supersecret"

    # Swap the class attribute for the duration of this test.
    monkeypatch.setattr(Secret, "get_secret", get_secret)
    s = Secret.DATABASE_A
    assert s.secret() == "supersecret"
This results in 1 passed test.
What is happening here is, that I created a function get_secret() in my test_secret_method as well, and then overwrite the Secret.get_secret() with that new method. Now, you can use the Secret class in your test_method and be sure what the 'get_secret()' method will return without actually running the original code.
I have some trouble with integration test. I'm using python 3.5, SQLAlchemy 1.2.0b3, latest docker image of postgresql. So, I wrote test:
# tests/integration/usecases/test_users_usecase.py
class TestGetUsersUsecase(unittest.TestCase):
    """Integration test for the get-user use case against the ``_pg`` database."""

    def setUp(self):
        # Rebuild the schema from scratch before every test.
        Base.metadata.reflect(_pg)
        Base.metadata.drop_all(_pg)
        Base.metadata.create_all(_pg)
        self._session = sessionmaker(bind=_pg, autoflush=True, autocommit=False, expire_on_commit=True)
        self.session = self._session()
        # Seed the single user the test looks up.
        self.session.add(User(id=1, username='user1'))
        self.session.commit()
        self.pg = PostgresService(session=self.session)

    def test_get_user(self):
        # NOTE(review): the helper compares with ==, so the outcome depends on
        # how User implements equality, which is not shown here -- confirm.
        expected = User(id=1, username='user1')
        boilerplates.get_user_usecase(storage_service=self.pg, id=1, expected=expected)
# tests/boilerplates/usecases/user_usecases.py
def get_user_usecase(storage_service, id, expected):
    """Run GetUser for ``id`` on ``storage_service`` and assert it equals ``expected``."""
    u = GetUser(storage_service=storage_service)
    actual = u.apply(id=id)
    assert expected == actual
In usecase I did next:
# usecases/users.py
class GetUser(object):
    """
    Usecase for getting user from storage service by Id
    """

    def __init__(self, storage_service):
        # Must provide get_user_by_id(id=...).
        self.storage_service = storage_service

    def apply(self, id):
        """Return the user with ``id``; raise UserDoesNotExists when the lookup gives None."""
        user = self.storage_service.get_user_by_id(id=id)
        if user is None:
            raise UserDoesNotExists('User with id=\'%s\' does not exists' % id)
        return user
storage_service.get_user_by_id looks like:
# infrastructure/postgres.py (method of Postgres class)
def get_user_by_id(self, id):
    """Return the User row whose id matches, via ``one_or_none()`` on the query."""
    # Imported inside the method rather than at module level.
    from anna.domain.user import User
    return self.session.query(User).filter(User.id == id).one_or_none()
And it does not work in my integration test. But if I add print(actual) in the test before the assert, everything is OK. I thought that my test was bad, so I tried many variants, and none of them works. I also tried returning a generator from storage_service.get_user_by_id(), and that does not work either. What did I do wrong? It works only if print() is called in the test.
I'm working with postgresql and I have used MagicMock to test, but I'm not sure that I have understood the concepts of mocking. This is my example code (I have a dbname=test, table=py_test and user=simone):
import psycopg2
import sys
from mock import Mock, patch
import unittest
from mock import MagicMock
from collections import Counter
import doctest
class db(object):
    """Small helper that opens a psycopg2 connection and runs simple checks."""

    def __init__(self,database, user):
        self.con = None
        self.database = database
        self.user = user

    def test_connection(self):
        """Open the connection; return True on success, False on a database error."""
        try:
            self.con = psycopg2.connect(database=self.database, user=self.user)
            return True
        except psycopg2.DatabaseError as e:
            print('Error %s' % e)
            return False

    def test_empty_table(self,table):
        """Return the first row of ``table``; return None after a database error."""
        try:
            cur = self.con.cursor()
            # NOTE(review): a table name cannot be sent as a bound value, so it
            # is still concatenated; only pass trusted table names.
            cur.execute('SELECT * from ' + table )
            ver = cur.fetchone()
            return ver
        except psycopg2.DatabaseError as e:
            print('Error %s' % e)

    def test_data_type(self, table, column):
        """Return the data_type row for ``table``.``column``; None after a database error."""
        try:
            cur = self.con.cursor()
            # Values are bound by the driver rather than spliced into the SQL.
            cur.execute(
                "SELECT data_type from information_schema.columns "
                "where table_name = %s and column_name = %s",
                (table, column),
            )
            ver = cur.fetchone()
            return ver
        except psycopg2.DatabaseError as e:
            print('Error %s' % e)

    def __del__(self):
        # Close the connection if one was opened.
        if self.con:
            self.con.close()
class test_db(unittest.TestCase):
    """Runs the three db checks against the 'test' database as user 'simone'."""

    def testing(self):
        tdb = db('test','simone')
        # NOTE(review): assertTrue's second argument is the failure message, so
        # the 1 and int below are not compared with the result.
        self.assertTrue(tdb.test_connection(), 1)
        self.assertTrue(tdb.test_empty_table('py_test'), 1)
        self.assertTrue(tdb.test_data_type('py_test','id'), int)
class test_mock(object):
    """Drives the db call sequence on whatever object is passed in as ``db``."""

    def __init__(self, db):
        self.db = db

    def execute(self, nomedb, user, table, field):
        # The result of this call is discarded; the following calls are made
        # on self.db itself, not on the object the call returned.
        self.db(nomedb, user)
        self.db.test_connection()
        self.db.test_empty_table(table)
        self.db.test_data_type(table, field)
if __name__ == "__main__":
    # Run the call sequence against a MagicMock and print what it recorded.
    c = MagicMock()
    d = test_mock(c)
    d.execute('test','simone','py_test','id')
    # Count how often each recorded method call appears.
    method_count = Counter([str(method) for method in c.method_calls])
    print c.method_calls
    print method_count
    print c.mock_calls
Maybe I'll give You some other example of mocking using Mockito package:
import sphinxsearch
import unittest
from mockito import mock, when, unstub, verify
class SearchManagerTest(unittest.TestCase):
    """Checks the calls SearchManager makes on a mocked SphinxClient."""

    def setUp(self):
        self.sphinx_client = mock()
        # Any SphinxClient() construction now yields the mock above.
        when(sphinxsearch).SphinxClient().thenReturn(self.sphinx_client)

    def tearDown(self):
        # Remove the stubbing installed in setUp.
        unstub()

    def test_search_manager(self):
        # given
        value = {'id': 142564}
        expected_result = 'some value returned from SphinxSearch'
        # when
        search_manager = SearchManager()
        result = search_manager.get(value)
        # then
        verify(self.sphinx_client).SetServer('127.0.0.1', 9312)
        verify(self.sphinx_client).SetMatchMode(sphinxsearch.SPH_MATCH_ALL)
        verify(self.sphinx_client).SetRankingMode(sphinxsearch.SPH_RANK_WORDCOUNT)
        self.assertEqual(result, expected_result)
The main concept is to replace a module that is tested somewhere else (it has its own unittest module) with a mock, and to record some behavior on that mock.
Replace module You use with mock:
self.sphinx_client = mock()
and then record on this mock that if You call specific method, this method will return some data - simple values like strings or mocked data if You need to check behavior:
when(sphinxsearch).SphinxClient().thenReturn(self.sphinx_client)
In this case You tell that if You import sphinxsearch module and call SphinxClient() on it, You get mocked object.
Then the main test comes in. You call the method or object to test (SearchManager here). Its body is tested with some given values:
self.search_manager = SearchManager()
The "then" section verifies whether certain actions were made:
verify(self.sphinx_client).SetServer('127.0.0.1', 9312)
verify(self.sphinx_client).SetMatchMode(sphinxsearch.SPH_MATCH_ALL)
verify(self.sphinx_client).SetRankingMode(sphinxsearch.SPH_RANK_WORDCOUNT)
Here, the first line checks that SetServer was called on self.sphinx_client with the parameters '127.0.0.1' and 9312. The other two lines are self-explanatory in the same way.
And here we do normal checks:
self.assertEqual(result, expected_result)
In this sample code I want to use the variables from the function db_properties in the function connect_and_query. To accomplish that I chose to return them, and with that strategy the code works perfectly. But in this example the db.properties file only has 4 variables. If the properties file had 20+ variables, should I continue using return? Or is there a more elegant/cleaner/correct way to do that?
import psycopg2
import sys
from ConfigParser import SafeConfigParser
class Main:
    """Read the connection settings from db.properties and run one query."""

    def db_properties(self):
        """Return (host, name, user, password) from the [database] section."""
        # Raw string: in a normal literal the '\t' in this path is a tab.
        cfgFile = r'c:\test\db.properties'
        parser = SafeConfigParser()
        parser.read(cfgFile)
        dbHost = parser.get('database','db_host')
        dbName = parser.get('database','db_name')
        dbUser = parser.get('database','db_login')
        dbPass = parser.get('database','db_pass')
        return dbHost,dbName,dbUser,dbPass

    def connect_and_query(self):
        """Connect, print the first column of every row, and always close the connection.

        Exits the process with status 1 on a psycopg2.DatabaseError.
        """
        con = None
        try:
            # Read the properties file once instead of once per value.
            dbHost, dbName, dbUser, dbPass = self.db_properties()
            qry = "select star from galaxy"
            con = psycopg2.connect(host=dbHost, database=dbName, user=dbUser,
                                   password=dbPass)
            cur = con.cursor()
            cur.execute(qry)
            data = cur.fetchall()
            for result in data:
                qryResult = result[0]
                print("the test result is : " + qryResult)
        except psycopg2.DatabaseError as e:
            print('Error %s' % e)
            sys.exit(1)
        finally:
            if con:
                con.close()
operation=Main()
operation.connect_and_query()
Im using python 2.7
Regards
If there are a lot of variables, or if you want to easily change the variables being read, return a dictionary.
def db_properties(self, *variables):
    # Return {name: value} for each requested option of the [database] section.
    # Raw string: in a normal literal the '\t' in this path is a tab.
    cfgFile = r'c:\test\db.properties'
    parser = SafeConfigParser()
    parser.read(cfgFile)
    return {
        variable: parser.get('database', variable) for variable in variables
    }
def connect_and_query(self):
try:
con = None
config = self.db_properties(
'db_host',
'db_name',
'db_login',
'db_pass',
)
#or you can use:
# variables = ['db_host','db_name','db_login','db_pass','db_whatever','db_whatever2',...]
# config = self.db_properties(*variables)
#now you can use any variable like: config['db_host']
# ---rest of the function here---
Edit: I refactored the code so you can specify the variables you want to load in the calling function itself.
You certainly don't want to call db_properties() 4 times; just call it once and store the result.
It's also almost certainly better to return a dict rather than a tuple, since as it is the caller needs to know what the method returns in order, rather than just having access to the values by their names. As the number of values getting passed around grows, this gets even harder to maintain.
e.g.:
class Main:
def db_properties(self):
    # Return the four connection settings of the [database] section as a dict.
    # Raw string: in a normal literal the '\t' in this path is a tab.
    cfgFile = r'c:\test\db.properties'
    parser = SafeConfigParser()
    parser.read(cfgFile)
    configDict = dict()
    configDict['dbHost'] = parser.get('database','db_host')
    configDict['dbName'] = parser.get('database','db_name')
    configDict['dbUser'] = parser.get('database','db_login')
    configDict['dbPass'] = parser.get('database','db_pass')
    return configDict
def connect_and_query(self):
try:
con = None
conf = self.db_properties()
con = None
qry=("select star from galaxy")
con = psycopg2.connect(host=conf['dbHost'],database=conf['dbName'],
user=conf['dbUser'],
password=conf['dbPass'])
NB: untested
You could change your db_properties to return a dict:
from functools import partial
# call as db_properties('db_host', 'db_name'...)
def db_properties(self, *args):
    """Return {name: value} for each option name in ``args`` from the [database] section."""
    parser = SafeConfigParser()
    parser.read('config file')
    # Pre-bind the section name so the getter only needs the option name.
    getter = partial(parser.get, 'database')
    return dict(zip(args, map(getter, args)))
But otherwise it's probably best to keep the parser as an attribute of the instance, and provide a convenience method...
class whatever(object):
    """Keeps the parsed properties file on the instance and looks values up by key."""

    def __init__(self, *args, **kwargs):
        # blah blah blah
        # Raw string: in a normal literal the '\t' in this path is a tab.
        cfgFile = r'c:\test\db.properties'
        self._parser = SafeConfigParser()
        self._parser.read(cfgFile)

    # A plain method, not a property: it takes a key and is called as
    # self.db_config('db_host').
    def db_config(self, key):
        """Return option ``key`` from the [database] section."""
        return self._parser.get('database', key)
Then use con = psycopg2.connect(host=self.db_config('db_host')...)
I'd suggest returning a namedtuple:
from collections import namedtuple
# in db_properties()
return namedtuple("dbconfig", "host name user password")(
parser.get('database','db_host'),
parser.get('database','db_name'),
parser.get('database','db_login'),
parser.get('database','db_pass'),
)
Now you have an object that you can access either by index or by attribute.
config = self.db_properties()
print config[0] # db_host
print config.host # same