diff --git a/llpython/addr_flow.py b/llpython/addr_flow.py index 7447789..8630a5a 100644 --- a/llpython/addr_flow.py +++ b/llpython/addr_flow.py @@ -5,6 +5,7 @@ from __future__ import absolute_import from .byte_flow import BytecodeFlowBuilder +from .opcode_util import build_basic_blocks, itercodeobjs # ______________________________________________________________________ # Class definition(s) @@ -79,16 +80,43 @@ def build_addr_flow(func): cfg = byte_control.build_cfg(func) return AddressFlowBuilder().visit_cfg(cfg) +# ______________________________________________________________________ + +def build_addr_flow_from_co(codeobj): + from .byte_control import ControlFlowBuilder + cfg = ControlFlowBuilder().visit(build_basic_blocks(codeobj), + codeobj.co_argcount) + return AddressFlowBuilder().visit_cfg(cfg) + +# ______________________________________________________________________ + +def build_addr_flows_from_co(root_co): + return dict((co, build_addr_flow_from_co(co)) + for co in itercodeobjs(root_co)) + # ______________________________________________________________________ # Main (self-test) routine def main(*args): import pprint - from .tests import llfuncs + try: + from .tests import llfuncs + except ImportError: + llfuncs = object() if not args: args = ('pymod',) for arg in args: - pprint.pprint(build_addr_flow(getattr(llfuncs, arg))) + if arg.endswith('.py'): + with open(arg) as in_file: + in_source = in_file.read() + in_codeobj = compile(in_source, arg, 'exec') + flow_map = build_addr_flows_from_co(in_codeobj) + for codeobj, flow in flow_map.items(): + print("_" * 70) + print(codeobj) + pprint.pprint(flow) + else: + pprint.pprint(build_addr_flow(getattr(llfuncs, arg))) # ______________________________________________________________________ diff --git a/llpython/opcode_util.py b/llpython/opcode_util.py index bec9c01..b6ccf11 100644 --- a/llpython/opcode_util.py +++ b/llpython/opcode_util.py @@ -1,9 +1,10 @@ #! /usr/bin/env python # ______________________________________________________________________ +from collections import namedtuple import dis import opcode -from collections import namedtuple +import types # ______________________________________________________________________ # Module data @@ -195,6 +196,16 @@ def itercode(code, start = 0): # ______________________________________________________________________ +def itercodeobjs(codeobj): + "Iterator that traverses code objects via the co_consts member." + yield codeobj + for const in codeobj.co_consts: + if isinstance(const, types.CodeType): + for childobj in itercodeobjs(const): + yield childobj + +# ______________________________________________________________________ + def extendlabels(code, labels = None): """Extend the set of jump target labels to account for the passthrough targets of conditional branches.