Coverage for asteval/astutils.py: 92%

407 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-02-07 10:50 +0000

1""" 

2utility functions for asteval 

3 

4 Matthew Newville <newville@cars.uchicago.edu>, 

5 The University of Chicago 

6""" 

7import ast 

8import io 

9import math 

10import numbers 

11import re 

12from sys import exc_info 

13from tokenize import ENCODING as tk_ENCODING 

14from tokenize import NAME as tk_NAME 

15from tokenize import tokenize as generate_tokens 

16from string import Formatter 

17 

# ``__builtins__`` is accepted either as a dict or as an object exposing the
# names through its ``__dict__``; ``builtins`` is a dict in both cases.
builtins = __builtins__
if not isinstance(builtins, dict):
    builtins = builtins.__dict__

21 

# numpy is optional: HAS_NUMPY records whether the import worked, and the
# name ``numpy`` is bound to None when it did not.
HAS_NUMPY = False
try:
    import numpy
    numpy_version = numpy.version.version.split('.', 2)
    HAS_NUMPY = True
except ImportError:
    numpy = None

# numpy_financial is optional too; only the flag changes when it is missing.
HAS_NUMPY_FINANCIAL = False
try:
    import numpy_financial
    HAS_NUMPY_FINANCIAL = True
except ImportError:
    pass

36 

# This is a necessary API but it's undocumented and moved around
# between Python releases
try:
    from _string import formatter_field_name_split
except ImportError:
    def formatter_field_name_split(x):
        """Fallback: delegate to the private splitter method of ``x``."""
        return x._formatter_field_name_split()

44 

45 

46 

# Upper bounds checked by the safe_* helpers and by _open()
MAX_EXPONENT = 10000
MAX_STR_LEN = 2 << 17  # 262144 == 256 * 1024
MAX_SHIFT = 1000
MAX_OPEN_BUFFER = 2 << 17

51 

# names that may never be used as symbol names
RESERVED_WORDS = ('False', 'None', 'True', 'and', 'as', 'assert',
                  'async', 'await', 'break', 'class', 'continue', 'def',
                  'del', 'elif', 'else', 'except', 'finally', 'for',
                  'from', 'global', 'if', 'import', 'in', 'is',
                  'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise',
                  'return', 'try', 'while', 'with', 'yield', 'exec',
                  'eval', 'execfile', '__import__', '__package__',
                  '__fstring__')

# bound ``match`` of a pattern for ASCII identifiers
NAME_MATCH = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$").match

62 

# unsafe attributes for all objects:
UNSAFE_ATTRS = ('__subclasses__', '__bases__', '__globals__', '__code__',
                '__reduce__', '__reduce_ex__', '__mro__',
                '__closure__', '__func__', '__self__', '__module__',
                '__dict__', '__class__', '__call__', '__get__',
                '__getattribute__', '__subclasshook__', '__new__',
                '__init__', 'func_globals', 'func_code', 'func_closure',
                'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
                'f_locals', '__asteval__', 'mro')

# unsafe attributes for particular objects, by type
UNSAFE_ATTRS_DTYPES = {str: ('format', 'format_map')}

75 

76 

# inherit these from python's __builtins__
FROM_PY = ('ArithmeticError', 'AssertionError', 'AttributeError',
           'BaseException', 'BufferError', 'BytesWarning',
           'DeprecationWarning', 'EOFError', 'EnvironmentError',
           'Exception', 'False', 'FloatingPointError', 'GeneratorExit',
           'IOError', 'ImportError', 'ImportWarning', 'IndentationError',
           'IndexError', 'KeyError', 'KeyboardInterrupt', 'LookupError',
           'MemoryError', 'NameError', 'None',
           'NotImplementedError', 'OSError', 'OverflowError',
           'ReferenceError', 'RuntimeError', 'RuntimeWarning',
           'StopIteration', 'SyntaxError', 'SyntaxWarning', 'SystemError',
           'SystemExit', 'True', 'TypeError', 'UnboundLocalError',
           'UnicodeDecodeError', 'UnicodeEncodeError', 'UnicodeError',
           'UnicodeTranslateError', 'UnicodeWarning', 'ValueError',
           'Warning', 'ZeroDivisionError', 'abs', 'all', 'any', 'bin',
           'bool', 'bytearray', 'bytes', 'chr', 'complex', 'dict', 'dir',
           'divmod', 'enumerate', 'filter', 'float', 'format', 'frozenset',
           'hash', 'hex', 'id', 'int', 'isinstance', 'len', 'list', 'map',
           'max', 'min', 'oct', 'ord', 'pow', 'range', 'repr',
           'reversed', 'round', 'set', 'slice', 'sorted', 'str', 'sum',
           'tuple', 'zip')

# name -> object, limited to the names actually present in ``builtins``
BUILTINS_TABLE = {sym: builtins[sym] for sym in FROM_PY if sym in builtins}

100 

# inherit these from python's math
FROM_MATH = ('acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
             'ceil', 'copysign', 'cos', 'cosh', 'degrees', 'e', 'exp',
             'fabs', 'factorial', 'floor', 'fmod', 'frexp', 'fsum',
             'hypot', 'isinf', 'isnan', 'ldexp', 'log', 'log10', 'log1p',
             'modf', 'pi', 'pow', 'radians', 'sin', 'sinh', 'sqrt', 'tan',
             'tanh', 'trunc')

# name -> object, limited to the names the math module actually provides
MATH_TABLE = {sym: getattr(math, sym) for sym in FROM_MATH if hasattr(math, sym)}

110 

# candidate names to take from numpy (filtered later against the installed numpy)
FROM_NUMPY = ('abs', 'add', 'all', 'amax', 'amin', 'angle', 'any', 'append',
    'arange', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctan2',
    'arctanh', 'argmax', 'argmin', 'argsort', 'argwhere', 'around', 'array',
    'asarray', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'average', 'bartlett',
    'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', 'blackman',
    'broadcast', 'ceil', 'choose', 'clip', 'column_stack', 'common_type',
    'complex128', 'compress', 'concatenate', 'conjugate', 'convolve',
    'copysign', 'corrcoef', 'correlate', 'cos', 'cosh', 'cov', 'cross',
    'cumprod', 'cumsum', 'datetime_data', 'deg2rad', 'degrees', 'delete',
    'diag', 'diag_indices', 'diag_indices_from', 'diagflat', 'diagonal',
    'diff', 'digitize', 'divide', 'dot', 'dsplit', 'dstack', 'dtype', 'e',
    'ediff1d', 'empty', 'empty_like', 'equal', 'exp', 'exp2', 'expand_dims',
    'expm1', 'extract', 'eye', 'fabs', 'fill_diagonal', 'finfo', 'fix',
    'flatiter', 'flatnonzero', 'fliplr', 'flipud', 'float64', 'floor',
    'floor_divide', 'fmax', 'fmin', 'fmod', 'format_parser', 'frexp',
    'frombuffer', 'fromfile', 'fromfunction', 'fromiter', 'frompyfunc',
    'fromregex', 'fromstring', 'genfromtxt', 'getbufsize', 'geterr',
    'gradient', 'greater', 'greater_equal', 'hamming', 'hanning', 'histogram',
    'histogram2d', 'histogramdd', 'hsplit', 'hstack', 'hypot', 'i0',
    'identity', 'iinfo', 'imag', 'indices', 'inexact', 'inf', 'info', 'inner',
    'insert', 'int32', 'integer', 'interp', 'intersect1d', 'invert',
    'iscomplex', 'iscomplexobj', 'isfinite', 'isinf', 'isnan', 'isneginf',
    'isposinf', 'isreal', 'isrealobj', 'isscalar', 'iterable', 'kaiser',
    'kron', 'ldexp', 'left_shift', 'less', 'less_equal', 'linspace',
    'little_endian', 'loadtxt', 'log', 'log10', 'log1p', 'log2', 'logaddexp',
    'logaddexp2', 'logical_and', 'logical_not', 'logical_or', 'logical_xor',
    'logspace', 'longdouble', 'longlong', 'mask_indices', 'matrix', 'maximum',
    'may_share_memory', 'mean', 'median', 'memmap', 'meshgrid', 'minimum',
    'mintypecode', 'mod', 'modf', 'msort', 'multiply', 'nan', 'nan_to_num',
    'nanargmax', 'nanargmin', 'nanmax', 'nanmin', 'nansum', 'ndarray',
    'ndenumerate', 'ndim', 'ndindex', 'negative', 'nextafter', 'nonzero',
    'not_equal', 'number', 'ones', 'ones_like', 'outer', 'packbits',
    'percentile', 'pi', 'piecewise', 'place', 'poly', 'poly1d', 'polyadd',
    'polyder', 'polydiv', 'polyint', 'polymul', 'polysub', 'polyval', 'power',
    'prod', 'ptp', 'put', 'putmask', 'rad2deg', 'radians', 'ravel', 'real',
    'real_if_close', 'reciprocal', 'record', 'remainder', 'repeat', 'reshape',
    'resize', 'right_shift', 'rint', 'roll', 'rollaxis', 'roots', 'rot90',
    'round', 'searchsorted', 'select', 'setbufsize', 'setdiff1d', 'seterr',
    'setxor1d', 'shape', 'short', 'sign', 'signbit', 'signedinteger', 'sin',
    'sinc', 'single', 'sinh', 'size', 'sort', 'sort_complex', 'spacing',
    'split', 'sqrt', 'square', 'squeeze', 'std', 'subtract', 'sum', 'swapaxes',
    'take', 'tan', 'tanh', 'tensordot', 'tile', 'trace', 'transpose', 'tri',
    'tril', 'tril_indices', 'tril_indices_from', 'trim_zeros', 'triu',
    'triu_indices', 'triu_indices_from', 'true_divide', 'trunc', 'ubyte',
    'uint', 'uint32', 'union1d', 'unique', 'unravel_index', 'unsignedinteger',
    'unwrap', 'ushort', 'vander', 'var', 'vdot', 'vectorize', 'vsplit',
    'vstack', 'where', 'zeros', 'zeros_like')


# candidate names to take from numpy_financial
FROM_NUMPY_FINANCIAL = ('fv', 'ipmt', 'irr', 'mirr', 'nper', 'npv',
                        'pmt', 'ppmt', 'pv', 'rate')

# alias exposed in the symbol table -> numpy attribute name it maps to
NUMPY_RENAMES = {'ln': 'log', 'asin': 'arcsin', 'acos': 'arccos',
                 'atan': 'arctan', 'atan2': 'arctan2', 'atanh':
                 'arctanh', 'acosh': 'arccosh', 'asinh': 'arcsinh'}

166 

if HAS_NUMPY:
    # de-duplicate, then keep only what the installed numpy actually provides
    FROM_NUMPY = tuple(set(FROM_NUMPY))
    FROM_NUMPY = tuple(sym for sym in FROM_NUMPY if hasattr(numpy, sym))
    NUMPY_RENAMES = {sym: value for sym, value in NUMPY_RENAMES.items()
                     if hasattr(numpy, value)}

    NUMPY_TABLE = {}
    for sym in FROM_NUMPY:
        obj = getattr(numpy, sym, None)
        if obj is not None:
            NUMPY_TABLE[sym] = obj

    # aliases: the table key is the alias, the value is the numpy object
    for sname, sym in NUMPY_RENAMES.items():
        obj = getattr(numpy, sym, None)
        if obj is not None:
            NUMPY_TABLE[sname] = obj

    # numpy_financial symbols are only added when numpy itself is present
    if HAS_NUMPY_FINANCIAL:
        for sym in FROM_NUMPY_FINANCIAL:
            obj = getattr(numpy_financial, sym, None)
            if obj is not None:
                NUMPY_TABLE[sym] = obj

else:
    NUMPY_TABLE = {}

191 

192 

def _open(filename, mode='r', buffering=-1, encoding=None):
    """Read-only version of open().

    Arguments
    ---------
    filename : path of the file to open
    mode : str, optional
        one of 'r', 'rb' or 'rU' ['r']
    buffering : int, optional
        buffering value passed to open(); at most MAX_OPEN_BUFFER [-1]
    encoding : str or None, optional
        text encoding passed to open() [None]

    Raises
    ------
    RuntimeError
        for any other mode, or a buffering value above MAX_OPEN_BUFFER.
    """
    if mode not in ('r', 'rb', 'rU'):
        raise RuntimeError("Invalid open file mode, must be 'r', 'rb', or 'rU'")
    if buffering > MAX_OPEN_BUFFER:
        raise RuntimeError(f"Invalid buffering value, max buffer size is {MAX_OPEN_BUFFER}")
    if mode == 'rU':
        # open() no longer accepts the 'U' flag on recent Python versions;
        # plain text mode already uses universal newlines.
        mode = 'r'
    return open(filename, mode, buffering, encoding=encoding)


def _type(x):
    """Return the name of the type of ``x`` as a string, not the type object."""
    return type(x).__name__


# replacements installed in the symbol table under the builtin names
LOCALFUNCS = {'open': _open, 'type': _type}

208 

209 

210# Safe versions of functions to prevent denial of service issues 

211 

def safe_pow(base, exp):
    """Return ``base ** exp`` after checking the exponent against MAX_EXPONENT.

    The limit applies to scalar numbers and, when numpy is available, to
    ndarray exponents (using their largest non-NaN element).  Numbers that
    cannot be ordered against the limit, such as complex values, are passed
    through unchecked instead of failing on the comparison.

    Raises
    ------
    RuntimeError
        if the exponent is larger than MAX_EXPONENT.
    """
    too_large = False
    if isinstance(exp, numbers.Number):
        try:
            too_large = exp > MAX_EXPONENT
        except TypeError:
            # complex exponents do not support ">" but are valid for "**"
            too_large = False
    elif HAS_NUMPY and isinstance(exp, numpy.ndarray):
        too_large = numpy.nanmax(exp) > MAX_EXPONENT
    if too_large:
        raise RuntimeError(f"Invalid exponent, max exponent is {MAX_EXPONENT}")
    return base ** exp

221 

222 

def safe_mult(arg1, arg2):
    """Return ``arg1 * arg2``, refusing to build over-long strings.

    String repetition is checked in either operand order (``str * int`` and
    ``int * str``); all other operand types are multiplied without a check.

    Raises
    ------
    RuntimeError
        if the repeated string would be longer than MAX_STR_LEN.
    """
    text = count = None
    if isinstance(arg1, str) and isinstance(arg2, int):
        text, count = arg1, arg2
    elif isinstance(arg1, int) and isinstance(arg2, str):
        text, count = arg2, arg1
    if text is not None and len(text) * count > MAX_STR_LEN:
        raise RuntimeError(f"String length exceeded, max string length is {MAX_STR_LEN}")
    return arg1 * arg2

228 

229 

def safe_add(arg1, arg2):
    """Return ``arg1 + arg2``, refusing to build over-long strings.

    Only str + str is checked; other operand types are added directly.

    Raises
    ------
    RuntimeError
        if the concatenated string would be longer than MAX_STR_LEN.
    """
    both_str = isinstance(arg1, str) and isinstance(arg2, str)
    if both_str and len(arg1) + len(arg2) > MAX_STR_LEN:
        raise RuntimeError(f"String length exceeded, max string length is {MAX_STR_LEN}")
    return arg1 + arg2

235 

236 

def safe_lshift(arg1, arg2):
    """Return ``arg1 << arg2`` after checking the shift against MAX_SHIFT.

    The limit applies to scalar numbers and, when numpy is available, to
    ndarray shifts (using their largest non-NaN element).

    Raises
    ------
    RuntimeError
        if the shift is larger than MAX_SHIFT.
    """
    too_large = False
    if isinstance(arg2, numbers.Number):
        too_large = arg2 > MAX_SHIFT
    elif HAS_NUMPY and isinstance(arg2, numpy.ndarray):
        too_large = numpy.nanmax(arg2) > MAX_SHIFT
    if too_large:
        raise RuntimeError(f"Invalid left shift, max left shift is {MAX_SHIFT}")
    return arg1 << arg2

246 

247 

# ast operator node class -> callable implementing it; Add, LShift, Mult and
# Pow use the guarded safe_* versions.
OPERATORS = {ast.Is: lambda a, b: a is b,
             ast.IsNot: lambda a, b: a is not b,
             ast.In: lambda a, b: a in b,
             ast.NotIn: lambda a, b: a not in b,
             ast.Add: safe_add,
             ast.BitAnd: lambda a, b: a & b,
             ast.BitOr: lambda a, b: a | b,
             ast.BitXor: lambda a, b: a ^ b,
             ast.Div: lambda a, b: a / b,
             ast.FloorDiv: lambda a, b: a // b,
             ast.LShift: safe_lshift,
             ast.RShift: lambda a, b: a >> b,
             ast.Mult: safe_mult,
             ast.Pow: safe_pow,
             ast.MatMult: lambda a, b: a @ b,
             ast.Sub: lambda a, b: a - b,
             ast.Mod: lambda a, b: a % b,
             ast.And: lambda a, b: a and b,
             ast.Or: lambda a, b: a or b,
             ast.Eq: lambda a, b: a == b,
             ast.Gt: lambda a, b: a > b,
             ast.GtE: lambda a, b: a >= b,
             ast.Lt: lambda a, b: a < b,
             ast.LtE: lambda a, b: a <= b,
             ast.NotEq: lambda a, b: a != b,
             ast.Invert: lambda a: ~a,
             ast.Not: lambda a: not a,
             ast.UAdd: lambda a: +a,
             ast.USub: lambda a: -a}

277 

278# Safe version of getattr 

279 

def safe_getattr(obj, attr, raise_exc, node):
    """Safe version of getattr.

    An attribute is treated as unsafe when it is listed in UNSAFE_ATTRS, when
    its name both starts and ends with a double underscore, or when
    UNSAFE_ATTRS_DTYPES lists it for a type that ``obj`` is or is an instance of.

    Arguments
    ---------
    obj : object to read the attribute from
    attr : str, attribute name
    raise_exc : callable invoked as ``raise_exc(node, exc=..., msg=...)`` for
        an unsafe attribute
    node : passed through to ``raise_exc``

    Returns
    -------
    the attribute value, or None when the attribute does not exist (or when
    ``raise_exc`` returns instead of raising).
    """
    is_dunder = attr.startswith('__') and attr.endswith('__')
    unsafe = attr in UNSAFE_ATTRS or is_dunder
    if not unsafe:
        unsafe = any((isinstance(obj, dtype) or obj is dtype) and attr in attrlist
                     for dtype, attrlist in UNSAFE_ATTRS_DTYPES.items())
    if unsafe:
        msg = f"no safe attribute '{attr}' for {repr(obj)}"
        raise_exc(node, exc=AttributeError, msg=msg)
        return None
    try:
        return getattr(obj, attr)
    except AttributeError:
        # a missing attribute is reported as None rather than as an error
        return None

297 

class SafeFormatter(Formatter):
    """Formatter whose attribute lookups in replacement fields use safe_getattr."""

    def __init__(self, raise_exc, node):
        """Store the callback and node that are handed to safe_getattr."""
        self.raise_exc = raise_exc
        self.node = node
        super().__init__()

    def get_field(self, field_name, args, kwargs):
        """Resolve ``field_name`` to ``(object, first_name)``.

        The first component is looked up with get_value(); every following
        ``.attr`` goes through safe_getattr and every ``[index]`` through
        plain item access.
        """
        first, rest = formatter_field_name_split(field_name)
        obj = self.get_value(first, args, kwargs)
        for is_attr, key in rest:
            if is_attr:
                obj = safe_getattr(obj, key, self.raise_exc, self.node)
            else:
                obj = obj[key]
        return obj, first

313 

def safe_format(_string, raise_exc, node, *args, **kwargs):
    """Format ``_string`` with ``args``/``kwargs`` through a SafeFormatter.

    ``raise_exc`` and ``node`` are passed on to the SafeFormatter.
    """
    formatter = SafeFormatter(raise_exc, node)
    return formatter.vformat(_string, args, kwargs)

317 

def valid_symbol_name(name):
    """Determine whether the input symbol name is a valid name.

    Arguments
    ---------
      name  : str
         name to check for validity.

    Returns
    --------
      valid :  bool
        whether name is a valid symbol name

    This checks for reserved words (RESERVED_WORDS) and that the Python
    tokenizer reads the whole of ``name`` as one single NAME token.
    Text that the tokenizer cannot process is reported as invalid.
    """
    from tokenize import TokenError

    if name in RESERVED_WORDS:
        return False

    gen = generate_tokens(io.BytesIO(name.encode('utf-8')).readline)
    try:
        typ, _, start, end, _ = next(gen)
        if typ == tk_ENCODING:
            # the byte-oriented tokenizer reports the encoding first
            typ, _, start, end, _ = next(gen)
    except (TokenError, SyntaxError, StopIteration):
        # the tokenizer may raise on malformed text instead of yielding a token
        return False
    return typ == tk_NAME and start == (1, 0) and end == (1, len(name))

342 

343 

def op2func(oper):
    """Return the function registered in OPERATORS for the class of node ``oper``.

    Raises KeyError when the node class has no entry.
    """
    return OPERATORS[oper.__class__]

347 

348 

class Empty:
    """Placeholder class whose instances are falsy and print as 'Empty'."""

    def __init__(self):
        """Nothing to initialize."""
        return

    def __bool__(self):
        """An Empty instance is always false."""
        return False

    # name of the truth-value hook before Python 3, kept for compatibility
    __nonzero__ = __bool__

    def __repr__(self):
        """Fixed representation."""
        return "Empty"


# shared Empty instance, used as a marker that is distinct from None
ReturnedNone = Empty()

364 

class ExceptionHolder:
    """Basic exception handler: records an error and formats it as text."""

    def __init__(self, node, exc=None, msg='', expr=None,
                 text=None, lineno=None):
        """Record the error details.

        Arguments
        ---------
        node : object whose lineno, end_lineno and col_offset are used when
            ``lineno`` is not given; may lack them
        exc : exception class, optional; defaults to the exception currently
            being handled, if any
        msg : str, optional; defaults to the text of the exception currently
            being handled, if any
        expr : expression text shown when no source line can be selected
        text : source text that line numbers refer to
        lineno : int, optional; used for both the first and last line
        """
        self.node = node
        self.expr = expr
        self.msg = msg
        self.exc = exc
        self.text = text
        self.lineno = lineno
        self.end_lineno = lineno
        self.col_offset = 0
        if lineno is None:
            try:
                self.lineno = node.lineno
                self.end_lineno = node.end_lineno
                self.col_offset = node.col_offset
            except AttributeError:
                # node carries no (complete) position information
                pass
        self.exc_info = exc_info()
        if self.exc is None and self.exc_info[0] is not None:
            self.exc = self.exc_info[0]
        if self.msg == '' and self.exc_info[1] is not None:
            self.msg = str(self.exc_info[1])

    def get_error(self):
        """Retrieve error data.

        Returns a tuple ``(exc_name, message)``: the message holds the source
        lines lineno..end_lineno (or ``expr`` when they cannot be selected),
        a caret line when col_offset > 0, and ``"<exc_name>: <msg>"``.
        Also sets ``self.code`` and ``self.codelines``.
        """
        try:
            exc_name = self.exc.__name__
        except AttributeError:
            exc_name = str(self.exc)
        if exc_name in (None, 'None'):
            exc_name = 'UnknownError'

        # a holder created without source text still yields a report
        text = '' if self.text is None else self.text
        self.code = text.split('\n')
        self.codelines = [f'{i+1}: {line}' for i, line in enumerate(self.code)]

        out = []
        try:
            out.append('\n'.join(self.code[self.lineno-1:self.end_lineno]))
        except TypeError:
            # no usable line numbers: fall back to the expression
            out.append(f"{self.expr}")
        if self.col_offset > 0:
            out.append(f"{self.col_offset*' '}^^^^")
        out.append(f"{exc_name}: {self.msg}")
        return (exc_name, '\n'.join(out))

    def __repr__(self):
        """Short representation with exception and message."""
        return f"ExceptionHolder({self.exc}, {self.msg})"

415 

class NameFinder(ast.NodeVisitor):
    """Find all symbol names used by a parsed node."""

    def __init__(self):
        """Start with an empty list of names."""
        self.names = []
        super().__init__()

    def generic_visit(self, node):
        """Record the id of a Name node (once), then visit the node's children."""
        if isinstance(node, ast.Name) and node.id not in self.names:
            self.names.append(node.id)
        super().generic_visit(node)

430 

431 

def get_ast_names(astnode):
    """Return the list of symbol names found in an AST node by NameFinder.

    Each name appears once, in the order it is first visited.
    """
    finder = NameFinder()
    finder.generic_visit(astnode)
    return finder.names

437 

438 

def valid_varname(name):
    """Return True if ``name`` is an identifier that is not in RESERVED_WORDS."""
    return name.isidentifier() and name not in RESERVED_WORDS

442 

443 

class _GroupAttributeError(AttributeError, KeyError):
    """Missing-attribute error of Group, catchable as either base class."""


class Group(dict):
    """
    Group: a container of objects that can be accessed either as an object attributes
    or dictionary key/value.  Attribute names must follow Python naming conventions.

    Attribute assignment writes into the dict itself, so the group name and
    the search-group list are stored as the entries '__name__' and
    '_searchgroups'.
    """
    def __init__(self, name=None, searchgroups=None, **kws):
        """Create a group.

        Arguments
        ---------
        name : str, optional
            group name [hex(id(self))]
        searchgroups : sequence of str or None, optional
            names of entries that get() also searches [None]
        kws : initial name, value pairs
        """
        if name is None:
            name = hex(id(self))
        self.__name__ = name
        dict.__init__(self, **kws)
        self._searchgroups = searchgroups

    def __setattr__(self, name, value):
        """Store ``value`` as entry ``name``; SyntaxError for an invalid name."""
        if not valid_varname(name):
            raise SyntaxError(f"invalid attribute name '{name}'")
        self[name] = value

    def __getattr__(self, name, default=None):
        """Return entry ``name``, or ``default`` when it is given (not None).

        A missing entry without default raises an error that is both a
        KeyError and an AttributeError, so that hasattr() and getattr() with
        a default behave as usual.
        """
        if name in self:
            return self[name]
        if default is not None:
            return default
        raise _GroupAttributeError(f"no attribute named '{name}'")

    def __setitem__(self, name, value):
        """Store ``value`` under ``name``; SyntaxError for an invalid name."""
        if not valid_varname(name):
            raise SyntaxError(f"invalid attribute name '{name}'")
        dict.__setitem__(self, name, value)

    def get(self, key, default=None):
        """Look up ``key`` here, then in each search group, else ``default``.

        Search groups are the entries named in ``_searchgroups``; names with
        no matching entry, and entries that are not dicts, are skipped.
        Values that are Empty instances count as missing.
        """
        val = dict.get(self, key, ReturnedNone)
        if not isinstance(val, Empty):
            return val
        searchgroups = self._searchgroups
        if searchgroups is not None:
            for sgroup in searchgroups:
                grp = dict.get(self, sgroup)
                if isinstance(grp, dict):
                    # direct lookup only: nested search groups are not followed
                    val = dict.get(grp, key, ReturnedNone)
                    if not isinstance(val, Empty):
                        return val
        return default

    def __repr__(self):
        """Name and number of entries (the '__name__' entry is not counted)."""
        keys = [a for a in self.keys() if a != '__name__']
        return f"Group('{self.__name__}', {len(keys)} symbols)"

    def _repr_html_(self):
        """HTML representation for Jupyter notebook"""
        from html import escape

        html = [f"<table><caption>Group('{escape(str(self.__name__))}')</caption>",
                "<tr><th>Attribute</th><th>DataType</th><th><b>Value</b></th></tr>"]
        for key, val in self.items():
            # truncate the repr first, then escape it so markup stays intact
            shown = escape(f"{repr(val):.75s}")
            html.append(f"""
<tr><td>{escape(str(key))}</td><td><i>{escape(type(val).__name__)}</i></td>
    <td>{shown}</td>
</tr>""")
        html.append("</table>")
        return '\n'.join(html)

504 

505 

def make_symbol_table(use_numpy=True, nested=False, top=True, **kws):
    """Create a default symboltable, taking dict of user-defined symbols.

    Arguments
    ---------
    use_numpy : bool, optional
       whether to include symbols from numpy [True]
    nested : bool, optional
       whether to make a "new-style" nested table instead of a plain dict [False]
    top : bool, optional
       whether this is the top-level table in a nested-table [True]
    kws :  optional
       additional symbol name, value pairs to include in symbol table;
       for a nested table, a 'name' entry is taken as the table name

    Returns
    --------
    symbol_table : dict or nested Group
       a symbol table that can be used in `asteval.Interpreter`

    """
    if nested:
        name = '_main' if top else '_'
        if 'name' in kws:
            name = kws.pop('name')
        symtable = Group(name=name, Group=Group)
    else:
        symtable = {}

    symtable.update(BUILTINS_TABLE)
    symtable.update(LOCALFUNCS)
    symtable.update(kws)
    math_functions = dict(MATH_TABLE)
    if use_numpy:
        math_functions.update(NUMPY_TABLE)

    if nested:
        symtable['math'] = Group(name='math', **math_functions)
        symtable['Group'] = Group
        symtable._searchgroups = ('math',)
    else:
        symtable.update(math_functions)
    # applied again so that user symbols win over anything added above
    symtable.update(**kws)
    return symtable

551 

552 

class Procedure:
    """Procedure: user-defined function for asteval.

    This stores the parsed ast nodes as from the 'functiondef' ast node
    for later evaluation.

    """

    def __init__(self, name, interp, doc=None, lineno=None,
                 body=None, text=None, args=None, kwargs=None,
                 vararg=None, varkws=None):
        """Store the definition of the procedure.

        Arguments
        ---------
        name : str, procedure name
        interp : interpreter object; its raise_exception, config, symtable,
            retval, error, code_text, _calldepth and run() are used
        doc : docstring, a str or an ast.Constant holding it
        lineno : line number passed along when reporting argument errors
        body : list of ast nodes to run on each call [empty]
        text : source text; built from signature and body when None
        args : list of positional argument names [empty]
        kwargs : list of (name, default) pairs [empty]
        vararg : name receiving extra positional arguments, or None
        varkws : name receiving extra keyword arguments, or None
        """
        self.__ininit__ = True
        self.name = name
        self.__name__ = self.name
        self.__asteval__ = interp
        self.__raise_exc__ = self.__asteval__.raise_exception
        self.__doc__ = doc
        # None stands for "nothing given"; store empty lists so that the
        # signature and call code can iterate without special cases
        self.__body__ = [] if body is None else body
        self.__argnames__ = [] if args is None else args
        self.__kwargs__ = [] if kwargs is None else kwargs
        self.__vararg__ = vararg
        self.__varkws__ = varkws
        self.lineno = lineno
        self.__text__ = text
        if text is None:
            self.__text__ = f'{self.__signature__()}\n' + ast.unparse(self.__body__)
        self.__ininit__ = False

    def __setattr__(self, attr, val):
        """Allow assignment during __init__ only; afterwards report TypeError."""
        if not getattr(self, '__ininit__', True):
            self.__raise_exc__(None, exc=TypeError,
                               msg="procedure is read-only")
        self.__dict__[attr] = val

    def __dir__(self):
        """Fixed list of names reported by dir()."""
        return ['__getdoc__', 'argnames', 'kwargs', 'name', 'vararg', 'varkws']

    def __getdoc__(self):
        """Return the docstring, unwrapping an ast.Constant if needed."""
        doc = self.__doc__
        if isinstance(doc, ast.Constant):
            doc = doc.value
        return doc

    def __repr__(self):
        """Signature, followed by the docstring when there is one."""
        sig = self.__signature__()
        rep = f"<Procedure {sig}>"
        doc = self.__getdoc__()
        if doc is not None:
            rep = f"{rep}\n  {doc}"
        return rep

    def __signature__(self):
        """Return the call signature as text, e.g. ``f(a, *rest, c=1, **kw)``."""
        parts = list(self.__argnames__)
        if self.__vararg__ is not None:
            parts.append(f"*{self.__vararg__}")
        parts.extend(f"{k}={v}" for k, v in self.__kwargs__)
        if self.__varkws__ is not None:
            parts.append(f"**{self.__varkws__}")
        return f"{self.name}({', '.join(parts)})"

    def __call__(self, *args, **kwargs):
        """Bind the arguments, run the body nodes and return the result.

        Argument errors are reported through the interpreter's
        raise_exception.  The interpreter's symbol table, code_text stack and
        call depth are restored afterwards, also when running the body raises.
        """
        interp = self.__asteval__
        nested = interp.config.get('nested_symtable', False)
        topsym = interp.symtable
        if nested:
            # local scope is a Group that can also search the top-level table
            # and the top-level search groups
            sargs = {'_main': topsym}
            sgroups = topsym.get('_searchgroups', None)
            if sgroups is not None:
                for sxname in sgroups:
                    sargs[sxname] = topsym.get(sxname)

            symlocals = Group(name=f'symtable_{self.name}_', **sargs)
            symlocals._searchgroups = list(sargs.keys())
        else:
            symlocals = {}

        args = list(args)
        nargs = len(args)
        nkws = len(kwargs)
        nargs_expected = len(self.__argnames__)

        # check for too few arguments, but the correct keyword given
        if (nargs < nargs_expected) and nkws > 0:
            for name in self.__argnames__[nargs:]:
                if name in kwargs:
                    args.append(kwargs.pop(name))
            nargs = len(args)
            nargs_expected = len(self.__argnames__)
            nkws = len(kwargs)
        if nargs < nargs_expected:
            msg = f"{self.name}() takes at least"
            msg = f"{msg} {nargs_expected} arguments, got {nargs}"
            self.__raise_exc__(None, exc=TypeError, msg=msg)

        # check for multiple values for named argument
        msg = "multiple values for keyword argument"
        for targ in self.__argnames__:
            if targ in kwargs:
                msg = f"{msg} '{targ}' in Procedure {self.name}"
                self.__raise_exc__(None, exc=TypeError, msg=msg,
                                   lineno=self.lineno)

        # still too few arguments (only reached if the report above returned)
        if nargs < nargs_expected:
            msg = f"not enough arguments for Procedure {self.name}()"
            msg = f"{msg} (expected {nargs_expected}, got {nargs})"
            self.__raise_exc__(None, exc=TypeError, msg=msg)

        # check more args given than expected, varargs not given
        if nargs > nargs_expected and self.__vararg__ is None:
            if nargs - nargs_expected > len(self.__kwargs__):
                msg = f"too many arguments for {self.name}() expected at most"
                msg = f"{msg} {len(self.__kwargs__)+nargs_expected}, got {nargs}"
                self.__raise_exc__(None, exc=TypeError, msg=msg)

            # surplus positional values fill the keyword arguments in order
            for i, xarg in enumerate(args[nargs_expected:]):
                kw_name = self.__kwargs__[i][0]
                if kw_name not in kwargs:
                    kwargs[kw_name] = xarg

        for argname in self.__argnames__:
            symlocals[argname] = args.pop(0)

        try:
            if self.__vararg__ is not None:
                symlocals[self.__vararg__] = tuple(args)

            for key, val in self.__kwargs__:
                if key in kwargs:
                    val = kwargs.pop(key)
                symlocals[key] = val

            if self.__varkws__ is not None:
                symlocals[self.__varkws__] = kwargs

            elif len(kwargs) > 0:
                msg = f"extra keyword arguments for Procedure {self.name}: "
                msg = msg + ','.join(list(kwargs.keys()))
                self.__raise_exc__(None, msg=msg, exc=TypeError,
                                   lineno=self.lineno)

        except (ValueError, LookupError, TypeError,
                NameError, AttributeError):
            msg = f"incorrect arguments for Procedure {self.name}"
            self.__raise_exc__(None, msg=msg, lineno=self.lineno)

        if nested:
            save_symtable = interp.symtable
            interp.symtable = symlocals
        else:
            # flat table: locals are merged in, and the copy taken before the
            # call is put back afterwards
            save_symtable = interp.symtable.copy()
            interp.symtable.update(symlocals)

        interp.retval = None
        interp._calldepth += 1
        retval = None

        # evaluate script of function
        interp.code_text.append(self.__text__)
        try:
            for node in self.__body__:
                interp.run(node, lineno=node.lineno)
                if len(interp.error) > 0:
                    break
                if interp.retval is not None:
                    retval = interp.retval
                    interp.retval = None
                    # ReturnedNone marks an explicit "return None"
                    if retval is ReturnedNone:
                        retval = None
                    break
        finally:
            interp.symtable = save_symtable
            interp.code_text.pop()
            interp._calldepth -= 1
        return retval