Using z3 where constraint depends on output of function - python

I want to use z3 to solve this case. The input is a 10 character string. Each character of the input is a printable character (ASCII). The input should be such that when calc2() function is called with input as a parameter, the result should be: 0x0009E38E1FB7629B.
How can I use z3py in such cases?
Usually I would just add independent equations as a constraint to z3. In this case, I am not sure how to use z3.
import sys
def calc2(input):
    result = 0
    for i in range(len(input)):
        r1 = (result << 0x5) & 0xffffffffffffffff
        r2 = result >> 0x1b
        r3 = (r1 ^ r2)
        result = (r3 ^ ord(input[i]))
    return result
if __name__ == "__main__":
    input = sys.argv[1]
    result = calc2(input)
    if result == 0x0009E38E1FB7629B:
        print("solved")
Update: I tried the following however it does not give me correct answer:
from z3 import *
def calc2(input):
    result = 0
    for i in range(len(input)):
        r1 = (result << 0x5) & 0xffffffffffffffff
        r2 = result >> 0x1b
        r3 = (r1 ^ r2)
        result = r3 ^ Concat(BitVec(0, 56), input[i])
    return result
if __name__ == "__main__":
    s = Solver()
    X = [BitVec('x' + str(i), 8) for i in range(10)]
    s.add(calc2(X) == 0x0009E38E1FB7629B)
    if s.check() == sat:
        print(s.model())

I hope this isn't homework, but here's one way to go about it:
from z3 import *
s = Solver()
# Input is 10 characters long; represent it with 10 8-bit symbolic variables
input = [BitVec("input%s" % i, 8) for i in range(10)]
# Make sure each character is printable ASCII, i.e., between 0x20 and 0x7E
for i in range(10):
    s.add(input[i] >= 0x20)
    s.add(input[i] <= 0x7E)
def calc2(input):
    # result is a 64-bit value
    result = BitVecVal(0, 64)
    for i in range(len(input)):
        # NB. We don't actually need to mask with 0xffffffffffffffff
        # Since we explicitly have a 64-bit value in result.
        # But it doesn't hurt to mask it, so we do it here.
        r1 = (result << 0x5) & 0xffffffffffffffff
        # NB. >> on a z3 bit-vector is an arithmetic shift (LShR is the
        # logical one). The two agree here because, with 10 characters,
        # the top bit of result is never set.
        r2 = result >> 0x1b
        r3 = r1 ^ r2
        # We need to zero-extend to match sizes
        result = r3 ^ ZeroExt(56, input[i])
    return result
# Assert the required equality
s.add(calc2(input) == 0x0009E38E1FB7629B)
# Check and get model
print(s.check())
m = s.model()
# Reconstruct the string (kept in its own variable so that s stays the solver):
secret = ''.join([chr(m[input[i]].as_long()) for i in range(10)])
print(secret)
This prints:
$ python a.py
sat
L`p:LxlBVU
Looks like your secret string is
"L`p:LxlBVU"
I've put in some comments in the program to help you with how things are coded in z3py, but feel free to ask for clarification. Hope this helps!
Getting all solutions
To get other solutions, you simply loop and assert that the solution shouldn't be the previous one. You can use the following while loop after the assertion:
while s.check() == sat:
    m = s.model()
    print(''.join([chr(m[input[i]].as_long()) for i in range(10)]))
    s.add(Or([input[i] != m[input[i]] for i in range(10)]))
When I ran it, it kept going! You might want to stop it after a while.
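Since the solver can return many different strings, it can be worth checking each candidate against a concrete version of the function. Here is a minimal Python 3 sketch of such a check; it does not use z3 at all, and the names TARGET, is_printable and is_solution are made up for illustration:

```python
TARGET = 0x0009E38E1FB7629B  # value required by the question

def calc2(text):
    """Concrete (non-symbolic) version of the function from the question."""
    result = 0
    for ch in text:
        r1 = (result << 5) & 0xFFFFFFFFFFFFFFFF
        r2 = result >> 0x1B
        result = r1 ^ r2 ^ ord(ch)
    return result

def is_printable(text):
    """True if every character lies in the printable ASCII range 0x20..0x7E."""
    return all(0x20 <= ord(ch) <= 0x7E for ch in text)

def is_solution(text):
    """Check all three requirements: length, printability, and the hash value."""
    return len(text) == 10 and is_printable(text) and calc2(text) == TARGET

# For two characters the result is simply (ord('A') << 5) ^ ord('B').
print(hex(calc2("AB")))
```

Each string printed by the enumeration loop could be passed to is_solution as a sanity check on the encoding.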

You can encode calc2 in Z3. You'll need to unroll the loop for 1, 2, 3, 4, ..., n iterations (where n is the maximum input size expected), but that's it.
(You don't actually need to unroll the loop by hand; you can use z3py to create the constraints.)

Related

z3py: Symbolic expressions cannot be cast to concrete Boolean values?

from z3 import *
x = Real('x')
s = Solver()
s.add(x > 1 or x < -1)
print(s.check())
if s.check() == sat:
    print(s.model())
I want to solve an expression that contains or. How can I do it?
z3 tells me "Symbolic expressions cannot be cast to concrete Boolean values".
Python's or is not symbolic aware. Instead, use z3py's Or:
from z3 import *
x = Real('x')
s = Solver()
s.add(Or(x > 1, x < -1))
r = s.check()
print(r)
if r == sat:
    print(s.model())
This prints:
sat
[x = -2]
Note that I'd also avoid two separate calls to check, by storing the result in a variable first. (Which I called r above.) In general, the second call to check will be cheap since you haven't added any constraints after the first, but this makes the intention clearer.
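The reason or fails is that Python evaluates "a or b" by first asking for bool(a), and that step cannot be overloaded to build an expression. The toy class below is a hypothetical stand-in (it is not how z3 is implemented, and Sym and toy_or are made-up names) that reproduces the behaviour without needing z3 installed:

```python
class Sym:
    """Toy symbolic expression: comparisons build new expressions."""

    def __init__(self, text):
        self.text = text

    def __gt__(self, other):
        return Sym("(%s > %s)" % (self.text, other))

    def __lt__(self, other):
        return Sym("(%s < %s)" % (self.text, other))

    def __bool__(self):
        # There is no concrete truth value until a solver assigns one.
        raise TypeError("symbolic expression has no concrete truth value")

def toy_or(*args):
    """Function-style disjunction: it never calls bool() on its arguments."""
    return Sym("Or(%s)" % ", ".join(a.text for a in args))

x = Sym("x")
combined = toy_or(x > 1, x < -1)
print(combined.text)

try:
    x > 1 or x < -1  # Python calls bool(x > 1) here
except TypeError as err:
    print("Python's or failed:", err)
```

The same reasoning applies to and and not, which is why z3py offers And and Not as functions too.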

Z3 String/Char xor?

I'm working with Z3 in Python and am trying to figure out how to do String operations. In general, I've played around with z3.String as the object, doing things like str1 + str2 == 'hello world'. However, I have been unable to accomplish the following behavior:
solver.add(str1[1] ^ str1[2] == 12) # -- or --
solver.add(str1[1] ^ str1[2] == str2[1])
So basically add the constraint that character 1 xor character 2 equals 12. My understanding is that the string is defined as a sequence of 8-bit BitVectors under the hood, and BitVectors should be able to be xor'd.
Thanks!
So far I don't expose ways to access characters with a function. You would have to define auxiliary functions and axioms that capture extraction. The operator [] extracts a sub-sequence, which is of length 1 if the index is within bounds.
Here is a way to access the elements:
from z3 import *
nth = Function('nth', StringSort(), IntSort(), BitVecSort(8))
k = Int('k')
str1, str2, s = Strings('str1 str2 s')
s = Solver()
s.add(ForAll([str1, k], Implies(And(0 <= k, k < Length(str1)), Unit(nth(str1, k)) == str1[k])))
s.add( ((nth(str1, 1)) ^ (nth(str2, 2))) == 12)

Convert hex string to hex number in python without loss of precision

So, I am using the answer to this question to color some polygons that I plot on a basemap instance, based on their values. I modified the function found in that link to be the following. The issue I'm having is that I have to convert the string that it returns to a hex number so that I can color the polygons. But when I convert something like "0x00ffaa" to a Python integer, it becomes "0xffaa", which cannot be used to color the polygon.
How can I get around this?
Here is the modified function:
def rgb(mini,maxi,value):
    mini, maxi, value = float(mini), float(maxi), float(value)
    ratio = 2* (value - mini) / (maxi-mini)
    b = int(max(0,255*(1-ratio)))
    r = int(max(0,255*(ratio -1)))
    g = 255 - b - r
    b = hex(b)
    r = hex(r)
    g = hex(g)
    if len(b) == 3:
        b = b[0:2] + '0' + b[-1]
    if len(r) == 3:
        r = r[0:2] + '0' + r[-1]
    if len(g) == 3:
        g = g[0:2] + '0' + g[-1]
    string = r+g[2:]+b[2:]
    return string
The answer from cdarke is OK, but using the % operator for string interpolation is somewhat discouraged in new code. For the sake of completeness, here is the format function and the str.format method:
>>> format(254, '06X')
'0000FE'
>>> "#{:06X}".format(255)
'#0000FF'
New code is expected to use one of the above instead of the % operator. If you are curious about "why does Python have a format function as well as a format method?", see my answer to this question.
But usually you don't have to worry about the representation of the value if the function/method you are using takes integers as well as strings, because in this case the string '0x0000AA' is the same as the integer value 0xAA or 170.
Use string formatting, for example:
>>> "0x%08x" % 0xffaa
'0x0000ffaa'
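Applied to the function from the question, the padding logic can be dropped entirely by formatting each channel as two hex digits. This is only a sketch: the name rgb_hex is made up, and it assumes the plotting call accepts '#rrggbb' color strings (as matplotlib generally does) rather than the '0x...' form the original function returns.

```python
def rgb_hex(mini, maxi, value):
    """Map value in [mini, maxi] to a blue-green-red color string '#rrggbb'."""
    mini, maxi, value = float(mini), float(maxi), float(value)
    ratio = 2 * (value - mini) / (maxi - mini)
    b = int(max(0, 255 * (1 - ratio)))
    r = int(max(0, 255 * (ratio - 1)))
    g = 255 - b - r
    # '02x' pads every channel to two hex digits, so leading zeros survive.
    return "#{:02x}{:02x}{:02x}".format(r, g, b)

print(rgb_hex(0, 10, 0), rgb_hex(0, 10, 5), rgb_hex(0, 10, 10))
```

The leading zeros are kept because the color stays a string the whole time; it is never converted to an integer and back.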

How to keep leading zeros in binary integer (python)?

I need to calculate a checksum for a hex serial word string using XOR. To my (limited) knowledge this has to be performed using the bitwise operator ^. Also, the data has to be converted to binary integer form. Below is my rudimentary code - but the checksum it calculates is 1000831. It should be 01001110 or 47hex. I think the error may be due to missing the leading zeros. All the formatting I've tried to add the leading zeros turns the binary integers back into strings. I appreciate any suggestions.
word = ('010900004f')
#divide word into 5 separate bytes
wd1 = word[0:2]
wd2 = word[2:4]
wd3 = word[4:6]
wd4 = word[6:8]
wd5 = word[8:10]
#this converts a hex string to a binary string
wd1bs = bin(int(wd1, 16))[2:]
wd2bs = bin(int(wd2, 16))[2:]
wd3bs = bin(int(wd3, 16))[2:]
wd4bs = bin(int(wd4, 16))[2:]
wd5bs = bin(int(wd5, 16))[2:]
#this converts binary string to binary integer
wd1i = int(wd1bs)
wd2i = int(wd2bs)
wd3i = int(wd3bs)
wd4i = int(wd4bs)
wd5i = int(wd5bs)
#now that I have binary integers, I can use the XOR bitwise operator to cal cksum
checksum = (wd1i ^ wd2i ^ wd3i ^ wd4i ^ wd5i)
#I should get 47 hex as the checksum
print (checksum, type(checksum))
Why use all this conversions and the costly string functions?
(I will answer the X part of your XY-Problem, not the Y part.)
def checksum (s):
    v = int (s, 16)
    checksum = 0
    while v:
        checksum ^= v & 0xff
        v >>= 8
    return checksum
cs = checksum ('010900004f')
print (cs, bin (cs), hex (cs) )
Result is 0x47 as expected. Btw 0x47 is 0b1000111 and not as stated 0b1001110.
s = '010900004f'
b = int(s, 16)
print reduce(lambda x, y: x ^ y, ((b>> 8*i)&0xff for i in range(0, len(s)/2)), 0)
Just modify like this.
before:
wd1i = int(wd1bs)
wd2i = int(wd2bs)
wd3i = int(wd3bs)
wd4i = int(wd4bs)
wd5i = int(wd5bs)
after:
wd1i = int(wd1bs, 2)
wd2i = int(wd2bs, 2)
wd3i = int(wd3bs, 2)
wd4i = int(wd4bs, 2)
wd5i = int(wd5bs, 2)
Why doesn't your code work?
Because you are misunderstanding the behavior of int(wd1bs).
See the doc here. By default, Python's int function expects its argument to be in base 10.
But you expect int to treat its argument as base 2.
So you need to write int(wd1bs, 2).
Or you can rewrite your entire code like this, in which case you don't need the bin function at all. This code is basically the same as @Hyperboreus's answer. :)
w = int('010900004f', 16)
w1 = (0xff00000000 & w) >> 4*8
w2 = (0x00ff000000 & w) >> 3*8
w3 = (0x0000ff0000 & w) >> 2*8
w4 = (0x000000ff00 & w) >> 1*8
w5 = (0x00000000ff & w)
checksum = w1 ^ w2 ^ w3 ^ w4 ^ w5
print hex(checksum)
#'0x47'
And here is a shorter one.
import binascii
word = '010900004f'
print hex(reduce(lambda a, b: a ^ b, (ord(i) for i in binascii.unhexlify(word))))
#0x47
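Several of the snippets above are Python 2 (print statement, builtin reduce, ord on the bytes from unhexlify). For readers on Python 3, a roughly equivalent sketch could look like this; the function name xor_checksum is just illustrative:

```python
from functools import reduce
from operator import xor

def xor_checksum(hex_word):
    """XOR together all bytes of a hex string such as '010900004f'."""
    # bytes.fromhex yields the raw bytes; iterating over bytes gives ints.
    return reduce(xor, bytes.fromhex(hex_word), 0)

cs = xor_checksum("010900004f")
# The value is a plain integer; leading zeros only exist when formatting it.
print(format(cs, "02x"), format(cs, "08b"))
```

This also addresses the leading-zero concern from the question: the checksum is computed on integers, and the zero padding is applied only when the result is turned into text.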

Developing a heuristic to test simple anonymous Python functions for equivalency

I know how function comparison works in Python 3 (just comparing address in memory), and I understand why.
I also understand that "true" comparison (do functions f and g return the same result given the same arguments, for any arguments?) is practically impossible.
I am looking for something in between. I want the comparison to work on the simplest cases of identical functions, and possibly some less trivial ones:
lambda x : x == lambda x : x # True
lambda x : 2 * x == lambda y : 2 * y # True
lambda x : 2 * x == lambda x : x * 2 # True or False is fine, but must be stable
lambda x : 2 * x == lambda x : x + x # True or False is fine, but must be stable
Note that I'm interested in solving this problem for anonymous functions (lambda), but wouldn't mind if the solution also works for named functions.
The motivation for this is that inside blist module, it would be nice to verify that two sortedset instances have the same sort function before performing a union, etc. on them.
Named functions are of less interest because I can assume them to be different when they are not identical. After all, suppose someone created two sortedsets with a named function in the key argument. If they intend these instances to be "compatible" for the purposes of set operations, they'd probably use the same function, rather than two separate named functions that perform identical operations.
I can only think of three approaches. All of them seem hard, so any ideas appreciated.
Comparing bytecodes might work but it might be annoying that it's implementation dependent (and hence the code that worked on one Python breaks on another).
Comparing tokenized source code seems reasonable and portable. Of course, it's less powerful (since identical functions are more likely to be rejected).
A solid heuristic borrowed from some symbolic computation textbook is theoretically the best approach. It might seem too heavy for my purpose, but it actually could be a good fit since lambda functions are usually tiny and so it would run fast.
EDIT
A more complicated example, based on the comment by @delnan:
# global variable
fields = ['id', 'name']
def my_function():
    global fields
    s1 = sortedset(key = lambda x : x[fields[0].lower()])
    # some intervening code here
    # ...
    s2 = sortedset(key = lambda x : x[fields[0].lower()])
Would I expect the key functions for s1 and s2 to evaluate as equal?
If the intervening code contains any function call at all, the value of fields may be modified, resulting in different key functions for s1 and s2. Since we clearly won't be doing control flow analysis to solve this problem, we have to evaluate these two lambda functions as different if we are trying to perform this evaluation before runtime. (Even if fields wasn't global, it might have had another name bound to it, etc.) This would severely curtail the usefulness of this whole exercise, since few lambda functions have no dependence on the environment.
EDIT 2:
I realized it's very important to compare the function objects as they exist in runtime. Without that, all the functions that depend on variables from outer scope cannot be compared; and most useful functions do have such dependencies. Considered in runtime, all functions with the same signature are comparable in a clean, logical way, regardless of what they depend on, whether they are impure, etc.
As a result, I need not just the bytecode but also the global state as of the time the function object was created (presumably __globals__). Then I have to match all variables from outer scope to the values from __globals__.
Edited to check whether external state will affect the sorting function as well as if the two functions are equivalent.
I hacked up dis.dis and friends to output to a global file-like object. I then stripped out line numbers and normalized variable names (without touching constants) and compared the result.
You could clean this up so dis.dis and friends yielded out lines so you wouldn't have to trap their output. But this is a working proof-of-concept for using dis.dis for function comparison with minimal changes.
import types
from opcode import *
_have_code = (types.MethodType, types.FunctionType, types.CodeType,
              types.ClassType, type)
def dis(x):
    """Disassemble classes, methods, functions, or code.
    With no argument, disassemble the last traceback.
    """
    if isinstance(x, types.InstanceType):
        x = x.__class__
    if hasattr(x, 'im_func'):
        x = x.im_func
    if hasattr(x, 'func_code'):
        x = x.func_code
    if hasattr(x, '__dict__'):
        items = x.__dict__.items()
        items.sort()
        for name, x1 in items:
            if isinstance(x1, _have_code):
                print >> out, "Disassembly of %s:" % name
                try:
                    dis(x1)
                except TypeError, msg:
                    print >> out, "Sorry:", msg
                print >> out
    elif hasattr(x, 'co_code'):
        disassemble(x)
    elif isinstance(x, str):
        disassemble_string(x)
    else:
        raise TypeError, \
              "don't know how to disassemble %s objects" % \
              type(x).__name__
def disassemble(co, lasti=-1):
    """Disassemble a code object."""
    code = co.co_code
    labels = findlabels(code)
    linestarts = dict(findlinestarts(co))
    n = len(code)
    i = 0
    extended_arg = 0
    free = None
    while i < n:
        c = code[i]
        op = ord(c)
        if i in linestarts:
            if i > 0:
                print >> out
            print >> out, "%3d" % linestarts[i],
        else:
            print >> out, ' ',
        if i == lasti: print >> out, '-->',
        else: print >> out, ' ',
        if i in labels: print >> out, '>>',
        else: print >> out, ' ',
        print >> out, repr(i).rjust(4),
        print >> out, opname[op].ljust(20),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == EXTENDED_ARG:
                extended_arg = oparg*65536L
            print >> out, repr(oparg).rjust(5),
            if op in hasconst:
                print >> out, '(' + repr(co.co_consts[oparg]) + ')',
            elif op in hasname:
                print >> out, '(' + co.co_names[oparg] + ')',
            elif op in hasjrel:
                print >> out, '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                print >> out, '(' + co.co_varnames[oparg] + ')',
            elif op in hascompare:
                print >> out, '(' + cmp_op[oparg] + ')',
            elif op in hasfree:
                if free is None:
                    free = co.co_cellvars + co.co_freevars
                print >> out, '(' + free[oparg] + ')',
        print >> out
def disassemble_string(code, lasti=-1, varnames=None, names=None,
                       constants=None):
    labels = findlabels(code)
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        if i == lasti: print >> out, '-->',
        else: print >> out, ' ',
        if i in labels: print >> out, '>>',
        else: print >> out, ' ',
        print >> out, repr(i).rjust(4),
        print >> out, opname[op].ljust(15),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            print >> out, repr(oparg).rjust(5),
            if op in hasconst:
                if constants:
                    print >> out, '(' + repr(constants[oparg]) + ')',
                else:
                    print >> out, '(%d)'%oparg,
            elif op in hasname:
                if names is not None:
                    print >> out, '(' + names[oparg] + ')',
                else:
                    print >> out, '(%d)'%oparg,
            elif op in hasjrel:
                print >> out, '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                if varnames:
                    print >> out, '(' + varnames[oparg] + ')',
                else:
                    print >> out, '(%d)' % oparg,
            elif op in hascompare:
                print >> out, '(' + cmp_op[oparg] + ')',
        print >> out
def findlabels(code):
    """Detect all offsets in a byte code which are jump targets.
    Return the list of offsets.
    """
    labels = []
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            label = -1
            if op in hasjrel:
                label = i+oparg
            elif op in hasjabs:
                label = oparg
            if label >= 0:
                if label not in labels:
                    labels.append(label)
    return labels
def findlinestarts(code):
    """Find the offsets in a byte code which are start of lines in the source.
    Generate pairs (offset, lineno) as described in Python/compile.c.
    """
    byte_increments = [ord(c) for c in code.co_lnotab[0::2]]
    line_increments = [ord(c) for c in code.co_lnotab[1::2]]
    lastlineno = None
    lineno = code.co_firstlineno
    addr = 0
    for byte_incr, line_incr in zip(byte_increments, line_increments):
        if byte_incr:
            if lineno != lastlineno:
                yield (addr, lineno)
                lastlineno = lineno
            addr += byte_incr
        lineno += line_incr
    if lineno != lastlineno:
        yield (addr, lineno)
class FakeFile(object):
    def __init__(self):
        self.store = []
    def write(self, data):
        self.store.append(data)
a = lambda x : x
b = lambda x : x # True
c = lambda x : 2 * x
d = lambda y : 2 * y # True
e = lambda x : 2 * x
f = lambda x : x * 2 # True or False is fine, but must be stable
g = lambda x : 2 * x
h = lambda x : x + x # True or False is fine, but must be stable
funcs = a, b, c, d, e, f, g, h
outs = []
for func in funcs:
    out = FakeFile()
    dis(func)
    outs.append(out.store)
import ast
def outfilter(out):
    for i in out:
        # drop purely numeric tokens (line numbers, offsets, raw opargs)
        if i.strip().isdigit():
            continue
        if '(' in i:
            try:
                ast.literal_eval(i)
            except ValueError:
                # not a constant, so it is a variable name: normalize it
                i = "(x)"
        yield i
processed_outs = [(out, 'LOAD_GLOBAL' in out or 'LOAD_DEREF' in out)
                  for out in (''.join(outfilter(out)) for out in outs)]
for (out1, polluted1), (out2, polluted2) in zip(processed_outs[::2], processed_outs[1::2]):
    print 'Bytecode Equivalent:', out1 == out2, '\nPolluted by state:', polluted1 or polluted2
The output is True, True, False, and False and is stable. The "Polluted" bool is true if the output will depend on external state -- either global state or a closure.
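The code above is Python 2 only (print >>, types.ClassType, func_code). On Python 3 a similar check can be roughly sketched without patching dis, by comparing attributes of the code objects directly. The helper names below are made up, and this is a simplification of the approach above, not a port of it:

```python
def same_bytecode(f, g):
    """True if f and g compile to the same instructions, constants and names.

    Local variable names live in co_varnames, which is deliberately not
    compared, so renaming a parameter does not matter.
    """
    cf, cg = f.__code__, g.__code__
    return (cf.co_code == cg.co_code
            and cf.co_consts == cg.co_consts
            and cf.co_names == cg.co_names)

def depends_on_outer_state(f):
    """True if f refers to non-local names (globals, builtins, attributes)
    or closes over variables from an enclosing scope."""
    code = f.__code__
    return bool(code.co_names or code.co_freevars)

print(same_bytecode(lambda x: 2 * x, lambda y: 2 * y))
print(depends_on_outer_state(lambda x: len(x)))
```

As with the dis-based version, the result depends on the interpreter that compiled the functions, so it is only meaningful for functions created in the same process.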
So, let's address some technical issues first.
1) Byte code: this is probably not a problem because, instead of inspecting the .pyc (binary) files, you can use the dis module to get the "bytecode", e.g.
>>> import dis
>>> f = lambda x, y : x+y
>>> dis.dis(f)
  1           0 LOAD_FAST                0 (x)
              3 LOAD_FAST                1 (y)
              6 BINARY_ADD
              7 RETURN_VALUE
No need to worry about platform.
2) Tokenized source code. Again, Python has all you need to do the job. You can use the ast module to parse the code and obtain the AST.
>>> import ast
>>> a = ast.parse("f = lambda x, y : x+y")
>>> ast.dump(a)
"Module(body=[Assign(targets=[Name(id='f', ctx=Store())], value=Lambda(args=arguments(args=[Name(id='x', ctx=Param()), Name(id='y', ctx=Param())], vararg=None, kwarg=None, defaults=[]), body=BinOp(left=Name(id='x', ctx=Load()), op=Add(), right=Name(id='y', ctx=Load()))))])"
So, the question we should really address is: is it feasible to determine analytically that two functions are equivalent?
It is easy for a human to say that 2*x equals x+x, but how can we create an algorithm to prove it?
If that is what you want to achieve, you may want to check this out: http://en.wikipedia.org/wiki/Computer-assisted_proof
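Short of a proof, a purely syntactic comparison with the ast module can at least treat renamed parameters as equal. The following Python 3 sketch assumes the lambdas are available as source strings (recovering the source of a live lambda is a separate problem), handles only plain positional parameters, and uses made-up helper names:

```python
import ast

class _RenameParams(ast.NodeTransformer):
    """Rewrite parameter names to positional placeholders."""

    def __init__(self, mapping):
        self.mapping = mapping

    def visit_arg(self, node):
        node.arg = self.mapping.get(node.arg, node.arg)
        return node

    def visit_Name(self, node):
        node.id = self.mapping.get(node.id, node.id)
        return node

def normalized_dump(source):
    """Parse one lambda expression and dump it with parameters renamed."""
    lam = ast.parse(source, mode="eval").body
    if not isinstance(lam, ast.Lambda):
        raise ValueError("expected a lambda expression")
    params = [a.arg for a in lam.args.args]
    mapping = {name: "_p%d" % i for i, name in enumerate(params)}
    _RenameParams(mapping).visit(lam)
    return ast.dump(lam)

def same_after_renaming(src_a, src_b):
    return normalized_dump(src_a) == normalized_dump(src_b)

print(same_after_renaming("lambda x: 2 * x", "lambda y: 2 * y"))
print(same_after_renaming("lambda x: 2 * x", "lambda x: x * 2"))
```

This gives a stable answer for every pair, but it knows nothing about algebra: 2 * x and x * 2 are reported as different because their trees differ.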
However, if ultimately you simply want to assert that two different data sets are sorted in the same order, you just need to run sort function A on dataset B and vice versa, and then check the outcome. If the results are identical, then the functions are probably functionally identical. Of course, the check is only valid for the datasets in question.
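The empirical check described above fits in a few lines. This is a sketch with a made-up function name, and a True result only says the two keys agree on the sample that was tried:

```python
def same_order_on(sample, key_a, key_b):
    """True if both key functions sort this particular sample identically."""
    return sorted(sample, key=key_a) == sorted(sample, key=key_b)

data = [3, 1, 2, 10, -4]
print(same_order_on(data, lambda x: 2 * x, lambda x: x + x))
print(same_order_on(data, lambda x: 2 * x, lambda x: -x))
```

For the sortedset use case, the natural sample is the combined contents of the two sets that are about to be merged.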
