[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)

Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
pull/2941/head
Vivek Khandelwal 2024-02-27 11:02:05 +05:30 committed by GitHub
parent 3cbe6c98ec
commit d81747eadb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 243 additions and 21 deletions

View File

@ -11231,6 +11231,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
printDefaultTorchOp(printer, *this, 1, 1); printDefaultTorchOp(printer, *this, 1, 1);
} }
}]; }];
let hasCanonicalizer = 1;
} }
def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
@ -11254,6 +11255,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [
printDefaultTorchOp(printer, *this, 1, 1); printDefaultTorchOp(printer, *this, 1, 1);
} }
}]; }];
let hasCanonicalizer = 1;
} }
def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [

View File

@ -1339,12 +1339,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value ratio, trainingMode; Value ratio, trainingMode;
if (numOperands == 3) { if (numOperands == 3) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]); ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
Value trainingModeScalar = Value trainVal = operands[2];
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]); auto trainTensorType =
Value cstOne = rewriter.create<Torch::ConstantIntOp>( trainVal.getType().dyn_cast<Torch::BaseTensorType>();
loc, rewriter.getI64IntegerAttr(1)); if (!trainTensorType)
trainingMode = rewriter.create<Torch::AtenEqIntOp>( return rewriter.notifyMatchFailure(binder.op,
loc, trainingModeScalar, cstOne); "train tensor must have a type");
Type inputDtype = trainTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(1))
return rewriter.notifyMatchFailure(
binder.op,
"train tensor must have an integer dtype of width 1");
std::optional<unsigned> inputRank = Torch::getTensorRank(trainVal);
if (!inputRank || *inputRank != 0)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have rank 0");
if (auto valueTensorLiteralOp =
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<bool>();
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
} else {
Value trainingModeScalar =
rewriter.create<Torch::AtenIntImplicitOp>(loc, operands[2]);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
trainingMode = rewriter.create<Torch::AtenEqIntOp>(
loc, trainingModeScalar, cstOne);
}
} else if (numOperands == 2) { } else if (numOperands == 2) {
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]); ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false); trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, false);

View File

@ -191,12 +191,14 @@ public:
} // namespace } // namespace
namespace { namespace {
class ConvertAtenScalarImplicitOp // Converts a tensor with one element to a scalar value.
: public OpConversionPattern<AtenScalarImplicitOp> { template <typename OpTy>
class ConvertAtenImplicitLikeOp : public OpConversionPattern<OpTy> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor, matchAndRewrite(OpTy op,
typename OpConversionPattern<OpTy>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA()); rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
return success(); return success();
@ -224,6 +226,12 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context); patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>(); target.addIllegalOp<PrimNumToTensorScalarOp>();
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context); patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
patterns.add<ConvertAtenScalarImplicitOp>(typeConverter, context); patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
target.addIllegalOp<AtenScalarImplicitOp>(); context);
patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,
context);
patterns.add<ConvertAtenImplicitLikeOp<AtenIntImplicitOp>>(typeConverter,
context);
target.addIllegalOp<AtenScalarImplicitOp, AtenFloatImplicitOp,
AtenIntImplicitOp>();
} }

View File

@ -725,13 +725,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(div.getType()) Type dtype = converter->convertType(div.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType(); .getElementType();
if (!dtype.isa<mlir::FloatType>()) {
div.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::DivFOp>(loc, lhs, rhs); if (dtype.isa<mlir::FloatType>())
return b.create<arith::DivFOp>(loc, lhs, rhs);
else if (dtype.isa<mlir::IntegerType>()) {
if (dtype.isUnsignedInteger())
return b.create<arith::DivUIOp>(loc, lhs, rhs);
return b.create<arith::DivSIOp>(loc, lhs, rhs);
}
div.emitError("unimplemented: non-floating point and non-integer dtype");
return nullptr;
} }
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) { if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
AtenDivTensorModeOp::Adaptor adaptor(operands); AtenDivTensorModeOp::Adaptor adaptor(operands);

View File

@ -1568,6 +1568,38 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
}); });
} }
//===----------------------------------------------------------------------===//
// AtenFloatImplicitOp
//===----------------------------------------------------------------------===//
void AtenFloatImplicitOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
Value scalarValue = getScalarFloatValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOp(op, scalarValue);
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenIntImplicitOp
//===----------------------------------------------------------------------===//
void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOp(op, scalarValue);
return success();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSizeOp // AtenSizeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -335,6 +335,9 @@ TORCHDYNAMO_XFAIL_SET = {
# Dynamo not supporting conv_tbc # Dynamo not supporting conv_tbc
"ConvTbcModule_basic", "ConvTbcModule_basic",
"FloatImplicitModule_basic",
"IntImplicitModule_basic",
} }
TORCHDYNAMO_CRASHING_SET = { TORCHDYNAMO_CRASHING_SET = {
@ -989,6 +992,8 @@ TOSA_PASS_SET = {
"ElementwiseCloneContiguousModule_basic", "ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic", "ElementwiseCloneModule_basic",
"ElementwiseDivScalarModule_basic", "ElementwiseDivScalarModule_basic",
"ElementwiseDivTensorIntegerModule_basic",
"ElementwiseDivTensorUnsignedIntegerModule_basic",
"ElementwiseEluModule_basic", "ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic", "ElementwiseEluNonDefaultModule_basic",
"ElementwiseEqBoolScalarModule_basic", "ElementwiseEqBoolScalarModule_basic",
@ -2146,8 +2151,6 @@ ONNX_XFAIL_SET = {
"ElementwiseSigmoidIntModule_basic", "ElementwiseSigmoidIntModule_basic",
# Failure - unknown # Failure - unknown
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic", "CopyWithDifferentDTypesModule_basic",
@ -2168,6 +2171,8 @@ ONNX_XFAIL_SET = {
"ReduceMinAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic", "TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic", "TensorsStackPromoteDTypeModule_basic",
"FloatImplicitModule_basic",
"IntImplicitModule_basic",
} }
ONNX_CRASHING_SET = { } ONNX_CRASHING_SET = { }

View File

@ -669,8 +669,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)") emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)") emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::FloatImplicit : (Tensor) -> (float)") emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True) emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)

View File

@ -3719,6 +3719,50 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100)) module.forward(tu.randint(low=-100, high=100))
# ==============================================================================
class FloatImplicitModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.float64, True),
])
def forward(self, x):
return float(torch.ops.aten.FloatImplicit(x))
@register_test_case(module_factory=lambda: FloatImplicitModule())
def FloatImplicitModule_basic(module, tu: TestUtils):
module.forward(tu.rand().double())
# ==============================================================================
class IntImplicitModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.int64, True),
])
def forward(self, x):
return float(torch.ops.aten.IntImplicit(x))
@register_test_case(module_factory=lambda: IntImplicitModule())
def IntImplicitModule_basic(module, tu: TestUtils):
module.forward(tu.randint())
# ============================================================================== # ==============================================================================
class PowIntFloat(torch.nn.Module): class PowIntFloat(torch.nn.Module):

View File

@ -2595,6 +2595,52 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseDivTensorIntegerModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1, -1], torch.int32, True),
])
def forward(self, a, b):
return torch.div(a, b)
@register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule())
def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32))
# ==============================================================================
class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.uint8, True),
([-1, -1], torch.uint8, True),
])
def forward(self, a, b):
return torch.div(a, b)
@register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule())
def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8))
# ==============================================================================
class ElementwiseDivRoundingModeTruncModule(torch.nn.Module): class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -706,6 +706,15 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3
// ----- // -----
// CHECK-LABEL: @test_div_int32
func.func @test_div_int32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],si32>
%0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32>
return %0 : !torch.vtensor<[3,4,5],si32>
}
// -----
// CHECK-LABEL: @test_div_uint8 // CHECK-LABEL: @test_div_uint8
func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8> // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8>

View File

@ -2145,6 +2145,52 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number
return %1 : !torch.number return %1 : !torch.number
} }
// -----
// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: return %[[FLOAT1]] : !torch.float
func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
%float1 = torch.constant.float 1.0
%0 = torch.prim.NumToTensor.Scalar %float1 : !torch.float -> !torch.vtensor<[],f64>
%1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
return %1 : !torch.float
}
// -----
// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: return %[[FLOAT1]] : !torch.float
func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float {
%0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
%1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float
return %1 : !torch.float
}
// -----
// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: return %[[INT1]] : !torch.int
func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
return %1 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: return %[[INT1]] : !torch.int
func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int {
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int
return %1 : !torch.int
}
// -----
// CHECK-LABEL: func.func @torch.prims.view_of$fold( // CHECK-LABEL: func.func @torch.prims.view_of$fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32> // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>