Function to test
def get_adgroups_not_taked_share(
    campaign_ids: List[str], src_table: str, spend_src_table: str
) -> List[Tuple[str, str]]:
    """Return the rows of a fixed Redshift query as ``(first column, str(second column))`` pairs.

    The query runs through ``RedshiftCursor`` used as a context manager. Any columns
    after the second are ignored. ``campaign_ids``, ``src_table`` and
    ``spend_src_table`` are not used by the query as written.
    """
    query = "\nSELET some_data from some_table WHERE some_condition\n"
    with RedshiftCursor() as cursor:
        cursor.execute(query)
        return [(record[0], str(record[1])) for record in cursor.fetchall()]
Here is a test for this function:
import pytest
from my_ap import get_adgroups_not_taked_share
@pytest.fixture
def campaigns_redshift_cursor_mock(mocker):
    """Patch ``RedshiftCursor`` so that its ``with`` target returns two canned rows.

    Returns the patched ``RedshiftCursor`` class mock. The cursor the code under test
    actually uses is ``return_value.__enter__.return_value`` of that mock.
    """
    cursor_mock = mocker.MagicMock()
    cursor_mock.fetchall.return_value = [
        ('hs_video544', '123123123', 100),
        ('hs_video547', '123123123', 50),
    ]
    # NOTE(review): patch the name where get_adgroups_not_taked_share looks it up; the
    # test imports that function from ``my_ap``, so confirm this dotted path matches it.
    rs_cursor_creator = mocker.patch('google_panel.logic.clean_creative.RedshiftCursor')
    # ``with RedshiftCursor() as cursor`` binds cursor to __enter__'s return value.
    rs_cursor_creator.return_value.__enter__.return_value = cursor_mock
    return rs_cursor_creator
@pytest.mark.django_db
def test_get_adgroups_not_taked_share(
    campaigns_redshift_cursor_mock,
):
    """Rows keep their first column, stringify the second and drop the rest."""
    campaign_ids = ['1111', '2222', '3333']
    result = get_adgroups_not_taked_share(campaign_ids, 'test_table', 'spend_src_table')
    assert result == [('hs_video544', '123123123'), ('hs_video547', '123123123')]
Now I want to extend the test to check the SQL script, i.e. verify that the correct SQL query is executed. Something like:
def test_get_adgroups_not_taked_share(
    campaigns_redshift_cursor_mock,
):
    ......
    # NOTE: the fixture returns the patched RedshiftCursor *class*, but execute is called on
    # ``RedshiftCursor().__enter__()`` (the object the fixture wires to cursor_mock), so this
    # assertion reports "Not called". The string the code executes also begins and ends
    # with a newline, which this literal does not.
    query = """SELET some_data from some_table WHERE some_condition"""
    campaigns_redshift_cursor_mock.execute.assert_called_with(query)
But got
E AssertionError: Expected call: execute('query')
E Not called
The short answer is that you need to assert this: campaigns_redshift_cursor_mock.return_value.__enter__.return_value.execute.assert_called_once_with(query). The reason is that you're using RedshiftCursor as a context manager (hence the return_value.__enter__.return_value part) and only then calling the execute method.
The slightly longer answer explains how I arrived at this assert.
I wrote this library which adds the mg pytest fixture. To use it, pip install it and then add the mg fixture to your test and execute its generate_asserts method like so:
@pytest.mark.django_db
def test_get_adgroups_not_taked_share(
    campaigns_redshift_cursor_mock,
    mg
):
    """Check the returned pairs and print every call recorded on the cursor mock."""
    campaign_ids = ['1111', '2222', '3333']
    result = get_adgroups_not_taked_share(campaign_ids, 'test_table', 'spend_src_table')
    # Prints ready-to-paste asserts for each recorded call on the patched RedshiftCursor.
    mg.generate_asserts(campaigns_redshift_cursor_mock)
    assert result == [('hs_video544', '123123123'), ('hs_video547', '123123123')]
Then run the test as usual and you get this output in the console:
# Asserts printed by ``mg.generate_asserts(campaigns_redshift_cursor_mock)``: one line per
# recorded call on the patched RedshiftCursor, including the ``__enter__``/``__exit__``
# calls made by the ``with`` statement. Keep only the ones relevant to your test.
assert 1 == campaigns_redshift_cursor_mock.call_count
campaigns_redshift_cursor_mock.assert_called_once_with()
campaigns_redshift_cursor_mock.return_value.__enter__.assert_called_once_with()
campaigns_redshift_cursor_mock.return_value.__enter__.return_value.execute.assert_called_once_with('\n SELET some_data from some_table WHERE some_condition\n ')
campaigns_redshift_cursor_mock.return_value.__enter__.return_value.fetchall.assert_called_once_with()
campaigns_redshift_cursor_mock.return_value.__exit__.assert_called_once_with(None, None, None)
These are all the calls made to the mock; keep the ones that are relevant to you.
Related
I've been trying to test the method below, especially the if block. I have tried multiple things, such as patching and mocking pyodbc in various combinations, but I haven't been able to cover the if condition.
def execute_read(self, query):
    """Run *query* on a fresh pyodbc connection and return every fetched row.

    Args:
        query: SQL text passed unchanged to ``cursor.execute``.

    Returns:
        The result of ``fetchall()`` on the cursor returned by ``execute``.

    Raises:
        Exception: carrying ``messages[0][1]`` when the executed cursor reports any
            ``messages`` (NOTE(review): presumably pyodbc's (type, text) pairs, so this
            is the first message's text).
    """
    dbconn = pyodbc.connect(self.connection_string, convert_unicode=True)
    try:
        with dbconn.cursor() as cur:
            cursor = cur.execute(query)
            if cursor.messages:
                raise Exception(cursor.messages[0][1])
            return cursor.fetchall()
    finally:
        # Without this the connection is never closed, leaking one per call.
        dbconn.close()
# unit test method
@patch.object(pyodbc, 'connect')
def test_execute_read(self, pyodbc_mock):
    """Exercise execute_read with ``pyodbc.connect`` replaced by a mock."""
    pyodbc_mock.return_value = MagicMock()
    # NOTE: execute_read reads ``messages`` from ``cur.execute(query)``, where ``cur`` is the
    # cursor's ``__enter__`` result. Left unconfigured, that attribute is a MagicMock, which
    # is truthy, so the raising branch runs. To cover the ``if`` branch, configure:
    #   pyodbc_mock.return_value.cursor.return_value.__enter__.return_value \
    #       .execute.return_value.messages = []
    # NOTE(review): execute_read takes ``self``; it is called here as a plain function.
    self.assertIsNotNone(execute_read('query'))
I've read the unittest.mock docs, but I haven't found a way to get the if condition above covered. Thank you.
You would want to patch the Connection class (given the Cursor object is immutable) and supply a return value for covering the if block. Something that may look like:
# patch() takes a dotted target string; patch.object() needs (object, "attribute"),
# so patch.object("pyodbc.Connection") fails with a TypeError.
with patch("pyodbc.Connection") as conn:
    # NOTE(review): the execute_read method in the question reads ``messages`` from the
    # object returned by ``cur.execute(query)`` on the cursor's ``__enter__`` result;
    # configure that object if this attribute does not reach it.
    conn.cursor().messages = []
    ...
Tried this with sqlite3 and that worked for me.
Here's an example of using patch.object, from a test I wrote for frappe/frappe:
def test_db_update(self):
    """``frappe.db.set_value`` issues exactly two SQL calls, the first containing FOR UPDATE."""
    with patch.object(Database, "sql") as sql_called:
        frappe.db.set_value(
            self.todo1.doctype,
            self.todo1.name,
            "description",
            f"{self.todo1.description}-edit by `test_for_update`",
        )
    # Check the count first: indexing call_args_list would otherwise fail with an
    # IndexError that hides how many queries were actually issued.
    self.assertEqual(sql_called.call_count, 2)
    first_query = sql_called.call_args_list[0].args[0]
    self.assertIn("FOR UPDATE", first_query)
I am trying to mock sqlbuilder.func for test cases with pytest
I successfully mocked sqlbuilder.func.TO_BASE64 with the correct output, but when I tried mocking sqlbuilder.func.FROM_UNIXTIME I didn't get any error, yet the generated query is incorrect. Below is a minimal working example of the problem.
models.py
from sqlobject import (
sqlbuilder,
sqlhub,
SQLObject,
StringCol,
BLOBCol,
TimestampCol,
)
class Store(SQLObject):
    """SQLObject model with ``name``, ``sample`` and ``createdAt`` columns."""
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()
# Format string passed as the second argument of func.FROM_UNIXTIME in retrieve().
DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    """Run the Store query for *name*, print the rendered SQL and return the rows.

    Selects ``func.TO_BASE64(Store.q.sample)`` where ``name`` matches and
    ``FROM_UNIXTIME(createdAt, DATE_FORMAT) >= FROM_UNIXTIME("2018-10-12", DATE_FORMAT)``.
    The query is rendered with the sqlhub connection's ``sqlrepr``, and the result of
    ``queryAll`` is returned.
    """
    columns = [sqlbuilder.func.TO_BASE64(Store.q.sample)]
    name_matches = Store.q.name == name
    created_since = (
        sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT)
        >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
    )
    select = sqlbuilder.Select(columns, sqlbuilder.AND(name_matches, created_since))
    db = sqlhub.getConnection()
    rendered = db.sqlrepr(select)
    print(rendered)
    return db.queryAll(rendered)
conftest.py
import pytest
from models import Store
from sqlobject import sqlhub
from sqlobject.sqlite import sqliteconnection
@pytest.fixture(autouse=True, scope="session")
def sqlite_db_session(tmpdir_factory):
    """Session-wide fixture: point sqlhub at a fresh SQLite file with the tables created.

    Applied to every test automatically (``autouse``). Yields the connection and
    closes it when the session ends.
    """
    db_file = tmpdir_factory.mktemp("db").join("sqlite.db")
    conn = sqliteconnection.SQLiteConnection(str(db_file))
    sqlhub.processConnection = conn
    try:
        init_tables()
        yield conn
    finally:
        # Close even if table creation fails, not only after normal teardown.
        conn.close()
def init_tables():
    """Create the Store table unless it already exists."""
    Store.createTable(ifNotExists=True)
test_ex1.py
import pytest
from sqlobject import sqlbuilder
from models import retrieve
try:
import mock
from mock import MagicMock
except ImportError:
from unittest import mock
from unittest.mock import MagicMock
def TO_BASE64(x):
    """Stand-in for ``sqlbuilder.func.TO_BASE64``: hand the argument back unchanged."""
    return x


def FROM_UNIXTIME(x, y):
    """Stand-in for ``sqlbuilder.func.FROM_UNIXTIME`` returning SQLite date SQL as text.

    The format argument *y* is ignored.
    """
    template = 'strftime("%Y%m%d", datetime({column},"unixepoch", "localtime"))'
    return template.format(column=x)
# #mock.patch("sqlobject.sqlbuilder.func.TO_BASE64")
# #mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", MagicMock(side_effect=lambda x: x))
# #mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", new_callable=MagicMock(side_effect=lambda x: x))
# Decorators (not comments): with an explicit replacement object, mock.patch
# injects no extra arguments, so test_retrieve keeps an empty signature.
@mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", TO_BASE64)
@mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME", FROM_UNIXTIME)
def test_retrieve():
    """retrieve() runs against the empty test database and returns no rows."""
    result = retrieve('Some')
    assert result == []
Current SQL:
SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND (1))
Expected SQL:
SELECT
store.sample
FROM
store
WHERE
store.name = 'Some'
AND
strftime(
'%Y%m%d',
datetime(store.created_at, 'unixepoch', 'localtime')
) >= strftime(
'%Y%m%d',
datetime('2018-10-12', 'unixepoch', 'localtime')
)
Edit Example
#! /usr/bin/env python
from sqlobject import *
__connection__ = "sqlite:/:memory:?debug=1&debugOutput=1"
try:
import mock
from mock import MagicMock
except ImportError:
from unittest import mock
from unittest.mock import MagicMock
class Store(SQLObject):
    """SQLObject model stored through this script's ``__connection__`` (in-memory SQLite)."""
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()
# Create the table immediately so retrieve() has something to query.
Store.createTable()
DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    """Run the Store SELECT for *name* on ``Store._connection`` and return the rows.

    Selects ``func.TO_BASE64(Store.q.sample)`` where ``name`` matches and
    ``FROM_UNIXTIME(createdAt, DATE_FORMAT) >= FROM_UNIXTIME("2018-10-12", DATE_FORMAT)``.
    The ``sqlbuilder.func`` attributes are looked up at call time, so patched ones apply.
    """
    query = sqlbuilder.Select([
        sqlbuilder.func.TO_BASE64(Store.q.sample),
    ],
        sqlbuilder.AND(
            Store.q.name == name,
            sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
        )
    )
    connection = Store._connection
    # Render the expression tree to SQL text for this backend, then execute it.
    query = connection.sqlrepr(query)
    queryResult = connection.queryAll(query)
    return queryResult
def TO_BASE64(x):
    """Stand-in for ``sqlbuilder.func.TO_BASE64``: hand the argument back unchanged."""
    return x


def FROM_UNIXTIME(x, y):
    """Stand-in for ``sqlbuilder.func.FROM_UNIXTIME`` returning SQLite date SQL as text.

    The format argument *y* is ignored.
    """
    template = 'strftime("%Y%m%d", datetime({column},"unixepoch", "localtime"))'
    return template.format(column=x)
# Patch both func attributes only for the duration of the call. Unlike
# start()/stopall(), the ``with`` blocks restore them even if retrieve() raises.
with mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", TO_BASE64):
    with mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME", FROM_UNIXTIME):
        retrieve('Some')
By default, sqlbuilder.func is an SQLExpression that passes its attribute (sqlbuilder.func.datetime, e.g.) to the SQL backend as a constant (sqlbuilder.func actually is an alias for sqlbuilder.ConstantSpace). See the docs about SQLExpression, the FAQ and the code for func.
When you mock an attribute in func namespace it's evaluated by SQLObject and passed to the backend in reduced form. If you want to return a string literal from the mocking function you need to tell SQLObject it's a value that has to be passed to the backend as is, unevaluated. The way to do it is to wrap the literal in SQLConstant like this:
def FROM_UNIXTIME(x, y):
    """Stand-in for ``sqlbuilder.func.FROM_UNIXTIME`` emitting SQLite date SQL.

    The text is wrapped in ``SQLConstant`` so SQLObject passes it to the backend as-is
    instead of evaluating it. The format argument *y* is ignored.

    NOTE(review): *x* is interpolated without quoting, so a plain string such as
    "2018-10-12" is emitted as ``datetime(2018-10-12, ...)``, which SQL evaluates as the
    arithmetic expression 2018 - 10 - 12.
    """
    sql_text = 'strftime("%Y%m%d", datetime({column},"unixepoch", "localtime"))'.format(column=x)
    return sqlbuilder.SQLConstant(sql_text)
See SQLConstant.
The entire test script now looks like this:
#! /usr/bin/env python3.7
from sqlobject import *
__connection__ = "sqlite:/:memory:?debug=1&debugOutput=1"
try:
import mock
from mock import MagicMock
except ImportError:
from unittest import mock
from unittest.mock import MagicMock
class Store(SQLObject):
    """SQLObject model stored through this script's ``__connection__`` (in-memory SQLite)."""
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()
# Create the table immediately so retrieve() has something to query.
Store.createTable()
DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    """Run the Store SELECT for *name* on ``Store._connection`` and return the rows.

    Selects ``func.TO_BASE64(Store.q.sample)`` where ``name`` matches and
    ``FROM_UNIXTIME(createdAt, DATE_FORMAT) >= FROM_UNIXTIME("2018-10-12", DATE_FORMAT)``.
    The ``sqlbuilder.func`` attributes are looked up at call time, so patched ones apply.
    """
    query = sqlbuilder.Select([
        sqlbuilder.func.TO_BASE64(Store.q.sample),
    ],
        sqlbuilder.AND(
            Store.q.name == name,
            sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
        )
    )
    connection = Store._connection
    # Render the expression tree to SQL text for this backend, then execute it.
    query = connection.sqlrepr(query)
    queryResult = connection.queryAll(query)
    return queryResult
def TO_BASE64(x):
    """Stand-in for ``sqlbuilder.func.TO_BASE64``: hand the argument back unchanged."""
    return x


def FROM_UNIXTIME(x, y):
    """Stand-in for ``sqlbuilder.func.FROM_UNIXTIME`` emitting SQLite date SQL.

    Wrapped in ``SQLConstant`` so SQLObject sends the text to the backend verbatim.
    The format argument *y* is ignored.

    NOTE(review): *x* is interpolated without quoting, so "2018-10-12" is emitted as
    ``datetime(2018-10-12, ...)`` (visible in the query log), which SQL evaluates as the
    arithmetic expression 2018 - 10 - 12.
    """
    sql_text = 'strftime("%Y%m%d", datetime({column},"unixepoch", "localtime"))'.format(column=x)
    return sqlbuilder.SQLConstant(sql_text)
# Patch both func attributes only for the duration of the call. Unlike
# start()/stopall(), the ``with`` blocks restore them even if retrieve() raises.
with mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", TO_BASE64):
    with mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME", FROM_UNIXTIME):
        retrieve('Some')
The output is:
1/Query : CREATE TABLE store (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
sample TEXT,
created_at TIMESTAMP
)
1/QueryR : CREATE TABLE store (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
sample TEXT,
created_at TIMESTAMP
)
2/QueryAll: SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
2/QueryR : SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
2/QueryAll-> []
PS. Full disclosure: I'm the current maintainer of SQLObject.
As @phd pointed out, SQLObject evaluates the expression before passing it to the backend in reduced form.
So instead of returning a string literal, we can also return an expression directly and let SQLObject evaluate it:
def FROM_UNIXTIME(x, y):
    """Stand-in for FROM_UNIXTIME built from SQLObject ``func`` expressions.

    Returns the expression ``strftime("%Y%m%d", datetime(x, "unixepoch", "localtime"))``
    for SQLObject to render itself. The format argument *y* is ignored.
    """
    func = sqlbuilder.func
    as_datetime = func.datetime(x, "unixepoch", "localtime")
    return func.strftime("%Y%m%d", as_datetime)
Output:
SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
I am implementing unit test on one of the classes of my project. The method that I want to test is queryCfsNoteVariations:
class PdfRaportDaoImpl:
    """DAO for the PDF report tables, queried through ``dbFind``.

    NOTE(review): the accompanying test imports ``PdfReportDaoImpl`` ("Report"), while
    this class is spelled ``PdfRaportDaoImpl`` ("Raport"); one of the two names is
    presumably a typo.
    """

    def queryCfsNoteVariations(self, reportId):
        """Find the CFS-note variations for the first matching CFS item of a report.

        Loads the items of the report's ``'CFS'`` table and every row of
        ``variations_cfs_note``. Items are checked in the order ``dbFind`` returns
        them. For the first item whose ``item_name`` equals, ignoring case, the
        ``item_name_cfs`` of at least one variation, this returns
        ``(those variations' item_name_cfs_note values, item["item_note"])``.

        Returns:
            ``(None, None)`` when no item matches, including when the item query
            returns nothing.

        Raises:
            Exception: if ``variations_cfs_note`` yields no rows.
        """
        itemSql = (
            "\n"
            "select v.* from item_value_table v\n"
            "where v.table_id in\n"
            "(select table_id from table_table t\n"
            "where t.report_id=%s and table_name='CFS')\n"
        )
        # ``(reportId)`` is just ``reportId`` in parentheses; a one-element
        # parameter tuple needs the trailing comma.
        cfsItemList = dbFind(itemSql, (reportId,))
        cfsNoteVariations = dbFind("select * from variations_cfs_note")
        if not cfsNoteVariations:
            raise Exception("cfs note variations is null!")

        for itemInfo in cfsItemList or []:
            itemName = itemInfo["item_name"].lower()
            matches = [
                variation["item_name_cfs_note"]
                for variation in cfsNoteVariations
                if variation["item_name_cfs"].lower() == itemName
            ]
            if matches:
                return matches, itemInfo["item_note"]
        return None, None
Which has a path: /com/pdfgather/PDFReportDao.py
In my test I am doing patch on dbFind() method which is located in /com/pdfgather/GlobalHelper.py. My current test looks like this:
from com.pdfgather.PDFReportDao import PdfReportDaoImpl
@patch("com.pdfgather.GlobalHelper.dbFind")
def test_query_cfs_note_variations(self, mock_find):
    """Feed queryCfsNoteVariations canned dbFind results and print what it returns."""
    # NOTE(review): patch() replaces a name where it is looked up. If PDFReportDao imports
    # dbFind into its own namespace (``from com.pdfgather.GlobalHelper import dbFind``),
    # this target leaves that reference untouched; patch
    # "com.pdfgather.PDFReportDao.dbFind" in that case.
    # side_effect supplies one list per dbFind call: first the report's CFS items, then
    # the variations_cfs_note rows. Each row carries every key the method reads.
    mock_find.side_effect = [
        [
            {"item_name": "foo", "item_note": "note for foo"},
            {"item_name": "hey", "item_note": "note for hey"},
        ],
        [
            {"item_name_cfs": "foo", "item_name_cfs_note": "foo variation 1"},
            {"item_name_cfs": "foo", "item_name_cfs_note": "foo variation 2"},
            {"item_name_cfs": "hey", "item_name_cfs_note": "hey variation"},
        ],
    ]
    report_id = 3578
    result = TestingDao.dao.queryCfsNoteVariations(report_id)
    # Printing result
    print(result)
However, I am not getting the desired result, which is entering the loop and returning from inside it. Instead, dbFind returns nothing (but it shouldn't, as I already assigned return values for dbFind).
Thanks in advance!
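Python treats com.pdfgather.PDFReportDao.dbFind and com.pdfgather.GlobalHelper.dbFind as two different names. queryCfsNoteVariations looks dbFind up in the PDFReportDao module, where it was presumably imported with `from com.pdfgather.GlobalHelper import dbFind`. So PDFReportDao.dbFind is the one you need to patch. Try changing your patch to: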
#patch("com.pdfgather.PDFReportDao.dbFind")
Python newbie here.
My class wraps some functions around a database connection. I have figured out some basic examples successfully, but for the more complex library I am working with, I cannot find close examples of mocking the database connection. In my case, the connection is passed to the constructor and a cursor is opened in a `with` block:
class DBSAccess:
    """Thin wrapper that runs DBC queries over a database connection.

    Args:
        db_con: connection whose ``cursor()`` is used as a context manager
            (e.g. the connection returned by ``teradata.UdaExec.connect``).
    """

    def __init__(self, db_con):
        self.db_con = db_con

    def get_db_perm(self, target_user):
        """Return the summed MaxPerm of *target_user* from ``dbc.diskspace``.

        Returns:
            The first column of the fetched row, or the message
            ``"<target_user> does not exist"`` when the query returns no row.
        """
        with self.db_con.cursor() as cursor:
            # Bind target_user as a parameter instead of %-formatting it into the SQL:
            # a name containing a quote would otherwise break (or inject into) the
            # statement. NOTE(review): ``?`` (qmark) placeholders as used by the
            # teradata/PyTd driver; confirm for your driver.
            cursor.execute(
                "SELECT CAST(sum(maxperm) AS bigint) "
                "FROM dbc.diskspace "
                "WHERE databasename = ? "
                "GROUP BY databasename",
                (target_user,),
            )
            res = cursor.fetchone()
        if res is not None:
            return res[0]
        return target_user + " does not exist"
where db_con is the connection returned by a teradata.UdaExec:
# Open a real Teradata connection through UdaExec and wrap it in the DBSAccess class.
udaExec = teradata.UdaExec (appName="whatever", version="1.0", logConsole=True)
db_con = udaExec.connect(method="odbc", system='my_sys', username='my_name', password='my_pswd')
# NOTE(review): the class in this question is spelled ``DBSAccess``; ``DBSaccess`` here
# presumably refers to it — confirm the casing.
dbc_instance = tdtestpy.DBSaccess (db_con)
So for my test to not use any real connection, I have to mock some things out. I tried this combination:
class DBAccessTest(unittest.TestCase):
    def test_get_db_free_perm_expects_500(self):
        uda_exec = mock.Mock(spec=teradata.UdaExec)
        db_con = MagicMock(return_value=None)
        # NOTE: get_db_perm fetches through ``db_con.cursor().__enter__()``, so this stub on
        # the ``cursor`` attribute itself is never reached; fetchone() returns an auto-made mock.
        db_con.cursor.fetchone.return_value = [500]
        # NOTE: uda_exec is never passed to DBSAccess, so this line does not affect the test.
        uda_exec.connect.return_value = db_con
        self.dbc_instance = DBSAccess(db_con)
        # NOTE(review): DBSAccess as shown defines get_db_perm, not get_db_free_perm.
        self.assertEqual(self.dbc_instance.get_db_free_perm("dbc"), 500)
but my result is messed up because fetchone is returning a mock, not the [500] one item list I was expecting:
AssertionError: <MagicMock name='mock.connect().cursor().[54 chars]312'> != 500
I've found some examples where there is a 'with block' for testing an OS operation, but nothing with database. Plus, I don't know what data type the db_con.cursor is so I can't spec that precisely - I think the cursor is found in UdaExecConnection.cursor() found at Teradata/PyTd.
I need to know how to mock the response that will allow me to test the logic within my method.
The source of your problem is in the following line:
with self.db_con.cursor() as cursor:
The `with` statement calls the `__enter__` method, which in your case returns a new mock.
The solution is to mock __enter__ method:
db_con.cursor.return_value.__enter__.return_value = cursor
Your tests:
class DBAccessTest(unittest.TestCase):
    """Unit tests for DBSAccess.get_db_perm with mocked Teradata connection and cursor."""

    def _instance_returning(self, row):
        """Build a DBSAccess whose cursor's ``fetchone()`` returns *row*.

        The cursor is installed as the ``__enter__`` result because get_db_perm
        opens it with ``with self.db_con.cursor() as cursor``.
        """
        connection = MagicMock(UdaExecConnection)
        fake_cursor = MagicMock(UdaExecCursor)
        fake_cursor.fetchone.return_value = row
        connection.cursor.return_value.__enter__.return_value = fake_cursor
        return DBSAccess(connection)

    def test_get_db_free_perm_expects_500(self):
        self.dbc_instance = self._instance_returning([500])
        self.assertEqual(self.dbc_instance.get_db_perm("dbc"), 500)

    def test_get_db_free_perm_expects_None(self):
        self.dbc_instance = self._instance_returning(None)
        self.assertEqual(self.dbc_instance.get_db_perm("dbc"), "dbc does not exist")
By default, cx_Oracle returns each row as a tuple.
>>> import cx_Oracle
>>> conn=cx_Oracle.connect('scott/tiger')
>>> curs=conn.cursor()
>>> curs.execute("select * from foo");
>>> curs.fetchone()
(33, 'blue')
How can I return each row as a dictionary?
You can override the cursor's rowfactory method. You will need to do this each time you perform the query.
Here's the results of the standard query, a tuple.
curs.execute('select * from foo')
curs.fetchone()
(33, 'blue')
Returning a named tuple:
def makeNamedTupleFactory(cursor):
    """Return a namedtuple class usable as ``cursor.rowfactory``.

    Field names are the lower-cased column names from ``cursor.description``.
    ``rename=True`` replaces names that are not valid identifiers (e.g. ``count(*)``
    or keywords such as ``from``), as well as duplicated names, with positional
    names (``_0``, ``_1``, ...) instead of raising ValueError.
    """
    import collections

    columnNames = [d[0].lower() for d in cursor.description]
    Row = collections.namedtuple('Row', columnNames, rename=True)
    return Row
curs.rowfactory = makeNamedTupleFactory(curs)
curs.fetchone()
Row(x=33, y='blue')
Returning a dictionary:
def makeDictFactory(cursor):
    """Return a rowfactory mapping each fetched row to a ``{column name: value}`` dict.

    Column names are read from ``cursor.description`` once, when the factory is built.
    """
    names = [column[0] for column in cursor.description]

    def createRow(*values):
        return {name: value for name, value in zip(names, values)}

    return createRow
curs.rowfactory = makeDictFactory(curs)
curs.fetchone()
{'Y': 'brown', 'X': 1}
Credit to Amaury Forgeot d'Arc:
http://sourceforge.net/p/cx-oracle/mailman/message/27145597
A very short version:
curs.rowfactory = lambda *args: dict(zip([d[0] for d in curs.description], args))
Tested on Python 3.7.0 & cx_Oracle 7.1.2
Old question, but here are some helpful links and a Python recipe.
According to cx_Oracle documentation:
Cursor.rowfactory
This read-write attribute specifies a method to call for each row that
is retrieved from the database. Ordinarily a tuple is returned for
each row but if this attribute is set, the method is called with the
tuple that would normally be returned, and the result of the method is
returned instead.
The cx_Oracle documentation ("Python Interface for Oracle Database") also points to a GitHub repository with many helpful samples; see GenericRowFactory.py.
Googled: This PPT can be further helpful: [PDF]CON6543 Python and Oracle Database - RainFocus
Recipe
Django's database backend for Oracle uses cx_Oracle under the hood. Earlier versions (Django 1.11 and below) define _rowfactory(cursor, row). It also casts cx_Oracle's numeric data types into the corresponding Python types and strings into unicode.
If you have installed Django Please check base.py as follows:
$ DJANGO_DIR="$(python -c 'import django, os; print(os.path.dirname(django.__file__))')"
$ vim $DJANGO_DIR/db/backends/oracle/base.py
One can borrow _rowfactory() from $DJANGO_DIR/db/backends/oracle/base.py and apply the `naming` decorator below to make it return a namedtuple instead of a plain tuple.
mybase.py
import functools
from itertools import izip, imap
from operator import itemgetter
from collections import namedtuple
import cx_Oracle as Database
import decimal
def naming(rename=False, case=None):
    """Decorator factory turning a tuple-returning rowfactory into a namedtuple-returning one.

    Args:
        rename: forwarded to ``collections.namedtuple``. When true, invalid or duplicate
            column names are replaced by positional names instead of raising.
        case: optional callable applied to each column name (e.g. ``str.lower``).
            ``None`` keeps the names as reported by ``cursor.description``.

    The decorated function takes ``(cursor, row, typename="GenericRow")`` and returns
    the wrapped rowfactory's result as a namedtuple named *typename*.
    """
    def decorator(rowfactory):
        @functools.wraps(rowfactory)
        def decorated_rowfactory(cursor, row, typename="GenericRow"):
            field_names = [d[0] for d in cursor.description]
            if case is not None:
                field_names = [case(name) for name in field_names]
            row_type = namedtuple(typename, field_names, rename=rename)
            return row_type._make(rowfactory(cursor, row))
        return decorated_rowfactory
    return decorator
use it as:
@naming(rename=False, case=str.lower)
def rowfactory(cursor, row):
    """Cast one fetched *row* and return it as a tuple; ``naming`` turns it into a namedtuple.

    Meant to be adapted from Django's Oracle backend ``_rowfactory``. The per-column
    casting is elided in this example.
    """
    casted = []
    # TODO: append each value of *row*, cast as in Django's _rowfactory (elided here).
    return tuple(casted)
oracle.py
import cx_Oracle as Database
from cx_Oracle import *
import mybase
class Cursor(Database.Cursor):
    """cx_Oracle cursor that installs ``mybase.rowfactory`` on the rows it fetches."""

    def execute(self, statement, args=None):
        """Execute *statement* and, for a new query statement, install ``mybase.rowfactory``.

        Returns whatever cx_Oracle's ``execute`` returns.
        """
        # Only a statement different from the previous one needs the rowfactory set again.
        prepareNested = statement is not None and self.statement != statement
        # Name the class explicitly: super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        result = super(Cursor, self).execute(statement, args or [])
        if prepareNested and self.description:
            self.rowfactory = lambda *row: mybase.rowfactory(self, row)
        return result

    def close(self):
        """Close the cursor, treating ``InterfaceError`` as "already closed"."""
        try:
            super(Cursor, self).close()
        except Database.InterfaceError:
            pass  # already closed
class Connection(Database.Connection):
    """cx_Oracle connection whose ``cursor()`` returns this module's ``Cursor``."""

    def cursor(self):
        """Return a new ``Cursor`` bound to this connection."""
        # The Cursor must be returned; merely constructing it made cursor() return None.
        return Cursor(self)


# Module-level ``connect`` mirrors cx_Oracle.connect but builds this module's Connection.
connect = Connection
Now, instead of importing cx_Oracle, import oracle in the user script:
user.py
import oracle
dsn = oracle.makedsn('HOSTNAME', 1521, service_name='dev_server')
# ``connect`` is not a bare name in this script; use the oracle module's version,
# which builds the Connection whose cursors return named rows.
db = oracle.connect('username', 'password', dsn)
cursor = db.cursor()
cursor.execute("""
SELECT 'Grijesh' as FirstName,
'Chauhan' as LastName,
CAST('10560.254' AS NUMBER(10, 2)) as Salary
FROM DUAL
""")
row = cursor.fetchone()
print("First Name is %s" % row.firstname)  # => Grijesh
print("Last Name is %s" % row.lastname)  # => Chauhan
print("Salary is %r" % row.salary)  # => Decimal('10560.25')
cursor.close()
db.close()
Give it a Try!!
Building on the answer by @maelcum73:
curs.rowfactory = lambda *args: dict(zip([d[0] for d in curs.description], args))
The issue with this solution is that you need to re-set this after every execution.
Going one step further, you can create a shell around the cursor object like so:
class dictcur(object):
    """Wrap a cx_Oracle cursor so that rows fetched after ``execute()`` are dicts.

    cx_Oracle's ``execute`` is a built-in method and cannot be monkeypatched, so this
    wrapper intercepts ``execute()`` and forwards every other attribute to the
    wrapped cursor.
    """

    def __init__(self, cursor):
        self._original_cursor = cursor

    def execute(self, *args, **kwargs):
        """Execute on the wrapped cursor, then install a dict-building rowfactory.

        Returns the wrapped cursor, as cx_Oracle's ``execute`` does, so
        ``for row in dict_cursor.execute(...)`` also yields dicts.
        """
        cursor = self._original_cursor
        cursor.execute(*args, **kwargs)
        # rowfactory needs to be set AFTER EACH execution.
        description = cursor.description
        if description is not None:
            # Read the column names once per execute instead of once per fetched row.
            columns = [d[0] for d in description]
            cursor.rowfactory = lambda *values: dict(zip(columns, values))
        return cursor

    def __iter__(self):
        # Dunder methods are looked up on the type, so __getattr__ never forwards
        # iteration; delegate explicitly so ``for row in dict_cursor:`` works.
        return iter(self._original_cursor)

    def __getattr__(self, attr):
        # Anything other than execute: go straight to the wrapped cursor.
        return getattr(self._original_cursor, attr)
dict_cursor = dictcur(cursor=conn.cursor())
Using this dict_cursor, every subsequent dict_cursor.execute() call will return a dictionary. Note: I tried monkeypatching the execute method directly, however that was not possible because it is a built-in method.