[TorchToLinalg]Lower torch.gcd to linalg and scf

Add verify() method to check if tensors are of
integer type. Also check if tensors are of same shape,
or if the second tensor is a single element tensor.

Add e2e tests. Put them into onnx and stablehlo
xfailed sets.
pull/3732/head
Bratislav Filipovic 2024-09-19 14:55:21 +02:00
parent 67732883fa
commit 7673a8ff28
9 changed files with 302 additions and 17 deletions

View File

@ -13148,6 +13148,31 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [
}];
}
def Torch_AtenGcdOp : Torch_Op<"aten.gcd", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::gcd : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGcdOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenGcdOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -13,6 +13,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@ -213,6 +215,112 @@ public:
};
} // namespace
namespace {
class ConvertAtenGcdOp : public OpConversionPattern<torch::Torch::AtenGcdOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(torch::Torch::AtenGcdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto self = adaptor.getSelf(); // tensor A
auto other = adaptor.getOther(); // tensor B of the same size
auto loc = op.getLoc();
TensorType resultType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
auto gcdPayloadBody = [&](OpBuilder &b, Location loc,
ValueRange payloadArgs) {
auto A = payloadArgs[0];
A = b.create<mlir::math::AbsIOp>(loc, A);
auto B = payloadArgs[1];
B = b.create<mlir::math::AbsIOp>(loc, B);
auto two = b.create<mlir::arith::ConstantIntOp>(loc, 2, A.getType());
auto one = b.create<mlir::arith::ConstantIntOp>(loc, 1, A.getType());
auto zero = b.create<mlir::arith::ConstantIntOp>(loc, 0, A.getType());
auto trailingZeroConditionBlock = [&](mlir::OpBuilder &b,
mlir::Location loc,
mlir::ValueRange whileArgs) {
auto current = whileArgs[0];
auto counter = whileArgs[1];
auto currentAndOne = b.create<mlir::arith::AndIOp>(loc, current, one);
auto cmp = b.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, currentAndOne, one);
b.create<mlir::scf::ConditionOp>(loc, cmp,
ValueRange{current, counter});
};
auto trailingZerosBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange args) {
auto current = args[0];
auto counter = args[1];
auto divided = b.create<mlir::arith::DivUIOp>(loc, current, two);
auto newCounter = b.create<mlir::arith::AddIOp>(loc, counter, one);
b.create<mlir::scf::YieldOp>(
loc, ValueRange{divided.getResult(), newCounter.getResult()});
};
auto AtrailingZerosOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{A.getType(), zero.getType()}, ValueRange{A, zero},
trailingZeroConditionBlock, trailingZerosBodyBlock);
auto BtrailingZerosOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{B.getType(), zero.getType()}, ValueRange{B, zero},
trailingZeroConditionBlock, trailingZerosBodyBlock);
Value AtrailingZerosCount = AtrailingZerosOp.getResult(0);
Value BtrailingZerosCount = BtrailingZerosOp.getResult(0);
auto smalerZerosCount = b.create<mlir::arith::MinSIOp>(
loc, AtrailingZerosCount, BtrailingZerosCount);
auto shiftedA = b.create<mlir::arith::ShRSIOp>(loc, A, smalerZerosCount);
auto shiftedB = b.create<mlir::arith::ShRSIOp>(loc, B, smalerZerosCount);
auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange args) {
Value min = b.create<mlir::arith::MinSIOp>(loc, args[0], args[1]);
Value max =
b.create<mlir::arith::MaxSIOp>(loc, payloadArgs[0], payloadArgs[1]);
auto cmp = b.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ne, min, zero);
b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max});
};
auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange args) {
Value min = args[0];
Value max = args[1];
max = b.create<mlir::arith::SubIOp>(loc, max, min);
auto maxTrailingZerosOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{B.getType(), zero.getType()}, ValueRange{max, zero},
trailingZeroConditionBlock, trailingZerosBodyBlock);
Value maxTrailingZerosCount = maxTrailingZerosOp.getResult(0);
max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount);
b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max});
};
auto findGcdWhileOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{shiftedA.getType(), shiftedB.getType()},
ValueRange{shiftedA, shiftedB}, findGcdConditionBlock,
findGcdBodyBlock);
Value gcdResult = findGcdWhileOp.getResult(1);
gcdResult =
b.create<mlir::arith::ShLIOp>(loc, gcdResult, smalerZerosCount);
b.create<linalg::YieldOp>(loc, gcdResult);
};
other = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{self, other},
cast<TensorType>(self.getType()).getElementType(), gcdPayloadBody);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, other);
return success();
}
};
} // namespace
namespace {
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
public:
@ -1400,4 +1508,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
target.addIllegalOp<AtenGcdOp>();
patterns.add<ConvertAtenGcdOp>(typeConverter, context);
}

View File

@ -5524,3 +5524,37 @@ LogicalResult AtenRot90Op::verify() {
return success();
}
LogicalResult AtenGcdOp::verify() {
auto selfType = cast<BaseTensorType>(getSelf().getType());
auto otherType = cast<BaseTensorType>(getOther().getType());
if (!selfType.hasDtype() || !selfType.hasSizes() || !otherType.hasDtype() ||
!otherType.hasSizes())
return success();
auto selfShape = selfType.getSizes();
auto otherShape = selfType.getSizes();
int64_t selfRank = selfShape.size();
int64_t otherRank = otherShape.size();
auto selfDtype = selfType.getDtype();
if (!isa<mlir::IntegerType>(selfDtype))
return emitOpError("expected an integer type for input tensor, but got ")
<< selfDtype;
if (otherRank == 1 && otherShape[0] == 1)
return success();
if (selfRank != otherRank)
return emitOpError("Tensors must be of same rank or second tensor must be "
"a single element tensor");
for (int i = 0; i < selfRank; i++) {
if (selfShape[i] != otherShape[i])
return emitOpError("Dimensions od tensors font match in dim ") << i;
}
return success();
}

View File

@ -6639,6 +6639,72 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %8 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.gcd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Shapes must be the same or 'other' must be a single element tensor.\"\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.eq.int_list %arg0, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %6 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.gcd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: aten.gcd works only with integer types\"\n"
" %false = torch.constant.bool false\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -11238,21 +11304,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

View File

@ -923,6 +923,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"GCDBatchedModule_I32",
"GCDDynamicModule_I32",
"GCDModule_I32",
}
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@ -3126,6 +3129,9 @@ ONNX_XFAIL_SET = {
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"UnfoldModule_basic",
"GCDBatchedModule_I32",
"GCDDynamicModule_I32",
"GCDModule_I32",
}
if torch_version_for_comparison() < version.parse("2.3.0.dev"):

View File

@ -265,6 +265,17 @@ def atenlinalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]:
shape = upstream_shape_functions.zero_dim_tensor(A)
return shape, shape
def atengcd〡shape(self: List[int], other: List[int]) -> List[int]:
assert self == other or (len(other) == 1 and other[0]==0), "Shapes must be the same or 'other' must be a single element tensor."
return self
def atengcd〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
other_rank, other_dtype = other_rank_dtype
assert is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype), "aten.gcd works only with integer types"
return self_dtype
def atendetach〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

View File

@ -964,6 +964,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)"
)
emit("aten::gcd : (Tensor, Tensor) -> (Tensor)", has_verifier=True)
# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")

View File

@ -6845,3 +6845,52 @@ class TrilIndicesOfssetGreaterThanRowModule(torch.nn.Module):
@register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule())
def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class GCDModule(torch.nn.Module):
@export
@annotate_args([None, [(4, 4), torch.int32, True], [(4, 4), torch.int32, True]])
def forward(self, A, B):
return torch.gcd(A, B)
@register_test_case(module_factory=lambda: GCDModule())
def GCDModule_I32(module, tu: TestUtils):
A = tu.rand(4, 4).to(dtype=torch.int32)
B = tu.rand(4, 4).to(dtype=torch.int32)
module.forward(A, B)
class GCDBatchedModule(torch.nn.Module):
@export
@annotate_args(
[None, [(4, 4, 4), torch.int32, True], [(4, 4, 4), torch.int32, True]]
)
def forward(self, A, B):
return torch.gcd(A, B)
@register_test_case(module_factory=lambda: GCDBatchedModule())
def GCDBatchedModule_I32(module, tu: TestUtils):
A = tu.rand(4, 4, 4).to(dtype=torch.int32)
B = tu.rand(4, 4, 4).to(dtype=torch.int32)
module.forward(A, B)
class GCDDynamicModule(torch.nn.Module):
@export
@annotate_args(
[None, [(-1, -1, -1), torch.int32, True], [(-1, -1, -1), torch.int32, True]]
)
def forward(self, A, B):
return torch.gcd(A, B)
@register_test_case(module_factory=lambda: GCDDynamicModule())
def GCDDynamicModule_I32(module, tu: TestUtils):
A = tu.rand(3, 4, 4).to(dtype=torch.int32)
B = tu.rand(3, 4, 4).to(dtype=torch.int32)
module.forward(A, B)

View File

@ -8,6 +8,4 @@ cd "$src_dir"
# Ensure PYTHONPATH is set for export to child processes, even if empty.
export PYTHONPATH=${PYTHONPATH-}
source $project_dir/.env
python -m e2e_testing.main "$@"