mirror of https://github.com/llvm/torch-mlir

[Torch] support AtenScalarImplicitOp canonicalize with float (#3231)

parent 4361178caa
commit b0ba3def93

@@ -1957,11 +1957,19 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
     Location loc = op.getLoc();
     Value a = op.getA();
     auto outType = op.getResult().getType();
-    Value scalarValue = getScalarIntValue(a, loc, rewriter);
-    if (!scalarValue)
-      return failure();
-    rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
-    return success();
+    Value scalarIntValue = getScalarIntValue(a, loc, rewriter);
+    if (scalarIntValue) {
+      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
+                                                     scalarIntValue);
+      return success();
+    }
+    Value scalarFloatValue = getScalarFloatValue(a, loc, rewriter);
+    if (scalarFloatValue) {
+      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
+                                                     scalarFloatValue);
+      return success();
+    }
+    return failure();
   });
 }
 
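In effect, the canonicalization now handles 0-d float tensor scalars the same way it already handled integers: the scalar is pulled out of the tensor and derefined to the opaque !torch.number type. A minimal before/after IR sketch, mirroring the new FileCheck test added below (SSA names follow torch-mlir's printer conventions):

// Before canonicalization: a 0-d f64 literal feeds torch.aten.ScalarImplicit.
%0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],f64> -> !torch.number

// After canonicalization: the tensor round-trip is gone; the scalar becomes a
// float constant that is derefined to !torch.number.
%float1.000000e00 = torch.constant.float 1.000000e+00
%1 = torch.derefine %float1.000000e00 : !torch.float to !torch.number
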
@@ -2174,6 +2174,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[]
   return %2 : !torch.vtensor<[],si64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
 // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
 // CHECK: return %[[CST]] : !torch.vtensor<[],si64>
@@ -2186,6 +2188,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtenso
   return %2 : !torch.vtensor<[],si64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
@@ -2197,6 +2201,8 @@ func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.num
   return %1 : !torch.number
 }
 
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
@@ -2209,6 +2215,18 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
+// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
+// CHECK: %[[VAL_0:.*]] = torch.derefine %float1.000000e00 : !torch.float to !torch.number
+// CHECK: return %[[VAL_0]] : !torch.number
+func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
+  %0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
+  %1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],f64> -> !torch.number
+  return %1 : !torch.number
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
 // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
 // CHECK: return %[[FLOAT1]] : !torch.float
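For context: the `// -----` separators that this commit inserts are what MLIR's --split-input-file mode splits on, so each test function is canonicalized and checked in isolation. A sketch of the usual invocation for a test file like this (the exact RUN line at the top of canonicalize.mlir may differ):

// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s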