Coverage for asteval/astutils.py: 92%
407 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-02-07 10:50 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-02-07 10:50 +0000
1"""
2utility functions for asteval
4 Matthew Newville <newville@cars.uchicago.edu>,
5 The University of Chicago
6"""
7import ast
8import io
9import math
10import numbers
11import re
12from sys import exc_info
13from tokenize import ENCODING as tk_ENCODING
14from tokenize import NAME as tk_NAME
15from tokenize import tokenize as generate_tokens
16from string import Formatter
# ``__builtins__`` may be bound either to a module or to a plain dict;
# normalise to the dict form so symbols can be looked up by key below.
builtins = (__builtins__ if isinstance(__builtins__, dict)
            else __builtins__.__dict__)
# Optional dependency: numpy.  When it cannot be imported the name ``numpy``
# is bound to None and HAS_NUMPY is False.
try:
    import numpy
    numpy_version = numpy.version.version.split('.', 2)
except ImportError:
    numpy = None
    HAS_NUMPY = False
else:
    HAS_NUMPY = True

# Optional dependency: numpy_financial.
try:
    import numpy_financial
except ImportError:
    HAS_NUMPY_FINANCIAL = False
else:
    HAS_NUMPY_FINANCIAL = True
# ``formatter_field_name_split`` is required to walk format-field names, but
# it is an undocumented API that has moved around between Python releases.
try:
    from _string import formatter_field_name_split
except ImportError:
    def formatter_field_name_split(x):
        """Fallback: defer to the string object's own private splitter."""
        return x._formatter_field_name_split()
# Upper bounds checked by the safe_* helpers and by _open().
MAX_EXPONENT = 10_000
MAX_STR_LEN = 1 << 18  # 256KiB
MAX_SHIFT = 1_000
MAX_OPEN_BUFFER = 1 << 18
# Names that are never accepted as symbol names.
RESERVED_WORDS = (
    'False', 'None', 'True',
    'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue',
    'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from',
    'global', 'if', 'import', 'in', 'is', 'lambda', 'nonlocal', 'not', 'or',
    'pass', 'raise', 'return', 'try', 'while', 'with', 'yield',
    'exec', 'eval', 'execfile', '__import__', '__package__', '__fstring__',
)

# Bound ``match`` of a pattern for plain ASCII identifiers.
NAME_MATCH = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$").match
# Attribute names that are refused on every object.
UNSAFE_ATTRS = (
    '__subclasses__', '__bases__', '__globals__', '__code__',
    '__reduce__', '__reduce_ex__', '__mro__',
    '__closure__', '__func__', '__self__', '__module__',
    '__dict__', '__class__', '__call__', '__get__',
    '__getattribute__', '__subclasshook__', '__new__',
    '__init__',
    'func_globals', 'func_code', 'func_closure',
    'im_class', 'im_func', 'im_self',
    'gi_code', 'gi_frame', 'f_locals',
    '__asteval__', 'mro',
)

# Attribute names that are refused only on particular types, keyed by type.
UNSAFE_ATTRS_DTYPES = {
    str: ('format', 'format_map'),
}
# Names copied from Python's builtins into the default symbol table.
FROM_PY = tuple("""
    ArithmeticError AssertionError AttributeError
    BaseException BufferError BytesWarning
    DeprecationWarning EOFError EnvironmentError
    Exception False FloatingPointError GeneratorExit
    IOError ImportError ImportWarning IndentationError
    IndexError KeyError KeyboardInterrupt LookupError
    MemoryError NameError None
    NotImplementedError OSError OverflowError
    ReferenceError RuntimeError RuntimeWarning
    StopIteration SyntaxError SyntaxWarning SystemError
    SystemExit True TypeError UnboundLocalError
    UnicodeDecodeError UnicodeEncodeError UnicodeError
    UnicodeTranslateError UnicodeWarning ValueError
    Warning ZeroDivisionError abs all any bin
    bool bytearray bytes chr complex dict dir
    divmod enumerate filter float format frozenset
    hash hex id int isinstance len list map
    max min oct ord pow range repr
    reversed round set slice sorted str sum
    tuple zip
""".split())

# Only names actually present in this interpreter's builtins are kept.
BUILTINS_TABLE = {name: builtins[name] for name in FROM_PY if name in builtins}
# Names copied from the ``math`` module into the default symbol table.
FROM_MATH = tuple("""
    acos acosh asin asinh atan atan2 atanh
    ceil copysign cos cosh degrees e exp
    fabs factorial floor fmod frexp fsum
    hypot isinf isnan ldexp log log10 log1p
    modf pi pow radians sin sinh sqrt tan
    tanh trunc
""".split())

# Names missing from this interpreter's ``math`` module are skipped.
MATH_TABLE = {name: getattr(math, name) for name in FROM_MATH if hasattr(math, name)}
# Candidate numpy names for the symbol table.
FROM_NUMPY = tuple("""
    abs add all amax amin angle any append
    arange arccos arccosh arcsin arcsinh arctan arctan2
    arctanh argmax argmin argsort argwhere around array
    asarray atleast_1d atleast_2d atleast_3d average bartlett
    bitwise_and bitwise_not bitwise_or bitwise_xor blackman
    broadcast ceil choose clip column_stack common_type
    complex128 compress concatenate conjugate convolve
    copysign corrcoef correlate cos cosh cov cross
    cumprod cumsum datetime_data deg2rad degrees delete
    diag diag_indices diag_indices_from diagflat diagonal
    diff digitize divide dot dsplit dstack dtype e
    ediff1d empty empty_like equal exp exp2 expand_dims
    expm1 extract eye fabs fill_diagonal finfo fix
    flatiter flatnonzero fliplr flipud float64 floor
    floor_divide fmax fmin fmod format_parser frexp
    frombuffer fromfile fromfunction fromiter frompyfunc
    fromregex fromstring genfromtxt getbufsize geterr
    gradient greater greater_equal hamming hanning histogram
    histogram2d histogramdd hsplit hstack hypot i0
    identity iinfo imag indices inexact inf info inner
    insert int32 integer interp intersect1d invert
    iscomplex iscomplexobj isfinite isinf isnan isneginf
    isposinf isreal isrealobj isscalar iterable kaiser
    kron ldexp left_shift less less_equal linspace
    little_endian loadtxt log log10 log1p log2 logaddexp
    logaddexp2 logical_and logical_not logical_or logical_xor
    logspace longdouble longlong mask_indices matrix maximum
    may_share_memory mean median memmap meshgrid minimum
    mintypecode mod modf msort multiply nan nan_to_num
    nanargmax nanargmin nanmax nanmin nansum ndarray
    ndenumerate ndim ndindex negative nextafter nonzero
    not_equal number ones ones_like outer packbits
    percentile pi piecewise place poly poly1d polyadd
    polyder polydiv polyint polymul polysub polyval power
    prod ptp put putmask rad2deg radians ravel real
    real_if_close reciprocal record remainder repeat reshape
    resize right_shift rint roll rollaxis roots rot90
    round searchsorted select setbufsize setdiff1d seterr
    setxor1d shape short sign signbit signedinteger sin
    sinc single sinh size sort sort_complex spacing
    split sqrt square squeeze std subtract sum swapaxes
    take tan tanh tensordot tile trace transpose tri
    tril tril_indices tril_indices_from trim_zeros triu
    triu_indices triu_indices_from true_divide trunc ubyte
    uint uint32 union1d unique unravel_index unsignedinteger
    unwrap ushort vander var vdot vectorize vsplit
    vstack where zeros zeros_like
""".split())

# Candidate numpy_financial names for the symbol table.
FROM_NUMPY_FINANCIAL = tuple("fv ipmt irr mirr nper npv pmt ppmt pv rate".split())

# Extra symbol names (keys) that map onto existing numpy names (values).
NUMPY_RENAMES = dict(ln='log', asin='arcsin', acos='arccos',
                     atan='arctan', atan2='arctan2', atanh='arctanh',
                     acosh='arccosh', asinh='arcsinh')
# NUMPY_TABLE maps symbol names to the numpy (and numpy_financial) objects
# made available to scripts.  It stays empty when numpy is not installed.
NUMPY_TABLE = {}
if HAS_NUMPY:
    # Drop duplicate names and names this numpy release does not provide.
    # dict.fromkeys() de-duplicates while keeping declaration order, so the
    # resulting order no longer depends on string hash randomisation
    # (which it did when de-duplicating through set()).
    FROM_NUMPY = tuple(sym for sym in dict.fromkeys(FROM_NUMPY)
                       if hasattr(numpy, sym))
    NUMPY_RENAMES = {alias: target for alias, target in NUMPY_RENAMES.items()
                     if hasattr(numpy, target)}

    for sym in FROM_NUMPY:
        obj = getattr(numpy, sym, None)
        if obj is not None:
            NUMPY_TABLE[sym] = obj

    # Aliases are added after the direct names, under the alias name.
    for sname, sym in NUMPY_RENAMES.items():
        obj = getattr(numpy, sym, None)
        if obj is not None:
            NUMPY_TABLE[sname] = obj

    # numpy_financial symbols are only added when numpy itself is available.
    if HAS_NUMPY_FINANCIAL:
        for sym in FROM_NUMPY_FINANCIAL:
            obj = getattr(numpy_financial, sym, None)
            if obj is not None:
                NUMPY_TABLE[sym] = obj
193def _open(filename, mode='r', buffering=-1, encoding=None):
194 """read only version of open()"""
195 if mode not in ('r', 'rb', 'rU'):
196 raise RuntimeError("Invalid open file mode, must be 'r', 'rb', or 'rU'")
197 if buffering > MAX_OPEN_BUFFER:
198 raise RuntimeError(f"Invalid buffering value, max buffer size is {MAX_OPEN_BUFFER}")
199 return open(filename, mode, buffering, encoding=encoding)
202def _type(x):
203 """type that prevents varargs and varkws"""
204 return type(x).__name__
# Restricted replacements installed under the names of the builtins they shadow.
LOCALFUNCS = dict(open=_open, type=_type)
210# Safe versions of functions to prevent denial of service issues
def safe_pow(base, exp):
    """Safe version of pow: ``base ** exp`` with a bounded exponent.

    Raises RuntimeError when a numeric exponent, or the largest non-NaN
    element of a numpy array exponent, exceeds MAX_EXPONENT.
    """
    if isinstance(exp, numbers.Number):
        # Complex numbers are unordered: comparing them with ``>`` raises
        # TypeError, which used to make valid expressions such as
        # ``2 ** 1j`` fail.  Only orderable exponents are range-checked.
        unordered = (isinstance(exp, numbers.Complex)
                     and not isinstance(exp, numbers.Real))
        if not unordered and exp > MAX_EXPONENT:
            raise RuntimeError(f"Invalid exponent, max exponent is {MAX_EXPONENT}")
    elif HAS_NUMPY and isinstance(exp, numpy.ndarray):
        if numpy.nanmax(exp) > MAX_EXPONENT:
            raise RuntimeError(f"Invalid exponent, max exponent is {MAX_EXPONENT}")
    return base ** exp
def safe_mult(arg1, arg2):
    """Safe version of multiply: ``arg1 * arg2`` with bounded string repetition.

    Raises RuntimeError when repeating a string would produce more than
    MAX_STR_LEN characters.  Both operand orders are checked
    (``'ab' * n`` and ``n * 'ab'``); other operand types are not limited.
    """
    if isinstance(arg1, str) and isinstance(arg2, int):
        length = len(arg1) * arg2
    elif isinstance(arg1, int) and isinstance(arg2, str):
        # int * str builds the same string and was previously unchecked
        length = arg1 * len(arg2)
    else:
        length = None
    if length is not None and length > MAX_STR_LEN:
        raise RuntimeError(f"String length exceeded, max string length is {MAX_STR_LEN}")
    return arg1 * arg2
def safe_add(arg1, arg2):
    """Safe version of add: ``arg1 + arg2`` with bounded string concatenation.

    Raises RuntimeError when both operands are strings and their combined
    length exceeds MAX_STR_LEN.
    """
    both_strings = isinstance(arg1, str) and isinstance(arg2, str)
    if both_strings and len(arg1) + len(arg2) > MAX_STR_LEN:
        raise RuntimeError(f"String length exceeded, max string length is {MAX_STR_LEN}")
    return arg1 + arg2
def safe_lshift(arg1, arg2):
    """Safe version of lshift: ``arg1 << arg2`` with a bounded shift count.

    Raises RuntimeError when a numeric shift, or the largest non-NaN element
    of a numpy array shift, exceeds MAX_SHIFT.
    """
    if isinstance(arg2, numbers.Number):
        largest = arg2
    elif HAS_NUMPY and isinstance(arg2, numpy.ndarray):
        largest = numpy.nanmax(arg2)
    else:
        largest = None  # any other type is not range-checked
    if largest is not None and largest > MAX_SHIFT:
        raise RuntimeError(f"Invalid left shift, max left shift is {MAX_SHIFT}")
    return arg1 << arg2
# Dispatch table: ast operator node class -> callable implementing it.
# Add, LShift, Mult and Pow go through the guarded safe_* helpers.
OPERATORS = {
    # identity / membership
    ast.Is: lambda lhs, rhs: lhs is rhs,
    ast.IsNot: lambda lhs, rhs: lhs is not rhs,
    ast.In: lambda item, box: item in box,
    ast.NotIn: lambda item, box: item not in box,
    # arithmetic / bitwise
    ast.Add: safe_add,
    ast.BitAnd: lambda lhs, rhs: lhs & rhs,
    ast.BitOr: lambda lhs, rhs: lhs | rhs,
    ast.BitXor: lambda lhs, rhs: lhs ^ rhs,
    ast.Div: lambda lhs, rhs: lhs / rhs,
    ast.FloorDiv: lambda lhs, rhs: lhs // rhs,
    ast.LShift: safe_lshift,
    ast.RShift: lambda lhs, rhs: lhs >> rhs,
    ast.Mult: safe_mult,
    ast.Pow: safe_pow,
    ast.MatMult: lambda lhs, rhs: lhs @ rhs,
    ast.Sub: lambda lhs, rhs: lhs - rhs,
    ast.Mod: lambda lhs, rhs: lhs % rhs,
    # boolean
    ast.And: lambda lhs, rhs: lhs and rhs,
    ast.Or: lambda lhs, rhs: lhs or rhs,
    # comparison
    ast.Eq: lambda lhs, rhs: lhs == rhs,
    ast.Gt: lambda lhs, rhs: lhs > rhs,
    ast.GtE: lambda lhs, rhs: lhs >= rhs,
    ast.Lt: lambda lhs, rhs: lhs < rhs,
    ast.LtE: lambda lhs, rhs: lhs <= rhs,
    ast.NotEq: lambda lhs, rhs: lhs != rhs,
    # unary
    ast.Invert: lambda operand: ~operand,
    ast.Not: lambda operand: not operand,
    ast.UAdd: lambda operand: +operand,
    ast.USub: lambda operand: -operand,
}
278# Safe version of getattr
def safe_getattr(obj, attr, raise_exc, node):
    """Safe version of getattr.

    An attribute is refused when its name is listed in UNSAFE_ATTRS, when it
    is a dunder name, or when UNSAFE_ATTRS_DTYPES lists it for ``obj``'s type
    (or for ``obj`` itself when ``obj`` is that type).  A refusal is reported
    through ``raise_exc(node, exc=AttributeError, msg=...)``.

    Returns the attribute value, or None when the attribute does not exist
    (or when ``raise_exc`` returns after a refusal).
    """
    blocked = (attr in UNSAFE_ATTRS or
               (attr.startswith('__') and attr.endswith('__')))
    if not blocked:
        blocked = any((isinstance(obj, dtype) or obj is dtype) and attr in names
                      for dtype, names in UNSAFE_ATTRS_DTYPES.items())
    if blocked:
        raise_exc(node, exc=AttributeError,
                  msg=f"no safe attribute '{attr}' for {repr(obj)}")
        return None
    try:
        return getattr(obj, attr)
    except AttributeError:
        # a missing attribute is deliberately reported as None
        return None
class SafeFormatter(Formatter):
    """string.Formatter whose attribute lookups go through safe_getattr."""

    def __init__(self, raise_exc, node):
        """Keep the error callback and the ast node handed to safe_getattr."""
        super().__init__()
        self.raise_exc = raise_exc
        self.node = node

    def get_field(self, field_name, args, kwargs):
        """Resolve ``field_name`` and return ``(value, first_component)``.

        ``.attr`` steps use safe_getattr; ``[index]`` steps use plain
        subscripting.
        """
        head, accessors = formatter_field_name_split(field_name)
        value = self.get_value(head, args, kwargs)
        for by_attribute, key in accessors:
            if by_attribute:
                value = safe_getattr(value, key, self.raise_exc, self.node)
            else:
                value = value[key]
        return value, head
def safe_format(_string, raise_exc, node, *args, **kwargs):
    """Format ``_string`` with ``args`` / ``kwargs`` through a SafeFormatter."""
    return SafeFormatter(raise_exc, node).vformat(_string, args, kwargs)
def valid_symbol_name(name):
    """Determine whether the input symbol name is a valid name.

    Arguments
    ---------
      name  : str
         name to check for validity.

    Returns
    --------
      valid :  bool
        whether name is a valid symbol name

    The name must not be one of RESERVED_WORDS, and Python's tokenizer must
    read the whole string as exactly one NAME token.  Input the tokenizer
    cannot process is reported as invalid rather than raising.
    """
    # imported here because the module only imports selected tokenize names
    from tokenize import TokenError

    if name in RESERVED_WORDS:
        return False

    try:
        gen = generate_tokens(io.BytesIO(name.encode('utf-8')).readline)
        typ, _, start, end, _ = next(gen)
        if typ == tk_ENCODING:
            # the first token only reports the encoding; look at the next one
            typ, _, start, end, _ = next(gen)
    except (TokenError, SyntaxError):
        # the tokenizer may raise on malformed input (e.g. an unterminated
        # string); such input is certainly not a symbol name
        return False
    return typ == tk_NAME and start == (1, 0) and end == (1, len(name))
def op2func(oper):
    """Look up the callable in OPERATORS for an operator node's class."""
    node_class = oper.__class__
    return OPERATORS[node_class]
class Empty:
    """Placeholder type used to mark the absence of a value."""

    def __bool__(self):
        """An Empty instance is always false."""
        return False

    # Python 2 spelling of the truth hook, kept so existing explicit calls
    # keep working; Python 3 only consults __bool__.
    __nonzero__ = __bool__

    def __repr__(self):
        """Return the fixed text 'Empty'."""
        return "Empty"


# Shared Empty instance used as a marker that is distinct from None.
ReturnedNone = Empty()
class ExceptionHolder:
    """Basic exception handler: a record of one error and where it occurred."""

    def __init__(self, node, exc=None, msg='', expr=None,
                 text=None, lineno=None):
        """Store the error details.

        When ``lineno`` is not given, the line span and column offset are
        read from ``node`` if it carries them.  When ``exc`` or ``msg`` is
        not given, it is filled in from the exception currently being
        handled (``sys.exc_info()``), if any.
        """
        self.node = node
        self.expr = expr
        self.msg = msg
        self.exc = exc
        self.text = text
        self.lineno = lineno
        self.end_lineno = lineno
        self.col_offset = 0
        if lineno is None:
            try:
                self.lineno = node.lineno
                self.end_lineno = node.end_lineno
                self.col_offset = node.col_offset
            except AttributeError:
                # node is None or carries no position information
                pass
        self.exc_info = exc_info()
        if self.exc is None and self.exc_info[0] is not None:
            self.exc = self.exc_info[0]
        if self.msg == '' and self.exc_info[1] is not None:
            self.msg = str(self.exc_info[1])

    def get_error(self):
        """Retrieve error data.

        Returns ``(exc_name, report)``.  ``report`` contains the offending
        source lines (or ``expr`` when the lines cannot be selected), a
        caret marker when the column offset is positive, and a final
        ``"<exc_name>: <msg>"`` line.  Also sets ``self.code`` and
        ``self.codelines``.
        """
        try:
            exc_name = self.exc.__name__
        except AttributeError:
            exc_name = str(self.exc)
        if exc_name in (None, 'None'):
            exc_name = 'UnknownError'

        # ``text`` defaults to None; treat that as "no source available"
        # instead of failing on None.split().
        self.code = [] if self.text is None else self.text.split('\n')
        self.codelines = [f'{i+1}: {line}' for i, line in enumerate(self.code)]

        snippet = None
        if self.text is not None:
            try:
                snippet = '\n'.join(self.code[self.lineno-1:self.end_lineno])
            except (TypeError, ValueError):
                # line numbers missing or unusable
                snippet = None
        if snippet is None:
            snippet = f"{self.expr}"

        out = [snippet]
        if self.col_offset > 0:
            out.append(f"{self.col_offset*' '}^^^^")
        out.append(f"{exc_name}: {self.msg}")
        return (exc_name, '\n'.join(out))

    def __repr__(self):
        return f"ExceptionHolder({self.exc}, {self.msg})"
class NameFinder(ast.NodeVisitor):
    """Find all symbol names used by a parsed node."""

    def __init__(self):
        """Start with an empty ``names`` list."""
        super().__init__()
        self.names = []

    def generic_visit(self, node):
        """Record the id of a Name node (first occurrence only), then descend."""
        is_name = node.__class__.__name__ == 'Name'
        if is_name and node.id not in self.names:
            self.names.append(node.id)
        super().generic_visit(node)
def get_ast_names(astnode):
    """Return the list of symbol names found in an AST node by NameFinder."""
    collector = NameFinder()
    collector.generic_visit(astnode)
    return collector.names
def valid_varname(name):
    """Return True if ``name`` is an identifier and not in RESERVED_WORDS."""
    if not name.isidentifier():
        return False
    return name not in RESERVED_WORDS
class _GroupLookupError(KeyError, AttributeError):
    """Raised for a missing Group member.

    It is still a KeyError (as raised before), and additionally an
    AttributeError so that hasattr() and getattr(obj, name, default)
    treat a missing member as "not there" instead of propagating it.
    """


class Group(dict):
    """
    Group: a container of objects that can be accessed either as an object attributes
    or dictionary key/value.  Attribute names must follow Python naming conventions.

    All attributes, including ``__name__`` and ``_searchgroups``, are stored
    as items of the underlying dict.
    """
    def __init__(self, name=None, searchgroups=None, **kws):
        """Create a Group.

        name : group name; defaults to ``hex(id(self))``
        searchgroups : names of member groups that get() also searches, or None
        kws : initial members
        """
        if name is None:
            name = hex(id(self))
        self.__name__ = name
        dict.__init__(self, **kws)
        self._searchgroups = searchgroups

    def __setattr__(self, name, value):
        """Store ``value`` as item ``name``; SyntaxError for an invalid name."""
        if not valid_varname(name):
            raise SyntaxError(f"invalid attribute name '{name}'")
        self[name] = value

    def __getattr__(self, name, default=None):
        """Return item ``name``, else a non-None ``default``, else raise.

        The error raised for a missing name is both a KeyError and an
        AttributeError.
        """
        if name in self:
            return self[name]
        if default is not None:
            return default
        raise _GroupLookupError(f"no attribute named '{name}'")

    def __setitem__(self, name, value):
        """Store an item; the key must satisfy valid_varname()."""
        if not valid_varname(name):
            raise SyntaxError(f"invalid attribute name '{name}'")
        dict.__setitem__(self, name, value)

    def get(self, key, default=None):
        """Look up ``key`` here, then in each search group, else ``default``.

        Search groups are the members named in ``_searchgroups``; only
        members that are dicts (including Groups) are consulted, and only
        their own items are checked.
        """
        if key in self:
            return self[key]
        searchgroups = self._searchgroups
        if searchgroups is not None:
            for sgroup in searchgroups:
                # a search group that is not (or no longer) a member is skipped
                grp = dict.get(self, sgroup)
                if isinstance(grp, dict) and key in grp:
                    return grp[key]
        return default

    def __repr__(self):
        keys = [a for a in self.keys() if a != '__name__']
        return f"Group('{self.__name__}', {len(keys)} symbols)"

    def _repr_html_(self):
        """HTML representation for Jupyter notebook"""
        from html import escape

        rows = [f"<table><caption>Group('{escape(str(self.__name__))}')</caption>",
                "<tr><th>Attribute</th><th>DataType</th><th><b>Value</b></th></tr>"]
        for key, val in self.items():
            # Escape after truncating: reprs such as "<function f at ...>"
            # would otherwise be parsed as markup and not displayed.
            shown = escape(f"{repr(val):.75s}")
            rows.append(f"""
<tr><td>{escape(str(key))}</td><td><i>{escape(type(val).__name__)}</i></td>
    <td>{shown}</td>
</tr>""")
        rows.append("</table>")
        return '\n'.join(rows)
506def make_symbol_table(use_numpy=True, nested=False, top=True, **kws):
507 """Create a default symboltable, taking dict of user-defined symbols.
509 Arguments
510 ---------
511 use_numpy : bool, optional
512 whether to include symbols from numpy [True]
513 nested : bool, optional
514 whether to make a "new-style" nested table instead of a plain dict [False]
515 top : bool, optional
516 whether this is the top-level table in a nested-table [True]
517 kws : optional
518 additional symbol name, value pairs to include in symbol table;
 for a nested table, a 'name' entry is used as the table name
520 Returns
521 --------
522 symbol_table : dict or nested Group
523 a symbol table that can be used in `asteval.Interpreter`
525 """
526 if nested:
527 name = '_'
528 if top:
529 name = '_main'
530 if 'name' in kws:
531 name = kws.pop('name')
532 symtable = Group(name=name, Group=Group)
533 else:
534 symtable = {}
 # seed the table: builtins first, then the local replacements, then user symbols
536 symtable.update(BUILTINS_TABLE)
537 symtable.update(LOCALFUNCS)
538 symtable.update(kws)
539 math_functions = dict(MATH_TABLE.items())
540 if use_numpy:
541 math_functions.update(NUMPY_TABLE)
 # nested tables keep the math symbols in a 'math' sub-Group listed in
 # _searchgroups; flat tables merge them directly into the dict
543 if nested:
544 symtable['math'] = Group(name='math', **math_functions)
545 symtable['Group'] = Group
546 symtable._searchgroups = ('math',)
547 else:
548 symtable.update(math_functions)
 # user symbols are applied a second time here
549 symtable.update(**kws)
550 return symtable
class Procedure:
    """Procedure: user-defined function for asteval.

    This stores the parsed ast nodes as from the 'functiondef' ast node
    for later evaluation.

    Instances are read-only once constructed: a later attribute assignment
    is reported through the interpreter's ``raise_exception``.
    """

    def __init__(self, name, interp, doc=None, lineno=None,
                 body=None, text=None, args=None, kwargs=None,
                 vararg=None, varkws=None):
        """Store the definition.

        name : procedure name
        interp : owning interpreter; its ``raise_exception`` is used for errors
        doc : docstring, either a value or an ast.Constant
        lineno : line number passed along with some error reports
        body : list of ast nodes making up the procedure body
        text : source text; generated from the signature and ``body`` if None
        args : names of the positional arguments
        kwargs : (name, default) pairs for the keyword arguments
        vararg, varkws : names of the ``*`` / ``**`` parameters, or None
        """
        # __setattr__ only permits assignment while __ininit__ is true
        self.__ininit__ = True
        self.name = name
        self.__name__ = self.name
        self.__asteval__ = interp
        self.__raise_exc__ = self.__asteval__.raise_exception
        self.__doc__ = doc
        self.__body__ = body
        self.__argnames__ = args
        self.__kwargs__ = kwargs
        self.__vararg__ = vararg
        self.__varkws__ = varkws
        self.lineno = lineno
        self.__text__ = text
        if text is None:
            self.__text__ = f'{self.__signature__()}\n' + ast.unparse(self.__body__)
        self.__ininit__ = False

    def __setattr__(self, attr, val):
        if not getattr(self, '__ininit__', True):
            self.__raise_exc__(None, exc=TypeError,
                               msg="procedure is read-only")
        self.__dict__[attr] = val

    def __dir__(self):
        return ['__getdoc__', 'argnames', 'kwargs', 'name', 'vararg', 'varkws']

    def __getdoc__(self):
        """Return the docstring, unwrapping an ast.Constant if necessary."""
        doc = self.__doc__
        if isinstance(doc, ast.Constant):
            doc = doc.value
        return doc

    def __repr__(self):
        """Return ``<Procedure signature>``, followed by the docstring if any."""
        sig = self.__signature__()
        rep = f"<Procedure {sig}>"
        doc = self.__getdoc__()
        if doc is not None:
            rep = f"{rep}\n  {doc}"
        return rep

    def __signature__(self):
        "call signature"
        # Collect every parameter first and join once, so separators are
        # always correct (previously "*vararg" was glued to the preceding
        # name, and "**varkws" got a leading ", " even with nothing before it).
        parts = list(self.__argnames__ or ())
        if self.__vararg__ is not None:
            parts.append(f"*{self.__vararg__}")
        parts.extend(f"{k}={v}" for k, v in (self.__kwargs__ or ()))
        if self.__varkws__ is not None:
            parts.append(f"**{self.__varkws__}")
        return f"{self.name}({', '.join(parts)})"

    def __call__(self, *args, **kwargs):
        """Bind the arguments, run the body in the interpreter, return its value."""
        topsym = self.__asteval__.symtable
        if self.__asteval__.config.get('nested_symtable', False):
            # local table is a Group that searches the caller's table and
            # that table's own search groups
            sargs = {'_main': topsym}
            sgroups = topsym.get('_searchgroups', None)
            if sgroups is not None:
                for sxname in sgroups:
                    sargs[sxname] = topsym.get(sxname)

            symlocals = Group(name=f'symtable_{self.name}_', **sargs)
            symlocals._searchgroups = list(sargs.keys())
        else:
            symlocals = {}

        args = list(args)
        nargs = len(args)
        nkws = len(kwargs)
        nargs_expected = len(self.__argnames__)

        # check for too few arguments, but the correct keyword given
        if (nargs < nargs_expected) and nkws > 0:
            for name in self.__argnames__[nargs:]:
                if name in kwargs:
                    args.append(kwargs.pop(name))
            nargs = len(args)
            nargs_expected = len(self.__argnames__)
            nkws = len(kwargs)
        if nargs < nargs_expected:
            msg = f"{self.name}() takes at least"
            msg = f"{msg} {nargs_expected} arguments, got {nargs}"
            self.__raise_exc__(None, exc=TypeError, msg=msg)
        # check for multiple values for named argument
        if len(self.__argnames__) > 0 and kwargs is not None:
            msg = "multiple values for keyword argument"
            for targ in self.__argnames__:
                if targ in kwargs:
                    msg = f"{msg} '{targ}' in Procedure {self.name}"
                    self.__raise_exc__(None, exc=TypeError, msg=msg,
                                       lineno=self.lineno)

        # check more args given than expected, varargs not given
        if nargs != nargs_expected:
            msg = None
            if nargs < nargs_expected:
                msg = f"not enough arguments for Procedure {self.name}()"
                msg = f"{msg} (expected {nargs_expected}, got {nargs})"
                self.__raise_exc__(None, exc=TypeError, msg=msg)

        if nargs > nargs_expected and self.__vararg__ is None:
            if nargs - nargs_expected > len(self.__kwargs__):
                msg = f"too many arguments for {self.name}() expected at most"
                msg = f"{msg} {len(self.__kwargs__)+nargs_expected}, got {nargs}"
                self.__raise_exc__(None, exc=TypeError, msg=msg)

            # surplus positional values fill the keyword parameters in order
            for i, xarg in enumerate(args[nargs_expected:]):
                kw_name = self.__kwargs__[i][0]
                if kw_name not in kwargs:
                    kwargs[kw_name] = xarg

        for argname in self.__argnames__:
            symlocals[argname] = args.pop(0)

        try:
            if self.__vararg__ is not None:
                symlocals[self.__vararg__] = tuple(args)

            for key, val in self.__kwargs__:
                if key in kwargs:
                    val = kwargs.pop(key)
                symlocals[key] = val

            if self.__varkws__ is not None:
                symlocals[self.__varkws__] = kwargs

            elif len(kwargs) > 0:
                msg = f"extra keyword arguments for Procedure {self.name}: "
                msg = msg + ','.join(list(kwargs.keys()))
                self.__raise_exc__(None, msg=msg, exc=TypeError,
                                   lineno=self.lineno)

        except (ValueError, LookupError, TypeError,
                NameError, AttributeError):
            msg = f"incorrect arguments for Procedure {self.name}"
            self.__raise_exc__(None, msg=msg, lineno=self.lineno)

        if self.__asteval__.config.get('nested_symtable', False):
            # nested: swap in the local Group for the duration of the call
            save_symtable = self.__asteval__.symtable
            self.__asteval__.symtable = symlocals
        else:
            # flat: run against the live table; the saved copy replaces it afterwards
            save_symtable = self.__asteval__.symtable.copy()
            self.__asteval__.symtable.update(symlocals)

        self.__asteval__.retval = None
        self.__asteval__._calldepth += 1
        retval = None

        # evaluate script of function
        self.__asteval__.code_text.append(self.__text__)
        for node in self.__body__:
            self.__asteval__.run(node, lineno=node.lineno)
            if len(self.__asteval__.error) > 0:
                break
            if self.__asteval__.retval is not None:
                retval = self.__asteval__.retval
                self.__asteval__.retval = None
                # ReturnedNone is converted back to a plain None for the caller
                if retval is ReturnedNone:
                    retval = None
                break

        self.__asteval__.symtable = save_symtable
        self.__asteval__.code_text.pop()
        self.__asteval__._calldepth -= 1
        symlocals = None
        return retval