mirror of https://github.com/llvm/torch-mlir
parent
7a77f9fe3d
commit
72e422b589
|
@ -50,6 +50,32 @@ MHLO_PASS_SET = {
|
|||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
"ElementwiseToDtypeF32ToI64Module_basic",
|
||||
"ElementwiseAddModule_basic",
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarInt64Module_basic",
|
||||
"ElementwiseAddScalarIntModule_basic",
|
||||
"ElementwiseDivScalarModule_basic",
|
||||
"ElementwiseEqDiffWidthScalarModule_basic",
|
||||
"ElementwiseEqFloatScalarModule_basic",
|
||||
"ElementwiseEqIntScalarModule_basic",
|
||||
"ElementwiseErfModule_basic",
|
||||
"ElementwiseGeluModule_basic",
|
||||
"ElementwiseGtFloatScalarModule_basic",
|
||||
"ElementwiseGtIntScalarModule_basic",
|
||||
"ElementwiseGtMixed2ScalarModule_basic",
|
||||
"ElementwiseLtDiffWidthScalarModule_basic",
|
||||
"ElementwiseLtFloatScalarModule_basic",
|
||||
"ElementwiseLtIntScalarModule_basic",
|
||||
"ElementwiseMulScalarModule_basic",
|
||||
"ElementwiseMulScalarModule_float",
|
||||
"ElementwiseMulScalarModule_int",
|
||||
"ElementwiseNeFloatTensorModule_basic",
|
||||
"ElementwiseNeIntScalarModule_basic",
|
||||
"ElementwiseReciprocalModule_basic",
|
||||
"ElementwiseRelu6Module_basic",
|
||||
"ElementwiseReluModule_basic",
|
||||
"ElementwiseSubScalarFloatModule_basic",
|
||||
"ElementwiseSubScalarIntModule_basic",
|
||||
"ExpandAsIntModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"FullLikeModuleDefaultDtype_basic",
|
||||
|
|
|
@ -158,6 +158,51 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRelu6Op : Torch_Op<"aten.relu6", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::relu6 : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRelu6Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenRelu6Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRelu6_Op : Torch_Op<"aten.relu6_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::relu6_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRelu6_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenRelu6_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -184,6 +184,41 @@ public:
|
|||
|
||||
} // namespace
|
||||
|
||||
// The binary broadcast patterns
|
||||
namespace {
|
||||
template <typename AtenOpT, typename ChloOpT>
|
||||
class ConvertAtenBinaryBroadcastOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("only Tensor types supported");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return op.emitError("input data types mismatched");
|
||||
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhs,
|
||||
/*broadcast_attr*/ nullptr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// These binary op legalizations are specific to add/sub which have an
|
||||
// alpha multiplier.
|
||||
namespace {
|
||||
|
@ -1231,4 +1266,13 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenSizeIntOp);
|
||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \
|
||||
context)
|
||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
|
||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
|
||||
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
|
||||
#undef INSERT_BINARY_BROADCAST_PATTERN
|
||||
}
|
||||
|
|
|
@ -645,6 +645,20 @@ static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
|||
return relu6Out;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenRelu6Op : public OpRewritePattern<AtenRelu6Op> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenRelu6Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value relu6 = getRelu6Results(rewriter, loc, op.self());
|
||||
rewriter.replaceOp(op, relu6);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Hardswish(x) = x * Relu6(x+3)/6
|
||||
namespace {
|
||||
class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
|
||||
|
@ -2907,6 +2921,8 @@ public:
|
|||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
||||
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||
patterns.add<DecomposeAtenRelu6Op>(context);
|
||||
target.addIllegalOp<AtenRelu6Op>();
|
||||
patterns.add<DecomposeAtenHardswishOp>(context);
|
||||
target.addIllegalOp<AtenHardswishOp>();
|
||||
patterns.add<DecomposeAtenSoftplusOp>(context);
|
||||
|
|
|
@ -677,7 +677,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
|
||||
// Take dtype from first operand.
|
||||
if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenBatchNormOp,
|
||||
AtenReluOp, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
|
||||
AtenReluOp, AtenRelu6Op, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
|
||||
AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
|
||||
AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op,
|
||||
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenClampMinOp,
|
||||
|
|
|
@ -6402,6 +6402,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.relu6\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -7821,4 +7825,4 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
#ifndef _MSC_VER
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
@ -45,9 +46,10 @@ class VerifyMhloBackendContractPass
|
|||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
|
||||
opHasLegalTypes);
|
||||
// Basic scalar operations.
|
||||
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes);
|
||||
// Shape operations.
|
||||
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
|
||||
|
||||
target.addLegalDialect<mhlo::MhloDialect>();
|
||||
target.addLegalDialect<chlo::ChloDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
|
|
|
@ -397,6 +397,9 @@ def aten〇log(self: List[int]) -> List[int]:
|
|||
def aten〇relu(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇relu6(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇_softmax(self: List[int], dim: int, half_to_float: bool) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
|
|
@ -241,6 +241,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::tanh : (Tensor) -> (Tensor)",
|
||||
"aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
"aten::relu : (Tensor) -> (Tensor)",
|
||||
"aten::relu6 : (Tensor) -> (Tensor)",
|
||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::log : (Tensor) -> (Tensor)",
|
||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
|
|
|
@ -345,6 +345,28 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRelu6Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.relu6(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRelu6Module())
|
||||
def ElementwiseRelu6Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 2) - 0.5)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeakyReluModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue