1010from pandas .api .types import (
1111 is_list_like ,
1212 is_string_dtype ,
13+ is_categorical_dtype ,
1314)
1415from pandas .core .dtypes .concat import concat_compat
1516
@@ -1176,6 +1177,7 @@ def _final_frame_longer(
11761177 df = {** index , ** outcome , ** values }
11771178
11781179 df = pd .DataFrame (df , copy = False , index = df_index )
1180+ df_index = None
11791181
11801182 if sort_by_appearance :
11811183 df = _sort_by_appearance_for_melt (df = df , len_index = len_index )
@@ -1198,6 +1200,9 @@ def pivot_wider(
11981200 flatten_levels : Optional [bool ] = True ,
11991201 names_sep : str = "_" ,
12001202 names_glue : str = None ,
1203+ reset_index : bool = True ,
1204+ names_expand : bool = False ,
1205+ index_expand : bool = False ,
12011206) -> pd .DataFrame :
12021207 """
12031208 Reshapes data from *long* to *wide* form.
@@ -1222,6 +1227,7 @@ def pivot_wider(
12221227 at the start of each label in the columns.
12231228
12241229
1230+
12251231 Example:
12261232
12271233 >>> import pandas as pd
@@ -1292,9 +1298,16 @@ def pivot_wider(
12921298 and uses python's `str.format_map` under the hood.
12931299 Simply create the string template,
12941300 using the column labels in `names_from`,
1295- and special `_value` as a placeholder
1296- if there are multiple `values_from`.
1301+ and special `_value` as a placeholder for `values_from`.
12971302 Applicable only if `flatten_levels` is `True`.
1303+ :param reset_index: Determines whether to restore `index`
1304+ as a column/columns. Applicable only if `index` is provided,
1305+ and `flatten_levels` is `True`. Default is `True`.
1306+ :param names_expand: Expand columns to show all the categories.
1307+ Applies only if `names_from` is a categorical column.
1308+ Default is `False`.
1309+ :param index_expand: Expand the index to show all the categories.
1310+ Applies only if `index` is a categorical column. Default is `False`.
12981311 :returns: A pandas DataFrame that has been unpivoted from long to wide
12991312 form.
13001313 """
@@ -1309,6 +1322,9 @@ def pivot_wider(
13091322 flatten_levels ,
13101323 names_sep ,
13111324 names_glue ,
1325+ reset_index ,
1326+ names_expand ,
1327+ index_expand ,
13121328 )
13131329
13141330
@@ -1320,6 +1336,9 @@ def _computations_pivot_wider(
13201336 flatten_levels : Optional [bool ] = True ,
13211337 names_sep : str = "_" ,
13221338 names_glue : str = None ,
1339+ reset_index : bool = True ,
1340+ names_expand : bool = False ,
1341+ index_expand : bool = False ,
13231342) -> pd .DataFrame :
13241343 """
13251344 This is the main workhorse of the `pivot_wider` function.
@@ -1339,6 +1358,9 @@ def _computations_pivot_wider(
13391358 flatten_levels ,
13401359 names_sep ,
13411360 names_glue ,
1361+ reset_index ,
1362+ names_expand ,
1363+ index_expand ,
13421364 ) = _data_checks_pivot_wider (
13431365 df ,
13441366 index ,
@@ -1347,30 +1369,54 @@ def _computations_pivot_wider(
13471369 flatten_levels ,
13481370 names_sep ,
13491371 names_glue ,
1372+ reset_index ,
1373+ names_expand ,
1374+ index_expand ,
13501375 )
1351- if flatten_levels :
1352- # check dtype of `names_from` is string
1353- names_from_all_strings = (
1354- df .filter (names_from ).agg (is_string_dtype ).all ().item ()
1355- )
1356-
1357- # check dtype of columns
1358- column_dtype = is_string_dtype (df .columns )
13591376
13601377 df = df .pivot ( # noqa: PD010
13611378 index = index , columns = names_from , values = values_from
13621379 )
13631380
1364- # an empty df is likely because
1365- # there is no `values_from`
1381+ indexer = df .index
1382+ if index_expand and index :
1383+ any_categoricals = (indexer .get_level_values (name ) for name in index )
1384+ any_categoricals = any (map (is_categorical_dtype , any_categoricals ))
1385+ if any_categoricals :
1386+ indexer = _expand (indexer , retain_categories = True )
1387+ df = df .reindex (index = indexer )
1388+
1389+ indexer = df .columns
1390+ if names_expand :
1391+ any_categoricals = (
1392+ indexer .get_level_values (name ) for name in names_from
1393+ )
1394+ any_categoricals = any (map (is_categorical_dtype , any_categoricals ))
1395+ if any_categoricals :
1396+ retain_categories = True
1397+ if flatten_levels & (
1398+ (names_glue is not None )
1399+ | isinstance (indexer , pd .MultiIndex )
1400+ | ((index is not None ) & reset_index )
1401+ ):
1402+ retain_categories = False
1403+ indexer = _expand (indexer , retain_categories = retain_categories )
1404+ df = df .reindex (columns = indexer )
1405+
1406+ indexer = None
13661407 if any ((df .empty , not flatten_levels )):
13671408 return df
13681409
13691410 if isinstance (df .columns , pd .MultiIndex ):
1370- if (not names_from_all_strings ) or (not column_dtype ):
1371- new_columns = [tuple (map (str , entry )) for entry in df ]
1372- else :
1373- new_columns = [entry for entry in df ]
1411+ new_columns = df .columns
1412+ all_strings = (
1413+ new_columns .get_level_values (num )
1414+ for num in range (new_columns .nlevels )
1415+ )
1416+ all_strings = all (map (is_string_dtype , all_strings ))
1417+ if not all_strings :
1418+ new_columns = (tuple (map (str , entry )) for entry in new_columns )
1419+
13741420 if names_glue is not None :
13751421 if ("_value" in names_from ) and (None in df .columns .names ):
13761422 warnings .warn (
@@ -1403,24 +1449,18 @@ def _computations_pivot_wider(
14031449
14041450 df .columns = new_columns
14051451 else :
1406- if (not names_from_all_strings ) or (not column_dtype ):
1407- df .columns = df .columns .astype (str )
14081452 if names_glue is not None :
14091453 try :
14101454 df .columns = [
14111455 names_glue .format_map ({names_from [0 ]: entry })
1412- for entry in df
1456+ for entry in df . columns
14131457 ]
14141458 except KeyError as error :
14151459 raise KeyError (
14161460 f"{ error } is not a column label in names_from."
14171461 ) from error
14181462
1419- # if columns are of category type
1420- # this returns columns to object dtype
1421- # also, resetting index with category columns is not possible
1422- df .columns = [* df .columns ]
1423- if index :
1463+ if index and reset_index :
14241464 df = df .reset_index ()
14251465
14261466 if df .columns .names :
@@ -1437,6 +1477,9 @@ def _data_checks_pivot_wider(
14371477 flatten_levels ,
14381478 names_sep ,
14391479 names_glue ,
1480+ reset_index ,
1481+ names_expand ,
1482+ index_expand ,
14401483):
14411484
14421485 """
@@ -1464,9 +1507,12 @@ def _data_checks_pivot_wider(
14641507 if values_from is not None :
14651508 if is_list_like (values_from ):
14661509 values_from = [* values_from ]
1467- values_from = _select_column_names (values_from , df )
1468- if len (values_from ) == 1 :
1469- values_from = values_from [0 ]
1510+ out = _select_column_names (values_from , df )
1511+ # hack to align with pd.pivot
1512+ if values_from == out [0 ]:
1513+ values_from = out [0 ]
1514+ else :
1515+ values_from = out
14701516
14711517 check ("flatten_levels" , flatten_levels , [bool ])
14721518
@@ -1476,6 +1522,10 @@ def _data_checks_pivot_wider(
14761522 if names_glue is not None :
14771523 check ("names_glue" , names_glue , [str ])
14781524
1525+ check ("reset_index" , reset_index , [bool ])
1526+ check ("names_expand" , names_expand , [bool ])
1527+ check ("index_expand" , index_expand , [bool ])
1528+
14791529 return (
14801530 df ,
14811531 index ,
@@ -1484,4 +1534,51 @@ def _data_checks_pivot_wider(
14841534 flatten_levels ,
14851535 names_sep ,
14861536 names_glue ,
1537+ reset_index ,
1538+ names_expand ,
1539+ index_expand ,
14871540 )
1541+
1542+
1543+ def _expand (indexer , retain_categories ):
1544+ """
1545+ Expand Index to all categories.
1546+ Applies to categorical index, and used
1547+ in _computations_pivot_wider for scenarios where
1548+ names_expand and/or index_expand is True.
1549+ Categories are preserved where possible.
1550+ If `retain_categories` is False, a fastpath is taken
1551+ to generate all possible combinations.
1552+
1553+ Returns an Index.
1554+ """
1555+ if indexer .nlevels > 1 :
1556+ names = indexer .names
1557+ if not retain_categories :
1558+ indexer = pd .MultiIndex .from_product (indexer .levels , names = names )
1559+ else :
1560+ indexer = [
1561+ indexer .get_level_values (n ) for n in range (indexer .nlevels )
1562+ ]
1563+ indexer = [
1564+ pd .Categorical (
1565+ values = arr .categories ,
1566+ categories = arr .categories ,
1567+ ordered = arr .ordered ,
1568+ )
1569+ if is_categorical_dtype (arr )
1570+ else arr .unique ()
1571+ for arr in indexer
1572+ ]
1573+ indexer = pd .MultiIndex .from_product (indexer , names = names )
1574+
1575+ else :
1576+ if not retain_categories :
1577+ indexer = indexer .categories
1578+ else :
1579+ indexer = pd .Categorical (
1580+ values = indexer .categories ,
1581+ categories = indexer .categories ,
1582+ ordered = indexer .ordered ,
1583+ )
1584+ return indexer
0 commit comments