mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg]Lower torch.gcd to linalg and scf
Add verify() method to check if tensors are of integer type. Also check if tensors are of same shape, or if the second tensor is a single element tensor. Add e2e tests. Put them into onnx and stablehlo xfailed sets.pull/3732/head
parent
67732883fa
commit
53b1ec3134
|
@ -13148,6 +13148,31 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenGcdOp : Torch_Op<"aten.gcd", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::gcd : (Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenGcdOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenGcdOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
@ -213,6 +215,83 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenGcdOp : public OpConversionPattern<torch::Torch::AtenGcdOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(torch::Torch::AtenGcdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto self = adaptor.getSelf(); // tensor A
|
||||
auto other = adaptor.getOther(); // tensor B of the same size
|
||||
auto loc = op.getLoc();
|
||||
|
||||
TensorType resultType =
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
auto gcdPayloadBody = [&](OpBuilder &b, Location loc,
|
||||
ValueRange genericInstructionArgs) {
|
||||
auto A = genericInstructionArgs[0];
|
||||
A = b.create<mlir::math::AbsIOp>(loc, A);
|
||||
auto B = genericInstructionArgs[1];
|
||||
B = b.create<mlir::math::AbsIOp>(loc, B);
|
||||
auto zero = b.create<mlir::arith::ConstantIntOp>(loc, 0, A.getType());
|
||||
|
||||
Value AtrailingZerosCount =
|
||||
b.create<mlir::math::CountTrailingZerosOp>(loc, A);
|
||||
Value BtrailingZerosCount =
|
||||
b.create<mlir::math::CountTrailingZerosOp>(loc, B);
|
||||
auto smalerZerosCount = b.create<mlir::arith::MinSIOp>(
|
||||
loc, AtrailingZerosCount, BtrailingZerosCount);
|
||||
auto shiftedA = b.create<mlir::arith::ShRSIOp>(loc, A, smalerZerosCount);
|
||||
auto shiftedB = b.create<mlir::arith::ShRSIOp>(loc, B, smalerZerosCount);
|
||||
|
||||
auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
|
||||
mlir::ValueRange innerLoopArgs) {
|
||||
Value min = b.create<mlir::arith::MinSIOp>(loc, innerLoopArgs[0],
|
||||
innerLoopArgs[1]);
|
||||
Value max = b.create<mlir::arith::MaxSIOp>(loc, innerLoopArgs[0],
|
||||
innerLoopArgs[1]);
|
||||
|
||||
auto cmp = b.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ne, min, zero);
|
||||
b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max});
|
||||
};
|
||||
auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
|
||||
mlir::ValueRange innerLoopArgs) {
|
||||
Value min = innerLoopArgs[0];
|
||||
Value max = innerLoopArgs[1];
|
||||
max = b.create<mlir::arith::SubIOp>(loc, max, min);
|
||||
|
||||
Value maxTrailingZerosCount =
|
||||
b.create<mlir::math::CountTrailingZerosOp>(loc, max);
|
||||
max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount);
|
||||
b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max});
|
||||
};
|
||||
|
||||
auto findGcdWhileOp = b.create<mlir::scf::WhileOp>(
|
||||
loc, TypeRange{shiftedA.getType(), shiftedB.getType()},
|
||||
ValueRange{shiftedA, shiftedB}, findGcdConditionBlock,
|
||||
findGcdBodyBlock);
|
||||
|
||||
Value gcdResult = findGcdWhileOp.getResult(1);
|
||||
gcdResult =
|
||||
b.create<mlir::arith::ShLIOp>(loc, gcdResult, smalerZerosCount);
|
||||
|
||||
b.create<linalg::YieldOp>(loc, gcdResult);
|
||||
};
|
||||
|
||||
other = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, ValueRange{self, other},
|
||||
cast<TensorType>(self.getType()).getElementType(), gcdPayloadBody);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, other);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
|
||||
public:
|
||||
|
@ -1400,4 +1479,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
|
|||
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenConvolutionOp>();
|
||||
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenGcdOp>();
|
||||
patterns.add<ConvertAtenGcdOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -5524,3 +5524,37 @@ LogicalResult AtenRot90Op::verify() {
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AtenGcdOp::verify() {
|
||||
|
||||
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
||||
auto otherType = cast<BaseTensorType>(getOther().getType());
|
||||
|
||||
if (!selfType.hasDtype() || !selfType.hasSizes() || !otherType.hasDtype() ||
|
||||
!otherType.hasSizes())
|
||||
return success();
|
||||
|
||||
auto selfShape = selfType.getSizes();
|
||||
auto otherShape = selfType.getSizes();
|
||||
int64_t selfRank = selfShape.size();
|
||||
int64_t otherRank = otherShape.size();
|
||||
auto selfDtype = selfType.getDtype();
|
||||
|
||||
if (!isa<mlir::IntegerType>(selfDtype))
|
||||
return emitOpError("expected an integer type for input tensor, but got ")
|
||||
<< selfDtype;
|
||||
|
||||
if (otherRank == 1 && otherShape[0] == 1)
|
||||
return success();
|
||||
|
||||
if (selfRank != otherRank)
|
||||
return emitOpError("Tensors must be of same rank or second tensor must be "
|
||||
"a single element tensor");
|
||||
|
||||
for (int i = 0; i < selfRank; i++) {
|
||||
if (selfShape[i] != otherShape[i])
|
||||
return emitOpError("Dimensions od tensors font match in dim ") << i;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -6639,6 +6639,72 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %8 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.gcd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: Shapes must be the same or 'other' must be a single element tensor.\"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.aten.eq.int_list %arg0, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %6 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.gcd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: aten.gcd works only with integer types\"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||
" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
|
||||
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" return %1 : !torch.bool\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -11238,21 +11304,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
|
||||
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" return %1 : !torch.bool\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
|
|
|
@ -923,6 +923,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
"SplitWithSizes_Module_basic",
|
||||
"GCDBatchedModule_I32",
|
||||
"GCDDynamicModule_I32",
|
||||
"GCDModule_I32",
|
||||
}
|
||||
|
||||
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
||||
|
@ -3126,6 +3129,9 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"UnfoldModule_basic",
|
||||
"GCDBatchedModule_I32",
|
||||
"GCDDynamicModule_I32",
|
||||
"GCDModule_I32",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||
|
|
|
@ -265,6 +265,17 @@ def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]:
|
|||
shape = upstream_shape_functions.zero_dim_tensor(A)
|
||||
return shape, shape
|
||||
|
||||
def aten〇gcd〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||
assert self == other or (len(other) == 1 and other[0]==0), "Shapes must be the same or 'other' must be a single element tensor."
|
||||
return self
|
||||
|
||||
def aten〇gcd〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
assert is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype), "aten.gcd works only with integer types"
|
||||
return self_dtype
|
||||
|
||||
|
||||
def aten〇detach〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
|
|
@ -964,6 +964,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::gcd : (Tensor, Tensor) -> (Tensor)", has_verifier=True)
|
||||
|
||||
# Functionalization ops
|
||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -6845,3 +6845,52 @@ class TrilIndicesOfssetGreaterThanRowModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule())
|
||||
def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GCDModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([None, [(4, 4), torch.int32, True], [(4, 4), torch.int32, True]])
|
||||
def forward(self, A, B):
|
||||
return torch.gcd(A, B)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GCDModule())
|
||||
def GCDModule_I32(module, tu: TestUtils):
|
||||
A = tu.rand(4, 4, low=-100, high=100).to(dtype=torch.int32)
|
||||
B = tu.rand(4, 4, low=-100, high=100).to(dtype=torch.int32)
|
||||
module.forward(A, B)
|
||||
|
||||
|
||||
class GCDBatchedModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args(
|
||||
[None, [(4, 4, 4), torch.int32, True], [(4, 4, 4), torch.int32, True]]
|
||||
)
|
||||
def forward(self, A, B):
|
||||
return torch.gcd(A, B)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GCDBatchedModule())
|
||||
def GCDBatchedModule_I32(module, tu: TestUtils):
|
||||
A = tu.rand(4, 4, 4, low=-100, high=100).to(dtype=torch.int32)
|
||||
B = tu.rand(4, 4, 4, low=-100, high=100).to(dtype=torch.int32)
|
||||
module.forward(A, B)
|
||||
|
||||
|
||||
class GCDDynamicModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args(
|
||||
[None, [(-1, -1, -1), torch.int32, True], [(-1, -1, -1), torch.int32, True]]
|
||||
)
|
||||
def forward(self, A, B):
|
||||
return torch.gcd(A, B)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GCDDynamicModule())
|
||||
def GCDDynamicModule_I32(module, tu: TestUtils):
|
||||
A = tu.rand(3, 4, 4, low=-100, high=100).to(dtype=torch.int32)
|
||||
B = tu.rand(3, 4, 4, low=-100, high=100).to(dtype=torch.int32)
|
||||
module.forward(A, B)
|
||||
|
|
|
@ -8,6 +8,4 @@ cd "$src_dir"
|
|||
|
||||
# Ensure PYTHONPATH is set for export to child processes, even if empty.
|
||||
export PYTHONPATH=${PYTHONPATH-}
|
||||
source $project_dir/.env
|
||||
|
||||
python -m e2e_testing.main "$@"
|
||||
|
|
Loading…
Reference in New Issue