mirror of https://github.com/llvm/torch-mlir
[Torch] support AtenScalarImplicitOp canonicalize with float (#3231)
parent
4361178caa
commit
b0ba3def93
|
@ -1957,11 +1957,19 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
|
|||
Location loc = op.getLoc();
|
||||
Value a = op.getA();
|
||||
auto outType = op.getResult().getType();
|
||||
Value scalarValue = getScalarIntValue(a, loc, rewriter);
|
||||
if (!scalarValue)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
|
||||
Value scalarIntValue = getScalarIntValue(a, loc, rewriter);
|
||||
if (scalarIntValue) {
|
||||
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
|
||||
scalarIntValue);
|
||||
return success();
|
||||
}
|
||||
Value scalarFloatValue = getScalarFloatValue(a, loc, rewriter);
|
||||
if (scalarFloatValue) {
|
||||
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
|
||||
scalarFloatValue);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -2174,6 +2174,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[]
|
|||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
|
||||
|
@ -2186,6 +2188,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtenso
|
|||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
|
||||
|
@ -2197,6 +2201,8 @@ func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.num
|
|||
return %1 : !torch.number
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
|
||||
|
@ -2209,6 +2215,18 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
|
||||
// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[VAL_0:.*]] = torch.derefine %float1.000000e00 : !torch.float to !torch.number
|
||||
// CHECK: return %[[VAL_0]] : !torch.number
|
||||
func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
|
||||
%0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
|
||||
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],f64> -> !torch.number
|
||||
return %1 : !torch.number
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
|
||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: return %[[FLOAT1]] : !torch.float
|
||||
|
|
Loading…
Reference in New Issue