[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)

Useful for `slice` lowerings that depend on tensors made from scalars.
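A hedged illustration (not taken from this diff) of what the new folder enables: a `prim.NumToTensor.Scalar` fed by a constant scalar now folds to a tensor literal, so consumers such as `slice` lowerings can read the value directly from the attribute:

    %int42 = torch.constant.int 42
    %0 = torch.prim.NumToTensor.Scalar %int42 : !torch.int -> !torch.vtensor<[],si64>

canonicalizes to

    %0 = torch.vtensor.literal(dense<42> : tensor<si64>) : !torch.vtensor<[],si64>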
pull/2924/merge
Rob Suderman 2024-02-19 11:55:54 -08:00 committed by GitHub
parent e80054a3cc
commit 135c81a416
7 changed files with 77 additions and 99 deletions

View File

@ -15070,6 +15070,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
  let hasFolder = 1;
}
def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [

View File

@ -452,7 +452,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    auto exp = b.create<math::ExpOp>(loc, negate);
    auto added = b.create<arith::AddFOp>(loc, exp, one);
    auto div = b.create<arith::DivFOp>(loc, one, added);
    outTy.dump();
    return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
  }
  if (auto relu = dyn_cast<AtenReluOp>(op)) {

View File

@ -3641,6 +3641,30 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
      std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
}
//===----------------------------------------------------------------------===//
// PrimNumToTensorScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
  Attribute a = adaptor.getA();
  auto resultTy = cast<BaseTensorType>(getType());
  if (!a)
    return {};
  if (!resultTy.hasDtype() || !resultTy.hasSizes())
    return {};
  auto dty = resultTy.getDtype();
  if (auto iattr = dyn_cast<IntegerAttr>(a)) {
    a = IntegerAttr::get(dty, iattr.getInt());
  } else if (auto fattr = dyn_cast<FloatAttr>(a)) {
    a = FloatAttr::get(dty, fattr.getValueAsDouble());
  }
  auto mlirTensorType =
      RankedTensorType::get(resultTy.getSizes(), resultTy.getDtype());
  return SplatElementsAttr::get(mlirTensorType, a);
}
//===----------------------------------------------------------------------===//
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//

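The fold above requires a known dtype and static sizes on the result type, converts the scalar attribute to that dtype, and splats it over the shape. A sketch of the float path, which the tests below do not exercise (values chosen here purely for illustration):

    %fp = torch.constant.float 2.500000e+00
    %0 = torch.prim.NumToTensor.Scalar %fp : !torch.float -> !torch.vtensor<[],f32>

should fold to

    %0 = torch.vtensor.literal(dense<2.500000e+00> : tensor<f32>) : !torch.vtensor<[],f32>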
View File

@ -419,6 +419,7 @@ STABLEHLO_PASS_SET = {
"AtenEyeModuleFloat2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenInstanceNormModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic",

View File

@ -846,7 +846,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("prim::device : (Tensor) -> (Device)", has_canonicalizer=True)
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)", has_folder=True)
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
emit("prim::min.int : (int, int) -> (int)", has_folder=True)
emit("prim::max.self_int : (int[]) -> (int)")

View File

@ -27,13 +27,9 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic(
// CHECK-SAME: ) -> !torch.vtensor<[],si64> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.*]] = stablehlo.constant dense<1> : tensor<i64>
// CHECK: %[[FROM:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[FROM]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>

View File

@ -1687,13 +1687,8 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool {
}
// CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@ -1705,11 +1700,8 @@ func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor
}
// CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@ -1760,11 +1752,8 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t
}
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int6 = torch.constant.int 6
%str = torch.constant.str "floor"
@ -1775,13 +1764,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtenso
}
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int6 = torch.constant.int 6
%int2 = torch.constant.int 2
@ -1793,11 +1777,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vt
}
// CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR1]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@ -1808,9 +1789,8 @@ func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
}
// CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
@ -1820,13 +1800,8 @@ func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],
}
// CHECK-LABEL: func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT_6:.*]] = torch.constant.int -6
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@ -1838,11 +1813,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor
}
// CHECK-LABEL: @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT_6:.*]] = torch.constant.int -6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]]
func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@ -1854,11 +1826,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],
}
// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT_6:.*]] = torch.constant.int -6
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR1]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
@ -1869,9 +1838,8 @@ func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
}
// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT_6:.*]] = torch.constant.int -6
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
@ -1891,9 +1859,8 @@ func.func @torch.aten.sub.float$fold() -> !torch.float {
}
// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6]] = torch.constant.int 6
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int3 = torch.constant.int 3
%0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
@ -1902,11 +1869,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],
}
// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR1]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
@ -1916,8 +1880,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
}
// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[INT6]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
@ -1926,13 +1890,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],
}
// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
@ -1943,13 +1902,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor
}
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> {
%int6 = torch.constant.int 6
%int2 = torch.constant.int 2
@ -1961,11 +1915,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !to
}
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> {
%int6 = torch.constant.int 6
%str = torch.constant.str "trunc"
@ -2151,9 +2102,8 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %int-1 = torch.constant.int -1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
@ -2163,11 +2113,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[]
}
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %int-1 = torch.constant.int -1
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
@ -2179,7 +2126,6 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtenso
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
// CHECK: return %[[VAL_1]] : !torch.number
func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
@ -2347,6 +2293,17 @@ func.func @fold_aten_where_true_attr() -> !torch.vtensor<[4],si64> {
// -----
// CHECK-LABEL: @fold_prim_numtotensor_scalar
func.func @fold_prim_numtotensor_scalar() -> !torch.vtensor<[1],si64> {
%int42 = torch.constant.int 42
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK: return %[[TENSOR]]
%0 = torch.prim.NumToTensor.Scalar %int42 : !torch.int -> !torch.vtensor<[1],si64>
return %0 : !torch.vtensor<[1],si64>
}
// -----
// CHECK-LABEL: @fold_aten_where_false_attr
func.func @fold_aten_where_false_attr() -> !torch.vtensor<[4],si64> {
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>