[Torch] add fold logic for some ops (#3794)

pull/3805/head
yyp0 2024-10-16 16:00:58 +08:00 committed by GitHub
parent 6b289f29f2
commit dc7a1ff7d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 254 additions and 2 deletions

View File

@ -3425,6 +3425,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
printDefaultTorchOp(printer, *this, 3, 1); printDefaultTorchOp(printer, *this, 3, 1);
} }
}]; }];
let hasFolder = 1;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
@ -4902,6 +4903,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1); printDefaultTorchOp(printer, *this, 3, 1);
} }
}]; }];
let hasFolder = 1;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
@ -12641,6 +12643,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
printDefaultTorchOp(printer, *this, 1, 1); printDefaultTorchOp(printer, *this, 1, 1);
} }
}]; }];
let hasFolder = 1;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
@ -15334,6 +15337,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [ def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [

View File

@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}); });
} }
// ===----------------------------------------------------------------------===//
// AtenRSubScalarOp
// ===----------------------------------------------------------------------===//
OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 3);
return inputs[1] - inputs[0] * inputs[2];
};
auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 3);
return inputs[1] - inputs[0] * inputs[2];
};
return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenMulTensorOp // AtenMulTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
}); });
} }
// ===----------------------------------------------------------------------===//
// AtenDivTensorModeOp
// ===----------------------------------------------------------------------===//
OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasDtype()) {
return nullptr;
}
std::function<double(ArrayRef<double>)> fpFold;
std::function<APInt(ArrayRef<APInt>)> intFold;
auto roundMode = dyn_cast_or_null<StringAttr>(adaptor.getRoundingMode());
auto unsign = false;
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
}
fpFold = [roundMode](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
if (!roundMode) {
return (double)inputs[0] / inputs[1];
} else if (roundMode.getValue().str() == "floor") {
return std::floor((double)inputs[0] / inputs[1]);
} else {
return std::trunc((double)inputs[0] / inputs[1]);
}
};
intFold = [unsign, roundMode](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
auto lhs = unsign ? inputs[0].getZExtValue() : inputs[0].getSExtValue();
auto rhs = unsign ? inputs[1].getZExtValue() : inputs[1].getSExtValue();
int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth());
int64_t res;
if (roundMode.getValue().str() == "floor") {
res = std::floor(lhs / rhs);
} else {
res = std::trunc(lhs / rhs);
}
return APInt(bits, res);
};
if (!roundMode) {
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
fpFold, std::nullopt);
}
return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(),
fpFold, intFold);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenDivScalarModeOp // AtenDivScalarModeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -3597,6 +3667,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
} }
// ===----------------------------------------------------------------------===//
// AtenRemainderScalarOp
// ===----------------------------------------------------------------------===//
OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
auto resultTy = dyn_cast_or_null<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasDtype()) {
return nullptr;
}
auto unsign = false;
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
unsign = cast<IntegerType>(resultTy.getDtype()).isUnsigned();
}
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
return std::fmod(inputs[0], inputs[1]);
};
auto intFold = [unsign](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]);
return ret;
};
return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenAddIntOp // AtenAddIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -4229,6 +4327,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}); });
} }
//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
auto value = adaptor.getA();
auto dense = dyn_cast_or_null<DenseElementsAttr>(value);
if (!dense || !dense.isSplat()) {
return nullptr;
}
auto splat = dense.getSplatValue<Attribute>();
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
auto type = getType();
if (!isa<mlir::IntegerType>(type)) {
return nullptr;
}
if (type.isSignlessInteger()) {
return getI64IntegerAttr(getContext(), intAttr.getInt());
} else if (type.isSignedInteger()) {
return getI64IntegerAttr(getContext(), intAttr.getSInt());
} else {
return getI64IntegerAttr(getContext(), intAttr.getUInt());
}
}
if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
return getI64IntegerAttr(
getContext(),
static_cast<long>(floatAttr.getValue().convertToDouble()));
}
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenFloatTensorOp // AtenFloatTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -379,6 +379,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# variants. # variants.
emit_with_mutating_variants( emit_with_mutating_variants(
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", "aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
has_folder=True,
has_canonicalizer=True, has_canonicalizer=True,
) )
emit_with_mutating_variants( emit_with_mutating_variants(
@ -481,6 +482,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
emit( emit(
"aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)",
has_folder=True,
has_canonicalizer=True, has_canonicalizer=True,
) )
emit("aten::gelu : (Tensor, str) -> (Tensor)") emit("aten::gelu : (Tensor, str) -> (Tensor)")
@ -928,7 +930,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit( emit(
"aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True "aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True
) )
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True) emit(
"aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True
)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)") emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)") emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
@ -1080,7 +1084,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
has_canonicalizer=True, has_canonicalizer=True,
) )
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True)

View File

@ -141,3 +141,113 @@ func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
%0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32> %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32> return %0 : !torch.vtensor<[4],f32>
} }
// -----
// CHECK-LABEL: @fold_aten_rsub_scalar_int
func.func @fold_aten_rsub_scalar_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<-4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.constant.int 2
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}
// -----
// CHECK-LABEL: @fold_aten_rsub_scalar_float
func.func @fold_aten_rsub_scalar_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<-4.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.constant.float 2.0
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],f32>, !torch.float, !torch.float -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}
// -----
// CHECK-LABEL: @fold_aten_remainder_scalar_int
func.func @fold_aten_remainder_scalar_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.constant.int 2
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}
// -----
// CHECK-LABEL: @fold_aten_remainder_scalar_float
func.func @fold_aten_remainder_scalar_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.constant.float 2.0
%cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}
// -----
// CHECK-LABEL: @fold_aten_int_tensor_int
func.func @fold_aten_int_tensor_int() -> !torch.int {
// CHECK: %int3 = torch.constant.int 3
%cst_3 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],si64> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: @fold_aten_int_tensor_bool
func.func @fold_aten_int_tensor_bool() -> !torch.int {
// CHECK: %int1 = torch.constant.int 1
%cst_false = torch.vtensor.literal(dense<true> : tensor<i1>) : !torch.vtensor<[],i1>
%0 = torch.aten.Int.Tensor %cst_false : !torch.vtensor<[],i1> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: @fold_aten_int_tensor_float
func.func @fold_aten_int_tensor_float() -> !torch.int {
// CHECK: %int3 = torch.constant.int 3
%cst_3 = torch.vtensor.literal(dense<3.1> : tensor<f32>) : !torch.vtensor<[],f32>
%0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],f32> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: @fold_aten_div_tensor_mode_int
func.func @fold_aten_div_tensor_mode_int() -> !torch.vtensor<[4],si64> {
// CHECK: torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_2 = torch.vtensor.literal(dense<2> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%trunc = torch.constant.str "trunc"
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %trunc : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.str -> !torch.vtensor<[4],si64>
return %0 : !torch.vtensor<[4],si64>
}
// -----
// CHECK-LABEL: @fold_aten_div_tensor_mode_float
func.func @fold_aten_div_tensor_mode_float() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<3.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_8 = torch.vtensor.literal(dense<8.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_2 = torch.vtensor.literal(dense<2.1> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%floor = torch.constant.str "floor"
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %floor : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.str -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}
// -----
// CHECK-LABEL: @fold_aten_div_tensor_mode_none
func.func @fold_aten_div_tensor_mode_none() -> !torch.vtensor<[4],f32> {
// CHECK: torch.vtensor.literal(dense<2.66666675> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%none = torch.constant.none
%0 = torch.aten.div.Tensor_mode %cst_8, %cst_3, %none : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.none -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}