diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index a8769def6..9d4687596 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1957,11 +1957,19 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
     Location loc = op.getLoc();
     Value a = op.getA();
     auto outType = op.getResult().getType();
-    Value scalarValue = getScalarIntValue(a, loc, rewriter);
-    if (!scalarValue)
-      return failure();
-    rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
-    return success();
+    Value scalarIntValue = getScalarIntValue(a, loc, rewriter);
+    if (scalarIntValue) {
+      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
+                                                     scalarIntValue);
+      return success();
+    }
+    Value scalarFloatValue = getScalarFloatValue(a, loc, rewriter);
+    if (scalarFloatValue) {
+      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
+                                                     scalarFloatValue);
+      return success();
+    }
+    return failure();
   });
 }
 
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 4d2a595da..1823393f2 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -2174,6 +2174,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   return %2 : !torch.vtensor<[],si64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
 // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
 // CHECK: return %[[CST]] : !torch.vtensor<[],si64>
@@ -2186,6 +2188,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   return %2 : !torch.vtensor<[],si64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
@@ -2197,6 +2201,8 @@ func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
   return %1 : !torch.number
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
@@ -2209,6 +2215,18 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
+// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_0:.*]] = torch.derefine %float1.000000e00 : !torch.float to !torch.number
+// CHECK: return %[[VAL_0]] : !torch.number
+func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
+  %0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
+  %1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],f64> -> !torch.number
+  return %1 : !torch.number
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
 // CHECK: return %[[FLOAT1]] : !torch.float