@@ -1,7 +1,7 @@
 import warnings
 from numbers import Number
 from textwrap import dedent
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Sequence, Tuple, Union

 import numpy as np

@@ -17,11 +17,64 @@
 from aesara.tensor import basic as at
 from aesara.tensor import get_vector_length
 from aesara.tensor.exceptions import NotScalarConstantError
-from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
+from aesara.tensor.type import (
+    DenseTensorType,
+    TensorType,
+    int_dtypes,
+    shape_encode,
+    tensor,
+)
 from aesara.tensor.type_other import NoneConst
 from aesara.tensor.var import TensorConstant, TensorVariable


+def filter_shape_vars(
+    ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True
+) -> Tuple[int, ...]:
+    r"""Compute the most "informative" shape based on a static reference.
+
+    Parameters
+    ----------
+    ref_shape
+        A static shape reference using static shape constraint encoding.
+    shape
+        A symbolic shape.
+    shape_is_encoded
+        If ``True``, `shape` is assumed to be static shape constraint encoded.
+
+    Returns
+    -------
+    The most specific, and compatible (with `ref_shape`), static shape
+    constraint encoded values.
+    """
+    shape_bottom = shape_encode(None)
+    type_shape = ()
+    for i, (xts, s) in enumerate(zip(ref_shape, shape)):
+        try:
+            # TODO FIXME: We shouldn't need to do this; let a rewrite
+            # do constant folding and update the `TensorType`s.
+            s_val = at.get_scalar_constant_value(s)
+
+            if isinstance(s_val, np.ndarray):
+                s_val = s_val.item()
+
+            if shape_is_encoded or (s_val is not None and s_val > 0):
+                type_s = shape_encode(s_val)
+            else:
+                type_s = shape_bottom
+        except NotScalarConstantError:
+            type_s = shape_bottom
+
+        if not (xts <= -1 or type_s <= -1 or type_s == xts):
+            raise AssertionError(
+                f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}."
+            )
+
+        type_shape += (max(type_s, xts),)
+
+    return type_shape
+
+
 def register_shape_c_code(type, code, version=()):
     """
     Tell Shape Op how to generate C code for an Aesara Type.
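
`filter_shape_vars` folds a symbolic shape into a statically encoded reference shape, one dimension at a time. The exact `shape_encode` mapping lives in `aesara.tensor.type` and is not shown in this diff; the sketch below uses a hypothetical stand-in inferred from the comparisons above (`None` becomes `-2` for "no information", `-1` means "any size except 1", and concrete sizes encode as themselves), under which `max` always selects the more informative constraint.

```python
# Hypothetical stand-in for `aesara.tensor.type.shape_encode`, inferred
# from the checks in `filter_shape_vars`; the real mapping may differ.
def shape_encode(s):
    if s is None:
        return -2  # assumed: "no static information"
    return int(s)  # concrete sizes; -1 is assumed to mean "any size but 1"


def combine(xts, type_s):
    """Mirror one iteration of the `filter_shape_vars` loop."""
    # Encoded values <= -1 are compatible with anything; two concrete
    # (non-negative) sizes must agree exactly.
    if not (xts <= -1 or type_s <= -1 or type_s == xts):
        raise AssertionError(f"incompatible static shapes: {xts} vs. {type_s}")
    # The ordering n >= 0 > -1 > -2 makes `max` pick the stronger constraint.
    return max(type_s, xts)


assert combine(shape_encode(None), shape_encode(5)) == 5  # unknown vs. 5 -> 5
assert combine(-1, shape_encode(3)) == 3                  # "not 1" vs. 3 -> 3
assert combine(shape_encode(4), shape_encode(None)) == 4  # 4 vs. unknown -> 4
```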
@@ -394,7 +448,6 @@ class SpecifyShape(COp):
     _f16_ok = True

     def make_node(self, x, *shape):
-        from aesara.tensor.basic import get_scalar_constant_value

         x = at.as_tensor_variable(x)

@@ -417,18 +470,7 @@ def make_node(self, x, *shape):
                 f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
             )

-        type_shape = [None] * x.ndim
-        for i, (xts, s) in enumerate(zip(x.type.shape, shape)):
-            if xts is not None:
-                type_shape[i] = xts
-            else:
-                try:
-                    type_s = get_scalar_constant_value(s)
-                    if type_s is not None:
-                        type_shape[i] = int(type_s)
-                except NotScalarConstantError:
-                    pass
-
+        type_shape = filter_shape_vars(x.type.shape_encoded, shape)
         out_var = x.type.clone(shape=type_shape)()

         return Apply(self, [x, *shape], [out_var])
@@ -441,10 +483,10 @@ def perform(self, node, inp, out_):
             raise AssertionError(
                 f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
             )
-        if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
-            raise AssertionError(
-                f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
-            )
+        for xs, s in zip(x.shape, shape):
+            if (s == -1 and xs == 1) or (s is not None and s > -1 and xs != s):
+                raise AssertionError(f"SpecifyShape: Got shape {xs}, expected {s}.")
+
         out[0] = x

     def infer_shape(self, fgraph, node, shapes):
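
The rewritten `perform` check gives `-1` the "any size except 1" reading instead of testing plain equality. The predicate, pulled out on its own with hypothetical values:

```python
def violates(xs, s):
    # `xs` is the actual dimension, `s` the encoded target, as in `perform`:
    # -1 forbids size 1, a non-negative `s` must match exactly, and
    # `None` (no constraint) accepts anything.
    return (s == -1 and xs == 1) or (s is not None and s > -1 and xs != s)


assert violates(1, -1)        # dim is 1, but encoded as "any size but 1"
assert not violates(7, -1)    # any other size satisfies -1
assert violates(3, 4)         # concrete mismatch
assert not violates(4, 4)     # concrete match
assert not violates(3, None)  # unconstrained dimension
```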
@@ -454,11 +496,11 @@ def infer_shape(self, fgraph, node, shapes):
         for dim in range(node.inputs[0].type.ndim):
             s = shape[dim]
             try:
-                s = at.get_scalar_constant_value(s)
-                # We assume that `None` shapes are always retrieved by
+                s = shape_encode(at.get_scalar_constant_value(s))
+                # We assume that negative shapes are always retrieved by
                 # `get_scalar_constant_value`, and only in that case do we default to
                 # the shape of the input variable
-                if s is None:
+                if s < 0:
                     s = xshape[dim]
             except NotScalarConstantError:
                 pass
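
In `infer_shape`, any negative encoded value carries no usable extent, so the dimension falls back to the shape inferred for the input variable; only non-negative encodings are used directly. A one-line rendering of that branch (names hypothetical):

```python
def infer_dim(encoded_s, input_dim):
    # Negative encodings (-1, or the `None` encoding) defer to the
    # input's inferred shape; concrete sizes are used as-is.
    return input_dim if encoded_s < 0 else encoded_s


assert infer_dim(5, "x.shape[0]") == 5              # concrete size wins
assert infer_dim(-1, "x.shape[0]") == "x.shape[0]"  # "not 1" has no extent
assert infer_dim(-2, "x.shape[1]") == "x.shape[1]"  # unknown defers too
```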
@@ -502,6 +544,9 @@ def c_code(self, node, name, i_names, o_names, sub):
                 );
                 {fail};
             }}
+
+            npy_intp shp;
+            npy_intp actual_shp;
             """
         )

@@ -510,9 +555,11 @@ def c_code(self, node, name, i_names, o_names, sub):
                 continue
             code += dedent(
                 f"""
-                if (py_{shp_name} != Py_None){{
-                    dtype_{shp_name} shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
-                    if (PyArray_DIMS({x_name})[{i}] != shp) {{
+                shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
+
+                if (shp > -2) {{
+                    actual_shp = PyArray_DIMS({x_name})[{i}];
+                    if ((shp == -1 && actual_shp == 1) || (shp > -1 && actual_shp != shp)) {{
                         PyErr_Format(PyExc_AssertionError,
                             "SpecifyShape: dim %d of input has shape %d, expected %d.",
                             {i}, PyArray_DIMS({x_name})[{i}], shp
@@ -533,7 +580,7 @@ def c_code(self, node, name, i_names, o_names, sub):
         return code

     def c_code_cache_version(self):
-        return (2,)
+        return (3,)


 _specify_shape = SpecifyShape()
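
For reference, `SpecifyShape` is normally applied through the `specify_shape` helper defined later in this module. A typical call might look like the following, where the `None` entry is an unconstrained dimension that this change now routes through `shape_encode` (usage assumed unchanged by this commit):

```python
import aesara.tensor as at
from aesara.tensor.shape import specify_shape

x = at.matrix("x")
# Constrain only the second dimension; the first stays unknown (`None`).
y = specify_shape(x, (None, 4))
```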