Coverage for asteval/asteval.py: 93%
623 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-02-07 10:50 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-02-07 10:50 +0000
1#!/usr/bin/env python
2"""
3Safe(ish) evaluation of minimal Python code using Python's ast module.
5This module provides an Interpreter class that compiles a restricted set of
6Python expressions and statements to Python's AST representation, and then
7executes that representation using values held in a symbol table.
9The symbol table is a simple dictionary, giving a flat namespace. This comes
10pre-loaded with many functions from Python's builtin and math module. If numpy
11is installed, many numpy functions are also included. Additional symbols can
12be added when an Interpreter is created, but the user of that interpreter will
13not be able to import additional modules.
Expressions, including loops, conditionals, and function definitions, can be
compiled into AST nodes and then evaluated later, using the current values
in the symbol table.
19The result is a restricted, simplified version of Python meant for numerical
20calculations that is somewhat safer than 'eval' because many unsafe operations
21(such as 'eval') are simply not allowed, and others (such as 'import') are
22disabled by default, but can be explicitly enabled.
24Many parts of Python syntax are supported, including:
25 for loops, while loops, if-then-elif-else conditionals, with,
26 try-except-finally
27 function definitions with def
28 advanced slicing: a[::-1], array[-3:, :, ::2]
29 if-expressions: out = one_thing if TEST else other
30 list, dict, and set comprehension
The following Python syntax elements are not supported:
  Import (unless explicitly enabled), Exec, Lambda, Class, Global,
  Generators, Yield, Decorators
In addition, while many builtin functions are supported, several builtin
functions that are considered unsafe ('eval', 'exec', and 'getattr', for
example) are missing.
39"""
40import ast
41import sys
42import copy
43import inspect
44import time
45from sys import exc_info, stderr, stdout
47from .astutils import (HAS_NUMPY,
48 ExceptionHolder, ReturnedNone, Empty, make_symbol_table,
49 numpy, op2func, safe_getattr, safe_format, valid_symbol_name, Procedure)
# Every AST node name the Interpreter builds a handler entry for: for each
# name an ``on_<name>`` method is looked up when an Interpreter is created.
ALL_NODES = ['arg', 'assert', 'assign', 'attribute', 'augassign', 'binop',
             'boolop', 'break', 'bytes', 'call', 'compare', 'constant',
             'continue', 'delete', 'dict', 'dictcomp', 'ellipsis',
             'excepthandler', 'expr', 'extslice', 'for', 'functiondef', 'if',
             'ifexp', 'import', 'importfrom', 'index', 'interrupt', 'list',
             'listcomp', 'module', 'name', 'nameconstant', 'num', 'pass',
             'raise', 'repr', 'return', 'set', 'setcomp', 'slice', 'str',
             'subscript', 'try', 'tuple', 'unaryop', 'while', 'with',
             'formattedvalue', 'joinedstr']

# Nodes (plus the 'print' pseudo-node) that are switched off in the minimal
# configuration and switched on in the default configuration.
_TOGGLED_NODES = ('assert', 'augassign', 'delete', 'if', 'ifexp', 'for',
                  'formattedvalue', 'functiondef', 'print', 'raise', 'listcomp',
                  'dictcomp', 'setcomp', 'try', 'while', 'with')

# 'import' and 'importfrom' are disabled in both configurations; they can be
# enabled explicitly through an Interpreter's ``config``.
MINIMAL_CONFIG = {'import': False, 'importfrom': False}
MINIMAL_CONFIG.update(dict.fromkeys(_TOGGLED_NODES, False))

DEFAULT_CONFIG = {'import': False, 'importfrom': False}
DEFAULT_CONFIG.update(dict.fromkeys(_TOGGLED_NODES, True))
class Interpreter:
    """create an asteval Interpreter: a restricted, simplified interpreter
    of mathematical expressions using Python syntax.

    Parameters
    ----------
    symtable : dict or `None`
        dictionary or SymbolTable to use as symbol table (if `None`, one will be created).
    nested_symtable : bool, optional
        whether to use a new-style nested symbol table instead of a plain dict [False]
    user_symbols : dict or `None`
        dictionary of user-defined symbols to add to symbol table.  Only used
        when a new symbol table is created (i.e. `symtable` is `None`).
    writer : file-like or `None`
        file-like object where standard output will be sent [sys.stdout].
    err_writer : file-like or `None`
        file-like object where standard error will be sent [sys.stderr].
    use_numpy : bool
        whether to use functions from numpy.
    max_statement_length : int
        maximum length of expression allowed [50,000 characters]; the value
        is clamped to the range 1 to 100,000,000.
    readonly_symbols : iterable or `None`
        symbols that the user can not assign to
    builtins_readonly : bool
        whether to make all symbols that are in the initial symtable read-only
    minimal : bool
        create a minimal interpreter: disable many nodes (see Note 1).
    config : dict
        dictionary listing which nodes to support (see Note 2).

    Notes
    -----
    1. setting `minimal=True` is equivalent to setting a config with the following
       nodes disabled: ('import', 'importfrom', 'if', 'for', 'while', 'try', 'with',
       'functiondef', 'ifexp', 'listcomp', 'dictcomp', 'setcomp', 'augassign',
       'assert', 'delete', 'raise', 'print', 'formattedvalue')
    2. by default 'import' and 'importfrom' are disabled, though they can be enabled.
    3. extra keyword arguments `no_<node>=True` / `with_<node>=True` disable /
       enable support for a node listed in ALL_NODES; other unknown keyword
       arguments are ignored.
    """
    def __init__(self, symtable=None, nested_symtable=False,
                 user_symbols=None, writer=None, err_writer=None,
                 use_numpy=True, max_statement_length=50000,
                 minimal=False, readonly_symbols=None,
                 builtins_readonly=False, config=None, **kws):

        self.config = copy.copy(MINIMAL_CONFIG if minimal else DEFAULT_CONFIG)
        if config is not None:
            self.config.update(config)
        self.config['nested_symtable'] = nested_symtable

        if user_symbols is None:
            user_symbols = {}
            if 'usersyms' in kws:
                # back compat, changed July, 2023, v 0.9.4
                user_symbols = kws.pop('usersyms')

        # 'no_<node>' / 'with_<node>' keywords toggle support for one node
        for key, val in kws.items():
            if key.startswith('no_'):
                node = key[3:]
                if node in ALL_NODES:
                    self.config[node] = not val
            elif key.startswith('with_'):
                node = key[5:]
                if node in ALL_NODES:
                    self.config[node] = val

        self.writer = writer or stdout
        self.err_writer = err_writer or stderr
        # Clamp to [1, 1e8] and store an int: parse() formats this value with
        # ':d', which raises ValueError for a float such as the clamped 1.e8.
        self.max_statement_length = int(max(1, min(1.e8, max_statement_length)))

        self.use_numpy = HAS_NUMPY and use_numpy
        if symtable is None:
            symtable = make_symbol_table(nested=nested_symtable,
                                         use_numpy=self.use_numpy, **user_symbols)

        symtable['print'] = self._printer
        self.symtable = symtable
        self._interrupt = None
        self.error = []
        self.error_msg = None
        self.expr = None
        self.retval = None
        self._calldepth = 0
        self.lineno = 0
        self.code_text = []
        self.start_time = time.time()

        # map lower-cased AST class names to handlers; disabled nodes map to
        # self.unimplemented, which raises NotImplementedError when reached
        self.node_handlers = {}
        for node in ALL_NODES:
            handler = self.unimplemented
            if self.config.get(node, True):
                handler = getattr(self, f"on_{node}", self.unimplemented)
            self.node_handlers[node] = handler

        # to rationalize try/except try/finally
        if 'try' in self.node_handlers:
            self.node_handlers['tryexcept'] = self.node_handlers['try']
            self.node_handlers['tryfinally'] = self.node_handlers['try']

        if readonly_symbols is None:
            self.readonly_symbols = set()
        else:
            self.readonly_symbols = set(readonly_symbols)

        if builtins_readonly:
            self.readonly_symbols |= set(self.symtable)

        # callables, modules and numpy index-trick objects present at
        # construction time are recorded here (see user_defined_symbols)
        self.no_deepcopy = [key for key, val in symtable.items()
                            if (callable(val)
                                or inspect.ismodule(val)
                                or 'numpy.lib.index_tricks' in repr(type(val)))]

    def remove_nodehandler(self, node):
        """remove support for a node
        returns current node handler, so that it
        might be re-added with add_nodehandler()
        """
        out = None
        if node in self.node_handlers:
            out = self.node_handlers.pop(node)
        return out

    def set_nodehandler(self, node, handler=None):
        """set node handler or use current built-in default"""
        if handler is None:
            handler = getattr(self, f"on_{node}", self.unimplemented)
        self.node_handlers[node] = handler
        return handler

    def user_defined_symbols(self):
        """Return the set of symbols in self.symtable not in self.no_deepcopy.

        self.no_deepcopy starts as the names of the callables and modules that
        were in the symbol table at construction; a name is removed from it
        when it is reassigned (node_assign) or redefined (on_functiondef).
        NOTE: any non-callable symbols present at construction (for example
        numeric constants, if the symbol table provides them) are therefore
        also included in the result, not only symbols added afterwards.

        Returns
        -------
        unique_symbols : set
            symbols in symtable that are not in self.no_deepcopy
        """
        sym_in_current = set(self.symtable.keys())
        sym_from_construction = set(self.no_deepcopy)
        return sym_in_current.difference(sym_from_construction)

    def unimplemented(self, node):
        """Unimplemented nodes: always raises NotImplementedError."""
        msg = f"{node.__class__.__name__} not supported"
        self.raise_exception(node, exc=NotImplementedError, msg=msg)

    def raise_exception(self, node, exc=None, msg='', expr=None, lineno=None):
        """Record an error in self.error and raise it.

        The exception class is `exc` if given; otherwise the class stored on
        the most recent recorded error that has one; otherwise Exception.
        This method always raises.
        """
        if expr is not None:
            self.expr = expr
        msg = str(msg)
        text = self.code_text[-1] if len(self.code_text) > 0 else self.expr
        err = ExceptionHolder(node, exc=exc, msg=msg, expr=self.expr,
                              text=text, lineno=lineno)
        self._interrupt = ast.Raise()
        self.error.append(err)
        if self.error_msg is None:
            self.error_msg = msg

        if exc is None:
            exc = self.error[-1].exc
            # walk back through earlier errors until one records a class
            while exc is None and len(self.error) > 0:
                err = self.error.pop()
                exc = err.exc
        if exc is None:
            exc = Exception

        if len(err.msg) == 0 and len(self.error_msg) == 0 and len(self.error) > 1:
            err = self.error.pop(-1)
            raise err.exc(err.msg)
        if len(err.msg) == 0:
            err.msg = self.error_msg
        raise exc(self.error_msg)

    # main entry point for Ast node evaluation
    #  parse:  text of statements -> ast
    #  run:    ast -> result
    #  eval:   string statement -> result = run(parse(statement))
    def parse(self, text):
        """Parse statement/expression to Ast representation."""
        if len(text) > self.max_statement_length:
            msg = f'length of text exceeds {self.max_statement_length:d} characters'
            self.raise_exception(None, exc=RuntimeError, expr=msg)
        self.expr = text
        try:
            out = ast.parse(text)
        except SyntaxError:
            self.raise_exception(None, exc=SyntaxError, expr=text)
        except Exception:
            self.raise_exception(None, exc=RuntimeError, expr=text)
        out = ast.fix_missing_locations(out)
        return out

    def run(self, node, expr=None, lineno=None, with_raise=True):
        """Execute parsed Ast representation for an expression."""
        # Note: keep the 'node is None' test: internal code here may run
        #    run(None) and expect a None in return.
        if isinstance(node, str):
            return self.eval(node, raise_errors=with_raise)
        out = None
        # pending errors, a pending `return` value, or a break/continue
        # short-circuit every further evaluation
        if len(self.error) > 0:
            return out
        if self.retval is not None:
            return self.retval
        if isinstance(self._interrupt, (ast.Break, ast.Continue)):
            return self._interrupt
        if node is None:
            return out

        if lineno is not None:
            self.lineno = lineno
        if expr is not None:
            self.expr = expr
            # NOTE(review): code_text grows by one entry per call with an
            # expr and is never trimmed here.
            self.code_text.append(expr)

        # get handler for this node:
        #   on_xxx with handle nodes of type 'xxx', etc
        try:
            handler = self.node_handlers[node.__class__.__name__.lower()]
        except KeyError:
            self.raise_exception(None, exc=NotImplementedError, expr=self.expr)

        # run the handler:  this will likely generate
        # recursive calls into this run method.
        try:
            ret = handler(node)
            if isinstance(ret, enumerate):
                ret = list(ret)
            return ret
        except BaseException:
            # deliberate catch-all: errors raised by evaluated code (even
            # SystemExit raised by user code) are recorded, not propagated raw
            if with_raise and self.expr is not None:
                self.raise_exception(node, expr=self.expr)

        # avoid too many repeated error messages (yes, this needs to be "2")
        if len(self.error) > 2:
            self._remove_duplicate_errors()

        return None

    def _remove_duplicate_errors(self):
        """remove duplicate exceptions (consecutive repeats and empty messages)"""
        if len(self.error) == 0:
            return
        error = [self.error[0]]
        for err in self.error[1:]:
            lerr = error[-1]
            if err.exc != lerr.exc or err.expr != lerr.expr or err.msg != lerr.msg:
                if isinstance(err.msg, str) and len(err.msg) > 0:
                    error.append(err)
        self.error = error

    def __call__(self, expr, **kw):
        """Call class instance as function."""
        return self.eval(expr, **kw)

    def eval(self, expr, lineno=0, show_errors=True, raise_errors=False):
        """Evaluate a single statement (a string or an already-parsed AST)."""
        self.lineno = lineno
        self.error = []
        self.error_msg = None
        self.start_time = time.time()
        if isinstance(expr, str):
            try:
                node = self.parse(expr)
            except Exception:
                errmsg = exc_info()[1]
                if len(self.error) > 0:
                    lerr = self.error[-1]
                    errmsg = lerr.get_error()[1]
                    if raise_errors:
                        raise lerr.exc(errmsg)
                if show_errors:
                    print(errmsg, file=self.err_writer)
                return None
        else:
            node = expr
        try:
            return self.run(node, expr=expr, lineno=lineno, with_raise=raise_errors)
        except Exception:
            if show_errors and not raise_errors:
                errmsg = exc_info()[1]
                if len(self.error) > 0:
                    errmsg = self.error[-1].get_error()[1]
                print(errmsg, file=self.err_writer)
            if raise_errors and len(self.error) > 0:
                self._remove_duplicate_errors()
                err = self.error[-1]
                raise err.exc(err.get_error()[1])
        return None

    @staticmethod
    def dump(node, **kw):
        """Simple ast dumper."""
        return ast.dump(node, **kw)

    # handlers for ast components
    def on_expr(self, node):
        """Expression."""
        return self.run(node.value)  # ('value',)

    # imports
    def on_import(self, node):    # ('names',)
        "simple import"
        for tnode in node.names:
            self.import_module(tnode.name, tnode.asname)

    def on_importfrom(self, node):    # ('module', 'names', 'level')
        "import/from"
        fromlist, asname = [], []
        for tnode in node.names:
            fromlist.append(tnode.name)
            asname.append(tnode.asname)
        self.import_module(node.module, asname, fromlist=fromlist)

    def import_module(self, name, asname, fromlist=None):
        """import a python module, installing it into the symbol table.
        options:
          name       name of module to import 'foo' in 'import foo'
          asname     alias for imported name(s)
                          'bar' in 'import foo as bar'
                       or
                          ['s','t'] in 'from foo import x as s, y as t'
          fromlist   list of symbols to import with 'from-import'
                          ['x','y'] in 'from foo import x, y'
        """
        # find module in sys.modules or import to it
        if name in sys.modules:
            thismod = sys.modules[name]
        else:
            try:
                __import__(name)
                thismod = sys.modules[name]
            except BaseException:
                # anything raised while importing (even SystemExit from a
                # module body) is reported as an ImportError
                self.raise_exception(None, exc=ImportError, msg='Import Error')

        if fromlist is None:
            if asname is not None:
                self.symtable[asname] = sys.modules[name]
            else:
                # 'import a.b' binds both 'a' and 'a.b' in the symbol table
                mparts = []
                parts = name.split('.')
                while len(parts) > 0:
                    mparts.append(parts.pop(0))
                    modname = '.'.join(mparts)
                    inname = name if (len(parts) == 0) else modname
                    self.symtable[inname] = sys.modules[modname]
        else:  # import-from construct
            if asname is None:
                asname = [None]*len(fromlist)
            for sym, alias in zip(fromlist, asname):
                if alias is None:
                    alias = sym
                self.symtable[alias] = getattr(thismod, sym)

    def on_index(self, node):
        """Index."""
        return self.run(node.value)  # ('value',)

    def on_return(self, node):  # ('value',)
        """Return statement: look for None, return special sentinel."""
        if self._calldepth == 0:
            raise SyntaxError('cannot return at top level')
        self.retval = self.run(node.value)
        if self.retval is None:
            self.retval = ReturnedNone

    def on_repr(self, node):
        """Repr."""
        return repr(self.run(node.value))  # ('value',)

    def on_module(self, node):    # ():('body',)
        """Module def: run each statement, returning the last result."""
        out = None
        for tnode in node.body:
            out = self.run(tnode)
        return out

    def on_expression(self, node):
        "basic expression"
        return self.on_module(node)  # ():('body',)

    def on_pass(self, node):
        """Pass statement."""
        return None  # ()

    # for break and continue: set the instance variable _interrupt
    def on_interrupt(self, node):    # ()
        """Interrupt handler."""
        self._interrupt = node
        return node

    def on_break(self, node):
        """Break."""
        return self.on_interrupt(node)

    def on_continue(self, node):
        """Continue."""
        return self.on_interrupt(node)

    def on_assert(self, node):    # ('test', 'msg')
        """Assert statement.

        The message expression is only evaluated when the test fails, and
        may be any expression (not only a literal).
        """
        if not self.run(node.test):
            msg = self.run(node.msg) if node.msg is not None else ""
            self.raise_exception(node, exc=AssertionError, msg=msg)
        return True

    def _run_elements(self, elts):
        """Evaluate element nodes in order, expanding ``*iterable`` entries."""
        values = []
        for elt in elts:
            if isinstance(elt, ast.Starred):
                values.extend(self.run(elt.value))
            else:
                values.append(self.run(elt))
        return values

    def on_list(self, node):    # ('elt', 'ctx')
        """List."""
        return self._run_elements(node.elts)

    def on_tuple(self, node):    # ('elts', 'ctx')
        """Tuple."""
        return tuple(self.on_list(node))

    def on_set(self, node):    # ('elts')
        """Set."""
        return set(self._run_elements(node.elts))

    def on_dict(self, node):    # ('keys', 'values')
        """Dictionary; a key of None marks ``**mapping`` unpacking."""
        out = {}
        for knode, vnode in zip(node.keys, node.values):
            if knode is None:
                out.update(self.run(vnode))
            else:
                key = self.run(knode)   # key before value, as in Python
                out[key] = self.run(vnode)
        return out

    def on_constant(self, node):    # ('value', 'kind')
        """Return constant value."""
        return node.value

    def on_joinedstr(self, node):  # ('values',)
        "join strings, used in f-strings"
        return ''.join([self.run(k) for k in node.values])

    def on_formattedvalue(self, node): # ('value', 'conversion', 'format_spec')
        "formatting used in f-strings"
        val = self.run(node.value)
        # conversion codes are ord('s'), ord('r') and ord('a')
        fstring_converters = {115: str, 114: repr, 97: ascii}
        if node.conversion in fstring_converters:
            val = fstring_converters[node.conversion](val)
        fmt = '{__fstring__}'
        if node.format_spec is not None:
            # build '{__fstring__:<spec>}' literally; the spec is itself a
            # JoinedStr that evaluates to a string
            fmt = '{__fstring__:' + self.run(node.format_spec) + '}'
        return safe_format(fmt, self.raise_exception, node, __fstring__=val)

    def _getsym(self, node):
        """Look up node.id in the symbol table, raising NameError if missing."""
        val = self.symtable.get(node.id, ReturnedNone)
        if isinstance(val, Empty):
            msg = f"name '{node.id}' is not defined"
            self.raise_exception(node, exc=NameError, msg=msg)
        return val

    def on_name(self, node):    # ('id', 'ctx')
        """Name node: the symbol's value, or its name in Del/Param context."""
        ctx = node.ctx.__class__
        # ast.Param is a legacy context class; compare by name so this keeps
        # working on Pythons that no longer provide it.
        if ctx is ast.Del or ctx.__name__ == 'Param':
            return str(node.id)
        return self._getsym(node)

    def node_assign(self, node, val):
        """Assign a value (not the node.value object) to a node.

        This is used by on_assign, but also by for, list comprehension,
        etc.  Tuple/List targets unpack `val`; iterables without len() are
        materialized first.
        """
        if node.__class__ == ast.Name:
            if (not valid_symbol_name(node.id) or
                    node.id in self.readonly_symbols):
                errmsg = f"invalid symbol name (reserved word?) {node.id}"
                self.raise_exception(node, exc=NameError, msg=errmsg)
            self.symtable[node.id] = val
            if node.id in self.no_deepcopy:
                self.no_deepcopy.remove(node.id)

        elif node.__class__ == ast.Attribute:
            if node.ctx.__class__ == ast.Load:
                msg = f"cannot assign to attribute {node.attr}"
                self.raise_exception(node, exc=AttributeError, msg=msg)

            setattr(self.run(node.value), node.attr, val)

        elif node.__class__ == ast.Subscript:
            self.run(node.value)[self.run(node.slice)] = val

        elif node.__class__ in (ast.Tuple, ast.List):
            if not hasattr(val, '__len__'):
                val = tuple(val)
            nexpected, ngot = len(node.elts), len(val)
            if ngot > nexpected:
                raise ValueError(f'too many values to unpack (expected {nexpected})')
            if ngot < nexpected:
                raise ValueError('not enough values to unpack '
                                 f'(expected {nexpected}, got {ngot})')
            for telem, tval in zip(node.elts, val):
                self.node_assign(telem, tval)

    def on_attribute(self, node):    # ('value', 'attr', 'ctx')
        """Extract attribute (via safe_getattr)."""

        ctx = node.ctx.__class__
        if ctx == ast.Store:
            msg = "attribute for storage: shouldn't be here!"
            self.raise_exception(node, exc=RuntimeError, msg=msg)

        sym = self.run(node.value)
        if ctx == ast.Del:
            return delattr(sym, node.attr)

        return safe_getattr(sym, node.attr, self.raise_exception, node)

    def on_assign(self, node):    # ('targets', 'value')
        """Simple assignment."""
        val = self.run(node.value)
        for tnode in node.targets:
            self.node_assign(tnode, val)

    def on_augassign(self, node):    # ('target', 'op', 'value')
        """Augmented assign, evaluated as ``target = target <op> value``."""
        # The target node has Store context; read it through a Load-context
        # copy so that attribute targets (obj.x += 1) can be evaluated.
        load_target = copy.copy(node.target)
        load_target.ctx = ast.Load()
        return self.on_assign(ast.Assign(targets=[node.target],
                                         value=ast.BinOp(left=load_target,
                                                         op=node.op,
                                                         right=node.value)))

    def on_slice(self, node):    # ():('lower', 'upper', 'step')
        """Simple slice."""
        return slice(self.run(node.lower),
                     self.run(node.upper),
                     self.run(node.step))

    def on_extslice(self, node):    # ():('dims',)
        """Extended slice."""
        return tuple([self.run(tnode) for tnode in node.dims])

    def on_subscript(self, node):    # ('value', 'slice', 'ctx')
        """Subscript handling"""
        return self.run(node.value)[self.run(node.slice)]

    def on_delete(self, node):    # ('targets',)
        """Delete statement.

        Names (including dotted 'a.b' keys) are popped from the symbol table;
        subscripts delete from the container.  Read-only names are skipped.
        """
        for tnode in node.targets:
            if tnode.ctx.__class__ != ast.Del:
                break
            children = []
            while tnode.__class__ == ast.Attribute:
                children.append(tnode.attr)
                tnode = tnode.value
            if (tnode.__class__ == ast.Name and
                    tnode.id not in self.readonly_symbols):
                children.append(tnode.id)
                children.reverse()
                self.symtable.pop('.'.join(children))
            elif tnode.__class__ == ast.Subscript:
                nslice = self.run(tnode.slice)
                children = []
                tnode = tnode.value
                while tnode.__class__ == ast.Attribute:
                    children.append(tnode.attr)
                    tnode = tnode.value
                if (tnode.__class__ == ast.Name and
                        tnode.id not in self.readonly_symbols):
                    children.append(tnode.id)
                    children.reverse()
                    sname = '.'.join(children)
                    # NOTE(review): run() with a string goes through eval(),
                    # which resets self.error / self.error_msg.
                    val = self.run(sname)
                    del val[nslice]
                    if len(children) == 1:
                        self.symtable[sname] = val
                    else:
                        child = self.symtable[children[0]]
                        for cname in children[1:-1]:
                            child = child[cname]
                        setattr(child, children[-1], val)

    def on_unaryop(self, node):    # ('op', 'operand')
        """Unary operator."""
        return op2func(node.op)(self.run(node.operand))

    def on_binop(self, node):    # ('left', 'op', 'right')
        """Binary operator."""
        return op2func(node.op)(self.run(node.left),
                                self.run(node.right))

    def on_boolop(self, node):    # ('op', 'values')
        """Boolean operator, short-circuiting like Python's and/or."""
        val = self.run(node.values[0])
        is_and = ast.And == node.op.__class__
        if (is_and and val) or (not is_and and not val):
            for nodeval in node.values[1:]:
                val = op2func(node.op)(val, self.run(nodeval))
                if (is_and and not val) or (not is_and and val):
                    break
        return val

    def on_compare(self, node):  # ('left', 'ops', 'comparators')
        """comparison operators, including chained comparisons (a<b<c)"""
        lval = self.run(node.left)
        results = []
        multi = len(node.ops) > 1
        for oper, rnode in zip(node.ops, node.comparators):
            rval = self.run(rnode)
            ret = op2func(oper)(lval, rval)
            if multi:
                results.append(ret)
                if not all(results):
                    return False
                lval = rval
        if multi:
            ret = all(results)
        return ret

    def _printer(self, *out, **kws):
        """Generic print function, installed as 'print' in the symbol table.

        Honors the `file`, `sep`, `end` and `flush` keywords (flush defaults
        to True); output goes to self.writer unless `file` is given.  Does
        nothing when 'print' is disabled in the config.
        """
        if self.config.get('print', True):
            flush = kws.pop('flush', True)
            fileh = kws.pop('file', self.writer)
            sep = kws.pop('sep', ' ')
            end = kws.pop('end', '\n')
            print(*out, file=fileh, sep=sep, end=end)
            if flush:
                fileh.flush()

    def on_if(self, node):    # ('test', 'body', 'orelse')
        """Regular if-then-else statement."""
        block = node.body
        if not self.run(node.test):
            block = node.orelse
        for tnode in block:
            self.run(tnode)

    def on_ifexp(self, node):    # ('test', 'body', 'orelse')
        """If expressions."""
        expr = node.orelse
        if self.run(node.test):
            expr = node.body
        return self.run(expr)

    def on_while(self, node):    # ('test', 'body', 'orelse')
        """While blocks.

        The loop also stops once a `return` has set self.retval: run() keeps
        returning that (usually truthy) value, so without this check the
        test would never become false.
        """
        while self.run(node.test):
            self._interrupt = None
            for tnode in node.body:
                self.run(tnode)
                if self._interrupt is not None or self.retval is not None:
                    break
            if isinstance(self._interrupt, ast.Break) or self.retval is not None:
                break
        else:
            for tnode in node.orelse:
                self.run(tnode)
        self._interrupt = None

    def on_for(self, node):    # ('target', 'iter', 'body', 'orelse')
        """For blocks; stops early on `break` or once a `return` is pending."""
        for val in self.run(node.iter):
            self.node_assign(node.target, val)
            self._interrupt = None
            for tnode in node.body:
                self.run(tnode)
                if self._interrupt is not None or self.retval is not None:
                    break
            if isinstance(self._interrupt, ast.Break) or self.retval is not None:
                break
        else:
            for tnode in node.orelse:
                self.run(tnode)
        self._interrupt = None

    def on_with(self, node):    # ('items', 'body', 'type_comment')
        """with blocks.

        Context managers are entered in order and exited in reverse order,
        always, via ``__exit__(exc_type, exc_value, traceback)``.  A true
        return value from ``__exit__`` does not suppress the exception.
        """
        contexts = []
        exit_args = (None, None, None)
        try:
            for item in node.items:
                ctx = self.run(item.context_expr)
                if not hasattr(ctx, '__enter__'):
                    msg = "object does not support the context manager protocol"
                    raise TypeError(f"'{type(ctx)}' {msg}")
                result = ctx.__enter__()
                contexts.append(ctx)
                if item.optional_vars is not None:
                    self.node_assign(item.optional_vars, result)
            for bnode in node.body:
                self.run(bnode)
                if self._interrupt is not None:
                    break
        except BaseException:
            exit_args = exc_info()
            raise
        finally:
            for ctx in reversed(contexts):
                if hasattr(ctx, '__exit__'):
                    ctx.__exit__(*exit_args)

    def _comp_save_syms(self, node):
        """find and save symbols that will be used in a comprehension"""
        saved_syms = {}
        for tnode in node.generators:
            if tnode.target.__class__ == ast.Name:
                if (not valid_symbol_name(tnode.target.id) or
                        tnode.target.id in self.readonly_symbols):
                    errmsg = f"invalid symbol name (reserved word?) {tnode.target.id}"
                    self.raise_exception(tnode.target, exc=NameError, msg=errmsg)
                if tnode.target.id in self.symtable:
                    saved_syms[tnode.target.id] = copy.deepcopy(self._getsym(tnode.target))

            elif tnode.target.__class__ == ast.Tuple:
                for tval in tnode.target.elts:
                    if tval.id in self.symtable:
                        saved_syms[tval.id] = copy.deepcopy(self._getsym(tval))
        return saved_syms

    def do_generator(self, gnodes, node, out):
        """general purpose generator: fill `out` (list or dict) for the
        comprehension `node`, recursing over the remaining generators"""
        gnode = gnodes[0]
        nametype = True
        target = None
        if gnode.target.__class__ == ast.Name:
            if (not valid_symbol_name(gnode.target.id) or
                    gnode.target.id in self.readonly_symbols):
                errmsg = f"invalid symbol name (reserved word?) {gnode.target.id}"
                self.raise_exception(gnode.target, exc=NameError, msg=errmsg)
            target = gnode.target.id
        elif gnode.target.__class__ == ast.Tuple:
            nametype = False
            # tuple targets get the same name checks as single-name targets
            for gval in gnode.target.elts:
                if (not valid_symbol_name(gval.id) or
                        gval.id in self.readonly_symbols):
                    errmsg = f"invalid symbol name (reserved word?) {gval.id}"
                    self.raise_exception(gval, exc=NameError, msg=errmsg)
            target = tuple([gval.id for gval in gnode.target.elts])

        for val in self.run(gnode.iter):
            if nametype and target is not None:
                self.symtable[target] = val
            else:
                for telem, tval in zip(target, val):
                    self.symtable[telem] = tval
            add = True
            for cond in gnode.ifs:
                add = add and self.run(cond)
                if not add:
                    break
            if add:
                if len(gnodes) > 1:
                    self.do_generator(gnodes[1:], node, out)
                elif isinstance(out, list):
                    out.append(self.run(node.elt))
                elif isinstance(out, dict):
                    out[self.run(node.key)] = self.run(node.value)

    def on_listcomp(self, node):
        """List comprehension v2; pre-existing loop-variable values are
        restored afterwards, even if evaluation fails."""
        saved_syms = self._comp_save_syms(node)

        out = []
        try:
            self.do_generator(node.generators, node, out)
        finally:
            for name, val in saved_syms.items():
                self.symtable[name] = val
        return out

    def on_setcomp(self, node):
        """Set comprehension"""
        return set(self.on_listcomp(node))

    def on_dictcomp(self, node):
        """Dict comprehension v2; pre-existing loop-variable values are
        restored afterwards, even if evaluation fails."""
        saved_syms = self._comp_save_syms(node)

        out = {}
        try:
            self.do_generator(node.generators, node, out)
        finally:
            for name, val in saved_syms.items():
                self.symtable[name] = val
        return out

    def on_excepthandler(self, node):  # ('type', 'name', 'body')
        """Exception handler..."""
        return (self.run(node.type), node.name, node.body)

    def on_try(self, node):    # ('body', 'handlers', 'orelse', 'finalbody')
        """Try/except/else/finally blocks.

        Body statements run with ``with_raise=False`` so that a failure is
        recorded in self.error rather than propagated.  The first matching
        ``except`` clause clears the error and runs its body.  The
        ``finally`` body always runs; an unhandled error stays recorded.
        """
        import builtins

        def handler_matches(type_node, e_type):
            """Return True if an except clause for `type_node` catches `e_type`."""
            if type_node is None:                # bare 'except:'
                return True
            if isinstance(type_node, ast.Tuple):
                return any(handler_matches(elt, e_type) for elt in type_node.elts)
            if isinstance(type_node, ast.Name):
                htype = getattr(builtins, type_node.id, None)
                if htype is None:
                    htype = self.symtable.get(type_node.id, None)
            else:                                # e.g. a dotted name
                htype = self.run(type_node)
            if not (isinstance(htype, type) and issubclass(htype, BaseException)):
                # unresolvable handler types have always acted as catch-all
                return True
            return isinstance(e_type, type) and issubclass(e_type, htype)

        no_errors = True
        try:
            for tnode in node.body:
                self.run(tnode, with_raise=False)
                no_errors = no_errors and len(self.error) == 0
                if len(self.error) > 0:
                    err = self.error[-1]
                    e_type, e_value, _ = err.exc_info
                    if e_type is None:
                        # NOTE(review): exc_info can be empty when the error
                        # was recorded outside an except block; fall back to
                        # the exception class stored on the holder.
                        e_type = err.exc
                    for hnd in node.handlers:
                        if handler_matches(hnd.type, e_type):
                            self.error = []
                            if hnd.name is not None:
                                self.symtable[hnd.name] = e_value
                            for tline in hnd.body:
                                self.run(tline)
                            break
                    break
            if no_errors and hasattr(node, 'orelse'):
                for tnode in node.orelse:
                    self.run(tnode)
        finally:
            finalbody = getattr(node, 'finalbody', None) or []
            if finalbody:
                # run() does nothing while errors are pending, so stash them
                pending = self.error
                self.error = []
                try:
                    for tnode in finalbody:
                        self.run(tnode)
                finally:
                    self.error = pending + self.error

    def on_raise(self, node):    # ('exc', 'cause')
        """Raise statement: record and raise the given exception.

        Accepts an exception instance or class; a `from` cause is appended
        to the message.
        """
        excnode = node.exc
        msgnode = node.cause
        out = self.run(excnode)
        if isinstance(out, type) and issubclass(out, BaseException):
            out = out()
        msg = ' '.join(str(arg) for arg in out.args)
        msg2 = self.run(msgnode)
        if msg2 not in (None, 'None'):
            msg = f"{msg}: {msg2}"
        self.raise_exception(None, exc=out.__class__, msg=msg, expr='')

    def on_call(self, node):
        """Function execution; supports *args and **kwargs unpacking."""
        func = self.run(node.func)
        if not callable(func) and not isinstance(func, type):
            msg = f"'{func}' is not callable!!"
            self.raise_exception(node, exc=TypeError, msg=msg)
        args = self._run_elements(node.args)
        starargs = getattr(node, 'starargs', None)   # legacy AST field
        if starargs is not None:
            args.extend(self.run(starargs))

        keywords = {}
        if func is print:
            keywords['file'] = self.writer
        for key in node.keywords:
            if not isinstance(key, ast.keyword):
                msg = f"keyword error in function call '{func}'"
                self.raise_exception(node, msg=msg)
            if key.arg is None:
                keywords.update(self.run(key.value))
            elif key.arg in keywords:
                self.raise_exception(node, exc=SyntaxError,
                                     msg=f"keyword argument repeated: {key.arg}")
            else:
                keywords[key.arg] = self.run(key.value)

        kwargs = getattr(node, 'kwargs', None)   # legacy AST field
        if kwargs is not None:
            keywords.update(self.run(kwargs))

        # track call depth so that 'return' is only allowed inside Procedures
        if isinstance(func, Procedure):
            self._calldepth += 1
        try:
            out = func(*args, **keywords)
        except Exception as ex:
            out = None
            func_name = getattr(func, '__name__', str(func))
            msg = f"Error running function '{func_name}' with args '{args}'"
            msg = f"{msg} and kwargs {keywords}: {ex}"
            self.raise_exception(node, msg=msg)
        finally:
            if isinstance(func, Procedure):
                self._calldepth -= 1
        return out

    def on_arg(self, node):    # ('arg', 'annotation')
        """Arg for function definitions: the argument name."""
        return node.arg

    def on_functiondef(self, node):
        """Define procedures: store a Procedure under the function's name."""
        # ('name', 'args', 'body', 'decorator_list')
        if node.decorator_list:
            raise Warning("decorated procedures not supported!")
        kwargs = []

        if (not valid_symbol_name(node.name) or
                node.name in self.readonly_symbols):
            errmsg = f"invalid function name (reserved word?) {node.name}"
            self.raise_exception(node, exc=NameError, msg=errmsg)

        # the last len(defaults) positional arguments carry default values
        offset = len(node.args.args) - len(node.args.defaults)
        for idef, defnode in enumerate(node.args.defaults):
            defval = self.run(defnode)
            keyval = self.run(node.args.args[idef+offset])
            kwargs.append((keyval, defval))

        args = [tnode.arg for tnode in node.args.args[:offset]]

        # a leading string literal is the docstring: pass its text rather
        # than the ast.Constant node itself
        doc = None
        nb0 = node.body[0]
        if (isinstance(nb0, ast.Expr) and isinstance(nb0.value, ast.Constant)
                and isinstance(nb0.value.value, str)):
            doc = nb0.value.value
        varkws = node.args.kwarg
        vararg = node.args.vararg
        if isinstance(vararg, ast.arg):
            vararg = vararg.arg
        if isinstance(varkws, ast.arg):
            varkws = varkws.arg
        self.symtable[node.name] = Procedure(node.name, self, doc=doc,
                                             lineno=self.lineno,
                                             body=node.body,
                                             text=ast.unparse(node),
                                             args=args, kwargs=kwargs,
                                             vararg=vararg, varkws=varkws)
        if node.name in self.no_deepcopy:
            self.no_deepcopy.remove(node.name)