@@ -1762,6 +1762,94 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
17621762 return %0 : !torch.tensor
17631763}
17641764
// Folding of aten.to.dtype on splat vtensor literals: each case below converts
// a dense splat constant to a new dtype and expects the op to fold away into a
// new torch.vtensor.literal with the converted splat value.
// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
func.func @torch.aten.to.dtype$fold_splat() -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>) {
  // CHECK-NOT: torch.aten.to.dtype
  %false = torch.constant.bool false
  %none = torch.constant.none

  // int32 splat -> float32
  %int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
  %int6 = torch.constant.int 6 // torch.float32
  // CHECK: %[[R1:.*]] = torch.vtensor.literal(dense<4.200000e+01> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
  %result1 = torch.aten.to.dtype %int_splat, %int6, %false, %false, %none
      : !torch.vtensor<[2,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2,3],f32>

  // float32 splat -> int32 (rounds toward zero)
  %float_splat = torch.vtensor.literal(dense<3.14159> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
  %int3 = torch.constant.int 3 // torch.int32
  // CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
  %result2 = torch.aten.to.dtype %float_splat, %int3, %false, %false, %none
      : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[4,4],si32>

  // int64 splat (max int32 + 1) -> int32 (truncation wraps to INT32_MIN)
  %int64_splat = torch.vtensor.literal(dense<2147483648> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
  // CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<-2147483648> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
  %result3 = torch.aten.to.dtype %int64_splat, %int3, %false, %false, %none
      : !torch.vtensor<[10],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[10],si32>

  // float32 splat -> float64 (widening keeps the exact f32 value)
  %float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
  %int7 = torch.constant.int 7 // torch.float64
  // CHECK: %[[R4:.*]] = torch.vtensor.literal(dense<2.7182800769805908> : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
  %result4 = torch.aten.to.dtype %float32_splat, %int7, %false, %false, %none
      : !torch.vtensor<[5,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[5,5],f64>

  // float64 splat -> float16 (rounds to nearest representable f16)
  %float64_splat = torch.vtensor.literal(dense<1.2> : tensor<3x3xf64>) : !torch.vtensor<[3,3],f64>
  %int5 = torch.constant.int 5 // torch.float16
  // CHECK: %[[R5:.*]] = torch.vtensor.literal(dense<1.200200e+00> : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
  %result5 = torch.aten.to.dtype %float64_splat, %int5, %false, %false, %none
      : !torch.vtensor<[3,3],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[3,3],f16>

  // float32 splat -> bfloat16 (rounds to nearest representable bf16)
  %float32_bf16 = torch.vtensor.literal(dense<-0.51> : tensor<2x2xf32>) : !torch.vtensor<[2,2],f32>
  %int15 = torch.constant.int 15 // torch.bfloat16
  // CHECK: %[[R6:.*]] = torch.vtensor.literal(dense<-5.117190e-01> : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
  %result6 = torch.aten.to.dtype %float32_bf16, %int15, %false, %false, %none
      : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2,2],bf16>

  // int32 splat -> int64 (sign-extend)
  %int32_ext = torch.vtensor.literal(dense<-1000> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
  %int4 = torch.constant.int 4 // torch.int64
  // CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %result7 = torch.aten.to.dtype %int32_ext, %int4, %false, %false, %none
      : !torch.vtensor<[4],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[4],si64>

  // int32 splat -> int16 (truncation wraps 32768 to INT16_MIN)
  %int32_trunc = torch.vtensor.literal(dense<32768> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
  %int2 = torch.constant.int 2 // torch.int16
  // CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<-32768> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
  %result8 = torch.aten.to.dtype %int32_trunc, %int2, %false, %false, %none
      : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[3],si16>

  // int32 splat -> bool (i1): any non-zero value becomes true
  %int40_splat = torch.vtensor.literal(dense<40> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
  %int11 = torch.constant.int 11 // torch.bool
  // CHECK: %[[R9:.*]] = torch.vtensor.literal(dense<true> : tensor<2xi1>) : !torch.vtensor<[2],i1>
  %result9 = torch.aten.to.dtype %int40_splat, %int11, %false, %false, %none
      : !torch.vtensor<[2],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2],i1>

  // float32 splat -> bool (i1): zero becomes false
  %float_zero = torch.vtensor.literal(dense<0.0> : tensor<2xf32>) : !torch.vtensor<[2],f32>
  // CHECK: %[[R10:.*]] = torch.vtensor.literal(dense<false> : tensor<2xi1>) : !torch.vtensor<[2],i1>
  %result10 = torch.aten.to.dtype %float_zero, %int11, %false, %false, %none
      : !torch.vtensor<[2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2],i1>

  return %result1, %result2, %result3, %result4, %result5, %result6, %result7, %result8, %result9, %result10
      : !torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>, !torch.vtensor<[2],i1>, !torch.vtensor<[2],i1>
}
1852+
17651853// CHECK-LABEL: func.func @torch.aten.to.other$basic(
17661854// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
17671855// CHECK: %[[NONE:.*]] = torch.constant.none
0 commit comments