Source code for cas.expressions

# -*- coding: utf-8 -*-

# This code is part of Amoco
# Copyright (C) 2006-2011 Axel Tillequin (bdcht3@gmail.com)
# published under GPLv2 license
"""
cas/expressions.py
==================

The expressions module implements all above :class:`exp` classes.
All symbolic representations of data in amoco rely on these expressions.
"""

from amoco.config import conf
from amoco.logger import Log

logger = Log(__name__)
logger.debug("loading module")
from amoco.ui import render
import operator


# decorators:
# ------------


def _checkarg1_exp(f):
    """Decorator: require the first positional argument to be an :class:`exp`.

    The wrapped function is called unchanged when the check passes.
    Otherwise an error is logged and TypeError is raised with the
    positional arguments tuple as its payload.
    """

    def checkarg1_exp(*args):
        # guard clause: no argument at all, or a first argument of the wrong kind
        if not args or not isinstance(args[0], exp):
            logger.error("first arg is not an expression")
            raise TypeError(args)
        return f(*args)

    return checkarg1_exp


def _checkarg_sizes(f):
    def checkarg_sizes(self, n):
        if self.size != n.size:
            if self.size > 0 and n.size > 0:
                logger.error("size mismatch")
                raise ValueError(n)
        return f(self, n)

    return checkarg_sizes


def _checkarg_numeric(f):
    def checkarg_numeric(self, n):
        if isinstance(n, int):
            n = cst(n, self.size)
        elif isinstance(n, (float)):
            n = cfp(n, self.size)
        return f(self, n)

    return checkarg_numeric


def _checkarg_slice(f):
    def checkarg_slice(self, *args):
        i = args[0]
        if isinstance(i, slice):
            if i.step != None:
                raise ValueError(i)
            if i.start < 0 or i.stop > self.size:
                logger.error("size mismatch")
                raise ValueError(i)
            if i.stop <= i.start:
                logger.error("invalid slice")
                raise ValueError(i)
        else:
            logger.error("argument should be a slice")
            raise TypeError(i)
        return f(self, *args)

    return checkarg_slice

# expression types (one bit per kind, combined with bitwise or):

et_cst = 1 << 0
et_reg = 1 << 1
# note:  bits 4..7 (0x000#0) are for reg subtypes (STD/PC/FLAG/STACK/OTHER)
et_slc = 1 << 8
et_ext = 1 << 9
et_lab = 1 << 10
et_mem = 1 << 11
et_ptr = 1 << 12
et_tst = 1 << 13
et_eqn = 1 << 14
et_vec = 1 << 15
et_cmp = 1 << 16
# mask covering all 17 type bits above (0x1ffff)
et_msk = (1 << 17) - 1

# ------------------------------------------------------------------------------
# exp is the core class for all expressions.
# It defines mandatory attributes, shared methods like dumps/loads, etc.
# ------------------------------------------------------------------------------

class exp(object):
    """the core class for all expressions.
    It defines mandatory attributes, shared methods like dumps/loads etc.

    Attributes:
        size (int): the bit size of the expression (default is 0.)
        sf (Bool): the sign flag of the expression (default is False: unsigned.)
        length (int): the byte size of the expression.
        mask (int): the bit mask of the expression.

    Note:
        len(exp) returns the byte size, assuming that size is a multiple of 8.
    """

    etype = 0
    __slots__ = ["size", "sf"]

    def __init__(self, size=0, sf=False):
        self.size = size
        # honor the requested sign flag (it used to be forced to False)
        self.sf = sf

    def __len__(self):
        return self.length

    def signed(self):
        "consider expression as signed (in place, returns self)"
        self.sf = True
        return self

    def unsigned(self):
        "consider expression as unsigned (in place, returns self)"
        self.sf = False
        return self

    @property
    def length(self):
        # length value is in bytes
        return self.size // 8

    @property
    def mask(self):
        return (1 << self.size) - 1

    def eval(self, env):
        "evaluate expression in given :class:`mapper` env"
        if self._is_top:
            return top(self.size)
        if not self._is_def:
            return exp(self.size)
        else:
            raise NotImplementedError("can't eval %s" % self)

    def simplify(self, **kargs):
        "simplify expression based on predefined heuristics"
        return self

    def depth(self):
        "depth size of the expression tree"
        return 1.0

    def addr(self, env):
        raise TypeError("exp has no address")

    def dumps(self):
        "pickle expression"
        from pickle import dumps, HIGHEST_PROTOCOL

        return dumps(self, HIGHEST_PROTOCOL)

    def loads(self, s):
        """unpickle expression: returns the object stored in s
        (the receiver itself is not modified).
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only use this on data coming from a trusted source.
        from pickle import loads

        self = loads(s)
        return self

    def __unicode__(self):
        if self._is_top:
            return render.icons.top + ("%d" % self.size)
        if not self._is_def:
            return render.icons.bot + ("%d" % self.size)
        raise ValueError("void expression")

    def __str__(self):
        res = self.__unicode__()
        try:
            return str(res)
        except UnicodeEncodeError:
            return res.encode("utf-8")

    def toks(self, **kargs):
        "returns list of pretty printing tokens of the expression"
        return [(render.Token.Literal, "%s" % self)]

    def pp(self, **kargs):
        "pretty-printed string of the expression"
        return render.highlight(self.toks(**kargs))

    def bit(self, i):
        "extract i-th bit expression of the expression (i is taken modulo size)"
        i = i % self.size
        return self[i : i + 1]

    def bytes(self, sta=0, sto=None, endian=1):
        """
        returns the expression slice located at bytes [sta,sto]
        taking into account given endianess 1 (little) or -1 (big).
        Defaults to little endian.
        """
        s = slice(sta, sto)
        l = self.length
        sta, sto, stp = s.indices(l)
        if endian == -1:
            # big endian: count bytes from the most significant end
            sta, sto = l - sto, l - sta
        return self[sta * 8 : sto * 8]

    # get item allows to extract the expression of a slice of the exp
    @_checkarg_slice
    def __getitem__(self, i):
        return slicer(self, i.start, i.stop - i.start)

    # set item allows to insert the expression of a slice in the exp
    # note: most child classes can't really inherit from this method
    # since the method makes sense only by returning an comp object
    # while __setitem__ is supposed to modify self...
    @_checkarg_slice
    def __setitem__(self, i, e):
        res = comp(self.size)
        res[0 : res.size] = self
        res[i.start : i.stop] = e
        return res.simplify()

    def extend(self, sign, size):
        "extend expression to given size, taking sign into account"
        xt = size - self.size
        if xt <= 0:
            return self
        sb = self[self.size - 1 : self.size]
        if sign is True:
            # replicate the sign bit over the xt added bits
            xx = tst(sb, cst(-1, xt), cst(0, xt))
            xx.sf = True
        else:
            xx = cst(0, xt)
            xx.sf = False
        return composer([self, xx])

    def signextend(self, size):
        "sign extend expression to given size"
        return self.extend(True, size)

    def zeroextend(self, size):
        "zero extend expression to given size"
        return self.extend(False, size)

    # arithmetic / logic methods : These methods are shared by all nodes.

    # unary operators:
    def __invert__(self):
        return oper(OP_NOT, self)

    def __neg__(self):
        return oper(OP_MIN, self)

    def __pos__(self):
        return self

    # binary operators:
    @_checkarg_numeric
    def __add__(self, n):
        return oper(OP_ADD, self, n)

    @_checkarg_numeric
    def __sub__(self, n):
        return oper(OP_MIN, self, n)

    @_checkarg_numeric
    def __mul__(self, n):
        return oper(OP_MUL, self, n)

    @_checkarg_numeric
    def __pow__(self, n):
        # ** is mapped to the OP_MUL2 operator, not to exponentiation
        return oper(OP_MUL2, self, n)

    @_checkarg_numeric
    def __div__(self, n):
        return oper(OP_DIV, self, n)

    @_checkarg_numeric
    def __truediv__(self, n):
        return oper(OP_DIV, self, n)

    @_checkarg_numeric
    def __mod__(self, n):
        return oper(OP_MOD, self, n)

    @_checkarg_numeric
    def __floordiv__(self, n):
        # // is mapped to the OP_ASR operator
        return oper(OP_ASR, self, n)

    @_checkarg_numeric
    def __and__(self, n):
        return oper(OP_AND, self, n)

    @_checkarg_numeric
    def __or__(self, n):
        return oper(OP_OR, self, n)

    @_checkarg_numeric
    def __xor__(self, n):
        return oper(OP_XOR, self, n)

    # reflected operand cases:
    @_checkarg_numeric
    def __radd__(self, n):
        return oper(OP_ADD, n, self)

    @_checkarg_numeric
    def __rsub__(self, n):
        return oper(OP_MIN, n, self)

    @_checkarg_numeric
    def __rmul__(self, n):
        return oper(OP_MUL, n, self)

    @_checkarg_numeric
    def __rpow__(self, n):
        return oper(OP_MUL2, n, self)

    @_checkarg_numeric
    def __rand__(self, n):
        return oper(OP_AND, n, self)

    @_checkarg_numeric
    def __ror__(self, n):
        return oper(OP_OR, n, self)

    @_checkarg_numeric
    def __rxor__(self, n):
        return oper(OP_XOR, n, self)

    # shifts:
    @_checkarg_numeric
    def __lshift__(self, n):
        return oper(OP_LSL, self, n)

    @_checkarg_numeric
    def __rshift__(self, n):
        return oper(OP_LSR, self, n)

    # WARNING: comparison operators below return an expression
    # (bit0, bit1 or an operator node), not a python bool !

    def __hash__(self):
        # the hash is derived from the printed form of the expression
        return hash("%s" % self) + self.size

    # An expression defaults to False, and only bit1 will return True.
    def __bool__(self):
        return False

    def __eq__(self, n):
        # we inline checkarg_numeric only here:
        if isinstance(n, int):
            n = cst(n, self.size)
        elif isinstance(n, (float)):
            n = cfp(n, self.size)
        if hash(self) == hash(n):
            return bit1
        return oper(OP_EQ, self, n)

    @_checkarg_numeric
    def __ne__(self, n):
        if hash(self) == hash(n):
            return bit0
        return oper(OP_NEQ, self, n)

    @_checkarg_numeric
    def __lt__(self, n):
        if hash(self) == hash(n):
            return bit0
        return oper(OP_LT, self, n)

    @_checkarg_numeric
    def __le__(self, n):
        if hash(self) == hash(n):
            return bit1
        return oper(OP_LE, self, n)

    @_checkarg_numeric
    def __ge__(self, n):
        if hash(self) == hash(n):
            return bit1
        return oper(OP_GE, self, n)

    @_checkarg_numeric
    def __gt__(self, n):
        if hash(self) == hash(n):
            return bit0
        return oper(OP_GT, self, n)

    def to_smtlib(self, solver=None):
        "translate expression to its smt form"
        logger.warning("no SMT solver defined")
        raise NotImplementedError

    def is_(self, t):
        "returns the (truthy) intersection of type bits t with etype"
        return t & self.etype

    def set_top(self):
        # clears the type bits and makes etype negative (see _is_top)
        self.etype = ~(~self.etype & et_msk)

    @property
    def _is_def(self):
        return self.etype > 0

    @property
    def _is_top(self):
        return self.etype < 0

    @property
    def _is_cst(self):
        return et_cst & self.etype

    @property
    def _is_reg(self):
        return et_reg & self.etype

    @property
    def _is_cmp(self):
        return et_cmp & self.etype

    @property
    def _is_slc(self):
        return et_slc & self.etype

    @property
    def _is_mem(self):
        return et_mem & self.etype

    @property
    def _is_ext(self):
        return et_ext & self.etype

    @property
    def _is_lab(self):
        return et_lab & self.etype

    @property
    def _is_ptr(self):
        return et_ptr & self.etype

    @property
    def _is_tst(self):
        return et_tst & self.etype

    @property
    def _is_eqn(self):
        return et_eqn & self.etype

    @property
    def _is_vec(self):
        return et_vec & self.etype
class top(exp):
    """
    top expression: stands for symbolic values that have reached a high
    complexity threshold.

    Note:
        This expression is an absorbing element of the algebra. Any
        expression that involves a top expression results in a top expression.
    """

    # negative etype with no type bit set
    etype = -(et_msk + 1)
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__

    def depth(self):
        "a top expression has infinite depth"
        infinite = float("inf")
        return infinite
# ----------------------------------- # cst holds numeric immediate values # -----------------------------------
class cst(exp):
    """
    cst expression represents concrete values (constants).

    Attributes:
        v (int): the raw unsigned value, masked to size bits.
        value (int): get the integer of the expression, taking into account
            the sign flag.
    """

    __slots__ = ["v"]
    etype = et_cst
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__

    def __init__(self, v, size=32):
        if isinstance(v, bool):
            # only True/False forces size=1 (not 0/1 !)
            v = 1 if v else 0
            size = 1
        self.sf = False if v >= 0 else True
        self.size = size
        self.v = v & self.mask

    @property
    def value(self):
        if self.sf and (self.v >> (self.size - 1) == 1):
            return -(self.v ^ self.mask) - 1
        else:
            return self.v

    def _signed_value(self):
        "two's complement reading of v, whatever the sign flag is"
        if self.v >> (self.size - 1) == 1:
            return -(self.v ^ self.mask) - 1
        return self.v

    # for slicing purpose:
    def __index__(self):
        return self.value

    # coercion to Python int:
    def __int__(self):
        return self.value

    # defaults to signed hex base
    def __unicode__(self):
        return "{:#x}".format(self.value)

    def toks(self, **kargs):
        return [(render.Token.Constant, "%s" % self)]

    def to_sym(self, ref):
        "cast into a symbol expression associated to name ref"
        return sym(ref, self.v, self.size)

    def to_bytes(self, endian=1):
        "returns the bytes of v, least significant first if endian is 1"
        s = []
        v = self.v
        for _ in range(0, self.size, 8):
            s.append(v & 0xFF)
            v = v >> 8
        return bytes(s[::endian])

    # eval of cst is always itself: (sf flag conserved)
    def eval(self, env):
        return cst(self.value, self.size)

    def zeroextend(self, size):
        return cst(self.v, max(size, self.size))

    def signextend(self, size):
        return cst(self._signed_value(), max(size, self.size))

    # bit-slice (returns cst) :
    @_checkarg_slice
    def __getitem__(self, i):
        start = i.start or 0
        stop = i.stop or self.size
        return cst(self.v >> start, stop - start)

    def __invert__(self):
        # note: masking is needed because python uses unlimited ints
        # so ~0x80 means not(...0000080) = ...fffffef
        return cst((~(self.v)) & self.mask, self.size)

    def __neg__(self):
        return cst(-(self.value), self.size)

    @_checkarg_numeric
    @_checkarg_sizes
    def __add__(self, n):
        if n._is_cst:
            return cst(self.value + n.value, self.size)
        else:
            return exp.__add__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __sub__(self, n):
        if n._is_cst:
            return cst(self.value - n.value, self.size)
        else:
            return exp.__sub__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __mul__(self, n):
        if n._is_cst:
            return cst(self.value * n.value, self.size)
        else:
            return exp.__mul__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __pow__(self, n):
        # ** is a multiplication with a result twice as wide
        if n._is_cst:
            return cst(self.value * n.value, 2 * self.size)
        else:
            return exp.__pow__(self, n)

    @_checkarg_numeric
    def __div__(self, n):
        if n._is_cst:
            return cst(self.value // n.value, self.size)
        else:
            return exp.__div__(self, n)

    @_checkarg_numeric
    def __truediv__(self, n):
        # constants use integer division
        if n._is_cst:
            return cst(self.value // n.value, self.size)
        else:
            return exp.__truediv__(self, n)

    @_checkarg_numeric
    def __mod__(self, n):
        if n._is_cst:
            return cst(self.value % n.value, self.size)
        else:
            return exp.__mod__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __and__(self, n):
        if n._is_cst:
            return cst(self.v & n.v, self.size)
        else:
            return exp.__and__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __or__(self, n):
        if n._is_cst:
            return cst(self.v | n.v, self.size)
        else:
            return exp.__or__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __xor__(self, n):
        if n._is_cst:
            return cst(self.v ^ n.v, self.size)
        else:
            return exp.__xor__(self, n)

    @_checkarg_numeric
    def __lshift__(self, n):
        if n._is_cst:
            return cst(self.value << n.value, self.size)
        else:
            return exp.__lshift__(self, n)

    @_checkarg_numeric
    def __rshift__(self, n):
        # rshift implements logical right shift
        if n._is_cst:
            # fold on the raw unsigned value, leaving the operand untouched
            # (constants like bit0/bit1 are shared objects)
            return cst(self.v >> n.value, self.size)
        # symbolic case: the operand is flagged unsigned as before, since
        # the flag drives how it is printed inside the resulting expression
        self.sf = False
        return exp.__rshift__(self, n)

    @_checkarg_numeric
    def __floordiv__(self, n):
        # floordiv implements arithmetic right shift
        if n._is_cst:
            # fold on the signed reading, leaving the operand untouched
            return cst(self._signed_value() >> n.value, self.size)
        # symbolic case: the operand is flagged signed as before
        self.sf = True
        return exp.__floordiv__(self, n)

    @_checkarg_numeric
    def __radd__(self, n):
        return n + self

    @_checkarg_numeric
    def __rsub__(self, n):
        return n - self

    @_checkarg_numeric
    def __rmul__(self, n):
        return n * self

    @_checkarg_numeric
    def __rpow__(self, n):
        return n ** self

    @_checkarg_numeric
    def __rdiv__(self, n):
        return n / self

    # Python 3 looks up __rtruediv__ (not __rdiv__) for "number / cst"
    @_checkarg_numeric
    def __rtruediv__(self, n):
        return n / self

    @_checkarg_numeric
    def __rand__(self, n):
        return n & self

    @_checkarg_numeric
    def __ror__(self, n):
        return n | self

    @_checkarg_numeric
    def __rxor__(self, n):
        return n ^ self

    # the only atom that is considered True is the cst(1,1) (ie bit1 below)
    def __bool__(self):
        if self.size == 1 and self.v == 1:
            return True
        else:
            return False

    @_checkarg_numeric
    @_checkarg_sizes
    def __eq__(self, n):
        if n._is_cst:
            return cst(self.v == n.v)
        else:
            return exp.__eq__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __ne__(self, n):
        if n._is_cst:
            return cst(self.v != n.v)
        else:
            return exp.__ne__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __lt__(self, n):
        if n._is_cst:
            return cst(self.value < n.value)
        else:
            return exp.__lt__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __le__(self, n):
        if n._is_cst:
            return cst(self.value <= n.value)
        else:
            return exp.__le__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __ge__(self, n):
        if n._is_cst:
            return cst(self.value >= n.value)
        else:
            return exp.__ge__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __gt__(self, n):
        if n._is_cst:
            return cst(self.value > n.value)
        else:
            return exp.__gt__(self, n)
# the two 1-bit constants returned by comparison operators
bit0, bit1 = cst(0, 1), cst(1, 1)
# bit1 is the only constant with a True truth value
assert bool(bit1)
class sym(cst):
    """symbol expression: a :class:`cst` carrying a reference name
    that is used (prefixed by '#') when the expression is printed.
    """

    __slots__ = ["ref"]
    __hash__ = cst.__hash__
    __eq__ = exp.__eq__

    def __init__(self, ref, v, size=32):
        # name first, then the regular constant initialisation
        self.ref = ref
        cst.__init__(self, v, size)

    def __unicode__(self):
        name = self.ref
        return "#%s" % name
class cfp(exp):
    """floating point concrete value expression.

    Attributes:
        v (float): the Python float held by the expression.
        value (float): same as v.
    """

    __slots__ = ["v"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_cst

    def __init__(self, v, size=32):
        self.size = size
        # give the inherited sf slot a value so that reading it does not
        # raise AttributeError (False is the default used by exp)
        self.sf = False
        self.v = float(v)

    @property
    def value(self):
        return self.v

    # coercion to integer is not supported:
    def __int__(self):
        # raise (not return) the exception
        raise NotImplementedError

    def __unicode__(self):
        return "{:f}".format(self.value)

    def toks(self, **kargs):
        return [(render.Token.Constant, "%s" % self)]

    def eval(self, env):
        return cfp(self.value, self.size)

    def __neg__(self):
        return cfp(-(self.value), self.size)

    @_checkarg_numeric
    @_checkarg_sizes
    def __add__(self, n):
        if n._is_cst:
            return cfp(self.v + n.value, self.size)
        else:
            return exp.__add__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __sub__(self, n):
        if n._is_cst:
            return cfp(self.v - n.value, self.size)
        else:
            return exp.__sub__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __mul__(self, n):
        if n._is_cst:
            return cfp(self.v * n.value, self.size)
        else:
            return exp.__mul__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __pow__(self, n):
        # ** is a multiplication here too (size is kept)
        if n._is_cst:
            return cfp(self.v * n.value, self.size)
        else:
            return exp.__pow__(self, n)

    @_checkarg_numeric
    def __div__(self, n):
        if n._is_cst:
            return cfp(self.v / n.value, self.size)
        else:
            return exp.__div__(self, n)

    @_checkarg_numeric
    def __truediv__(self, n):
        if n._is_cst:
            return cfp(self.v / n.value, self.size)
        else:
            return exp.__truediv__(self, n)

    @_checkarg_numeric
    def __radd__(self, n):
        return n + self

    @_checkarg_numeric
    def __rsub__(self, n):
        return n - self

    @_checkarg_numeric
    def __rmul__(self, n):
        return n * self

    @_checkarg_numeric
    def __rpow__(self, n):
        return n ** self

    @_checkarg_numeric
    def __rdiv__(self, n):
        return n / self

    # Python 3 looks up __rtruediv__ (not __rdiv__) for "number / cfp"
    @_checkarg_numeric
    def __rtruediv__(self, n):
        return n / self

    @_checkarg_numeric
    @_checkarg_sizes
    def __eq__(self, n):
        if n._is_cst:
            return cst(self.value == n.value)
        else:
            return exp.__eq__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __ne__(self, n):
        if n._is_cst:
            return cst(self.value != n.value)
        else:
            return exp.__ne__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __lt__(self, n):
        if n._is_cst:
            return cst(self.value < n.value)
        else:
            return exp.__lt__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __le__(self, n):
        if n._is_cst:
            return cst(self.value <= n.value)
        else:
            return exp.__le__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __ge__(self, n):
        if n._is_cst:
            return cst(self.value >= n.value)
        else:
            return exp.__ge__(self, n)

    @_checkarg_numeric
    @_checkarg_sizes
    def __gt__(self, n):
        if n._is_cst:
            return cst(self.value > n.value)
        else:
            return exp.__gt__(self, n)
# ------------------------------------------------------------------------------ # reg holds 32-bit register reference (refname). # ------------------------------------------------------------------------------
class reg(exp):
    """symbolic register expression.

    The size of a register is write-protected once the object is built.
    """

    __slots__ = ["ref", "etype", "_subrefs", "__protect"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__

    def __init__(self, refname, size=32):
        # size can only be written while the protection flag is down
        self.__protect = False
        self.size = size
        self.__protect = True
        self.sf = False
        self.ref = refname
        self._subrefs = {}
        # register category: the one of the active regtype context, else STD
        category = regtype.cur or regtype.STD
        self.etype = et_reg | category

    def __unicode__(self):
        return "%s" % self.ref

    def toks(self, **kargs):
        token = (render.Token.Register, "%s" % self)
        return [token]

    def eval(self, env):
        "value bound to this register in env, tagged with our sign flag"
        found = env[self]
        found.sf = self.sf
        return found

    def addr(self, env):
        return self

    def __setattr__(self, a, v):
        if a == "size":
            if self.__protect == True:
                raise AttributeError("protected attribute")
        exp.__setattr__(self, a, v)

    # howto pickle/unpickle reg objects:
    def __setstate__(self, state):
        v = state[1]
        # lower the protection so that size can be restored
        self.__protect = False
        for attr in ("size", "sf", "ref", "etype", "_subrefs"):
            setattr(self, attr, v[attr])
        self.__protect = v["_reg__protect"]
class regtype(object):
    """
    decorator and context manager (with...) for associating a register to a
    specific category among STD (standard), PC (program counter), FLAGS,
    STACK, OTHER.

    Used as a context manager it sets the class attribute ``cur`` to its
    category and restores the previous value on exit, so that contexts can
    be nested.
    """

    STD = 0x00
    PC = 0x10
    FLAGS = 0x20
    STACK = 0x40
    OTHER = 0x80

    # category of the innermost active context (None outside any context)
    cur = None

    def __init__(self, t):
        self.t = t
        # values of 'cur' saved by __enter__, most recent last
        self._saved = []

    def __call__(self, r):
        "or the category bits into r.etype and return r"
        if not r._is_reg:
            # NOTE(review): the message says "ignored" but the category is
            # still applied below; behaviour kept as is.
            logger.error("pc decorator ignored (not a register)")
        r.etype |= self.t
        return r

    def __enter__(self):
        self._saved.append(regtype.cur)
        regtype.cur = self.t

    def __exit__(self, exc_type, exc_value, traceback):
        # restore what was active before this context (None at top level)
        regtype.cur = self._saved.pop() if self._saved else None
# one ready-made decorator / context manager per non-standard category
is_reg_pc, is_reg_flags = regtype(regtype.PC), regtype(regtype.FLAGS)
is_reg_stack, is_reg_other = regtype(regtype.STACK), regtype(regtype.OTHER)

# ------------------------------------------------------------------------------
# ext holds external symbols used by the dynamic linker.
# ------------------------------------------------------------------------------
class ext(reg):
    """external reference to a dynamic (lazy or non-lazy) symbol.

    Keyword arguments given at creation are kept in ``_subrefs`` and are
    handed to the stub on implicit calls.
    """

    __hash__ = reg.__hash__
    __eq__ = exp.__eq__

    def __init__(self, refname, **kargs):
        self.ref = refname
        self._subrefs = kargs
        self.size = kargs.get("size")
        self.sf = False
        self._reg__protect = False
        self.etype = et_ext | et_reg | regtype.OTHER
        self.stub = None
        # add the instruction interface:
        self.address = None
        self.operands = []
        self.misc = {}
        self.type = 2  # type_control_flow

    def __unicode__(self):
        return "@%s" % self.ref

    def toks(self, **kargs):
        if "!" in self.ref:
            kind = render.Token.Tainted
        else:
            kind = render.Token.Name
        return [(kind, "%s" % self)]

    def __setattr__(self, a, v):
        # no size protection for ext: plain attribute assignment
        exp.__setattr__(self, a, v)

    def call(self, env, **kargs):
        "explicit call to the ext's stub"
        logger.info("stub %s explicit call" % self.ref)
        if "size" not in kargs:
            kargs["size"] = self.size
        try:
            res = self.stub(env, **kargs)
        except TypeError:
            # a TypeError raised by the call is treated as "no result"
            return top(self.size)
        if res is None:
            return top(self.size)
        return res[0 : self.size]

    def __call__(self, env):
        "used when the expression is used as a target instruction"
        logger.info("stub %s implicit call" % self.ref)
        self.stub(env, **self._subrefs)
# ------------------------------------------------------------------------------ # lab holds labels/symbols, e.g. from relocations # ------------------------------------------------------------------------------
class lab(ext):
    "label expression used by the assembler (an ext with the et_lab bit set)"

    __hash__ = ext.__hash__
    __eq__ = exp.__eq__

    def __init__(self, refname, **kargs):
        super().__init__(refname, **kargs)
        # add the label bit on top of the type bits set by ext
        self.etype = self.etype | et_lab
# ------------------------------------------------------------------------------
def composer(parts):
    """
    composer builds a comp object out of the given parts, listed from the
    least significant bits to the most significant bits, and returns its
    simplified form. A single part is returned as is.
    The sf flag of the last part propagates to the resulting comp.
    """
    assert len(parts) > 0
    if len(parts) == 1:
        return parts[0]
    total = sum(p.size for p in parts)
    res = comp(total)
    res.sf = parts[-1].sf
    offset = 0
    for p in parts:
        upper = offset + p.size
        res[offset:upper] = p
        offset += p.size
    return res.simplify()
# ------------------------------------------------------------------------------
class comp(exp):
    """
    composite expression, represents an expression made of several parts.

    Attributes:
        parts (dict): expressions parts dictionary. Each key is a tuple
            (start, stop) of bit positions and the value is the exp part
            that defines bits start to stop-1.
        smask (list): mapping of bit index to the part's key that defines
            this bit (None when the bit is not defined).

    Note:
        Each part can be accessed by 'slicing' the comp to obtain another
        comp or the part if the given slice indices match the part position.
    """

    __slots__ = ["smask", "parts"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_cmp

    def __init__(self, s):
        self.size = s
        self.sf = False
        self.smask = [None] * self.size
        self.parts = {}

    # the symp is only obtained after a restruct !
    def __unicode__(self):
        s = "{ |"
        cur = 0
        for nv in self:
            nk = cur, cur + nv.size
            s += " %s->%s |" % ("[%d:%d]" % nk, nv)
            cur += nv.size
        return s + " }"

    def toks(self, **kargs):
        if "indent" in kargs:
            p = kargs.get("indent", 0)
            pad = "\n".ljust(p + 1)
            # nested comps are indented 4 more columns
            kargs["indent"] = p + 4
        else:
            pad = ""
        tl = (render.Token.Literal, ", ")
        s = [(render.Token.Literal, "{")]
        cur = 0
        for nv in self:
            loc = "%s[%2d:%2d] -> " % (pad, cur, cur + nv.size)
            cur += nv.size
            s.append((render.Token.Literal, loc))
            t = nv.toks(**kargs)
            s.extend(t)
            s.append(tl)
        if len(s) > 1:
            # drop the trailing separator
            s.pop()
        s.append((render.Token.Literal, "}"))
        return s

    def eval(self, env):
        res = comp(self.size)
        res.sf = self.sf
        res.smask = self.smask[:]
        for nk, nv in self.parts.items():
            res.parts[nk] = nv.eval(env)
        # now there may be raw numeric value in enode dict, so tiddy up:
        res.restruct()
        # once simplified, it may be reduced to 1 part, so:
        if (0, res.size) in res.parts:
            res = res.parts[(0, res.size)]
        return res

    def copy(self):
        "shallow copy: parts are shared, smask and the dict are duplicated"
        res = comp(self.size)
        res.smask = self.smask[:]
        for nk, nv in self.parts.items():
            res.parts[nk] = nv
        res.sf = self.sf
        return res

    def simplify(self, **kargs):
        # only values of existing keys are replaced while iterating
        for nk, nv in self.parts.items():
            self.parts[nk] = nv.simplify(**kargs)
        self.restruct()
        if (0, self.size) in self.parts:
            return self.parts[(0, self.size)]
        else:
            return self

    @_checkarg_slice
    def __getitem__(self, i):
        start = i.start or 0
        stop = i.stop or self.size
        # see if the slice is exactly in the compound set:
        if (start, stop) in self.parts:
            return self.parts[(start, stop)]
        if start == 0 and stop == self.size:
            return self.copy()
        l = stop - start
        res = comp(l)
        res.sf = self.sf
        b = 0
        # scan with a separate cursor so that 'start' keeps the requested
        # position (it is needed for the fallback slicer below)
        pos = start
        while b < l:
            # select symbol index and object:
            idx = self.smask[pos]
            if idx is None:
                b += 1
                pos += 1
                continue
            # idx is a slice keyed in enode dict
            s = self.parts[idx]
            # get slice for this symbol:
            deb = pos - idx[0]
            fin = min(idx[1], stop) - idx[0]
            d = fin - deb
            res[b : b + d] = s[deb:fin]
            b += d
            pos += d
        res.restruct()
        if len(res.parts) == 0:
            # nothing defined in this range
            return slicer(self, start, stop - start)
        if len(res.parts) == 1:
            return list(res.parts.values())[0]
        return res

    @_checkarg_slice
    def __setitem__(self, i, v):
        sta = i.start or 0
        sto = i.stop or self.size
        l = sto - sta
        if v.size != l:
            raise ValueError("size mismatch")
        # make cmp always flat:
        if v._is_cmp:
            for vp, vv in v.parts.items():
                vsta, vsto = vp
                self[sta + vsta : sta + vsto] = vv
        else:
            # see if the slice is exactly in the compound set:
            if (sta, sto) in self.parts:
                self.parts[(sta, sto)] = v
            else:
                self.parts[(sta, sto)] = v
                self.cut(sta, sto)

    def cut(self, start, stop):
        """
        cut will scan the parts dict to find those spanning **over**
        start and/or stop bounds then it will split them and remove their
        inner parts.

        Note:
            cut is an in-place method (affects self).
        """
        # list parts that cover (start,stop) range:
        maskset = []
        for nk in filter(None, self.smask[start:stop]):
            if nk not in maskset:
                maskset.append(nk)
        # for each listed part, remove its covering in this range
        # and update parts and smask dicts accordingly:
        for nk in maskset:
            nv = self.parts.pop(nk)
            if nk[0] < start:
                self.parts[(nk[0], start)] = nv[0 : start - nk[0]]
                self.smask[nk[0] : start] = [(nk[0], start)] * (start - nk[0])
            if nk[1] > stop:
                self.parts[(stop, nk[1])] = nv[stop - nk[0] : nk[1] - nk[0]]
                self.smask[stop : nk[1]] = [(stop, nk[1])] * (nk[1] - stop)
        self.smask[start:stop] = [(start, stop)] * (stop - start)

    def __iter__(self):
        "yield parts by increasing position (parts must be contiguous)"
        part = list(self.parts.keys())
        part.sort(key=operator.itemgetter(0))
        cur = 0
        for p in part:
            assert p[0] == cur
            yield self.parts[p]
            cur = p[1]

    def restruct(self):
        """
        restruct will aggregate consecutive cst expressions in order
        to minimize the number of parts. Two adjacent parts that are both
        not 'defined' are likewise merged into a single top part.
        """
        # gather cst as possible:
        part = list(self.parts.keys())
        part.sort(key=operator.itemgetter(0))
        for i in range(len(part) - 1):
            ra = part[i]
            rb = part[i + 1]
            if ra[1] == rb[0]:
                na = self.parts[ra]
                nb = self.parts[rb]
                if na._is_cst and nb._is_cst:
                    v = (nb.v << na.size) | (na.v)
                    self.parts[(ra[0], rb[1])] = cst(v, na.size + nb.size)
                    self.parts.pop(ra)
                    self.parts.pop(rb)
                    self.smask[ra[0] : rb[1]] = [(ra[0], rb[1])] * (rb[1] - ra[0])
                    # keys have changed: start over on the new layout
                    self.restruct()
                    break
                elif not (na._is_def or nb._is_def):
                    self.parts[(ra[0], rb[1])] = top(rb[1] - ra[0])
                    self.parts.pop(ra)
                    self.parts.pop(rb)
                    self.smask[ra[0] : rb[1]] = [(ra[0], rb[1])] * (rb[1] - ra[0])
                    self.restruct()
                    break

    def depth(self):
        return sum((p.depth() for p in self))
# ------------------------------------------------------------------------------
class mem(exp):
    """
    memory expression: a symbolic value of given bit size, located in
    segment seg at the given address expression.

    Attributes:
        a (ptr): a pointer expression that represents the address.
        endian (int): 1 means little, -1 means big.
        mods (list): list of possibly aliasing operations affecting this exp.

    Note:
        The mods list allows to handle aliasing issues detected at fetching
        time and adjust the eval result accordingly.
    """

    __slots__ = ["a", "mods", "endian"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_mem

    def __init__(self, a, size=32, seg=None, disp=0, mods=None, endian=1):
        self.size = size
        self.sf = False
        self.a = ptr(a, seg, disp)
        self.mods = mods or []
        self.endian = endian

    def __unicode__(self):
        count = len(self.mods)
        # the number of pending mods is shown only when there are some
        tag = "$%d" % count if count > 0 else ""
        return "M%d%s%s" % (self.size, tag, self.a)

    def toks(self, **kargs):
        token = (render.Token.Memory, "%s" % self)
        return [token]

    def eval(self, env):
        address = self.a.eval(env)
        m = env.use()
        # replay the recorded (location, value) pairs on m
        for loc, v in self.mods:
            if loc._is_ptr:
                loc = env(loc)
            m[loc] = env(v)
        res = m[mem(address, self.size, endian=self.endian)]
        res.sf = self.sf
        return res

    def simplify(self, **kargs):
        self.a.simplify(**kargs)
        base = self.a.base
        if not base._is_vec:
            return self
        # vector base: build one mem per candidate address
        seg, disp = self.a.seg, self.a.disp
        cells = [
            mem(a, self.size, seg, disp, mods=self.mods, endian=self.endian)
            for a in base.l
        ]
        v = vec(cells)
        if self.a.base._is_def:
            return v
        return vecw(v)

    def addr(self, env):
        return self.a.eval(env).unsigned()

    def bytes(self, sta=0, sto=None, endian=0):
        "mem expression for bytes [sta,sto] (the endian argument is not used)"
        first, last, _ = slice(sta, sto).indices(self.length)
        nbits = (last - first) * 8
        return mem(self.a, nbits, disp=first, mods=self.mods, endian=self.endian)

    @_checkarg_slice
    def __getitem__(self, i):
        sta, sto, stp = i.indices(self.size)
        # byte index and bit remainder of both bounds
        b1, r1 = divmod(sta, 8)
        b2, r2 = divmod(sto, 8)
        if r2 > 0:
            # partial last byte: fetch it entirely
            b2 += 1
        nbytes = self.length
        if self.endian == -1:
            b1, b2 = nbytes - b2, nbytes - b1
        nbits = (b2 - b1) * 8
        x = mem(self.a, nbits, disp=b1, mods=self.mods, endian=self.endian)
        x.sf = self.sf
        if r1 > 0 or r2 > 0:
            # bounds not byte-aligned: slice inside the fetched bytes
            x = slc(x, r1, (sto - sta))
        return x
# ------------------------------------------------------------------------------
class ptr(exp):
    """
    ptr holds memory addresses with segment, base expressions and
    displacement integer (offset relative to base).

    Attributes:
        base (exp): symbolic expression for the base of pointer address.
        disp (int): offset relative to base for the pointer address.
        seg (reg): segment register (or None if unused.)
    """

    __slots__ = ["base", "disp", "seg"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_ptr

    def __init__(self, base, seg=None, disp=0):
        if base._is_ptr:
            # flatten a pointer to a pointer: inherit its segment (unless
            # one is given) and accumulate the displacements.
            if seg is None:
                seg = base.seg
            disp = base.disp + disp
            base = base.base
        # any constant offset found in base is moved into disp:
        self.base, offset = extract_offset(base)
        self.disp = disp + offset
        self.seg = seg
        self.size = base.size
        self.sf = False

    def __unicode__(self):
        prefix = "" if self.seg is None else self.seg
        return "%s(%s%s)" % (prefix, self.base, self.disp_tostring())

    def disp_tostring(self, base10=True):
        "returns the string form of the displacement (empty if it is 0)."
        if hasattr(self.disp, "_is_cst"):
            # When allowing label in expressions, e.g. when parsing
            # relocatable objects and relocations, 'disp' (displacement
            # from a base address in memory) can not only be a number as
            # in standard amoco, but also a label, a difference of labels
            # or even a difference of labels added with an integer
            return "+%s" % self.disp
        if self.disp == 0:
            return ""
        if base10:
            return "%+d" % self.disp
        unsigned_disp = cst(self.disp, self.size)
        unsigned_disp.sf = False
        return "+%s" % unsigned_disp

    def toks(self, **kargs):
        "returns a single Address token holding the string form of this pointer."
        return [(render.Token.Address, "%s" % self)]

    def simplify(self, **kargs):
        """
        simplifies the pointer in place and returns it: constant offsets of
        the base are folded into disp, the segment is simplified if it is an
        expression and disp is reset to 0 if the base is not defined.
        """
        self.base, offset = extract_offset(self.base)
        self.disp += offset
        if isinstance(self.seg, exp):
            self.seg = self.seg.simplify(**kargs)
        if not self.base._is_def:
            self.disp = 0
        return self

    # default segment handler does not care about seg value:
    @staticmethod
    def segment_handler(env, s, bd):
        base, disp = bd
        return ptr(base, s, disp)

    def eval(self, env):
        "evaluates base (and segment if it is an expression) in env."
        base = self.base.eval(env)
        seg = self.seg
        if isinstance(seg, exp):
            seg = seg.eval(env)
        return self.segment_handler(env, seg, (base, self.disp))
# ------------------------------------------------------------------------------
def slicer(x, pos, size):
    """
    wrapper of slc class that returns a simplified version of
    x[pos:pos+size]: top if x is not defined, x itself if the slice covers
    all of x, the result of x's own slicing for mem and comp expressions
    and a slc expression otherwise.
    """
    if not isinstance(x, exp):
        raise TypeError(x)
    if not x._is_def:
        return top(size)
    if pos == 0 and size == x.size:
        return x
    if not (x._is_mem or x._is_cmp):
        return slc(x, pos, size)
    part = x[pos : pos + size]
    part.sf = x.sf
    return part
# ------------------------------------------------------------------------------
class slc(exp):
    """
    slice expression, represents an expression part.

    Attributes:
        x (exp): reference to the symbolic expression
        pos (int): start bit for the part.
        ref (str): an alternative symbolic name for this part.
    """

    __slots__ = ["x", "pos", "ref", "__protect", "etype"]
    __eq__ = exp.__eq__

    def __init__(self, x, pos, size, ref=None):
        if not isinstance(pos, int):
            raise TypeError(pos)
        # __protect must exist before size is set (see __setattr__):
        self.__protect = False
        self.size = size
        self.sf = x.sf
        if isinstance(x, slc):
            # a slice of a slice is re-expressed on the underlying expression
            res = x[pos : pos + size]
            x, pos = res.x, res.pos
        self.x = x
        self.pos = pos
        self.etype = et_slc
        self.setref(ref)

    def setref(self, ref):
        """
        sets the alternative name of this part. For a slice of a register,
        the name is looked up in (or recorded into) the register's _subrefs
        and the size attribute becomes read-only.
        """
        if self.x._is_reg:
            self.etype |= self.x.etype
            if ref is None:
                ref = self.x._subrefs.get((self.pos, self.size), None)
            else:
                self.x._subrefs[(self.pos, self.size)] = ref
            self.__protect = True
        self.ref = ref

    def raw(self):
        "returns the raw symbolic name (ignore the ref attribute.)"
        return "%s[%d:%d]" % (self.x, self.pos, self.pos + self.size)

    def __setattr__(self, a, v):
        if a == "size" and self.__protect == True:
            raise AttributeError("protected attribute")
        exp.__setattr__(self, a, v)

    def __unicode__(self):
        return self.ref or self.raw()

    def toks(self, **kargs):
        "returns a Register token for register slices, else x's tokens plus [pos:end]."
        if self._is_reg:
            return [(render.Token.Register, "%s" % self)]
        subpart = [(render.Token.Literal, "[%d:%d]" % (self.pos, self.pos + self.size))]
        return self.x.toks(**kargs) + subpart

    def __hash__(self):
        return hash(self.raw())  # lgtm [py/equals-hash-mismatch]

    def depth(self):
        return 2 * self.x.depth()

    def eval(self, env):
        "evaluates x in env and returns the same bit range of the result."
        n = self.x.eval(env)
        res = n[self.pos : self.pos + self.size]
        res.sf = self.sf
        return res

    # slc of mem objects are simplified by adjusting the disp offset of
    # the sliced mem object.
    def simplify(self, **kargs):
        """
        simplifies the sliced expression and tries to push the slice inside:

        - an undefined x gives top,
        - comp and cst are sliced directly,
        - a byte-aligned slice of a mem becomes a smaller mem,
        - logic operators (and the low bits of + and -) are applied to the
          slices of their operands,
        - a vec becomes the vec of the slices of its elements.

        Returns self when no rule applies.
        """
        self.x = self.x.simplify(**kargs)
        if not self.x._is_def:
            return top(self.size)
        if self.x._is_cmp or self.x._is_cst:
            res = self.x[self.pos : self.pos + self.size]
            res.sf = self.sf
            return res
        if self.x._is_mem and self.size % 8 == 0 and self.pos % 8 == 0:
            # Delegate to mem.__getitem__ rather than rebuilding the mem by
            # hand: it keeps the mods and endian attributes of x and takes
            # the byte order into account when computing the displacement.
            res = self.x[self.pos : self.pos + self.size]
            res.sf = self.sf
            return res
        if self.x._is_eqn and (
            self.x.op.type == 2
            or (self.x.op.symbol in (OP_ADD, OP_MIN) and self.pos == 0)
        ):
            r = self.x.r[self.pos : self.pos + self.size]
            if self.x.op.unary:
                return self.x.op(r)
            l = self.x.l[self.pos : self.pos + self.size]
            return self.x.op(l, r)
        if self.x._is_vec:
            return vec([x[self.pos : self.pos + self.size] for x in self.x.l])
        else:
            return self

    # slice of a slice:
    @_checkarg_slice
    def __getitem__(self, i):
        if i.start == 0 and i.stop == self.size:
            return self
        else:
            start = self.pos + i.start
            return slicer(self.x, start, i.stop - i.start)

    ##
    # simplify: the only simplification would apply on slc'ed expression x
    # but x can't be of type slc...

    def addr(self, env):
        "returns the location of this part; raises TypeError if x is not a mem or reg."
        if self.x._is_mem:
            a = self.x.addr(env).unsigned()
            a.disp = self.pos
            return a
        elif self.x._is_reg:
            return self.x
        else:
            raise TypeError("this expression is not a location")

    def __setstate__(self, state):
        v = state[1]
        # size can only be restored while the instance is unprotected:
        self.__protect = False
        self.size = v["size"]
        self.sf = v["sf"]
        self.x = v["x"]
        self.pos = v["pos"]
        self.ref = v["ref"]
        self.etype = v["etype"]
        self.__protect = v["_slc__protect"]
# ------------------------------------------------------------------------------
class tst(exp):
    """
    Conditional expression.

    Attributes:
        tst (exp): the boolean expression that represents the condition.
        l (exp): the resulting expression if test == bit1.
        r (exp): the resulting expression if test == bit0.
    """

    __slots__ = ["tst", "l", "r"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_tst

    def __init__(self, t, l, r):
        if isinstance(t, bool):
            # python booleans are turned into 1-bit constants
            t = cst(t, 1)
        self.tst = t
        if l.size != r.size:
            raise ValueError((l, r))
        self.l = l  # value when the condition holds
        self.r = r  # value otherwise
        self.size = self.l.size
        self.sf = False

    ##
    def __unicode__(self):
        return "(%s ? %s : %s)" % (self.tst, self.l, self.r)

    def toks(self, **kargs):
        "returns the tokens of the condition, ' ? ', the true value, ' : ' and the false value."
        cond_toks = self.tst.toks(**kargs)
        cond_toks.append((render.Token.Literal, " ? "))
        true_toks = self.l.toks(**kargs)
        true_toks.append((render.Token.Literal, " : "))
        false_toks = self.r.toks(**kargs)
        return cond_toks + true_toks + false_toks

    # default verify method if smt module is not loaded.
    # here we check if tst or its negation exist in env.conds but we can
    # only rely on "syntaxic" features unless we have a solver.
    # see smt.py: tst_verify() for a SMT-based implementation.
    def verify(self, env):
        cond = self.tst.eval(env)
        for known in env.conds:
            if known == cond:
                return bit1
            if known == (~cond):
                return bit0
        return cond

    def eval(self, env):
        """
        evaluates the condition and both values in env. Returns the selected
        value if the condition is a constant, or a new tst otherwise.
        """
        cond = self.verify(env)
        if_true = self.l.eval(env)
        if_false = self.r.eval(env)
        if not cond._is_cst:
            return tst(cond, if_true, if_false)
        return if_true if cond.v == 1 else if_false

    def simplify(self, **kargs):
        """
        simplifies the condition and the values in place. Returns a vec of
        both values when widening is requested or the condition is not
        defined, the selected value when the condition is bit1 or bit0 or
        when both values are equal, and self otherwise.
        """
        self.tst = self.tst.simplify(**kargs)
        if kargs.get("widening", False) or not self.tst._is_def:
            return vec([self.l, self.r]).simplify()
        self.l = self.l.simplify(**kargs)
        if self.tst == bit1:
            return self.l
        self.r = self.r.simplify(**kargs)
        if self.tst == bit0:
            return self.r
        if self.l == self.r:
            return self.l
        return self

    def depth(self):
        "returns the mean of the depths of the condition and of both values."
        total = self.tst.depth() + self.l.depth() + self.r.depth()
        return total / 3.0
# ------------------------------------------------------------------------------
def oper(opsym, l, r=None):
    "wrapper of the operator expression that detects unary operations"
    # without a right-hand operand, l is the single operand of a unary operator
    expr = uop(opsym, l) if r is None else op(opsym, l, r)
    return expr.simplify()
# ------------------------------------------------------------------------------
class op(exp):
    """
    op holds binary integer arithmetic and bitwise logic expressions

    Attributes:
        op (_operator): binary operator
        prop (int): type of operator (ARITH, LOGIC, CONDT, SHIFT)
        l (exp): left-hand expression of the operator
        r (exp): right-hand expression of the operator
    """

    __slots__ = ["op", "l", "r", "prop"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_eqn

    def __init__(self, op, l, r):
        self.op = _operator(op)
        self.prop = self.op.type
        # arithmetic and logic operators need operands of identical sizes:
        if self.prop < 4 and l.size != r.size:
            raise ValueError("Size mismatch %d != %d" % (l.size, r.size))
        self.l = l
        self.r = r
        self.size = self.l.size
        if self.prop == 4:
            # conditions are 1 bit wide
            self.size = 1
        elif self.op.symbol == OP_MUL2:
            self.size *= 2
        self.sf = l.sf
        if self.prop == 1:
            self.sf |= r.sf
        # prop accumulates the operator types found in the operands:
        if self.l._is_eqn:
            self.prop |= self.l.prop
        if self.r._is_eqn:
            self.prop |= self.r.prop

    def eval(self, env):
        "evaluates both operands in env and applies the operator to the results."
        left = self.l.eval(env)
        right = self.r.eval(env)
        res = self.op(left, right)
        res.sf = self.sf
        return res

    ##
    def __unicode__(self):
        return "(%s%s%s)" % (self.l, render.icons.op(self.op.symbol), self.r)

    def toks(self, **kargs):
        "returns the tokens of '(', the left operand, the operator symbol, the right operand and ')'."
        left = self.l.toks(**kargs)
        left.insert(0, (render.Token.Literal, "("))
        right = self.r.toks(**kargs)
        right.append((render.Token.Literal, ")"))
        return left + [(render.Token.Literal, self.op.symbol)] + right

    def simplify(self, **kargs):
        """
        simplifies both operands, normalizes their order for arithmetic and
        logic operators other than / and % (constant on the right, otherwise
        ordering by symbol names) and delegates to eqn2_helpers.
        """
        l = self.l.simplify(**kargs)
        r = self.r.simplify(**kargs)
        if self.prop < 4 and self.op.symbol not in (OP_DIV, OP_MOD):
            # top is absorbing:
            if l._is_top:
                return l
            if r._is_top:
                return r
            if l._is_cst and r._is_cst:
                return self.op(l, r)
            if l._is_cst:
                # push cst to the right
                swap = True
            elif r._is_cst:
                swap = False
            else:
                # lexical ordering of symbols:
                lh = "".join(["%s" % x for x in symbols_of(l)])
                rh = "".join(["%s" % x for x in symbols_of(r)])
                swap = lh > rh
            if swap:
                if self.op.symbol == OP_MIN:
                    # (l - r) is rewritten as ((-r) + l)
                    l, r = (-r), l
                    self.op = _operator(OP_ADD)
                else:
                    l, r = r, l
        self.l = l
        self.r = r
        return eqn2_helpers(self, **kargs)

    def depth(self):
        return self.l.depth() + self.r.depth()
# ------------------------------------------------------------------------------
class uop(exp):
    """
    uop holds unary integer arithmetic and bitwise logic expressions

    Attributes:
        op (_operator): unary operator
        prop (int): type of operator (ARITH, LOGIC, CONDT, SHIFT)
        l (None): returns None in case uop is treated as an op instance.
        r (exp): right-hand expression of the operator
    """

    __slots__ = ["op", "r", "prop"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_eqn

    def __init__(self, op, r):
        self.op = _operator(op, unary=1)
        self.prop = self.op.type
        self.r = r
        self.size = r.size
        self.sf = r.sf
        # prop accumulates the operator types found in the operand:
        if self.r._is_eqn:
            self.prop |= self.r.prop

    def eval(self, env):
        "evaluates the operand in env and applies the operator to the result."
        operand = self.r.eval(env)
        res = self.op(operand)
        res.sf = self.sf
        return res

    @property
    def l(self):
        return None

    def __unicode__(self):
        return "(%s%s)" % (render.icons.op(self.op.symbol), self.r)

    def toks(self, **kargs):
        "returns the tokens of '(' followed by the operator symbol, the operand and ')'."
        operand = self.r.toks(**kargs)
        operand.append((render.Token.Literal, ")"))
        return [(render.Token.Literal, "(%s" % self.op.symbol)] + operand

    def simplify(self, **kargs):
        "simplifies the operand (top is absorbing) and delegates to eqn1_helpers."
        operand = self.r.simplify(**kargs)
        if operand._is_top:
            return operand
        self.r = operand
        return eqn1_helpers(self, **kargs)

    def depth(self):
        return self.r.depth()
# operators: # ----------- OP_ADD = "+" OP_MIN = "-" OP_MUL = "*" OP_MUL2 = "**" OP_DIV = "/" OP_MOD = "%" OP_AND = "&" OP_OR = "|" OP_XOR = "^" OP_NOT = "~" OP_EQ = "==" OP_NEQ = "!=" OP_LE = "<=" OP_GE = ">=" OP_GEU = ">=." OP_LT = "<" OP_LTU = "<." OP_GT = ">" OP_LSL = "<<" OP_LSR = ">>" OP_ASR = ".>>" OP_ROR = ">>>" OP_ROL = "<<<"
def ror(x, n):
    "high-level rotate right n bits"
    if not x._is_cst:
        # symbolic operand: keep the rotation as an expression
        return op(OP_ROR, x, n)
    low = x >> n
    high = x << (x.size - n)
    return low | high
def rol(x, n):
    "high-level rotate left n bits"
    if not x._is_cst:
        # symbolic operand: keep the rotation as an expression
        return op(OP_ROL, x, n)
    high = x << n
    low = x >> (x.size - n)
    return high | low
def ltu(x, y):
    """
    high-level less-than-unsigned operation.

    Returns an op(OP_LTU, x, y) expression unless both operands are
    constants, in which case both are flagged as unsigned (sf=False) and
    compared with '<'.
    """
    try:
        if not (x._is_cst and y._is_cst):
            return op(OP_LTU, x, y)
    except AttributeError:
        pass
    # An unsigned comparison must not interpret the operands as signed
    # values: _operator.__call__ clears sf for the same reason.
    x.sf = y.sf = False
    return x < y
def geu(x, y):
    """
    high level greater-or-equal-unsigned operation.

    Returns an op(OP_GEU, x, y) expression unless both operands are
    constants, in which case both are flagged as unsigned (sf=False) and
    compared with '>='.
    """
    try:
        if not (x._is_cst and y._is_cst):
            return op(OP_GEU, x, y)
    except AttributeError:
        pass
    # An unsigned comparison must not interpret the operands as signed
    # values: _operator.__call__ clears sf for the same reason.
    x.sf = y.sf = False
    return x >= y
# implementation of each operator, grouped by type of operator:
OP_ARITH = {
    OP_ADD: operator.add,
    OP_MIN: operator.sub,
    OP_MUL: operator.mul,
    OP_MUL2: operator.pow,
    OP_DIV: operator.truediv,
    OP_MOD: operator.mod,
}

OP_LOGIC = {
    OP_AND: operator.and_,
    OP_OR: operator.or_,
    OP_XOR: operator.xor,
    OP_NOT: operator.invert,
}

OP_CONDT = {
    OP_EQ: operator.eq,
    OP_NEQ: operator.ne,
    OP_LE: operator.le,
    OP_GE: operator.ge,
    OP_GEU: geu,
    OP_LT: operator.lt,
    OP_LTU: ltu,
    OP_GT: operator.gt,
}

OP_SHIFT = {
    OP_LSR: operator.rshift,  # logical shift right (see cst.value)
    OP_LSL: operator.lshift,
    OP_ASR: operator.floordiv,  # this is arithmetic shift right
    OP_ROR: ror,
    OP_ROL: rol,
}


class _operator(object):
    """
    callable wrapper of an operator symbol.

    Attributes:
        symbol (str): the operator symbol.
        unary (int): non-zero for a unary operator.
        unsigned (bool): if True, calling the operator with two operands
            resets their sf attribute first.
        type (int): 1 for arithmetic, 2 for logic, 4 for conditions and
            8 for shifts.
        impl (callable): the function that implements the operator.
    """

    def __init__(self, op, unary=0):
        self.symbol = op
        self.unary = unary
        self.unsigned = False
        if op in OP_ARITH:
            self.type = 1
            if self.unary:
                unary_impl = {OP_ADD: operator.pos, OP_MIN: operator.neg}
                self.impl = unary_impl[op]
            else:
                self.impl = OP_ARITH[op]
        elif op in OP_LOGIC:
            self.type = 2
            self.unsigned = True
            if self.unary:
                assert op == OP_NOT
            self.impl = OP_LOGIC[op]
        elif op in OP_CONDT:
            self.type = 4
            self.impl = OP_CONDT[op]
            if op in (OP_GEU, OP_LTU):
                self.unsigned = True
        elif op in OP_SHIFT:
            self.type = 8
            self.impl = OP_SHIFT[op]
        else:
            raise NotImplementedError

    def __call__(self, l, r=None):
        if r is not None:
            if self.unsigned:
                l.sf = r.sf = False
            return self.impl(l, r)
        # single operand:
        assert self.unary
        return self.impl(l)

    def __mul__(self, op):
        # "product" of two +/- operators following the rule of signs: the
        # result is a symbol, or None for any other pair of operators.
        signs = {"++": OP_ADD, "--": OP_ADD, "+-": OP_MIN, "-+": OP_MIN}
        return signs.get(self.symbol + op.symbol)


# basic simplifier:
# ------------------
def symbols_of(e):
    "returns all symbols contained in expression e"

    def gather(children):
        # concatenate the symbols of every child expression, in order
        found = []
        for child in children:
            found.extend(symbols_of(child))
        return found

    if e is None or e._is_cst:
        return []
    if e._is_reg:
        return [e]
    if e._is_mem:
        return symbols_of(e.a.base)
    if e._is_ptr:
        return symbols_of(e.base)
    if e._is_eqn:
        return gather((e.l, e.r))
    if e._is_tst:
        return gather((e.tst, e.l, e.r))
    if e._is_slc:
        return symbols_of(e.x)
    if e._is_cmp:
        return gather(e.parts.values())
    if e._is_vec:
        return gather(e.l)
    if not e._is_def:
        return []
    raise ValueError(e)
def locations_of(e):
    "returns all locations contained in expression e"

    def gather(children):
        # concatenate the locations of every child expression, in order
        found = []
        for child in children:
            found.extend(locations_of(child))
        return found

    if e is None or e._is_cst:
        return []
    # registers, memory expressions and pointers are locations themselves:
    if e._is_reg:
        return [e]
    if e._is_mem:
        return [e]
    if e._is_ptr:
        return [e]
    if e._is_eqn:
        return gather((e.l, e.r))
    if e._is_tst:
        return gather((e.tst, e.l, e.r))
    if e._is_slc:
        return locations_of(e.x)
    if e._is_cmp:
        return gather(e.parts.values())
    if e._is_vec:
        return gather(e.l)
    if not e._is_def:
        return []
    raise ValueError(e)
def complexity(e):
    "evaluate the complexity of expression e"
    # operator expressions are weighted by their prop attribute:
    if e._is_eqn:
        factor = e.prop
    else:
        factor = 1
    base = e.depth() + len(symbols_of(e))
    return base * factor
def eqn1_helpers(e, **kargs):
    "helpers for simplifying unary expressions"
    assert e.op.unary
    # a constant operand is folded right away:
    if e.r._is_cst:
        return e.op(e.r)
    # the operator is distributed over the elements of a vec:
    if e.r._is_vec:
        return vec([e.op(x) for x in e.r.l])
    if not e.r._is_eqn:
        return e
    if e.r.op.unary:
        # two nested unary +/- operators collapse into at most one
        sign = e.op * e.r.op
        if sign == OP_ADD:
            return e.r.r
        elif sign == OP_MIN:
            return -e.r.r
    elif e.op.symbol == OP_MIN:
        # -(l [+-] r) is rewritten as ((-l) [-+] r)
        if e.r.op.symbol in (OP_MIN, OP_ADD):
            l = -e.r.l
            r = e.r.r
            return OP_ARITH[e.op * e.r.op](l, r)
    elif e.op.symbol == OP_NOT and e.r.op.type == 4:
        # the negation of a condition is the opposite condition
        opposite = {
            OP_EQ: OP_NEQ,
            OP_NEQ: OP_EQ,
            OP_LT: OP_GE,
            OP_GT: OP_LE,
            OP_LTU: OP_GEU,
            OP_GEU: OP_LTU,
            OP_LE: OP_GT,
            OP_GE: OP_LT,
        }
        notop = opposite[e.r.op.symbol]
        return OP_CONDT[notop](e.r.l, e.r.r)
    return e
def get_lsb_msb(v):
    """
    returns the tuple (lsb, msb) of the indices of the lowest and highest
    bits set in integer v. Both indices are -1 for v == 0.
    """
    msb = v.bit_length() - 1
    # (v & -v) isolates the lowest bit set in v:
    lsb = (v & -v).bit_length() - 1
    return (lsb, msb)


def ismask(v):
    """
    returns True if integer v is a non-empty run of contiguous bits set to 1,
    and False otherwise (including for 0 and for negative values).
    """
    if v <= 0:
        # 0 has no bit set (its lsb index would be -1, which is not a valid
        # shift count) and a negative value never equals the mask below.
        return False
    i1, i2 = get_lsb_msb(v)
    return ((1 << (i2 + 1)) - 1) ^ ((1 << i1) - 1) == v


# reminder: be careful not to modify the internal structure of
# e.l or e.r because these objects might be used also in other
# expressions. See tests/test_cas_exp.py for details.
def eqn2_helpers(e, bitslice=False, widening=False):
    """
    helpers for simplifying binary expressions.

    Arguments:
        e (op): the binary expression to simplify. Its op, l and r
            attributes may be modified.
        bitslice (bool): if True, logic operations and shifts by a constant
            are decomposed into a composition of 1-bit expressions.
        widening (bool): forwarded to the simplification of vec results.

    Returns:
        the simplified expression, which is e itself when no rule applies.
    """
    threshold = conf.Cas.complexity
    # operands that are too complex are replaced by top:
    if complexity(e.r) > threshold:
        e.r = top(e.r.size)
    if complexity(e.l) > threshold:
        e.l = top(e.l.size)
    if e.r._is_top or e.l._is_top:
        return top(e.size)
    # if e := ((a l.op cst) e.op r)
    if e.l._is_eqn and e.l.r._is_cst and e.l.op.unary == 0:
        xop = e.op * e.l.op
        # if ++ -- +- -+,
        if xop:
            # move cst to the right:
            # e := (a e.op r) l.op cst
            e.op, lop = e.l.op, e.op
            lr, e.r = e.r, e.l.r
            e.l = lop(e.l.l, lr)
    # if e:= (l + (- r)
    # change into e:= l - r
    if e.r._is_eqn and e.r.op.unary:
        if e.op.symbol == OP_ADD and e.r.op.symbol == OP_MIN:
            e.op = _operator(OP_MIN)
            e.r = e.r.r
    # if e:= (l [+-] (a [+-] cst))
    # move cst to the right:
    # e:= (l [+-] a) xop cst
    if e.r._is_eqn and e.r.r._is_cst:
        xop = e.op * e.r.op
        if xop:
            e.l = e.op(e.l, e.r.l)
            e.r = e.r.r
            e.op = _operator(xop)
    # now if e:= (l op cst)
    if e.r._is_cst:
        if e.r.value == 0:
            # if e:= (l [|^+-...] 0) then e:= l
            if e.op.symbol in (
                OP_OR,
                OP_XOR,
                OP_ADD,
                OP_MIN,
                OP_LSR,
                OP_LSL,
                OP_ROR,
                OP_ROL,
            ):
                return e.l
            # if e:= (l [|&*] 0) then e:= 0
            if e.op.symbol in (OP_AND, OP_MUL, OP_MUL2):
                return cst(0, e.size)
        # if e:= (l [|*/] 1) then e:= l
        elif e.r.value == 1 and e.op.symbol in (OP_MUL, OP_MUL2, OP_DIV):
            return e.l
        # if e:= (l & mask) then e:= l[i1:i2]
        elif e.op.symbol == OP_AND and ismask(e.r.value):
            i1, i2 = get_lsb_msb(e.r.value)
            c = comp(e.size)
            c[0 : e.size] = cst(0, e.size)
            c[i1 : i2 + 1] = e.l[i1 : i2 + 1]
            return c.simplify()
        elif bitslice and e.op.symbol in (OP_AND, OP_OR, OP_XOR):
            return composer(
                [e.op(e.l[i : i + 1], e.r[i : i + 1]) for i in range(e.size)]
            )
        # The symbols are compared with == here: "in (OP_LSL)" is a
        # substring test on the string "<<" which is also true for "<".
        elif bitslice and e.op.symbol == OP_LSL:
            return composer(
                [bit0] * e.r.value
                + [e.l[i : i + 1] for i in range(0, e.size - e.r.value)]
            )
        elif bitslice and e.op.symbol == OP_LSR:
            return composer(
                [e.l[i : i + 1] for i in range(e.r.value, e.size)]
                + [bit0] * e.r.value
            )
        # if e:= (l [>> <<] r) then e:= l[i1:i2]
        elif e.op.symbol in (OP_LSL, OP_LSR):
            c = comp(e.l.size)
            c[0 : e.l.size] = cst(0, e.l.size)
            if e.op.symbol == OP_LSL:
                l = e.l[0 : e.l.size - e.r.value]
                c[e.r.value : e.l.size] = l
            elif e.op.symbol == OP_LSR:
                l = e.l[e.r.value : e.l.size]
                c[0 : e.l.size - e.r.value] = l
            return c.simplify()
        # if e:= ((a op b) e.op cst)
        if e.l._is_eqn:
            xop = e.op * e.l.op
            if xop:
                # if e:= ((a [+-] cst) [+-] cst)
                # merge constants:
                # change into e := a [+-] cst
                if e.l.r._is_cst:
                    cc = OP_ARITH[xop](e.l.r, e.r)
                    e.op = e.l.op
                    if not e.l.op.unary:
                        e.l = e.l.l
                    e.r = cc
                    return e
            elif e.r.size == 1:
                # if e:= ((a op b) == bit1) change in e := (a op b)
                # if e:= ((a op b) == bit0) change in e := ~(a op b)
                if e.op.symbol == OP_EQ:
                    return e.l if e.r.value == 1 else ~(e.l)
                # if e:= ((a op b) != bit1) change in e := ~(a op b)
                # if e:= ((a op b) != bit0) change in e := (a op b)
                if e.op.symbol == OP_NEQ:
                    return ~(e.l) if e.r.value == 1 else e.l
        elif e.l._is_ptr:
            if e.op.symbol in (OP_MIN, OP_ADD):
                return ptr(e.l, disp=e.op(0, e.r.value))
        elif e.l._is_cmp:
            # logic operators are applied part by part:
            if e.op.symbol in (OP_AND, OP_OR, OP_XOR):
                cc = comp(e.l.size)
                for (ij, p) in e.l.parts.items():
                    i, j = ij
                    cc[i:j] = e.op(p, e.r[i:j])
                return cc.simplify(bitslice=bitslice)
        elif e.l._is_cst:
            return e.op(e.l, e.r)
    # the operator is distributed over the elements of a vec operand:
    if e.l._is_vec:
        return vec([e.op(x, e.r) for x in e.l.l]).simplify(widening=widening)
    if e.r._is_vec:
        return vec([e.op(e.l, x) for x in e.r.l]).simplify(widening=widening)
    # both operands have the same string form:
    if "%s" % (e.l) == "%s" % (e.r):
        if e.op.symbol in (OP_NEQ, OP_LT, OP_GT):
            return bit0
        if e.op.symbol in (OP_EQ, OP_LE, OP_GE):
            return bit1
        if e.op.symbol == OP_MIN:
            return cst(0, e.size)
        if e.op.symbol == OP_XOR:
            return cst(0, e.size)
        if e.op.symbol == OP_AND:
            return e.l
        if e.op.symbol == OP_OR:
            return e.l
    return e
def extract_offset(e):
    """
    separate expression e into (e' + C) with C cst offset.

    Returns:
        a tuple (e', C) where C is an integer. C is 0, and e' is the
        simplified unsigned form of e, unless that form is an addition or
        a subtraction with a constant right-hand operand.
    """
    x = e.simplify().unsigned()
    if x._is_eqn and x.r._is_cst:
        # The operator is read from the simplified expression x, not from
        # e: e may not be an operator expression at all (or may hold a
        # different operator) even when its simplified form is (e' +/- C).
        if x.op.symbol == OP_ADD:
            return (x.l, x.r.value)
        elif x.op.symbol == OP_MIN:
            return (x.l, -x.r.value)
    return (x, 0)
# -----------------------------------------------------
class vec(exp):
    """
    vec holds a list of expressions each being a possible
    representation of the current expression. A vec object
    is obtained by merging several execution paths using the merge
    function in the mapper module.
    The simplify method uses the complexity measure to
    eventually "reduce" the expression to top with a hard-limit
    currently set to conf.Cas.complexity.

    Attributes:
        l (list): the list of possible expressions, all of the same size.
    """

    __slots__ = ["l"]
    __hash__ = exp.__hash__
    __eq__ = exp.__eq__
    etype = et_vec

    def __init__(self, l=None):
        self.l = [] if l is None else l
        size = max([0] + [e.size for e in self.l])
        # every element must have that same size:
        if any(e.size != size for e in self.l):
            raise ValueError("size mismatch")
        self.size = size
        self.sf = any([e.sf for e in self.l])

    def __unicode__(self):
        return "[%s]" % (",".join(["%s" % x for x in self.l]))

    def toks(self, **kargs):
        "returns the tokens of all elements, separated by ', ' and enclosed in brackets."
        tokens = []
        for x in self.l:
            tokens.extend(x.toks(**kargs))
            tokens.append((render.Token.Literal, ", "))
        if len(tokens) > 0:
            # drop the separator that follows the last element
            tokens.pop()
        tokens.insert(0, (render.Token.Literal, "["))
        tokens.append((render.Token.Literal, "]"))
        return tokens

    def simplify(self, **kargs):
        """
        simplifies all elements, flattens nested vec elements and removes
        duplicates in place. Returns an undefined element if there is one,
        the single remaining element if there is only one, a vecw if
        widening is requested (or if a vecw element was found), top if the
        total complexity exceeds conf.Cas.complexity, and self otherwise.
        """
        widening = kargs.get("widening", False)
        flat = []
        for e in self.l:
            ee = e.simplify()
            if not ee._is_def:
                return ee
            if not ee._is_vec:
                flat.append(ee)
                continue
            flat.extend(ee.l)
            if isinstance(ee, vecw):
                widening = True
        # keep the first occurrence of every element:
        self.l = unique = []
        for e in flat:
            if e not in unique:
                unique.append(e)
        if len(self.l) == 1:
            return self.l[0]
        if widening:
            return vecw(self)
        total = sum([complexity(x) for x in self.l], 0.0)
        if total > conf.Cas.complexity:
            return top(self.size)
        return self

    def eval(self, env):
        "returns the vec of all elements evaluated in env."
        return vec([e.eval(env) for e in self.l])

    def depth(self):
        "returns the highest depth among the elements times the number of elements."
        if self.size == 0:
            return 0.0
        deepest = max([e.depth() for e in self.l])
        return deepest * len(self.l)

    @_checkarg_slice
    def __getitem__(self, i):
        sta, sto, _ = i.indices(self.size)
        return vec([e[sta:sto] for e in self.l])

    def __contains__(self, x):
        return x in self.l

    def __bool__(self):
        return all([e.__bool__() for e in self.l])
class vecw(top):
    """
    vecw is a *widened* vec expression: it allows to limit the list of possible
    values to a fixed range and acts as a top (absorbing) expression.

    Attributes:
        l (list): the list of expressions of the vec it was built from.
    """

    __slots__ = ["l"]
    __hash__ = top.__hash__
    __eq__ = exp.__eq__
    etype = ~(~et_vec & et_msk)

    def __init__(self, v):
        self.l = v.l
        self.size = v.size
        self.sf = False

    def __unicode__(self):
        items = ",".join(["%s" % x for x in self.l])
        return "[%s, %s]" % (items, render.icons.dots)

    def toks(self, **kargs):
        "returns the tokens of all elements, separated by ', ', followed by the dots icon."
        tokens = []
        for x in self.l:
            tokens.extend(x.toks(**kargs))
            tokens.append((render.Token.Literal, ", "))
        if len(tokens) > 0:
            # drop the separator that follows the last element
            tokens.pop()
        tokens.insert(0, (render.Token.Literal, "["))
        tokens.append((render.Token.Literal, ", %s]" % render.icons.dots))
        return tokens

    def eval(self, env):
        "returns the vecw of all elements evaluated in env."
        evaluated = vec([x.eval(env) for x in self.l])
        return vecw(evaluated)

    @_checkarg_slice
    def __getitem__(self, i):
        sta, sto, _ = i.indices(self.size)
        return vecw(vec([e[sta:sto] for e in self.l]))