Add relu6 and binary broadcasts (#1408)

* Add relu6 and binary broadcasts
Tanyo Kwok 2022-09-23 20:39:15 +08:00 committed by GitHub
parent 7a77f9fe3d
commit 72e422b589
10 changed files with 168 additions and 5 deletions
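
For reference, aten::relu6 clamps its input to the range [0, 6]. A minimal plain-PyTorch sketch of the semantics this patch lowers (illustrative only, not part of the diff):

    import torch

    x = torch.tensor([-1.0, 3.0, 9.0])
    # relu6(x) == clamp(x, 0, 6)
    assert torch.equal(torch.ops.aten.relu6(x), torch.clamp(x, 0.0, 6.0))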

View File

@ -50,6 +50,32 @@ MHLO_PASS_SET = {
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatScalarModule_basic",
"ElementwiseEqIntScalarModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
"ElementwiseLtFloatScalarModule_basic",
"ElementwiseLtIntScalarModule_basic",
"ElementwiseMulScalarModule_basic",
"ElementwiseMulScalarModule_float",
"ElementwiseMulScalarModule_int",
"ElementwiseNeFloatTensorModule_basic",
"ElementwiseNeIntScalarModule_basic",
"ElementwiseReciprocalModule_basic",
"ElementwiseRelu6Module_basic",
"ElementwiseReluModule_basic",
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ExpandAsIntModule_basic",
"ExpandModule_basic",
"FullLikeModuleDefaultDtype_basic",

View File

@ -158,6 +158,51 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
  }];
}

def Torch_AtenRelu6Op : Torch_Op<"aten.relu6", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::relu6 : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenRelu6Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenRelu6Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenRelu6_Op : Torch_Op<"aten.relu6_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::relu6_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenRelu6_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenRelu6_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
    AllowsTypeRefinement,
    HasValueSemantics,

View File

@ -184,6 +184,41 @@ public:
} // namespace

// The binary broadcast patterns
namespace {
template <typename AtenOpT, typename ChloOpT>
class ConvertAtenBinaryBroadcastOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    // dyn_cast (not cast) so the null check below actually rejects
    // non-tensor operands instead of asserting.
    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
    Value rhs = adaptor.other();
    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
    if (!lhsTy || !rhsTy)
      return op.emitError("only Tensor types supported");

    auto lhsElemTy = lhsTy.getElementType();
    auto rhsElemTy = rhsTy.getElementType();
    if (lhsElemTy != rhsElemTy)
      return op.emitError("input data types mismatched");

    rewriter.replaceOpWithNewOp<ChloOpT>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()),
        lhs, rhs,
        /*broadcast_attr*/ nullptr);
    return success();
  }
};
} // namespace

// These binary op legalizations are specific to add/sub which have an
// alpha multiplier.
namespace {
@ -1231,4 +1266,13 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenSizeIntOp);
  INSERT_ATENOP_PATTERN(AtenToDtypeOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp)                        \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter,    \
                                                             context)
  INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
  INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
  INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
#undef INSERT_BINARY_BROADCAST_PATTERN
}
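
The three ops routed through ConvertAtenBinaryBroadcastOp are elementwise with NumPy-style broadcasting, which the chlo.broadcast_* ops provide directly; the pattern therefore only checks that the operand element types match and passes a null broadcast attribute. A plain-PyTorch sketch of the behavior being matched (illustrative only, not part of the diff):

    import torch

    a, b = torch.rand(4, 2), torch.rand(2)          # shapes broadcast to (4, 2)
    assert torch.ops.aten.maximum(a, b).shape == (4, 2)
    assert torch.ops.aten.minimum(a, b).shape == (4, 2)

    p = torch.tensor([True, False])
    q = torch.tensor([True, True])
    assert torch.equal(p & q, torch.bitwise_and(p, q))  # aten::__and__.Tensor aliases bitwise_and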

View File

@ -645,6 +645,20 @@ static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
  return relu6Out;
}

namespace {
class DecomposeAtenRelu6Op : public OpRewritePattern<AtenRelu6Op> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRelu6Op op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value relu6 = getRelu6Results(rewriter, loc, op.self());
    rewriter.replaceOp(op, relu6);
    return success();
  }
};
} // namespace

// Hardswish(x) = x * Relu6(x+3)/6
namespace {
class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
@ -2907,6 +2921,8 @@ public:
    target.addIllegalOp<AtenRandLikeOp>();
    patterns.add<DecomposeAtenHardsigmoidOp>(context);
    target.addIllegalOp<AtenHardsigmoidOp>();
    patterns.add<DecomposeAtenRelu6Op>(context);
    target.addIllegalOp<AtenRelu6Op>();
    patterns.add<DecomposeAtenHardswishOp>(context);
    target.addIllegalOp<AtenHardswishOp>();
    patterns.add<DecomposeAtenSoftplusOp>(context);
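
DecomposeAtenRelu6Op reuses the existing getRelu6Results helper (already used by the Hardswish decomposition, Hardswish(x) = x * Relu6(x+3)/6), so no new arithmetic is introduced; mathematically relu6(x) = min(max(x, 0), 6). A plain-PyTorch equivalent of that identity (illustrative only, not part of the diff):

    import torch

    def relu6_decomposed(x: torch.Tensor) -> torch.Tensor:
        return torch.minimum(torch.relu(x), torch.full_like(x, 6.0))

    x = torch.tensor([-2.0, 4.0, 7.5])
    assert torch.equal(relu6_decomposed(x), torch.ops.aten.relu6(x))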

View File

@ -677,7 +677,7 @@ void TypeAnalysis::visitOperation(Operation *op,
  // Take dtype from first operand.
  if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenBatchNormOp,
-         AtenReluOp, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
+         AtenReluOp, AtenRelu6Op, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
          AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
          AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op,
          AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenClampMinOp,
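
relu6, like relu, produces the same dtype as its operand, which is why AtenRelu6Op joins this "take dtype from first operand" list. Illustrative plain-PyTorch check (not part of the diff):

    import torch

    y = torch.ops.aten.relu6(torch.ones(2, dtype=torch.float64))
    assert y.dtype == torch.float64  # result dtype matches the input dtype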

View File

@ -6402,6 +6402,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.relu6\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -7821,4 +7825,4 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
#ifndef _MSC_VER
#pragma clang diagnostic pop
#endif
}
}

View File

@ -12,6 +12,7 @@
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
@ -45,9 +46,10 @@ class VerifyMhloBackendContractPass
    ConversionTarget target(*context);
    // Structural operations.
-   target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
-       opHasLegalTypes);
-   // Basic scalar operations.
+   target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes);
+   // Shape operations.
+   target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
    target.addLegalDialect<mhlo::MhloDialect>();
    target.addLegalDialect<chlo::ChloDialect>();
    target.addLegalDialect<tensor::TensorDialect>();

View File

@ -397,6 +397,9 @@ def aten〇log(self: List[int]) -> List[int]:
def aten〇relu(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇relu6(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇_softmax(self: List[int], dim: int, half_to_float: bool) -> List[int]:
    return upstream_shape_functions.unary(self)
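
relu6 is shape-preserving, so its shape function simply forwards the input shape through upstream_shape_functions.unary. A standalone sketch of that contract (the local unary below is a hypothetical stand-in for the upstream helper):

    from typing import List

    def unary(self: List[int]) -> List[int]:
        # the result shape equals the input shape
        return list(self)

    assert unary([4, 2]) == [4, 2]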

View File

@ -241,6 +241,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::tanh : (Tensor) -> (Tensor)",
"aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::relu : (Tensor) -> (Tensor)",
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",

View File

@ -345,6 +345,28 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseRelu6Module(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.relu6(x)


@register_test_case(module_factory=lambda: ElementwiseRelu6Module())
def ElementwiseRelu6Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 2) - 0.5)
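
Note that tu.rand(4, 2) - 0.5 draws values from [-0.5, 0.5), so this case exercises the clamp at 0 but never the clamp at 6. A hypothetical wider-range check (plain PyTorch, not part of the diff) that makes the upper bound reachable:

    import torch

    x = torch.rand(4, 2) * 16 - 8                  # values in [-8, 8)
    assert (torch.ops.aten.relu6(x) <= 6).all()    # output never exceeds 6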
# ==============================================================================
class ElementwiseLeakyReluModule(torch.nn.Module):

    def __init__(self):