diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index c9c003631..4c807d8a6 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -50,6 +50,32 @@ MHLO_PASS_SET = {
     "ElementwiseUnaryModule_basic",
     "ElementwiseUnsqueezeNegDimsModule_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
+    "ElementwiseAddModule_basic",
+    "ElementwiseAddScalarFloatModule_basic",
+    "ElementwiseAddScalarInt64Module_basic",
+    "ElementwiseAddScalarIntModule_basic",
+    "ElementwiseDivScalarModule_basic",
+    "ElementwiseEqDiffWidthScalarModule_basic",
+    "ElementwiseEqFloatScalarModule_basic",
+    "ElementwiseEqIntScalarModule_basic",
+    "ElementwiseErfModule_basic",
+    "ElementwiseGeluModule_basic",
+    "ElementwiseGtFloatScalarModule_basic",
+    "ElementwiseGtIntScalarModule_basic",
+    "ElementwiseGtMixed2ScalarModule_basic",
+    "ElementwiseLtDiffWidthScalarModule_basic",
+    "ElementwiseLtFloatScalarModule_basic",
+    "ElementwiseLtIntScalarModule_basic",
+    "ElementwiseMulScalarModule_basic",
+    "ElementwiseMulScalarModule_float",
+    "ElementwiseMulScalarModule_int",
+    "ElementwiseNeFloatTensorModule_basic",
+    "ElementwiseNeIntScalarModule_basic",
+    "ElementwiseReciprocalModule_basic",
+    "ElementwiseRelu6Module_basic",
+    "ElementwiseReluModule_basic",
+    "ElementwiseSubScalarFloatModule_basic",
+    "ElementwiseSubScalarIntModule_basic",
     "ExpandAsIntModule_basic",
     "ExpandModule_basic",
     "FullLikeModuleDefaultDtype_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 26b77cca3..faaeaa7a3 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -158,6 +158,51 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
   }];
 }
 
+def Torch_AtenRelu6Op : Torch_Op<"aten.relu6", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::relu6 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRelu6Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenRelu6Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenRelu6_Op : Torch_Op<"aten.relu6_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::relu6_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRelu6_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenRelu6_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 2deba5073..338b1aed1 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -184,6 +184,41 @@ public:
 } // namespace
 
+// The binary broadcast patterns
+namespace {
+template <typename AtenOpT, typename MhloOpT>
+class ConvertAtenBinaryBroadcastOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<TensorType>();
+    Value rhs = adaptor.other();
+    auto rhsTy = rhs.getType().cast<TensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("only Tensor types supported");
+
+    auto lhsElemTy = lhsTy.getElementType();
+    auto rhsElemTy = rhsTy.getElementType();
+
+    if (lhsElemTy != rhsElemTy)
+      return op.emitError("input data types mismatched");
+
+    rewriter.replaceOpWithNewOp<MhloOpT>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        lhs, rhs,
+        /*broadcast_attr*/ nullptr);
+    return success();
+  }
+};
+} // namespace
+
 // These binary op legalizations are specific to add/sub which have an
 // alpha multiplier.
 namespace {
@@ -1231,4 +1266,13 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSizeIntOp);
   INSERT_ATENOP_PATTERN(AtenToDtypeOp);
 #undef INSERT_ATENOP_PATTERN
+
+#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp)                       \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter,   \
+                                                             context)
+  INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
+  INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
+  INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
+#undef INSERT_BINARY_BROADCAST_PATTERN
 }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 45a100177..73c47ec26 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -645,6 +645,20 @@ static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
   return relu6Out;
 }
 
+namespace {
+class DecomposeAtenRelu6Op : public OpRewritePattern<AtenRelu6Op> {
+public:
+  using OpRewritePattern<AtenRelu6Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRelu6Op op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value relu6 = getRelu6Results(rewriter, loc, op.self());
+    rewriter.replaceOp(op, relu6);
+    return success();
+  }
+};
+} // namespace
+
 // Hardswish(x) = x * Relu6(x+3)/6
 namespace {
 class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
@@ -2907,6 +2921,8 @@ public:
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenRelu6Op>(context);
+    target.addIllegalOp<AtenRelu6Op>();
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index c04cd7681..b286546ec 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -677,7 +677,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Take dtype from first operand.
   if (isa
) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.relu6\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
@@ -7821,4 +7825,4 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 #ifndef _MSC_VER
 #pragma clang diagnostic pop
 #endif
-}
\ No newline at end of file
+}
diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp
index 8bc19645d..c28ac45eb 100644
--- a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp
@@ -12,6 +12,7 @@
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/OpDefinition.h"
@@ -45,9 +46,10 @@ class VerifyMhloBackendContractPass
     ConversionTarget target(*context);
 
     // Structural operations.
-    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
-        opHasLegalTypes);
-    // Basic scalar operations.
+    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes);
+    // Shape operations.
+    target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
+
     target.addLegalDialect();
     target.addLegalDialect();
     target.addLegalDialect();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 17da415d7..c0e96c40b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -397,6 +397,9 @@ def aten〇log(self: List[int]) -> List[int]:
 def aten〇relu(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇relu6(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇_softmax(self: List[int], dim: int, half_to_float: bool) -> List[int]:
     return upstream_shape_functions.unary(self)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 153e00ad5..7192b5540 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -241,6 +241,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::tanh : (Tensor) -> (Tensor)",
         "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
         "aten::relu : (Tensor) -> (Tensor)",
+        "aten::relu6 : (Tensor) -> (Tensor)",
         "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
         "aten::log : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 68e84a07f..6770d7237 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -345,6 +345,28 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseRelu6Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.relu6(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRelu6Module())
+def ElementwiseRelu6Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2) - 0.5)
+
+
+# ==============================================================================
+
+
 class ElementwiseLeakyReluModule(torch.nn.Module):
 
     def __init__(self):
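
Reviewer note, not part of the patch: aten.relu6 computes relu6(x) = min(max(x, 0), 6), i.e. a clamp to the closed interval [0, 6]. The new DecomposeAtenRelu6Op pattern reuses the pre-existing getRelu6Results helper, which, judging from the relu6Out value returned in the hunk above, builds that clamp from aten.relu followed by aten.minimum against a constant 6. The NumPy sketch below is only an illustration of the expected numerics; the function name and sample input are made up for this note and are not part of the change.

import numpy as np

def relu6_reference(x: np.ndarray) -> np.ndarray:
    # relu6(x) = min(max(x, 0), 6): clamp to [0, 6].
    return np.minimum(np.maximum(x, 0.0), 6.0)

# Mirrors the e2e test input, tu.rand(4, 2) - 0.5, so values straddle the
# lower clamp at 0 (the upper clamp at 6 only matters for larger inputs).
x = np.random.rand(4, 2).astype(np.float32) - 0.5
print(relu6_reference(x))

Because aten.minimum now lowers through the new INSERT_BINARY_BROADCAST_PATTERN entry (chlo::BroadcastMinOp), the decomposed relu6 survives the TorchToMhlo path, which is presumably why ElementwiseRelu6Module_basic is added to MHLO_PASS_SET.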