diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index f7e146d8a..21002e03b 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -259,3 +259,125 @@ class GatherModule(torch.nn.Module): @register_test_case(module_factory=lambda: GatherModule()) def GatherModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]])) + +class AddSizeIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, tensor): + # This is a workaround for not supporting scalar arguments. + # TODO: pass in dim as an argument to the forward method when scalar + # arguments are supported. + return tensor.add(tensor, alpha=tensor.size(1)) + + +@register_test_case(module_factory=lambda: AddSizeIntModule()) +def AddSizeIntModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3)) + + +class AddSizeIntNegDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, tensor): + # This is a workaround for not supporting scalar arguments. + # TODO: pass in dim as an argument to the forward method when scalar + # arguments are supported. + return tensor.add(tensor, alpha=tensor.size(-2)) + + +@register_test_case(module_factory=lambda: AddSizeIntNegDimModule()) +def AddSizeIntNegDimModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 3)) + +class EmbeddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.embed = torch.nn.Embedding(num_embeddings=100, + embedding_dim=50, + padding_idx=4) + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, indices): + return self.embed.forward(indices) + + +@register_test_case(module_factory=lambda: EmbeddingModule()) +def EmbeddingModule_basic(module, tu: TestUtils): + module.forward(torch.randint(100, (3, 3))) + + +class SoftmaxIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.softmax = torch.nn.Softmax(2) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, tensor): + return self.softmax.forward(tensor) + + +@register_test_case(module_factory=lambda: SoftmaxIntModule()) +def SoftmaxIntModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 2, 4)) + + +class SoftmaxIntNegDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.softmax = torch.nn.Softmax(-2) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, tensor): + return self.softmax.forward(tensor) + + +@register_test_case(module_factory=lambda: SoftmaxIntNegDimModule()) +def SoftmaxIntNegDimModule_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 2, 4)) + + +class SoftmaxIntArgTypeF64Module(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.softmax = torch.nn.Softmax(2) + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, tensor): + return self.softmax.forward(tensor) + + +@register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module()) +def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils): + module.forward(torch.randn(3, 2, 4).double()) diff --git 
a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index 5bc50c3e1..0796a19ff 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -937,6 +937,22 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ let assemblyFormat = "$self `,` $kernel_size `,` $stride `,` $padding `,` $dilation `,` $ceil_mode attr-dict `:` type($self) `,` type($kernel_size) `,` type($stride) `,` type($padding) `,` type($dilation) `,` type($ceil_mode) `->` type($result)"; } +def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::softmax.int : (Tensor, int, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + TorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)"; +} + def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ AllowsTypeRefinement, HasValueSemantics @@ -1591,6 +1607,7 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ Torch_IntType:$result ); let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)"; + let hasFolder = 1; } def Torch_AtenStackOp : Torch_Op<"aten.stack", [ diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 9775ca3ec..fc2a0f06c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -819,6 +819,7 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [ let assemblyFormat = [{ $operand attr-dict `:` type($operand) `to` type($result) }]; + let hasCanonicalizer = 1; } def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [ diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 34fd64170..85f55b981 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -54,6 +54,8 @@ std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass(); std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass(); +std::unique_ptr<OperationPass<FuncOp>> createDecomposeComplexOpsPass(); + } // namespace Torch /// Registers all Torch transformation passes. diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 464e64079..9d973b1d3 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -215,4 +215,20 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> { }]; } +def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> { + let summary = "Decompose complicated torch operations"; + let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()"; + let description = [{ + Decompose torch operations that are losslessly represented as combinations of + other operations, modulo appropriate compiler fusion. Note that this pass + is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about + systematic reductions of a large number of ops at once, guided mostly by + traits. 
+ + An example of the transformations done in this pass is: + - convert aten.softmax to: softmax(x, dim) + => tmp = exp(x); tmp / sum(tmp, dim, keepdim=True) + }]; +} + #endif // TORCHMLIR_TORCH_PASSES diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h new file mode 100644 index 000000000..d1e43d152 --- /dev/null +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +#ifndef TORCHMLIR_DIALECT_TORCH_UTILS_H +#define TORCHMLIR_DIALECT_TORCH_UTILS_H + +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace torch { +namespace Torch { + +int64_t toPositiveDim(int64_t dim, int64_t inputRank); +bool isValidDim(int64_t dim, int64_t inputRank); + +} // namespace Torch +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_DIALECT_TORCH_UTILS_H diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 7442ebc3c..7e4680dfb 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" @@ -70,6 +71,35 @@ static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, return success(); } +// Generate IR: dim = dim >= 0 ? dim : dim + inputRank +static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim, + Value inputRank) { + assert(dim.getType().isa<mlir::IntegerType>() && + "dim arg of toPositiveDim must be integer type"); + Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank); + Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType())); + Value predDimGEZero = + b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0); + Value dimInt = b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank); + return dimInt; +} + +// Generate IR: assert(dim >= 0 && dim < inputRank) +static void assertIsValidDim(OpBuilder &b, Location loc, Value dim, + Value inputRank) { + assert(dim.getType().isa<mlir::IntegerType>() && + "dim arg of assertIsValidDim must be integer type"); + Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType())); + Value predGEZero = + b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0); + b.create<AssertOp>(loc, predGEZero, + b.getStringAttr("dim must be greater than or equal to zero")); + Value predLTInputRank = + b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank); + b.create<AssertOp>(loc, predLTInputRank, + b.getStringAttr("dim must be smaller than inputRank")); +} + // Hack to deal with the Torch list type arguments which are not supported end // to end. Constant values can be extracted directly and non-constant // list values are not supported. 
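As a quick numerical cross-check of the decomposition described in the DecomposeComplexOps pass above, here is a minimal PyTorch sketch; it is illustrative only and not part of the patch:

import torch

def decomposed_softmax(x, dim):
    # The exact rewrite emitted by the pass:
    # tmp = exp(x); tmp / sum(tmp, dim, keepdim=True)
    tmp = torch.exp(x)
    return tmp / torch.sum(tmp, dim, keepdim=True)

x = torch.randn(3, 2, 4)
assert torch.allclose(decomposed_softmax(x, 2), torch.nn.Softmax(dim=2)(x))

Note that this reproduces aten.softmax exactly but, unlike the numerically stabilized formulation, does not subtract the per-slice max before exponentiating.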
@@ -459,6 +489,41 @@ static Value createLinalgPayloadCalculationForNormOps( return plusBias; } +static void createLinalgPayloadCalculationForGatherOps( + OpBuilder &b, Location loc, Value input, int64_t inputRank, Value index, + int64_t dim, int64_t outputRank) { + SmallVector<Value> indices; + for (int i = 0; i < inputRank; i++) { + if (i == dim) { + indices.push_back(castIntToIndex(b, loc, index)); + } else { + // `outputRank` might be larger than `inputRank`. The `linalg::IndexOp` + // takes in the dimension of the output, so add `inputDimOffset` to `i` + // to map to the corresponding output dimension for dimensions larger + // than the given `dim`. + int64_t inputDimOffset = i < dim ? 0 : outputRank - inputRank; + indices.push_back(b.create<linalg::IndexOp>(loc, i + inputDimOffset)); + } + } + + // Assert index < input.sizes[dim] + Value indexLTInputDim = b.create<arith::CmpIOp>( + loc, arith::CmpIPredicate::slt, index, + castIndexToInt(b, loc, getDimOp(b, loc, input, dim))); + b.create<AssertOp>(loc, indexLTInputDim, + b.getStringAttr("index must be smaller than dim size")); + + // Assert index >= 0 + Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType())); + Value indexGEThanZero = + b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0); + b.create<AssertOp>(loc, indexGEThanZero, + b.getStringAttr("index must be greater than or equal to 0")); + + Value extract = b.create<tensor::ExtractOp>(loc, input, indices); + b.create<linalg::YieldOp>(loc, extract); +} + namespace { class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> { public: @@ -1027,6 +1092,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( ArrayRef<Value> operands) { if (isa<AtenTanhOp>(op)) return b.create<math::TanhOp>(loc, payloadArgs[0]); + if (isa<AtenExpOp>(op)) + return b.create<math::ExpOp>(loc, payloadArgs[0]); if (isa<AtenSigmoidOp>(op)) { Type elementType = payloadArgs[0].getType(); auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1)); @@ -1330,8 +1397,8 @@ struct ConvertElementwiseOp : ConversionPattern { matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { if (!isa<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp, - AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp>( - op)) + AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, + AtenExpOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1846,13 +1913,12 @@ public: for (auto i = 0; i < inputRank; i++) idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); for (auto i = 0; i < inputRank; i++) { - if (i == dim0) { + if (i == dim0) swapExprs.push_back(idExprs[dim1]); - } else if (i == dim1) { + else if (i == dim1) swapExprs.push_back(idExprs[dim0]); - } else { + else swapExprs.push_back(idExprs[i]); - } } SmallVector<AffineMap> indexingMaps = { @@ -1893,19 +1959,15 @@ public: // Collect all the tensors to be concatenated. 
auto tensorList = op.tensors(); - auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>(); - if (!listConstruct) + SmallVector<Value> tensorsTorchType; + if (!getListConstructElements(tensorList, tensorsTorchType)) return op.emitError( "unimplemented: the tensor list is not from list construct"); - auto tensors = llvm::to_vector<4>( - llvm::map_range(listConstruct.elements(), [&](Value tensor) -> Value { - return typeConverter->materializeTargetConversion( - rewriter, loc, getTypeConverter()->convertType(tensor.getType()), - tensor); - })); + auto tensors = + getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); RankedTensorType newResultType = - getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); + typeConverter->convertType(op.getType()).cast<RankedTensorType>(); int rank = newResultType.getRank(); SmallVector<Value> offsets, sizes, strides; sizes.reserve(rank); @@ -1975,18 +2037,9 @@ public: auto genericOp = rewriter.create<linalg::GenericOp>( loc, newResultTy, indices, result, affineMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - auto indexValue = args[0]; - Value indexOfDim = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), indexValue); - SmallVector<Value> indices; - for (int i = 0; i < rank; i++) { - indices.push_back(i == dim - ? indexOfDim - : rewriter.create<linalg::IndexOp>(loc, i)); - } - Value extract = - rewriter.create<tensor::ExtractOp>(loc, self, indices); - rewriter.create<linalg::YieldOp>(loc, extract); + auto index = args[0]; + createLinalgPayloadCalculationForGatherOps(b, loc, self, rank, index, + dim, rank); }); rewriter.replaceOp(op, genericOp.getResult(0)); return success(); @@ -1994,6 +2047,93 @@ public: }; } // namespace +namespace { +class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenEmbeddingOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + Location loc = op->getLoc(); + AtenEmbeddingOp::Adaptor adaptor(operands); + Value weight = adaptor.weight(); + Value indices = adaptor.indices(); + RankedTensorType newResultType = + typeConverter->convertType(op.getType()).cast<RankedTensorType>(); + + auto weightTy = weight.getType().cast<RankedTensorType>(); + if (weightTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "weight must be rank 2"); + Value embeddingDim = getDimOp(rewriter, loc, weight, 1); + Type elemTy = weightTy.getElementType(); + + SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices); + sizes.push_back(embeddingDim); + int64_t resultRank = sizes.size(); + + auto indicesTy = indices.getType().cast<RankedTensorType>(); + int64_t indicesRank = indicesTy.getRank(); + SmallVector<AffineExpr> indicesExprs; + for (int i = 0; i < indicesRank; i++) + indicesExprs.push_back(rewriter.getAffineDimExpr(i)); + auto indicesAffineMap = AffineMap::get( + /*dimCount=*/resultRank, + /*symbolCount=*/0, indicesExprs, op->getContext()); + SmallVector<AffineMap> indexingMaps = { + indicesAffineMap, + rewriter.getMultiDimIdentityMap(resultRank), + }; + SmallVector<StringRef> iteratorTypes(sizes.size(), + getParallelIteratorTypeName()); + Value initTensor = + rewriter.create<linalg::InitTensorOp>(loc, sizes, elemTy); + Value embeddingResult = + rewriter + .create<linalg::GenericOp>( + loc, initTensor.getType(), indices, initTensor, + /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value index = args[0]; + createLinalgPayloadCalculationForGatherOps( + b, loc, weight, weightTy.getRank(), index, /*dim=*/0, + resultRank); + }) + .getResult(0); + rewriter.replaceOpWithNewOp<tensor::CastOp>(op, 
newResultType, + embeddingResult); + return success(); + } +}; +} // namespace + +namespace { +class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenSizeIntOp op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + Location loc = op->getLoc(); + AtenSizeIntOp::Adaptor adaptor(operands); + Value self = adaptor.self(); + Value dim = adaptor.dim(); + auto type = self.getType().cast<RankedTensorType>(); + Value inputRank = rewriter.create<arith::ConstantOp>( + loc, rewriter.getI64IntegerAttr(type.getRank())); + Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank); + assertIsValidDim(rewriter, loc, dimPositive, inputRank); + Value size = rewriter.create<tensor::DimOp>( + loc, adaptor.self(), castIntToIndex(rewriter, loc, dimPositive)); + rewriter.replaceOp(op, castIndexToInt(rewriter, loc, size)); + return success(); + } +}; +} // namespace + // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- @@ -2057,6 +2197,10 @@ public: patterns.add<ConvertAtenCatOp>(typeConverter, context); target.addIllegalOp<AtenGatherOp>(); patterns.add<ConvertAtenGatherOp>(typeConverter, context); + target.addIllegalOp<AtenEmbeddingOp>(); + patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context); + target.addIllegalOp<AtenSizeIntOp>(); + patterns.add<ConvertAtenSizeIntOp>(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Dialect/Torch/CMakeLists.txt b/lib/Dialect/Torch/CMakeLists.txt index 9f57627c3..31167e6af 100644 --- a/lib/Dialect/Torch/CMakeLists.txt +++ b/lib/Dialect/Torch/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index f8357002f..cf2f904c4 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/StringMap.h" using namespace mlir; @@ -505,6 +506,29 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenSizeIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) { + auto type = getOperand(0).getType().dyn_cast<BaseTensorType>(); + if (!type || !type.hasSizes()) + return nullptr; + + int64_t inputRank = type.getSizes().size(); + int64_t dim; + if (!matchPattern(this->dim(), m_TorchConstantInt(&dim))) + return nullptr; + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return nullptr; + + if (type.getSizes()[dim] == kUnknownSize) + return nullptr; + return IntegerAttr::get(IntegerType::get(getContext(), 64), + type.getSizes()[dim]); +} + //===----------------------------------------------------------------------===// // AtenGtIntOp //===----------------------------------------------------------------------===// @@ -655,6 +679,19 @@ bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs, outputs[0].cast<BaseTensorType>()); } +void TensorStaticInfoCastOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) { +
auto reverseCast = + op.operand().getDefiningOp<TensorStaticInfoCastOp>(); + if (!reverseCast || reverseCast.operand().getType() != op.getType()) + return failure(); + + rewriter.replaceOp(op, reverseCast.operand()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // CopyToNonValueTensorOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index 81187ec58..adde4c377 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TorchMLIRTorchPasses AdjustCallingConventions.cpp + DecomposeComplexOps.cpp Passes.cpp GlobalizeObjectGraph.cpp InlineGlobalSlots.cpp @@ -23,6 +24,7 @@ add_mlir_library(TorchMLIRTorchPasses MLIRPass MLIRTransforms TorchMLIRTorchDialect + TorchMLIRTorchUtils ) torch_mlir_target_includes(TorchMLIRTorchPasses) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp new file mode 100644 index 000000000..46e7432d9 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -0,0 +1,100 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/StringExtras.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +// Decompose softmax into: exp(x) / sum(exp(x), dim, keepdim=True) +namespace { +class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.self(); + Value dim = op.dim(); + if (!op.dtype().getType().isa<Torch::NoneType>()) + return rewriter.notifyMatchFailure( + op, "Unimplemented non-None dtype for softmax"); + + BaseTensorType tensorType = self.getType().cast<BaseTensorType>(); + if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>()) + return rewriter.notifyMatchFailure(op, "Only support floating-point type"); + // exp(x) + Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self); + + // sum(exp(x)) + Value dimList = rewriter.create<PrimListConstructOp>( + loc, Torch::ListType::get(dim.getType()), dim); + Value keepDim = rewriter.create<ConstantBoolOp>(loc, true); + Value dtype = rewriter.create<ConstantNoneOp>(loc); + SmallVector<int64_t> sizes; + int64_t dimInt; + if (tensorType.hasSizes()) { + ArrayRef<int64_t> inputShape = tensorType.getSizes(); + int64_t inputRank = inputShape.size(); + if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + sizes.append(inputShape.begin(), inputShape.end()); + sizes[dimInt] = 1; + } else { + sizes.resize(inputRank, kUnknownSize); + } + } + Type resultType = 
tensorType.getWithSizesAndDtype( + sizes.size() == 0 ? Optional<ArrayRef<int64_t>>() + : llvm::makeArrayRef(sizes), + tensorType.getDtype()); + Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, exp, + dimList, keepDim, dtype); + // exp(x) / sum(exp(x)) + Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum); + rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(), + result); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeComplexOpsPass + : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addLegalDialect<Torch::TorchDialect>(); + + patterns.add<DecomposeAtenSoftmaxIntOp>(context); + target.addIllegalOp<AtenSoftmaxIntOp>(); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; +} // namespace +std::unique_ptr<OperationPass<FuncOp>> +mlir::torch::Torch::createDecomposeComplexOpsPass() { + return std::make_unique<DecomposeComplexOpsPass>(); } diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 5d641269c..42a8ea33c 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -90,7 +90,7 @@ public: if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) { copyToValueTensorOps.push_back(copyToValueTensor); } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp, - AtenTransposeIntOp>(op)) { + AtenTransposeIntOp, TensorStaticInfoCastOp>(op)) { viewLikeOps.push_back(op); llvm::append_range(workList, op->getResult(0).getUsers()); } else { diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index a7891cee6..164355df4 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -119,9 +119,17 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( // Do shape and dtype refinement. pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass()); + // Propagate to ABI return types the shape/dtype information discovered by // the previous pass. Doing this is ABI-compatible for our backends. pm.addPass(Torch::createRefinePublicReturnPass()); + + if (options.optimize) { + // This can fold away some branches given the information gained from + // RefineTypes before doing maximize value semantics, which only works + // with basic blocks. + pm.addNestedPass<FuncOp>(createCanonicalizerPass()); + } // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's. pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass()); @@ -134,6 +142,7 @@ // only-used-in-training operations on `torch.global_slot`'s. pm.addNestedPass<FuncOp>(createCanonicalizerPass()); } + pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass()); // TODO: VerifyTorchBackendContractPass. 
} diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 9d46ecc06..8db7cefbe 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -19,6 +19,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; @@ -263,7 +264,8 @@ public: } else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) { return visitAtenArangeStartOp(arangeStart); } else if (auto sum = dyn_cast<AtenSumOp>(op)) { - return visitReductionAlongAllDimsOp(sum, operands); + Type dtype = operands[0]->getValue().dtype; + return visitReductionAlongAllDimsOp(sum, dtype, operands); } else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) { return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(), sumDimIntList.keepdim(), operands); @@ -271,9 +273,17 @@ return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(), meanDim.keepdim(), operands); } else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) { - return visitAtenArgmaxOp(argmax, operands); + Value dim = argmax.dim(); + Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed); + if (dim.getType().isa<Torch::NoneType>()) + return visitReductionAlongAllDimsOp(op, dtype, operands); + if (dim.getType().isa<Torch::IntType>()) + return visitReductionAlongDimIntOp(argmax, argmax.dim(), + argmax.keepdim(), dtype, operands); } else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) { - return visitAtenAnyDimOp(anyDim, operands); + Type dtype = operands[0]->getValue().dtype; + return visitReductionAlongDimIntOp(anyDim, anyDim.dim(), anyDim.keepdim(), + dtype, operands); } else if (auto view = dyn_cast<AtenViewOp>(op)) { return visitReshapeLikeOp(view, operands); } else if (auto resize = dyn_cast<AtenResize_Op>(op)) { @@ -353,6 +363,8 @@ return visitAtenEmbeddingOp(embedding, operands); } else if (auto bmm = dyn_cast<AtenBmmOp>(op)) { return visitAtenBmmOp(bmm, operands); + } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) { + return visitAtenSoftmaxIntOp(softmaxIntOp, operands); } // Otherwise, this is an unknown operation. Just mark all results as @@ -394,16 +406,14 @@ private: ChangeResult visitAtenArangeStartOp(AtenArangeStartOp op); ChangeResult visitAtenArangeOp(AtenArangeOp op); ChangeResult visitReductionAlongAllDimsOp( - Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands); + Operation *op, Type dtype, + ArrayRef<LatticeElement<ValueKnowledge> *> operands); ChangeResult visitReductionAlongDimIntListOp( Operation *op, Value dim, Value keepdim, ArrayRef<LatticeElement<ValueKnowledge> *> operands); - ChangeResult - visitAtenArgmaxOp(AtenArgmaxOp op, - ArrayRef<LatticeElement<ValueKnowledge> *> operands); - ChangeResult - visitAtenAnyDimOp(AtenAnyDimOp op, - ArrayRef<LatticeElement<ValueKnowledge> *> operands); + ChangeResult visitReductionAlongDimIntOp( + Operation *op, Value dim, Value keepdim, Type dtype, + ArrayRef<LatticeElement<ValueKnowledge> *> operands); template <typename OpTy> ChangeResult visitReshapeLikeOp(OpTy op, @@ -448,27 +458,34 @@ private: ChangeResult visitAtenBmmOp(AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands); + ChangeResult + visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op, + ArrayRef<LatticeElement<ValueKnowledge> *> operands); }; } // namespace -static int64_t toPositiveDim(int64_t dim, int64_t inputRank) { - return dim >= 0 ? dim : dim + inputRank; -} - -static bool isValidDim(int64_t dim, int64_t inputRank) { - return dim >= 0 && dim < inputRank; -} +// Get the MLIR type of the tensor dtype given the dtype integer value and the +// input dtype. When DType is None the type is inferred from the input dtype. 
+static void fillInDTypeGivenDTypeIntAndInputDType(MLIRContext *context, + ValueKnowledge &knowledge, + Value dtype, + Type inputDType) { + int64_t dtypeInt; + if (dtype.getType().isa<Torch::NoneType>()) + knowledge.dtype = inputDType; + else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) + knowledge.dtype = getTypeFromDTypeInteger(context, dtypeInt); +} // Get the MLIR type of the tensor dtype given the dtype integer value and data -// type. When DType is None the type is inferred from the data type. +// type of the torch type. When DType is None the type is inferred from the +// data type. static void fillInDTypeGivenDTypeAndDataType(MLIRContext *context, ValueKnowledge &knowledge, Value dtype, Type dataType) { - int64_t dtypeInt; - if (dtype.getType().isa<Torch::NoneType>()) - knowledge.dtype = getDTypeFromTorchType(context, dataType); - else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) - knowledge.dtype = getTypeFromDTypeInteger(context, dtypeInt); + Type dtypeFromDataType = getDTypeFromTorchType(context, dataType); + fillInDTypeGivenDTypeIntAndInputDType(context, knowledge, dtype, + dtypeFromDataType); } static void fillInSizesGivenSizesList(ValueKnowledge &knowledge, Value sizes) { @@ -718,10 +735,10 @@ ChangeResult TypeAnalyzer::visitAtenArangeOp(AtenArangeOp op) { } ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp( - Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { - auto input = operands[0]->getValue(); + Operation *op, Type dtype, + ArrayRef<LatticeElement<ValueKnowledge> *> operands) { auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext()); - knowledge.dtype = input.dtype; + knowledge.dtype = dtype; // Reduction along all dims always results in a tensor of rank zero, // which is represented by the default empty `knowledge.sizes` vector knowledge.hasSizes = true; @@ -764,53 +781,28 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp( } return getLatticeElement(op->getResult(0)).join(knowledge); } -ChangeResult TypeAnalyzer::visitAtenArgmaxOp( - AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { - auto input = operands[0]->getValue(); - auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext()); - knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed); - int64_t dim; - bool keepDim; - if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { - int64_t inputRank = input.sizes.size(); - knowledge.hasSizes = true; - if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) { - knowledge.sizes = input.sizes; - dim = toPositiveDim(dim, inputRank); - if (isValidDim(dim, inputRank)) { - if (keepDim) - knowledge.sizes[dim] = 1; - else - knowledge.sizes.erase(knowledge.sizes.begin() + dim); - } - } else if (op.dim().getType().isa<Torch::IntType>()) - knowledge.sizes.resize(keepDim ? 
inputRank : inputRank - 1, - kUnknownSize); - } - // If dim is no kind of Integer, keepDim is ignored, - // and the result will be a rank-0 tensor - return getLatticeElement(op->getResult(0)).join(knowledge); -} -ChangeResult TypeAnalyzer::visitAtenAnyDimOp( - AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { +ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp( + Operation *op, Value dim, Value keepdim, Type dtype, + ArrayRef<LatticeElement<ValueKnowledge> *> operands) { + assert(dim.getType().isa<Torch::IntType>() && "dim must be int type"); auto input = operands[0]->getValue(); auto knowledge = ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); - knowledge.dtype = input.dtype; - int64_t dim; + knowledge.dtype = dtype; + int64_t dimInt; bool keepDim; - if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { + if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) { int64_t inputRank = input.sizes.size(); knowledge.hasSizes = true; - if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) { + if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { knowledge.sizes = input.sizes; - dim = toPositiveDim(dim, inputRank); - if (isValidDim(dim, inputRank)) { + dimInt = toPositiveDim(dimInt, inputRank); + if (isValidDim(dimInt, inputRank)) { if (keepDim) - knowledge.sizes[dim] = 1; + knowledge.sizes[dimInt] = 1; else - knowledge.sizes.erase(knowledge.sizes.begin() + dim); + knowledge.sizes.erase(knowledge.sizes.begin() + dimInt); } } else { knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize); @@ -1081,6 +1073,19 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp( return getLatticeElement(op.getResult()).join(knowledge); } +ChangeResult TypeAnalyzer::visitAtenSoftmaxIntOp( + AtenSoftmaxIntOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { + auto input = operands[0]->getValue(); + auto dtype = op.dtype(); + auto knowledge = + ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); + knowledge.hasSizes = input.hasSizes; + knowledge.sizes = input.sizes; + fillInDTypeGivenDTypeIntAndInputDType(op->getContext(), knowledge, dtype, + input.dtype); + return getLatticeElement(op.getResult()).join(knowledge); +} + ChangeResult TypeAnalyzer::visitAtenBmmOp( AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { auto knowledge = diff --git a/lib/Dialect/Torch/Utils/CMakeLists.txt b/lib/Dialect/Torch/Utils/CMakeLists.txt new file mode 100644 index 000000000..18fc94edd --- /dev/null +++ b/lib/Dialect/Torch/Utils/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect_library(TorchMLIRTorchUtils + Utils.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Utils + ) diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp new file mode 100644 index 000000000..f74ad4f24 --- /dev/null +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +namespace mlir { +namespace torch { +namespace Torch { + +int64_t toPositiveDim(int64_t dim, int64_t inputRank) { + return dim >= 0 ? 
dim : dim + inputRank; +} + +bool isValidDim(int64_t dim, int64_t inputRank) { + return dim >= 0 && dim < inputRank; +} + +} // namespace Torch +} // namespace torch +} // namespace mlir diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index a8898c071..3771a85a8 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -45,6 +45,8 @@ static bool isArgMemRefTypeValid(Type type) { Type elemTy = memRefType.getElementType(); if (elemTy.isa<Float32Type>()) { return true; + } else if (elemTy.isa<Float64Type>()) { + return true; } else if (auto integerTy = elemTy.dyn_cast<IntegerType>()) { if (integerTy.isSignlessInteger(64)) return true; @@ -81,7 +83,8 @@ static LogicalResult mungeFunction( for (auto arg : func.getArguments()) { auto type = arg.getType(); if (!isArgMemRefTypeValid(type)) - return emitError(arg.getLoc(), "argument must be a memref of f32 or i64"); + return emitError(arg.getLoc(), + "argument must be a memref of f32, f64, or i64"); auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type); arg.replaceAllUsesExcept(cast, cast); arg.setType(getAbiTypeForMemRef(type)); @@ -91,12 +94,17 @@ SmallVector<Operation *> toErase; bool hadError = false; func.walk([&](ReturnOp op) { - auto returnType = - op.getOperandTypes()[0].dyn_cast<MemRefType>().getElementType(); + auto memRefType = op.getOperandTypes()[0].dyn_cast<MemRefType>(); + if (!memRefType) { + hadError = true; + op.emitError("return value must be memref type"); + return; + } + auto returnType = memRefType.getElementType(); auto it = consumeFuncReturnFuncs.find(returnType); if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) { hadError = true; - op.emitError("must have one return value: a memref of f32 or i64"); + op.emitError("must have one return value: a memref of f32, f64, or i64"); return; } @@ -126,26 +134,27 @@ class MungeCallingConventions void runOnOperation() override { auto module = getOperation(); OpBuilder b(module.getBodyRegion()); - auto consumeFuncReturnInt64Func = b.create<FuncOp>( - module.getLoc(), "refbackend_consume_int64_func_return", - FunctionType::get( - module.getContext(), - UnrankedMemRefType::get(b.getI64Type(), /*memorySpace=*/0), {}), - b.getStringAttr("private")); - auto consumeFuncReturnFloat32Func = b.create<FuncOp>( - module.getLoc(), "refbackend_consume_float32_func_return", - FunctionType::get( - module.getContext(), - UnrankedMemRefType::get(b.getF32Type(), /*memorySpace=*/0), {}), - b.getStringAttr("private")); - addEmitCInterfaceAttr(consumeFuncReturnInt64Func); - addEmitCInterfaceAttr(consumeFuncReturnFloat32Func); DenseMap<Type, FuncOp> consumeFuncReturnFuncs; - consumeFuncReturnFuncs[b.getF32Type()] = consumeFuncReturnFloat32Func; - consumeFuncReturnFuncs[b.getI64Type()] = consumeFuncReturnInt64Func; + DenseSet<FuncOp> consumeFuncReturnFuncsSet; + auto createConsumeFuncReturnFunc = [&](Type elemTy, std::string funcName) { + auto consumeFuncReturnFunc = b.create<FuncOp>( + module.getLoc(), funcName, + FunctionType::get(module.getContext(), + UnrankedMemRefType::get(elemTy, /*memorySpace=*/0), + {}), + b.getStringAttr("private")); + addEmitCInterfaceAttr(consumeFuncReturnFunc); + consumeFuncReturnFuncs[elemTy] = consumeFuncReturnFunc; + consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc); + }; + createConsumeFuncReturnFunc(b.getI64Type(), + "refbackend_consume_int64_func_return"); + createConsumeFuncReturnFunc(b.getF32Type(), + "refbackend_consume_float32_func_return"); + createConsumeFuncReturnFunc(b.getF64Type(), + "refbackend_consume_float64_func_return"); for (auto func : module.getOps<FuncOp>()) { - if (func == 
consumeFuncReturnInt64Func || - func == consumeFuncReturnFloat32Func) + if (consumeFuncReturnFuncsSet.contains(func)) continue; if (failed(mungeFunction(func, consumeFuncReturnFuncs))) return signalPassFailure(); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index c8daa0c9b..06692db52 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -482,6 +482,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): emit( "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)" ) + emit( + "aten::softmax.int : (Tensor, int, int?) -> (Tensor)" + ) emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") @@ -525,7 +528,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): emit("aten::repeat : (Tensor, int[]) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") - emit("aten::size.int : (Tensor, int) -> (int)") + emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::sum : (Tensor, int?) -> (Tensor)") emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 037568246..899ae2c70 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -24,13 +24,8 @@ __all__ = [ def checkArgTypeIsSupported(ty): - if ty == np.float32: - return - elif ty == np.int64: - return - assert False, "Only tensor argument of float32 and int64 are supported but got " + str( - ty) - + SUPPORTED = [np.float32, np.float64, np.int64] + assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported" class RefBackendInvoker: def __init__(self, module): @@ -45,12 +40,19 @@ class RefBackendInvoker: def consume_f32_return(a): self.result = unranked_memref_to_numpy(a, np.float32) + @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) + def consume_f64_return(a): + self.result = unranked_memref_to_numpy(a, np.float64) + self.ee.register_runtime("refbackend_consume_int64_func_return", consume_i64_return) self.ee.register_runtime("refbackend_consume_float32_func_return", consume_f32_return) + self.ee.register_runtime("refbackend_consume_float64_func_return", + consume_f64_return) + def __getattr__(self, function_name: str): def invoke(*args): ffi_args = [] diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 0ae655e9a..63c87418e 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -489,3 +489,52 @@ func @torch.prim.dtype$int64(%t : !torch.tensor<*,si64>) -> !torch.int { %ret = torch.prim.dtype %t: !torch.tensor<*,si64> -> !torch.int return %ret : !torch.int } + +// CHECK-LABEL: func @torch.aten.size.int$neg_dim( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int { +// CHECK: %[[RET:.*]] = torch.constant.int 2 +// CHECK: return %[[RET]] : !torch.int +func 
@torch.aten.size.int$neg_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int { + %int-2 = torch.constant.int -2 + %ret = torch.aten.size.int %t, %int-2 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int + return %ret : !torch.int +} + +// CHECK-LABEL: func @torch.aten.size.int$pos_dim( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int { +// CHECK: %[[RET:.*]] = torch.constant.int 3 +// CHECK: return %[[RET]] : !torch.int +func @torch.aten.size.int$pos_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int { + %int1 = torch.constant.int 1 + %ret = torch.aten.size.int %t, %int1 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int + return %ret : !torch.int +} + +// CHECK-LABEL: func @torch.aten.size.int$invalid_dim( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.int { +// CHECK: %[[CST3:.*]] = torch.constant.int 3 +// CHECK: %[[RET:.*]] = torch.aten.size.int %[[T]], %[[CST3]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int +// CHECK: return %[[RET]] : !torch.int +func @torch.aten.size.int$invalid_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.int { + %int3 = torch.constant.int 3 + %ret = torch.aten.size.int %t, %int3 : !torch.tensor<[2,3],f32>, !torch.int -> !torch.int + return %ret : !torch.int +} + +// CHECK-LABEL: func @torch.tensor_static_info_cast$downcast_first( +// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: return %[[T]] : !torch.tensor +func @torch.tensor_static_info_cast$downcast_first(%t: !torch.tensor) -> !torch.tensor { + %downcast = torch.tensor_static_info_cast %t : !torch.tensor to !torch.tensor<[?,?],f64> + %upcast = torch.tensor_static_info_cast %downcast : !torch.tensor<[?,?],f64> to !torch.tensor + return %upcast: !torch.tensor +} + +// CHECK-LABEL: func @torch.tensor_static_info_cast$upcast_first( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f64>) -> !torch.tensor<[?,?],f64> { +// CHECK: return %[[T]] : !torch.tensor<[?,?],f64> +func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -> !torch.tensor<[?,?],f64> { + %upcast = torch.tensor_static_info_cast %t : !torch.tensor<[?,?],f64> to !torch.tensor + %downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64> + return %downcast: !torch.tensor<[?,?],f64> +} diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index ad43a5498..2173d196f 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -925,3 +925,33 @@ builtin.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indic %ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor return %ret: !torch.tensor } + +// ----- +// CHECK-LABEL: func @torch.aten.softmax.int( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, +// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor { +// CHECK: %[[DTYPE:.*]] = torch.constant.none +// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32> +// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor +// CHECK: return %[[RET]] : !torch.tensor +func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor { + %none = torch.constant.none + %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, 
!torch.none -> !torch.tensor + return %ret : !torch.tensor +} + + +// ----- +// CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype( +// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, +// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor { +// CHECK: %[[DTYPE:.*]] = torch.constant.int 4 +// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],si64> +// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],si64> to !torch.tensor +// CHECK: return %[[RET]] : !torch.tensor +// CHECK: } +func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor { + %int4 = torch.constant.int 4 + %ret = torch.aten.softmax.int %t, %dim, %int4 : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor + return %ret : !torch.tensor +}
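The aten.size.int folding exercised by the canonicalization tests above hinges on the new Torch::toPositiveDim/isValidDim utilities. A small Python mirror of their logic, a sketch for illustration rather than part of the patch, shows how the three size.int test cases behave for a [2,3] tensor:

def to_positive_dim(dim, input_rank):
    # dim >= 0 ? dim : dim + input_rank
    return dim if dim >= 0 else dim + input_rank

def is_valid_dim(dim, input_rank):
    return 0 <= dim < input_rank

sizes = [2, 3]
assert sizes[to_positive_dim(-2, 2)] == 2           # $neg_dim folds to 2
assert sizes[to_positive_dim(1, 2)] == 3            # $pos_dim folds to 3
assert not is_valid_dim(to_positive_dim(3, 2), 2)   # $invalid_dim: no fold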