Coverage for asteval/asteval.py: 93%
629 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-05 11:00 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-05 11:00 +0000
1#!/usr/bin/env python
2"""
3Safe(ish) evaluation of minimal Python code using Python's ast module.
5This module provides an Interpreter class that compiles a restricted set of
6Python expressions and statements to Python's AST representation, and then
7executes that representation using values held in a symbol table.
9The symbol table is a simple dictionary, giving a flat namespace. This comes
10pre-loaded with many functions from Python's builtin and math module. If numpy
11is installed, many numpy functions are also included. Additional symbols can
12be added when an Interpreter is created, but the user of that interpreter will
13not be able to import additional modules.
Expressions, including loops, conditionals, and function definitions, can be
compiled into AST nodes and then evaluated later, using the current values
in the symbol table.
19The result is a restricted, simplified version of Python meant for numerical
20calculations that is somewhat safer than 'eval' because many unsafe operations
21(such as 'eval') are simply not allowed, and others (such as 'import') are
22disabled by default, but can be explicitly enabled.
24Many parts of Python syntax are supported, including:
25 for loops, while loops, if-then-elif-else conditionals, with,
26 try-except-finally
27 function definitions with def
28 advanced slicing: a[::-1], array[-3:, :, ::2]
29 if-expressions: out = one_thing if TEST else other
30 list, dict, and set comprehension
32The following Python syntax elements are not supported:
33 Import, Exec, Lambda, Class, Global, Generators,
34 Yield, Decorators
In addition, while many builtin functions are supported, several builtin
functions that are considered unsafe ('eval', 'exec', and 'getattr', for
example) are missing.
39"""
40import ast
41import sys
42import copy
43import inspect
44import time
45from sys import exc_info, stderr, stdout
47from .astutils import (HAS_NUMPY, UNSAFE_ATTRS, UNSAFE_ATTRS_DTYPES,
48 ExceptionHolder, ReturnedNone, Empty, make_symbol_table,
49 numpy, op2func, valid_symbol_name, Procedure)
# Lower-cased AST node class names the interpreter knows about; each maps
# to an ``on_<name>`` handler method (or to ``unimplemented``).
ALL_NODES = [
    'arg', 'assert', 'assign', 'attribute', 'augassign', 'binop', 'boolop',
    'break', 'bytes', 'call', 'compare', 'constant', 'continue', 'delete',
    'dict', 'dictcomp', 'ellipsis', 'excepthandler', 'expr', 'extslice',
    'for', 'functiondef', 'if', 'ifexp', 'import', 'importfrom', 'index',
    'interrupt', 'list', 'listcomp', 'module', 'name', 'nameconstant', 'num',
    'pass', 'raise', 'repr', 'return', 'set', 'setcomp', 'slice', 'str',
    'subscript', 'try', 'tuple', 'unaryop', 'while', 'with',
    'formattedvalue', 'joinedstr',
]

# Nodes that can be switched on/off: all off in MINIMAL_CONFIG, all on in
# DEFAULT_CONFIG.  Imports are disabled in both.
_TOGGLEABLE_NODES = ('assert', 'augassign', 'delete', 'if', 'ifexp', 'for',
                     'formattedvalue', 'functiondef', 'print', 'raise',
                     'listcomp', 'dictcomp', 'setcomp', 'try', 'while', 'with')

MINIMAL_CONFIG = {'import': False, 'importfrom': False,
                  **dict.fromkeys(_TOGGLEABLE_NODES, False)}
DEFAULT_CONFIG = {'import': False, 'importfrom': False,
                  **dict.fromkeys(_TOGGLEABLE_NODES, True)}
71class Interpreter:
72 """create an asteval Interpreter: a restricted, simplified interpreter
73 of mathematical expressions using Python syntax.
75 Parameters
76 ----------
77 symtable : dict or `None`
78 dictionary or SymbolTable to use as symbol table (if `None`, one will be created).
79 nested_symtable : bool, optional
80 whether to use a new-style nested symbol table instead of a plain dict [False]
81 user_symbols : dict or `None`
82 dictionary of user-defined symbols to add to symbol table.
83 writer : file-like or `None`
84 callable file-like object where standard output will be sent.
85 err_writer : file-like or `None`
86 callable file-like object where standard error will be sent.
87 use_numpy : bool
88 whether to use functions from numpy.
89 max_statement_length : int
90 maximum length of expression allowed [50,000 characters]
91 readonly_symbols : iterable or `None`
92 symbols that the user can not assign to
    builtins_readonly : bool
        whether to make all symbols in the initial symtable read-only
95 minimal : bool
96 create a minimal interpreter: disable many nodes (see Note 1).
    config : dict
        dictionary listing which nodes to support (see Note 2)
100 Notes
101 -----
102 1. setting `minimal=True` is equivalent to setting a config with the following
103 nodes disabled: ('import', 'importfrom', 'if', 'for', 'while', 'try', 'with',
104 'functiondef', 'ifexp', 'listcomp', 'dictcomp', 'setcomp', 'augassign',
105 'assert', 'delete', 'raise', 'print')
106 2. by default 'import' and 'importfrom' are disabled, though they can be enabled.
107 """
108 def __init__(self, symtable=None, nested_symtable=False,
109 user_symbols=None, writer=None, err_writer=None,
110 use_numpy=True, max_statement_length=50000,
111 minimal=False, readonly_symbols=None,
112 builtins_readonly=False, config=None, **kws):
114 self.config = copy.copy(MINIMAL_CONFIG if minimal else DEFAULT_CONFIG)
115 if config is not None:
116 self.config.update(config)
117 self.config['nested_symtable'] = nested_symtable
119 if user_symbols is None:
120 user_symbols = {}
121 if 'usersyms' in kws:
122 user_symbols = kws.pop('usersyms') # back compat, changed July, 2023, v 0.9.4
124 if len(kws) > 0:
125 for key, val in kws.items():
126 if key.startswith('no_'):
127 node = key[3:]
128 if node in ALL_NODES:
129 self.config[node] = not val
130 elif key.startswith('with_'):
131 node = key[5:]
132 if node in ALL_NODES:
133 self.config[node] = val
135 self.writer = writer or stdout
136 self.err_writer = err_writer or stderr
137 self.max_statement_length = max(1, min(1.e8, max_statement_length))
139 self.use_numpy = HAS_NUMPY and use_numpy
140 if symtable is None:
141 symtable = make_symbol_table(nested=nested_symtable,
142 use_numpy=self.use_numpy, **user_symbols)
144 symtable['print'] = self._printer
145 self.symtable = symtable
146 self._interrupt = None
147 self.error = []
148 self.error_msg = None
149 self.expr = None
150 self.retval = None
151 self._calldepth = 0
152 self.lineno = 0
153 self.start_time = time.time()
155 self.node_handlers = {}
156 for node in ALL_NODES:
157 handler = self.unimplemented
158 if self.config.get(node, True):
159 handler = getattr(self, f"on_{node}", self.unimplemented)
160 self.node_handlers[node] = handler
162 # to rationalize try/except try/finally
163 if 'try' in self.node_handlers:
164 self.node_handlers['tryexcept'] = self.node_handlers['try']
165 self.node_handlers['tryfinally'] = self.node_handlers['try']
167 if readonly_symbols is None:
168 self.readonly_symbols = set()
169 else:
170 self.readonly_symbols = set(readonly_symbols)
172 if builtins_readonly:
173 self.readonly_symbols |= set(self.symtable)
175 self.no_deepcopy = [key for key, val in symtable.items()
176 if (callable(val)
177 or inspect.ismodule(val)
178 or 'numpy.lib.index_tricks' in repr(type(val)))]
180 def remove_nodehandler(self, node):
181 """remove support for a node
182 returns current node handler, so that it
183 might be re-added with add_nodehandler()
184 """
185 out = None
186 if node in self.node_handlers:
187 out = self.node_handlers.pop(node)
188 return out
190 def set_nodehandler(self, node, handler=None):
191 """set node handler or use current built-in default"""
192 if handler is None:
193 handler = getattr(self, f"on_{node}", self.unimplemented)
194 self.node_handlers[node] = handler
195 return handler
197 def user_defined_symbols(self):
198 """Return a set of symbols that have been added to symtable after
199 construction.
201 I.e., the symbols from self.symtable that are not in
202 self.no_deepcopy.
204 Returns
205 -------
206 unique_symbols : set
207 symbols in symtable that are not in self.no_deepcopy
209 """
210 sym_in_current = set(self.symtable.keys())
211 sym_from_construction = set(self.no_deepcopy)
212 unique_symbols = sym_in_current.difference(sym_from_construction)
213 return unique_symbols
215 def unimplemented(self, node):
216 """Unimplemented nodes."""
217 msg = f"{node.__class__.__name__} not supported"
218 self.raise_exception(node, exc=NotImplementedError, msg=msg)
    def raise_exception(self, node, exc=None, msg='', expr=None, lineno=None):
        """Record an error in ``self.error`` and raise it.

        Parameters
        ----------
        node : ast.AST or None
            node being evaluated when the error occurred (stored on the
            ExceptionHolder).
        exc : exception class or None
            class to raise.  When None, the class recorded on the newest
            ExceptionHolder is used; if that is None too, entries are popped
            from ``self.error`` until one with a class is found, and
            ``Exception`` is the final fallback.
        msg : str
            error message (converted with str()).
        expr : str or None
            if given, replaces ``self.expr`` (the expression text stored with
            the error).
        lineno : int or None
            line number stored on the ExceptionHolder.

        Raises
        ------
        Always raises: normally `exc` with ``self.error_msg`` (the first
        message recorded since ``error_msg`` was last reset to None, e.g. by
        eval()); when no message is available at all and more than one error
        is recorded, an earlier recorded error's own class and message.
        """
        if expr is not None:
            self.expr = expr
        msg = str(msg)
        err = ExceptionHolder(node, exc=exc, msg=msg, expr=self.expr, lineno=lineno)
        # loops stop running their body when _interrupt is not None
        self._interrupt = ast.Raise()

        self.error.append(err)
        # keep the first message of this evaluation as the headline message
        if self.error_msg is None:
            self.error_msg = msg
        elif len(msg) > 0:
            pass
            # if err.exc is not None:
            #     self.error_msg = f"{err.exc.__name__}: {msg}"
        if exc is None:
            exc = self.error[-1].exc
            if exc is None and len(self.error) > 0:
                while exc is None and len(self.error) > 0:
                    err = self.error.pop()
                    exc = err.exc
        if exc is None:
            exc = Exception
        # NOTE(review): len(err.msg) assumes err.msg is a str; confirm that
        # ExceptionHolder never stores a non-str message here.
        if len(err.msg) == 0 and len(self.error_msg) == 0 and len(self.error) > 1:
            err = self.error.pop(-1)
            raise err.exc(err.msg)
        else:
            if len(err.msg) == 0:
                err.msg = self.error_msg
            raise exc(self.error_msg)
251 # main entry point for Ast node evaluation
252 # parse: text of statements -> ast
253 # run: ast -> result
254 # eval: string statement -> result = run(parse(statement))
255 def parse(self, text):
256 """Parse statement/expression to Ast representation."""
257 if len(text) > self.max_statement_length:
258 msg = f'length of text exceeds {self.max_statement_length:d} characters'
259 self.raise_exception(None, exc=RuntimeError, expr=msg)
260 self.expr = text
261 try:
262 out = ast.parse(text)
263 except SyntaxError:
264 self.raise_exception(None, exc=SyntaxError, expr=text)
265 except:
266 self.raise_exception(None, exc=RuntimeError, expr=text)
268 return out
    def run(self, node, expr=None, lineno=None, with_raise=True):
        """Execute a parsed AST node and return its value.

        Parameters
        ----------
        node : ast.AST, str or None
            node to evaluate.  A str is handed to ``self.eval()`` (parsed and
            run from scratch); None evaluates to None.
        expr : str or None
            if given, replaces ``self.expr`` (source text used in errors).
        lineno : int or None
            if given, replaces ``self.lineno``.
        with_raise : bool
            when True (and ``self.expr`` is set), an exception raised by the
            node handler is recorded and re-raised through raise_exception;
            otherwise it is swallowed here and None is returned (errors
            already recorded in ``self.error`` stay there).

        Returns
        -------
        The handler's result (an ``enumerate`` object is converted to a
        list).  Early exits: None while an error is pending, ``self.retval``
        once a ``return`` has executed, and the pending Break/Continue node
        while a loop is being interrupted.
        """
        # Note: keep the 'node is None' test: internal code here may run
        # run(None) and expect a None in return.
        if isinstance(node, str):
            return self.eval(node, raise_errors=with_raise)

        out = None
        if len(self.error) > 0:
            return out
        if self.retval is not None:
            return self.retval
        if isinstance(self._interrupt, (ast.Break, ast.Continue)):
            return self._interrupt
        if node is None:
            return out

        if lineno is not None:
            self.lineno = lineno
        if expr is not None:
            self.expr = expr

        # get handler for this node:
        #   on_xxx will handle nodes of type 'xxx', etc.
        # Unknown node types raise NotImplementedError (raise_exception always
        # raises, so `handler` is bound below).
        try:
            handler = self.node_handlers[node.__class__.__name__.lower()]
        except KeyError:
            self.raise_exception(None, exc=NotImplementedError, expr=self.expr)

        # run the handler: this will likely generate
        # recursive calls into this run method.
        try:
            ret = handler(node)
            if isinstance(ret, enumerate):
                ret = list(ret)
            return ret
        except:
            # deliberately broad: any failure inside a handler is routed
            # through raise_exception (or swallowed when with_raise is False)
            if with_raise and self.expr is not None:
                self.raise_exception(node, expr=self.expr)

        # avoid too many repeated error messages (yes, this needs to be "2")
        if len(self.error) > 2:
            self._remove_duplicate_errors()

        return None
316 def _remove_duplicate_errors(self):
317 """remove duplicate exceptions"""
318 error = [self.error[0]]
319 for err in self.error[1:]:
320 lerr = error[-1]
321 if err.exc != lerr.exc or err.expr != lerr.expr or err.msg != lerr.msg:
322 if isinstance(err.msg, str) and len(err.msg) > 0:
323 error.append(err)
324 self.error = error
326 def __call__(self, expr, **kw):
327 """Call class instance as function."""
328 return self.eval(expr, **kw)
    def eval(self, expr, lineno=0, show_errors=True, raise_errors=False):
        """Evaluate a statement given as source text or as a parsed AST.

        Resets ``self.lineno``, ``self.error``, ``self.error_msg`` and
        ``self.start_time``, parses `expr` when it is a str, then runs it.

        Parameters
        ----------
        expr : str or ast.AST
            source text, or an already-parsed node.
        lineno : int
            stored as ``self.lineno`` and passed on to run().
        show_errors : bool
            print the error message to ``self.err_writer``: for a parse
            error unless it was re-raised, for a run error only when
            `raise_errors` is False.
        raise_errors : bool
            re-raise the recorded error (same class, message from
            ``get_error()``) instead of returning None.

        Returns
        -------
        The result of run(), or None on error.

        NOTE(review): run() is called with ``with_raise=raise_errors``.  With
        the default raise_errors=False the top-level run() swallows exceptions
        from node handlers, so the ``except`` branch below (and its
        show_errors printing) seems reachable only for errors raised outside
        a handler; the errors still end up in ``self.error``.  Confirm.
        """
        self.lineno = lineno
        self.error = []
        self.error_msg = None
        self.start_time = time.time()
        if isinstance(expr, str):
            try:
                node = self.parse(expr)
            except Exception:
                errmsg = exc_info()[1]
                if len(self.error) > 0:
                    lerr = self.error[-1]
                    errmsg = lerr.get_error()[1]
                    if raise_errors:
                        raise lerr.exc(errmsg)
                if show_errors:
                    print(errmsg, file=self.err_writer)
                return None
        else:
            node = expr
        try:
            return self.run(node, expr=expr, lineno=lineno, with_raise=raise_errors)
        except Exception:
            if show_errors and not raise_errors:
                errmsg = exc_info()[1]
                if len(self.error) > 0:
                    errmsg = self.error[-1].get_error()[1]
                print(errmsg, file=self.err_writer)
            if raise_errors and len(self.error) > 0:
                # collapse repeated entries before picking the last one
                self._remove_duplicate_errors()
                err = self.error[-1]
                raise err.exc(err.get_error()[1])
        return None
365 @staticmethod
366 def dump(node, **kw):
367 """Simple ast dumper."""
368 return ast.dump(node, **kw)
370 # handlers for ast components
371 def on_expr(self, node):
372 """Expression."""
373 return self.run(node.value) # ('value',)
375 # imports
376 def on_import(self, node): # ('names',)
377 "simple import"
378 for tnode in node.names:
379 self.import_module(tnode.name, tnode.asname)
381 def on_importfrom(self, node): # ('module', 'names', 'level')
382 "import/from"
383 fromlist, asname = [], []
384 for tnode in node.names:
385 fromlist.append(tnode.name)
386 asname.append(tnode.asname)
387 self.import_module(node.module, asname, fromlist=fromlist)
389 def import_module(self, name, asname, fromlist=None):
390 """import a python module, installing it into the symbol table.
391 options:
392 name name of module to import 'foo' in 'import foo'
393 asname alias for imported name(s)
394 'bar' in 'import foo as bar'
395 or
396 ['s','t'] in 'from foo import x as s, y as t'
397 fromlist list of symbols to import with 'from-import'
398 ['x','y'] in 'from foo import x, y'
399 """
400 # find module in sys.modules or import to it
401 if name in sys.modules:
402 thismod = sys.modules[name]
403 else:
404 try:
405 __import__(name)
406 thismod = sys.modules[name]
407 except:
408 self.raise_exception(None, exc=ImportError, msg='Import Error')
410 if fromlist is None:
411 if asname is not None:
412 self.symtable[asname] = sys.modules[name]
413 else:
414 mparts = []
415 parts = name.split('.')
416 while len(parts) > 0:
417 mparts.append(parts.pop(0))
418 modname = '.'.join(mparts)
419 inname = name if (len(parts) == 0) else modname
420 self.symtable[inname] = sys.modules[modname]
421 else: # import-from construct
422 if asname is None:
423 asname = [None]*len(fromlist)
424 for sym, alias in zip(fromlist, asname):
425 if alias is None:
426 alias = sym
427 self.symtable[alias] = getattr(thismod, sym)
429 def on_index(self, node):
430 """Index."""
431 return self.run(node.value) # ('value',)
433 def on_return(self, node): # ('value',)
434 """Return statement: look for None, return special sentinel."""
435 if self._calldepth == 0:
436 raise SyntaxError('cannot return at top level')
437 self.retval = self.run(node.value)
438 if self.retval is None:
439 self.retval = ReturnedNone
441 def on_repr(self, node):
442 """Repr."""
443 return repr(self.run(node.value)) # ('value',)
445 def on_module(self, node): # ():('body',)
446 """Module def."""
447 out = None
448 for tnode in node.body:
449 out = self.run(tnode)
450 return out
452 def on_expression(self, node):
453 "basic expression"
454 return self.on_module(node) # ():('body',)
456 def on_pass(self, node):
457 """Pass statement."""
458 return None # ()
460 # for break and continue: set the instance variable _interrupt
461 def on_interrupt(self, node): # ()
462 """Interrupt handler."""
463 self._interrupt = node
464 return node
466 def on_break(self, node):
467 """Break."""
468 return self.on_interrupt(node)
470 def on_continue(self, node):
471 """Continue."""
472 return self.on_interrupt(node)
474 def on_assert(self, node): # ('test', 'msg')
475 """Assert statement."""
476 if not self.run(node.test):
477 msg = node.msg.value if node.msg else ""
478 # msg = node.msg.s if node.msg else ""
479 self.raise_exception(node, exc=AssertionError, msg=msg)
480 return True
482 def on_list(self, node): # ('elt', 'ctx')
483 """List."""
484 return [self.run(e) for e in node.elts]
486 def on_tuple(self, node): # ('elts', 'ctx')
487 """Tuple."""
488 return tuple(self.on_list(node))
490 def on_set(self, node): # ('elts')
491 """Set."""
492 return set([self.run(k) for k in node.elts])
494 def on_dict(self, node): # ('keys', 'values')
495 """Dictionary."""
496 return {self.run(k): self.run(v) for k, v in
497 zip(node.keys, node.values)}
499 def on_constant(self, node): # ('value', 'kind')
500 """Return constant value."""
501 return node.value
503 def on_joinedstr(self, node): # ('values',)
504 "join strings, used in f-strings"
505 return ''.join([self.run(k) for k in node.values])
507 def on_formattedvalue(self, node): # ('value', 'conversion', 'format_spec')
508 "formatting used in f-strings"
509 val = self.run(node.value)
510 fstring_converters = {115: str, 114: repr, 97: ascii}
511 if node.conversion in fstring_converters:
512 val = fstring_converters[node.conversion](val)
513 fmt = '{__fstring__}'
514 if node.format_spec is not None:
515 fmt = f'{ __fstring__:{self.run(node.format_spec)}} '
516 return fmt.format(__fstring__=val)
518 def _getsym(self, node):
519 val = self.symtable.get(node.id, ReturnedNone)
520 if isinstance(val, Empty):
521 msg = f"name '{node.id}' is not defined"
522 self.raise_exception(node, exc=NameError, msg=msg)
523 return val
525 def on_name(self, node): # ('id', 'ctx')
526 """Name node."""
527 ctx = node.ctx.__class__
528 if ctx in (ast.Param, ast.Del):
529 return str(node.id)
530 return self._getsym(node)
532 def node_assign(self, node, val):
533 """Assign a value (not the node.value object) to a node.
535 This is used by on_assign, but also by for, list comprehension,
536 etc.
538 """
539 if node.__class__ == ast.Name:
540 if (not valid_symbol_name(node.id) or
541 node.id in self.readonly_symbols):
542 errmsg = f"invalid symbol name (reserved word?) {node.id}"
543 self.raise_exception(node, exc=NameError, msg=errmsg)
544 self.symtable[node.id] = val
545 if node.id in self.no_deepcopy:
546 self.no_deepcopy.remove(node.id)
548 elif node.__class__ == ast.Attribute:
549 if node.ctx.__class__ == ast.Load:
550 msg = f"cannot assign to attribute {node.attr}"
551 self.raise_exception(node, exc=AttributeError, msg=msg)
553 setattr(self.run(node.value), node.attr, val)
555 elif node.__class__ == ast.Subscript:
556 self.run(node.value)[self.run(node.slice)] = val
558 elif node.__class__ in (ast.Tuple, ast.List):
559 if len(val) == len(node.elts):
560 for telem, tval in zip(node.elts, val):
561 self.node_assign(telem, tval)
562 else:
563 raise ValueError('too many values to unpack')
565 def on_attribute(self, node): # ('value', 'attr', 'ctx')
566 """Extract attribute."""
568 ctx = node.ctx.__class__
569 if ctx == ast.Store:
570 msg = "attribute for storage: shouldn't be here!"
571 self.raise_exception(node, exc=RuntimeError, msg=msg)
573 sym = self.run(node.value)
574 if ctx == ast.Del:
575 return delattr(sym, node.attr)
576 #
577 unsafe = (node.attr in UNSAFE_ATTRS or
578 (node.attr.startswith('__') and node.attr.endswith('__')))
579 if not unsafe:
580 for dtype, attrlist in UNSAFE_ATTRS_DTYPES.items():
581 unsafe = isinstance(sym, dtype) and node.attr in attrlist
582 if unsafe:
583 break
584 if unsafe:
585 msg = f"no safe attribute '{node.attr}' for {repr(sym)}"
586 self.raise_exception(node, exc=AttributeError, msg=msg)
587 else:
588 try:
589 return getattr(sym, node.attr)
590 except AttributeError:
591 pass
594 def on_assign(self, node): # ('targets', 'value')
595 """Simple assignment."""
596 val = self.run(node.value)
597 for tnode in node.targets:
598 self.node_assign(tnode, val)
600 def on_augassign(self, node): # ('target', 'op', 'value')
601 """Augmented assign."""
602 return self.on_assign(ast.Assign(targets=[node.target],
603 value=ast.BinOp(left=node.target,
604 op=node.op,
605 right=node.value)))
607 def on_slice(self, node): # ():('lower', 'upper', 'step')
608 """Simple slice."""
609 return slice(self.run(node.lower),
610 self.run(node.upper),
611 self.run(node.step))
613 def on_extslice(self, node): # ():('dims',)
614 """Extended slice."""
615 return tuple([self.run(tnode) for tnode in node.dims])
617 def on_subscript(self, node): # ('value', 'slice', 'ctx')
618 """Subscript handling"""
619 return self.run(node.value)[self.run(node.slice)]
    def on_delete(self, node):  # ('targets',)
        """Delete statement: ``del name``, ``del a.b``, ``del x[i]``, ``del a.b[i]``.

        Name and attribute targets are popped from the symbol table under
        their dotted name.  For subscript targets, the item is deleted from
        the container found under the dotted name, and the container is then
        written back.  A target whose context is not Del ends the loop.
        """
        for tnode in node.targets:
            if tnode.ctx.__class__ != ast.Del:
                break
            # collect the attribute chain: a.b.c -> ['c', 'b'] (reversed below)
            children = []
            while tnode.__class__ == ast.Attribute:
                children.append(tnode.attr)
                tnode = tnode.value
            # NOTE(review): a read-only Name is skipped silently -- nothing is
            # deleted and no error is raised.
            if (tnode.__class__ == ast.Name and
                    tnode.id not in self.readonly_symbols):
                children.append(tnode.id)
                children.reverse()
                # NOTE(review): for attribute targets this pops the dotted key
                # 'a.b'; a plain-dict symtable has no such key (KeyError).
                # Presumably relies on the nested symbol table -- confirm.
                self.symtable.pop('.'.join(children))
            elif tnode.__class__ == ast.Subscript:
                nslice = self.run(tnode.slice)
                children = []
                tnode = tnode.value
                while tnode.__class__ == ast.Attribute:
                    children.append(tnode.attr)
                    tnode = tnode.value
                if (tnode.__class__ == ast.Name and not
                        tnode.id in self.readonly_symbols):
                    children.append(tnode.id)
                children.reverse()
                sname = '.'.join(children)
                # NOTE(review): run() with a str goes through eval(), which
                # resets self.error, self.error_msg and self.lineno.
                val = self.run(sname)
                del val[nslice]
                # write the mutated container back
                if len(children) == 1:
                    self.symtable[sname] = val
                else:
                    child = self.symtable[children[0]]
                    # NOTE(review): intermediate names are looked up with item
                    # access (child[cname]) but the last one is set with
                    # setattr -- getattr was probably intended here.
                    for cname in children[1:-1]:
                        child = child[cname]
                    setattr(child, children[-1], val)
658 def on_unaryop(self, node): # ('op', 'operand')
659 """Unary operator."""
660 return op2func(node.op)(self.run(node.operand))
662 def on_binop(self, node): # ('left', 'op', 'right')
663 """Binary operator."""
664 return op2func(node.op)(self.run(node.left),
665 self.run(node.right))
667 def on_boolop(self, node): # ('op', 'values')
668 """Boolean operator."""
669 val = self.run(node.values[0])
670 is_and = ast.And == node.op.__class__
671 if (is_and and val) or (not is_and and not val):
672 for nodeval in node.values[1:]:
673 val = op2func(node.op)(val, self.run(nodeval))
674 if (is_and and not val) or (not is_and and val):
675 break
676 return val
678 def on_compare(self, node): # ('left', 'ops', 'comparators')
679 """comparison operators, including chained comparisons (a<b<c)"""
680 lval = self.run(node.left)
681 results = []
682 multi = len(node.ops) > 1
683 for oper, rnode in zip(node.ops, node.comparators):
684 rval = self.run(rnode)
685 ret = op2func(oper)(lval, rval)
686 if multi:
687 results.append(ret)
688 if not all(results):
689 return False
690 lval = rval
691 if multi:
692 ret = all(results)
693 return ret
695 def _printer(self, *out, **kws):
696 """Generic print function."""
697 if self.config.get('print', True):
698 flush = kws.pop('flush', True)
699 fileh = kws.pop('file', self.writer)
700 sep = kws.pop('sep', ' ')
701 end = kws.pop('sep', '\n')
702 print(*out, file=fileh, sep=sep, end=end)
703 if flush:
704 fileh.flush()
706 def on_if(self, node): # ('test', 'body', 'orelse')
707 """Regular if-then-else statement."""
708 block = node.body
709 if not self.run(node.test):
710 block = node.orelse
711 for tnode in block:
712 self.run(tnode)
714 def on_ifexp(self, node): # ('test', 'body', 'orelse')
715 """If expressions."""
716 expr = node.orelse
717 if self.run(node.test):
718 expr = node.body
719 return self.run(expr)
721 def on_while(self, node): # ('test', 'body', 'orelse')
722 """While blocks."""
723 while self.run(node.test):
724 self._interrupt = None
725 for tnode in node.body:
726 self.run(tnode)
727 if self._interrupt is not None:
728 break
729 if isinstance(self._interrupt, ast.Break):
730 break
731 else:
732 for tnode in node.orelse:
733 self.run(tnode)
734 self._interrupt = None
736 def on_for(self, node): # ('target', 'iter', 'body', 'orelse')
737 """For blocks."""
738 for val in self.run(node.iter):
739 self.node_assign(node.target, val)
740 self._interrupt = None
741 for tnode in node.body:
742 self.run(tnode)
743 if self._interrupt is not None:
744 break
745 if isinstance(self._interrupt, ast.Break):
746 break
747 else:
748 for tnode in node.orelse:
749 self.run(tnode)
750 self._interrupt = None
752 def on_with(self, node): # ('items', 'body', 'type_comment')
753 """with blocks."""
754 contexts = []
755 for item in node.items:
756 ctx = self.run(item.context_expr)
757 contexts.append(ctx)
758 if hasattr(ctx, '__enter__'):
759 result = ctx.__enter__()
760 if item.optional_vars is not None:
761 self.node_assign(item.optional_vars, result)
762 else:
763 msg = "object does not support the context manager protocol"
764 raise TypeError(f"'{type(ctx)}' {msg}")
765 for bnode in node.body:
766 self.run(bnode)
767 if self._interrupt is not None:
768 break
770 for ctx in contexts:
771 if hasattr(ctx, '__exit__'):
772 ctx.__exit__()
774 def _comp_save_syms(self, node):
775 """find and save symbols that will be used in a comprehension"""
776 saved_syms = {}
777 for tnode in node.generators:
778 if tnode.target.__class__ == ast.Name:
779 if (not valid_symbol_name(tnode.target.id) or
780 tnode.target.id in self.readonly_symbols):
781 errmsg = f"invalid symbol name (reserved word?) {tnode.target.id}"
782 self.raise_exception(tnode.target, exc=NameError, msg=errmsg)
783 if tnode.target.id in self.symtable:
784 saved_syms[tnode.target.id] = copy.deepcopy(self._getsym(tnode.target))
786 elif tnode.target.__class__ == ast.Tuple:
787 for tval in tnode.target.elts:
788 if tval.id in self.symtable:
789 saved_syms[tval.id] = copy.deepcopy(self._getsym(tval))
790 return saved_syms
793 def do_generator(self, gnodes, node, out):
794 """general purpose generator """
795 gnode = gnodes[0]
796 nametype = True
797 target = None
798 if gnode.target.__class__ == ast.Name:
799 if (not valid_symbol_name(gnode.target.id) or
800 gnode.target.id in self.readonly_symbols):
801 errmsg = f"invalid symbol name (reserved word?) {gnode.target.id}"
802 self.raise_exception(gnode.target, exc=NameError, msg=errmsg)
803 target = gnode.target.id
804 elif gnode.target.__class__ == ast.Tuple:
805 nametype = False
806 target = tuple([gval.id for gval in gnode.target.elts])
808 for val in self.run(gnode.iter):
809 if nametype and target is not None:
810 self.symtable[target] = val
811 else:
812 for telem, tval in zip(target, val):
813 self.symtable[telem] = tval
814 add = True
815 for cond in gnode.ifs:
816 add = add and self.run(cond)
817 if not add:
818 break
819 if add:
820 if len(gnodes) > 1:
821 self.do_generator(gnodes[1:], node, out)
822 elif isinstance(out, list):
823 out.append(self.run(node.elt))
824 elif isinstance(out, dict):
825 out[self.run(node.key)] = self.run(node.value)
827 def on_listcomp(self, node):
828 """List comprehension v2"""
829 saved_syms = self._comp_save_syms(node)
831 out = []
832 self.do_generator(node.generators, node, out)
833 for name, val in saved_syms.items():
834 self.symtable[name] = val
835 return out
837 def on_setcomp(self, node):
838 """Set comprehension"""
839 return set(self.on_listcomp(node))
841 def on_dictcomp(self, node):
842 """Dict comprehension v2"""
843 saved_syms = self._comp_save_syms(node)
845 out = {}
846 self.do_generator(node.generators, node, out)
847 for name, val in saved_syms.items():
848 self.symtable[name] = val
849 return out
851 def on_excepthandler(self, node): # ('type', 'name', 'body')
852 """Exception handler..."""
853 return (self.run(node.type), node.name, node.body)
855 def on_try(self, node): # ('body', 'handlers', 'orelse', 'finalbody')
856 """Try/except/else/finally blocks."""
857 no_errors = True
858 for tnode in node.body:
859 self.run(tnode, with_raise=False)
860 no_errors = no_errors and len(self.error) == 0
861 if len(self.error) > 0:
862 e_type, e_value, _ = self.error[-1].exc_info
863 for hnd in node.handlers:
864 htype = None
865 if hnd.type is not None:
866 htype = __builtins__.get(hnd.type.id, None)
867 if htype is None or isinstance(e_type(), htype):
868 self.error = []
869 if hnd.name is not None:
870 self.node_assign(hnd.name, e_value)
871 for tline in hnd.body:
872 self.run(tline)
873 break
874 break
875 if no_errors and hasattr(node, 'orelse'):
876 for tnode in node.orelse:
877 self.run(tnode)
879 if hasattr(node, 'finalbody'):
880 for tnode in node.finalbody:
881 self.run(tnode)
883 def on_raise(self, node): # ('type', 'inst', 'tback')
884 """Raise statement: note difference for python 2 and 3."""
885 excnode = node.exc
886 msgnode = node.cause
887 out = self.run(excnode)
888 msg = ' '.join(out.args)
889 msg2 = self.run(msgnode)
890 if msg2 not in (None, 'None'):
891 msg = f"{msg:s}: {msg2:s}"
892 self.raise_exception(None, exc=out.__class__, msg=msg, expr='')
894 def on_call(self, node):
895 """Function execution."""
896 func = self.run(node.func)
897 if not hasattr(func, '__call__') and not isinstance(func, type):
898 msg = f"'{func}' is not callable!!"
899 self.raise_exception(node, exc=TypeError, msg=msg)
900 args = [self.run(targ) for targ in node.args]
901 starargs = getattr(node, 'starargs', None)
902 if starargs is not None:
903 args = args + self.run(starargs)
905 keywords = {}
906 if func == print:
907 keywords['file'] = self.writer
908 for key in node.keywords:
909 if not isinstance(key, ast.keyword):
910 msg = f"keyword error in function call '{func}'"
911 self.raise_exception(node, msg=msg)
912 if key.arg is None:
913 keywords.update(self.run(key.value))
914 elif key.arg in keywords:
915 self.raise_exception(node, exc=SyntaxError,
916 msg=f"keyword argument repeated: {key.arg}")
917 else:
918 keywords[key.arg] = self.run(key.value)
920 kwargs = getattr(node, 'kwargs', None)
921 if kwargs is not None:
922 keywords.update(self.run(kwargs))
924 if isinstance(func, Procedure):
925 self._calldepth += 1
926 try:
927 out = func(*args, **keywords)
928 except Exception as ex:
929 out = None
930 func_name = getattr(func, '__name__', str(func))
931 msg = f"Error running function '{func_name}' with args '{args}'"
932 msg = f"{msg} and kwargs {keywords}: {ex}"
933 self.raise_exception(node, msg=msg)
934 finally:
935 if isinstance(func, Procedure):
936 self._calldepth -= 1
937 return out
939 def on_arg(self, node): # ('test', 'msg')
940 """Arg for function definitions."""
941 return node.arg
943 def on_functiondef(self, node):
944 """Define procedures."""
945 # ('name', 'args', 'body', 'decorator_list')
946 if node.decorator_list:
947 raise Warning("decorated procedures not supported!")
948 kwargs = []
950 if (not valid_symbol_name(node.name) or
951 node.name in self.readonly_symbols):
952 errmsg = f"invalid function name (reserved word?) {node.name}"
953 self.raise_exception(node, exc=NameError, msg=errmsg)
955 offset = len(node.args.args) - len(node.args.defaults)
956 for idef, defnode in enumerate(node.args.defaults):
957 defval = self.run(defnode)
958 keyval = self.run(node.args.args[idef+offset])
959 kwargs.append((keyval, defval))
961 args = [tnode.arg for tnode in node.args.args[:offset]]
962 doc = None
963 nb0 = node.body[0]
964 if isinstance(nb0, ast.Expr) and isinstance(nb0.value, ast.Constant):
965 doc = nb0.value
966 varkws = node.args.kwarg
967 vararg = node.args.vararg
968 if isinstance(vararg, ast.arg):
969 vararg = vararg.arg
970 if isinstance(varkws, ast.arg):
971 varkws = varkws.arg
972 self.symtable[node.name] = Procedure(node.name, self, doc=doc,
973 lineno=self.lineno,
974 body=node.body,
975 args=args, kwargs=kwargs,
976 vararg=vararg, varkws=varkws)
977 if node.name in self.no_deepcopy:
978 self.no_deepcopy.remove(node.name)