mirror of https://github.com/llvm/torch-mlir
[Torch] add fold logic for some ops (#3794)
parent
6b289f29f2
commit
dc7a1ff7d9
|
@ -3425,6 +3425,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
@ -4902,6 +4903,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
@ -12641,6 +12643,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
|
@ -15334,6 +15337,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [
|
||||
|
|
|
@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
// ===----------------------------------------------------------------------===//
|
||||
// AtenRSubScalarOp
|
||||
// ===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) {
|
||||
auto fpFold = [](llvm::ArrayRef<double> inputs) {
|
||||
assert(inputs.size() == 3);
|
||||
return inputs[1] - inputs[0] * inputs[2];
|
||||
};
|
||||
|
||||
auto intFold = [](llvm::ArrayRef<APInt> inputs) {
|
||||
assert(inputs.size() == 3);
|
||||
return inputs[1] - inputs[0] * inputs[2];
|
||||
};
|
||||
|
||||
return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMulTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
// ===----------------------------------------------------------------------===//
|
||||
// AtenDivTensorModeOp
|
||||
// ===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
|
||||
if (!resultTy || !resultTy.hasDtype()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::function<double(ArrayRef<double>)> fpFold;
|
||||
std::function<APInt(ArrayRef<APInt>)> intFold;
|
||||
|
||||
auto roundMode = dyn_cast_or_null<StringAttr>(adaptor.getRoundingMode());
|
||||
auto unsign = false;
|
||||
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
|
||||
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
|
||||
}
|
||||
|
||||
fpFold = [roundMode](llvm::ArrayRef<double> inputs) {
|
||||
assert(inputs.size() == 2);
|
||||
if (!roundMode) {
|
||||
return (double)inputs[0] / inputs[1];
|
||||
} else if (roundMode.getValue().str() == "floor") {
|
||||
return std::floor((double)inputs[0] / inputs[1]);
|
||||
} else {
|
||||
return std::trunc((double)inputs[0] / inputs[1]);
|
||||
}
|
||||
};
|
||||
|
||||
intFold = [unsign, roundMode](llvm::ArrayRef<APInt> inputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto lhs = unsign ? inputs[0].getZExtValue() : inputs[0].getSExtValue();
|
||||
auto rhs = unsign ? inputs[1].getZExtValue() : inputs[1].getSExtValue();
|
||||
int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth());
|
||||
int64_t res;
|
||||
if (roundMode.getValue().str() == "floor") {
|
||||
res = std::floor(lhs / rhs);
|
||||
} else {
|
||||
res = std::trunc(lhs / rhs);
|
||||
}
|
||||
return APInt(bits, res);
|
||||
};
|
||||
|
||||
if (!roundMode) {
|
||||
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
|
||||
fpFold, std::nullopt);
|
||||
}
|
||||
|
||||
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
|
||||
fpFold, intFold);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDivScalarModeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3597,6 +3667,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
|
|||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
|
||||
}
|
||||
|
||||
// ===----------------------------------------------------------------------===//
|
||||
// AtenRemainderScalarOp
|
||||
// ===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
|
||||
if (!resultTy || !resultTy.hasDtype()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto unsign = false;
|
||||
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
|
||||
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
|
||||
}
|
||||
auto fpFold = [](llvm::ArrayRef<double> inputs) {
|
||||
assert(inputs.size() == 2);
|
||||
return std::fmod(inputs[0], inputs[1]);
|
||||
};
|
||||
|
||||
auto intFold = [unsign](llvm::ArrayRef<APInt> inputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]);
|
||||
return ret;
|
||||
};
|
||||
|
||||
return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenAddIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -4229,6 +4327,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
|
||||
auto value = adaptor.getA();
|
||||
auto dense = dyn_cast_or_null<DenseElementsAttr>(value);
|
||||
if (!dense || !dense.isSplat()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto splat = dense.getSplatValue<Attribute>();
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
|
||||
auto type = getType();
|
||||
if (!isa<mlir::IntegerType>(type)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (type.isSignlessInteger()) {
|
||||
return getI64IntegerAttr(getContext(), intAttr.getInt());
|
||||
} else if (type.isSignedInteger()) {
|
||||
return getI64IntegerAttr(getContext(), intAttr.getSInt());
|
||||
} else {
|
||||
return getI64IntegerAttr(getContext(), intAttr.getUInt());
|
||||
}
|
||||
}
|
||||
|
||||
if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
|
||||
return getI64IntegerAttr(
|
||||
getContext(),
|
||||
static_cast<long>(floatAttr.getValue().convertToDouble()));
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenFloatTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -379,6 +379,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
# variants.
|
||||
emit_with_mutating_variants(
|
||||
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit_with_mutating_variants(
|
||||
|
@ -481,6 +482,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit(
|
||||
"aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
||||
|
@ -928,7 +930,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
|
||||
)
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
emit(
|
||||
"aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True
|
||||
)
|
||||
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
||||
|
@ -1080,7 +1084,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
|
||||
|
|
|
@ -141,3 +141,113 @@ func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
|
|||
%0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
|
||||
return %0 : !torch.vtensor<[4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_rsub_scalar_int
|
||||
func.func @fold_aten_rsub_scalar_int() -> !torch.vtensor<[4],si64> {
|
||||
// CHECK: torch.vtensor.literal(dense<-4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%cst_2 = torch.constant.int 2
|
||||
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
|
||||
return %0 : !torch.vtensor<[4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_rsub_scalar_float
|
||||
func.func @fold_aten_rsub_scalar_float() -> !torch.vtensor<[4],f32> {
|
||||
// CHECK: torch.vtensor.literal(dense<-4.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%cst_2 = torch.constant.float 2.0
|
||||
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],f32>, !torch.float, !torch.float -> !torch.vtensor<[4],f32>
|
||||
return %0 : !torch.vtensor<[4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_remainder_scalar_int
|
||||
func.func @fold_aten_remainder_scalar_int() -> !torch.vtensor<[4],si64> {
|
||||
// CHECK: torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%cst_2 = torch.constant.int 2
|
||||
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
|
||||
return %0 : !torch.vtensor<[4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_remainder_scalar_float
|
||||
func.func @fold_aten_remainder_scalar_float() -> !torch.vtensor<[4],f32> {
|
||||
// CHECK: torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%cst_2 = torch.constant.float 2.0
|
||||
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
|
||||
return %0 : !torch.vtensor<[4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_int_tensor_int
|
||||
func.func @fold_aten_int_tensor_int() -> !torch.int {
|
||||
// CHECK: %int3 = torch.constant.int 3
|
||||
%cst_3 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],si64> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_int_tensor_bool
|
||||
func.func @fold_aten_int_tensor_bool() -> !torch.int {
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
%cst_false = torch.vtensor.literal(dense<true> : tensor<i1>) : !torch.vtensor<[],i1>
|
||||
%0 = torch.aten.Int.Tensor %cst_false : !torch.vtensor<[],i1> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_int_tensor_float
|
||||
func.func @fold_aten_int_tensor_float() -> !torch.int {
|
||||
// CHECK: %int3 = torch.constant.int 3
|
||||
%cst_3 = torch.vtensor.literal(dense<3.1> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],f32> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_div_tensor_mode_int
|
||||
func.func @fold_aten_div_tensor_mode_int() -> !torch.vtensor<[4],si64> {
|
||||
// CHECK: torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%cst_2 = torch.vtensor.literal(dense<2> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%trunc = torch.constant.str "trunc"
|
||||
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %trunc : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.str -> !torch.vtensor<[4],si64>
|
||||
return %0 : !torch.vtensor<[4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_div_tensor_mode_float
|
||||
func.func @fold_aten_div_tensor_mode_float() -> !torch.vtensor<[4],f32> {
|
||||
// CHECK: torch.vtensor.literal(dense<3.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%cst_8 = torch.vtensor.literal(dense<8.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%cst_2 = torch.vtensor.literal(dense<2.1> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%floor = torch.constant.str "floor"
|
||||
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %floor : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.str -> !torch.vtensor<[4],f32>
|
||||
return %0 : !torch.vtensor<[4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_aten_div_tensor_mode_none
|
||||
func.func @fold_aten_div_tensor_mode_none() -> !torch.vtensor<[4],f32> {
|
||||
// CHECK: torch.vtensor.literal(dense<2.66666675> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_3, %none : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.none -> !torch.vtensor<[4],f32>
|
||||
return %0 : !torch.vtensor<[4],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue