[Torch] support AtenScalarImplicitOp canonicalize with float (#3231)

Yuanqiang Liu 2024-04-26 02:36:13 +08:00 committed by GitHub
parent 4361178caa
commit b0ba3def93
2 changed files with 31 additions and 5 deletions


@@ -1957,11 +1957,19 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
     Location loc = op.getLoc();
     Value a = op.getA();
     auto outType = op.getResult().getType();
-    Value scalarValue = getScalarIntValue(a, loc, rewriter);
-    if (!scalarValue)
-      return failure();
-    rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
-    return success();
+    Value scalarIntValue = getScalarIntValue(a, loc, rewriter);
+    if (scalarIntValue) {
+      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
+                                                     scalarIntValue);
+      return success();
+    }
+    Value scalarFloatValue = getScalarFloatValue(a, loc, rewriter);
+    if (scalarFloatValue) {
+      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
+                                                     scalarFloatValue);
+      return success();
+    }
+    return failure();
   });
 }
 
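The new float path mirrors the existing integer path: getScalarIntValue is tried first, getScalarFloatValue second, and the pattern fails only when the operand is neither a constant integer nor a constant float 0-d tensor. A minimal before/after sketch of the float case, adapted from the regression test added below (the function name and SSA value names are illustrative):

  // Input: a 0-d f64 vtensor literal fed into ScalarImplicit.
  func.func @scalar_implicit_float() -> !torch.number {
    %0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
    %1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],f64> -> !torch.number
    return %1 : !torch.number
  }

  // After -canonicalize: the tensor round-trip is folded away and the float
  // constant is derefined to the !torch.number result type.
  func.func @scalar_implicit_float() -> !torch.number {
    %float1 = torch.constant.float 1.000000e+00
    %0 = torch.derefine %float1 : !torch.float to !torch.number
    return %0 : !torch.number
  }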


@@ -2174,6 +2174,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   return %2 : !torch.vtensor<[],si64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
 // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
 // CHECK: return %[[CST]] : !torch.vtensor<[],si64>
@@ -2186,6 +2188,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
   return %2 : !torch.vtensor<[],si64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
@@ -2197,6 +2201,8 @@ func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
   return %1 : !torch.number
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
@@ -2209,6 +2215,18 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
+// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_0:.*]] = torch.derefine %float1.000000e00 : !torch.float to !torch.number
+// CHECK: return %[[VAL_0]] : !torch.number
+func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
+  %0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
+  %1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],f64> -> !torch.number
+  return %1 : !torch.number
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
 // CHECK: return %[[FLOAT1]] : !torch.float
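Each test in this file is isolated by a // ----- marker, so the file can run under the standard mlir-opt --split-input-file mode with FileCheck verifying each split independently; the commit also inserts the separators that were missing between the earlier tests. The new test checks that a 0-d f64 literal feeding torch.aten.ScalarImplicit folds to a torch.constant.float that is derefined to !torch.number, matching the integer tests that precede it. (The RUN line at the top of the test file governs the exact invocation and is unchanged by this commit.)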