@@ -578,31 +578,7 @@ def plot(self, index_by='element', material=[], loc=0, **kwargs):
578578 xlim = [float ('inf' ), - float ('inf' )]
579579 ylim = [float ('inf' ), - float ('inf' )]
580580 if n_dim == 2 :
581- simps = np .array (self .elements )
582- pts = np .array (self .points )
583- xy = pts [simps , :]
584-
585- plt_kwargs = {}
586- for key , value in kwargs .items ():
587- if type (value ) in (list , np .array ):
588- plt_value = []
589- for e_num , e_att in enumerate (self .element_attributes ):
590- if index_by == 'element' :
591- ind = e_num
592- elif index_by == 'attribute' :
593- ind = int (e_att )
594- else :
595- e_str = 'Cannot index by {}.' .format (index_by )
596- raise ValueError (e_str )
597- v = value [ind ]
598- plt_value .append (v )
599- else :
600- plt_value = value
601- plt_kwargs [key ] = plt_value
602-
603- pc = collections .PolyCollection (xy , ** plt_kwargs )
604- ax .add_collection (pc )
605- ax .autoscale_view ()
581+ _plot_2d (ax , self , index_by , ** kwargs )
606582 else :
607583 if n_obj > 0 :
608584 zlim = ax .get_zlim ()
@@ -804,7 +780,7 @@ def from_polymesh(cls, polymesh, mesh_size, phases=None):
804780 elem_atts = np .full (cens .shape [0 ], - 1 )
805781 for r_num , region in enumerate (polymesh .regions ):
806782 # A. Create a bounding box
807- r_kps = np .unique ([polymesh . facets [ f ] for f in region ])
783+ r_kps = np .unique ([k for f in region for k in polymesh . facets [ f ] ])
808784 r_pts = p_pts [r_kps ]
809785 r_mins = r_pts .min (axis = 0 )
810786 r_maxs = r_pts .max (axis = 0 )
@@ -1260,8 +1236,12 @@ def as_array(self, element_attributes=True):
12601236
12611237 Args:
12621238 element_attributes (bool): *(optional)* Flag to return element
1263- attributes in the array. Set to True return attributes and set to
1264- False to return element indices. Defaults to True.
1239+ attributes in the array. Set to True return attributes and
1240+ set to False to return element indices. Defaults to True.
1241+
1242+ Returns:
1243+ numpy.ndarray: Array of values of element atttributes, or indices.
1244+
12651245 """
12661246 # 1. Convert 1st node of each element into array indices
12671247 pts = np .array (self .points )
@@ -1270,7 +1250,7 @@ def as_array(self, element_attributes=True):
12701250
12711251 corner_pts = pts [np .array (self .elements )[:, 0 ]]
12721252 rel_pos = corner_pts - mins
1273- elem_tups = (rel_pos / sz ).astype (int )
1253+ elem_tups = np . round (rel_pos / sz ).astype (int )
12741254
12751255 # 2. Create array full of -1 values
12761256 inds_maxs = elem_tups .max (axis = 0 )
@@ -1290,7 +1270,119 @@ def as_array(self, element_attributes=True):
12901270 # ----------------------------------------------------------------------- #
12911271 # Plot Function #
12921272 # ----------------------------------------------------------------------- #
1293- # Inherited from TriMesh
1273+ def plot (self , index_by = 'element' , material = [], loc = 0 , ** kwargs ):
1274+ """Plot the mesh.
1275+
1276+ This method plots the mesh using matplotlib.
1277+ In 2D, this creates a :class:`matplotlib.collections.PolyCollection`
1278+ and adds it to the current axes.
1279+ In 3D, it creates a
1280+ :meth:`mpl_toolkits.mplot3d.axes3d.Axes3D.voxels` and
1281+ adds it to the current axes.
1282+ The keyword arguments are passed though to matplotlib.
1283+
1284+ Args:
1285+ index_by (str): *(optional)* {'element' | 'attribute'}
1286+ Flag for indexing into the other arrays passed into the
1287+ function. For example,
1288+ ``plot(index_by='attribute', color=['blue', 'red'])`` will plot
1289+ the elements with ``element_attribute`` equal to 0 in blue, and
1290+ elements with ``element_attribute`` equal to 1 in red.
1291+ Defaults to 'element'.
1292+ material (list): *(optional)* Names of material phases. One entry
1293+ per material phase (the ``index_by`` argument is ignored).
1294+ If this argument is set, a legend is added to the plot with
1295+ one entry per material. Note that the ``element_attributes``
1296+ must be the material numbers for the legend to be
1297+ formatted properly.
1298+ loc (int or str): *(optional)* The location of the legend,
1299+ if 'material' is specified. This argument is passed directly
1300+ through to :func:`matplotlib.pyplot.legend`. Defaults to 0,
1301+ which is 'best' in matplotlib.
1302+ **kwargs: Keyword arguments that are passed through to matplotlib.
1303+
1304+ """
1305+ n_dim = len (self .points [0 ])
1306+ if n_dim == 2 :
1307+ ax = plt .gca ()
1308+ else :
1309+ ax = plt .gcf ().gca (projection = Axes3D .name )
1310+ n_obj = _misc .ax_objects (ax )
1311+ if n_obj > 0 :
1312+ xlim = ax .get_xlim ()
1313+ ylim = ax .get_ylim ()
1314+ else :
1315+ xlim = [float ('inf' ), - float ('inf' )]
1316+ ylim = [float ('inf' ), - float ('inf' )]
1317+ if n_dim == 2 :
1318+ _plot_2d (ax , self , index_by , ** kwargs )
1319+ else :
1320+ if n_obj > 0 :
1321+ zlim = ax .get_zlim ()
1322+ else :
1323+ zlim = [float ('inf' ), - float ('inf' )]
1324+
1325+ inds = self .as_array (element_attributes = index_by == 'attribute' )
1326+ plt_kwargs = {}
1327+ for key , value in kwargs .items ():
1328+ if type (value ) in (list , np .array ):
1329+ plt_value = np .empty (inds .shape , dtype = object )
1330+ for i , val_i in enumerate (value ):
1331+ plt_value [inds == i ] = val_i
1332+ else :
1333+ plt_value = value
1334+ plt_kwargs [key ] = plt_value
1335+
1336+ # Scale axes
1337+ pts = np .array (self .points )
1338+ mins = pts .min (axis = 0 )
1339+ sz = self .mesh_size
1340+ pt_tups = np .round ((pts - mins ) / sz ).astype (int )
1341+ maxs = pt_tups .max (axis = 0 )
1342+ grids = np .indices (maxs + 1 , dtype = float )
1343+ for pt , pt_tup in zip (pts , pt_tups ):
1344+ for i , x in enumerate (pt ):
1345+ grids [i ][tuple (pt_tup )] = x
1346+ ax .voxels (* grids , inds >= 0 , ** plt_kwargs )
1347+
1348+ # Add legend
1349+ if material and index_by == 'attribute' :
1350+ p_kwargs = [{'label' : m } for m in material ]
1351+ for key , value in kwargs .items ():
1352+ if type (value ) not in (list , np .array ):
1353+ for kws in p_kwargs :
1354+ kws [key ] = value
1355+
1356+ for i , m in enumerate (material ):
1357+ if type (value ) in (list , np .array ):
1358+ p_kwargs [i ][key ] = value [i ]
1359+ else :
1360+ p_kwargs [i ][key ] = value
1361+
1362+ # Replace plural keywords
1363+ for p_kw in p_kwargs :
1364+ for kw in _misc .mpl_plural_kwargs :
1365+ if kw in p_kw :
1366+ p_kw [kw [:- 1 ]] = p_kw [kw ]
1367+ del p_kw [kw ]
1368+ handles = [patches .Patch (** p_kw ) for p_kw in p_kwargs ]
1369+ ax .legend (handles = handles , loc = loc )
1370+
1371+ # Adjust Axes
1372+ mins = np .array (self .points ).min (axis = 0 )
1373+ maxs = np .array (self .points ).max (axis = 0 )
1374+ xlim = (min (xlim [0 ], mins [0 ]), max (xlim [1 ], maxs [0 ]))
1375+ ylim = (min (ylim [0 ], mins [1 ]), max (ylim [1 ], maxs [1 ]))
1376+ if n_dim == 2 :
1377+ plt .axis ('square' )
1378+ plt .xlim (xlim )
1379+ plt .ylim (ylim )
1380+ elif n_dim == 3 :
1381+ zlim = (min (zlim [0 ], mins [2 ]), max (zlim [1 ], maxs [2 ]))
1382+ ax .set_xlim (xlim )
1383+ ax .set_ylim (ylim )
1384+ ax .set_zlim (zlim )
1385+ _misc .axisEqual3D (ax )
12941386
12951387
12961388def facet_check (neighs , polymesh , phases ):
@@ -1794,3 +1886,31 @@ def _facet_in_normal(pts, cen_pt):
17941886 vn *= sgn # flip so center is inward
17951887 un = vn / np .linalg .norm (vn )
17961888 return un , pts .mean (axis = 0 )
1889+
1890+
1891+ def _plot_2d (ax , mesh , index_by , ** kwargs ):
1892+ simps = np .array (mesh .elements )
1893+ pts = np .array (mesh .points )
1894+ xy = pts [simps , :]
1895+
1896+ plt_kwargs = {}
1897+ for key , value in kwargs .items ():
1898+ if type (value ) in (list , np .array ):
1899+ plt_value = []
1900+ for e_num , e_att in enumerate (mesh .element_attributes ):
1901+ if index_by == 'element' :
1902+ ind = e_num
1903+ elif index_by == 'attribute' :
1904+ ind = int (e_att )
1905+ else :
1906+ e_str = 'Cannot index by {}.' .format (index_by )
1907+ raise ValueError (e_str )
1908+ v = value [ind ]
1909+ plt_value .append (v )
1910+ else :
1911+ plt_value = value
1912+ plt_kwargs [key ] = plt_value
1913+
1914+ pc = collections .PolyCollection (xy , ** plt_kwargs )
1915+ ax .add_collection (pc )
1916+ ax .autoscale_view ()
0 commit comments