From 8d4879feb0a0941636a106b6dd45f7a3eb071107 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Fri, 10 Dec 2021 20:25:47 +0530
Subject: [PATCH] [TORCH][MLIR] Add and templatize lowering of [`aten.zeros|aten.ones|aten.empty`] ops

- Templatize the lowering of the `aten.zeros` and `aten.ones` ops.
- Add E2E support for the `aten.empty` op.
- Add integer type support to the `aten.mul.Scalar` op lowering.

Signed-off-by: Gaurav Shukla
---
 e2e_testing/torchscript/basic.py              |  51 ++++++
 e2e_testing/torchscript/elementwise.py        |  38 ++++-
 .../TorchToLinalg/TorchToLinalg.cpp           | 156 +++++++++---------
 3 files changed, 168 insertions(+), 77 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 7b7a7abbb..bec21657e 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -637,6 +637,57 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class EmptyIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return 0 * torch.empty((3, 4), dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: EmptyIntModule())
+def EmptyModule_int(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
+class EmptyFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32)) > -1.0
+
+@register_test_case(module_factory=lambda: EmptyFloatModule())
+def EmptyModule_float(module, tu: TestUtils):
+    module.forward()
+
+
+class EmptyFalsePinMemoryModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.abs(torch.empty((3, 4), dtype=torch.float32,
+                                     pin_memory=False)) > -1.0
+
+@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
+def EmptyModule_falsePinMemory(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
 class ContiguousModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
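Note on the tests above: `torch.empty` allocates storage without initializing it, so its raw contents cannot be compared against eager mode. Each test therefore maps the arbitrary values to something deterministic first: multiplying by 0, or reducing through `torch.abs(...) > -1.0`, which holds for any non-NaN value. A minimal C++ illustration of the same point via libtorch (assuming a libtorch build; this snippet is illustrative and not part of the patch):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // torch::empty returns uninitialized memory, so `e` holds arbitrary values.
      torch::Tensor e = torch::empty({3, 4}, torch::dtype(torch::kInt64));
      // Multiplying by zero yields a deterministic all-zero tensor; this is
      // exactly what the EmptyModule_int test checks end-to-end.
      torch::Tensor z = 0 * e;
      std::cout << z << "\n";
    }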
diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index ab8e3a70f..0c9ae4cbe 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -443,8 +443,24 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseMulScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
 
-class ElementwiseMulScalarModule(torch.nn.Module):
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 4)
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule())
+def ElementwiseMulScalarModule_int(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseMulScalarFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -456,11 +472,27 @@ class ElementwiseMulScalarModule(torch.nn.Module):
     def forward(self, x):
         return torch.mul(x, 100.0)
 
-@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
-def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ElementwiseMulScalarFloatModule())
+def ElementwiseMulScalarModule_float(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 
+class ElementwiseMulScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.mul(x, 8.0)
+
+@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
+def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
 
 class ElementwiseMulTensorFloatModule(torch.nn.Module):
     def __init__(self):
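The re-added `ElementwiseMulScalarModule` above deliberately mixes dtypes: an int64 tensor times the float scalar 8.0, which PyTorch promotes to a floating-point result. The lowering change below has to reproduce this by converting both operands to the result element type before selecting the multiply op. A small libtorch sketch of the eager behavior being matched (assuming a libtorch build; the explicit dtype mirrors Python's `torch.randint` default):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      torch::Tensor x =
          torch::randint(/*high=*/10, {3, 4}, torch::dtype(torch::kInt64));
      torch::Tensor y = x * 8.0;  // int64 tensor * double scalar
      // Type promotion gives the result the default float dtype.
      std::cout << x.dtype() << " -> " << y.dtype() << "\n";
    }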
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index f6a46a32a..f4483bef5 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -189,6 +189,21 @@ static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
   return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
 }
+// Creates a constant of type `elemType` with value `val`.
+static Value getConstant(OpBuilder &b, Location loc, int64_t val,
+                         Type elemType) {
+  Attribute attr = {};
+  if (elemType.isa<mlir::FloatType>())
+    attr = b.getFloatAttr(elemType, val);
+  if (elemType.isa<mlir::IndexType>())
+    attr = b.getIndexAttr(val);
+  if (elemType.isa<mlir::IntegerType>())
+    attr = b.getIntegerAttr(
+        elemType, APInt(elemType.cast<mlir::IntegerType>().getWidth(), val));
+  if (!attr)
+    return nullptr;
+  return b.create<arith::ConstantOp>(loc, elemType, attr);
+}
 
 // Helper function to calculate the output tensor dims for convolution-like ops.
 // Along each dim:
@@ -1828,13 +1843,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(mulScalar.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      mulScalar.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    Value self = payloadArgs[0];
-    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
-    return b.create<arith::MulFOp>(loc, self, other);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    if (dtype.isa<mlir::IntegerType>())
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
+    mulScalar.emitError("unimplemented: only integer/float dtype supported");
+    return nullptr;
   }
   if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
     Value input = payloadArgs[0];
@@ -3417,83 +3433,68 @@
 } // namespace
 
 namespace {
-// Converts AtenOnesOp and AtenZerosOp.
-struct ConvertAtenOnesZerosOp : ConversionPattern {
-  ConvertAtenOnesZerosOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
-                          context) {}
-
+// Converts constant-tensor-allocation-like ops (`aten.zeros`, `aten.ones`,
+// `aten.empty`).
+template <typename OpTy>
+class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
+public:
+  using OpConversionPattern<OpTy>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenOnesOp, AtenZerosOp>(op))
-      return rewriter.notifyMatchFailure(op,
-                                         "not a supported ones or zeros op");
-
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    Location loc = op->getLoc();
-    Value size, layout, pin_memory;
-    int64_t elementValue;
-
-    if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
-      size = onesOp.size();
-      layout = onesOp.layout();
-      pin_memory = onesOp.pin_memory();
-      elementValue = 1;
-    } else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
-      size = zerosOp.size();
-      layout = zerosOp.layout();
-      pin_memory = zerosOp.pin_memory();
-      elementValue = 0;
-    }
-
-    // We ignore device, but add simple asserts for unimplemented kwargs
-    if (!layout.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "only default layout is supported");
-
-    bool pinMemory = false;
-    if (!pin_memory.getType().isa<Torch::NoneType>() &&
-        !matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
+    // Currently, memory pinning and non-default layouts are not supported.
+    if (!op.layout().getType().template isa<Torch::NoneType>())
       return rewriter.notifyMatchFailure(
-          op, "pin_memory must be constant bool or None");
-    }
-    if (pinMemory)
-      return rewriter.notifyMatchFailure(op, "memory pinning not supported");
-
-    SmallVector<Value> sizes, sizeIndex;
-    if (!getListConstructElements(size, sizes)) {
+          op, "unimplemented: only default layout is supported");
+    bool pinMemory;
+    if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
+         pinMemory)) {
       return rewriter.notifyMatchFailure(
-          op, "size must be created by ListConstruct");
+          op, "unimplemented: pin_memory must be either None or false");
     }
-    sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
-    for (size_t i = 0; i < sizes.size(); i++)
-      sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
 
-    RankedTensorType newResultType =
-        getTypeConverter()
-            ->convertType(op->getResult(0).getType())
-            .cast<RankedTensorType>();
-    Type outElementType = newResultType.getElementType();
+    // Only the default memory format is supported in the case of
+    // `AtenEmptyMemoryFormatOp`; the other ops have no memory_format operand.
+    if constexpr (std::is_same<OpTy, AtenEmptyMemoryFormatOp>::value) {
+      if (!op.memory_format().getType().template isa<Torch::NoneType>())
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only default memory format is supported");
+    }
 
-    Value constantOp = rewriter.create<arith::ConstantOp>(
-        loc, outElementType,
-        (outElementType.isa<mlir::FloatType>()
-             ? rewriter.getFloatAttr(outElementType, elementValue)
-                   .cast<Attribute>()
-             : rewriter.getIntegerAttr(outElementType, elementValue)
-                   .cast<Attribute>()));
-    Value outTensor = rewriter
-                          .create<linalg::InitTensorOp>(
-                              loc, sizeIndex, newResultType.getElementType())
-                          .getResult();
-    Value fillOp = rewriter.create<linalg::FillOp>(loc, constantOp, outTensor)
-                       .getResult(0);
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = this->getTypeConverter();
+    SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
+    if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: size must be constructed using ListConstruct");
+    }
+    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
+                                        resultSizeTorchInt);
+    for (auto size : resultSize)
+      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, fillOp);
+    auto resultType = typeConverter->convertType(op.getType())
+                          .template cast<RankedTensorType>();
+    Type outElemType = resultType.getElementType();
+    // Create an uninitialized tensor of `resultSize` shape. It is returned
+    // as-is, without filling, in the case of `AtenEmptyMemoryFormatOp`.
+    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultSizeIndex, outElemType);
+
+    // `AtenZerosOp` and `AtenOnesOp` results are filled with the
+    // corresponding constant.
+    if (std::is_same<OpTy, AtenZerosOp>::value) {
+      Value zero = getConstant(rewriter, loc, 0, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, zero, outputTensor).getResult(0);
+    } else if (std::is_same<OpTy, AtenOnesOp>::value) {
+      Value one = getConstant(rewriter, loc, 1, outElemType);
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, one, outputTensor).getResult(0);
+    }
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
     return success();
   }
 };
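One C++ detail worth calling out in the pattern above: only `AtenEmptyMemoryFormatOp` has a `memory_format()` accessor, so that access sits behind `if constexpr`, which discards the branch at compile time for the `AtenZerosOp`/`AtenOnesOp` instantiations. The plain (non-constexpr) `std::is_same` checks for the fill value are fine because their bodies only use members common to all three ops. A self-contained toy sketch of the technique (C++17; the types are stand-ins, not torch-mlir classes):

    #include <iostream>
    #include <type_traits>

    // Toy stand-ins: only EmptyOp carries a memory_format member.
    struct ZerosOp {};
    struct OnesOp {};
    struct EmptyOp { int memory_format = 0; };

    template <typename OpTy>
    void convert(OpTy op) {
      // `if constexpr` discards this branch at compile time, so referencing
      // op.memory_format compiles even though ZerosOp/OnesOp lack the member.
      if constexpr (std::is_same<OpTy, EmptyOp>::value)
        std::cout << "memory_format = " << op.memory_format << "\n";

      if (std::is_same<OpTy, ZerosOp>::value)
        std::cout << "fill with 0\n";
      else if (std::is_same<OpTy, OnesOp>::value)
        std::cout << "fill with 1\n";
      else
        std::cout << "leave uninitialized\n";
    }

    int main() {
      convert(ZerosOp{});  // fill with 0
      convert(OnesOp{});   // fill with 1
      convert(EmptyOp{});  // memory_format = 0, leave uninitialized
    }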
@@ -3704,8 +3705,15 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
-    target.addIllegalOp<AtenOnesOp, AtenZerosOp>();
-    patterns.add<ConvertAtenOnesZerosOp>(typeConverter, context);
+    target.addIllegalOp<AtenEmptyMemoryFormatOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenEmptyMemoryFormatOp>>(
+        typeConverter, context);
+    target.addIllegalOp<AtenZerosOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp>>(typeConverter,
+                                                            context);
+    target.addIllegalOp<AtenOnesOp>();
+    patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp>>(typeConverter,
+                                                           context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
     target.addIllegalOp();
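The registration above follows the usual MLIR dialect-conversion pairing: each op is declared illegal on the target (so conversion fails loudly if no pattern fires) and the matching pattern instantiation is added. Supporting a further constant-allocation op would be one more such pair; a hedged sketch, where `AtenSomeAllocOp` is a placeholder name rather than a real torch-mlir op:

    // Hypothetical extra instantiation. Note that an op without zeros/ones
    // fill semantics would fall through to the uninitialized (`empty`) path
    // unless the template gains another fill branch.
    target.addIllegalOp<AtenSomeAllocOp>();
    patterns.add<ConvertConstantTensorAllocOp<AtenSomeAllocOp>>(typeConverter,
                                                                context);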