//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Helper function to check whether the `dtype` is None or Float type. static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { if (dtype.getType().isa()) return true; int64_t dtypeInt; if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) return false; FailureOr resDtype = getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); if (failed(resDtype)) return false; return resDtype->isa(); } // Helper function to compute the return type of the reduction function. // `dim` specifies the dimension to reduce and `keepDim` preserves the rank of // the input tensor. 
static Type computeReductionType(PatternRewriter &rewriter, Operation *op, BaseTensorType tensorType, Value dim, bool keepDim) { SmallVector sizes; int64_t dimInt; if (tensorType.hasSizes()) { ArrayRef inputShape = tensorType.getSizes(); int64_t inputRank = inputShape.size(); if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) { (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim"); return nullptr; } sizes.append(inputShape.begin(), inputShape.end()); // The dimension to be reduced is set to 1 when `keepDim` is true else it // is removed. if (keepDim) sizes[dimInt] = 1; else sizes.erase(sizes.begin() + dimInt); } else { unsigned reducedRank = keepDim ? inputRank : inputRank - 1; sizes.resize(reducedRank, kUnknownSize); } } Type resultType = tensorType.getWithSizesAndDtype( sizes.size() == 0 ? std::optional>() : llvm::ArrayRef(sizes), tensorType.getOptionalDtype()); return resultType; } // Reduction function to calculate sum along given `dim`. static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { Value dimList = rewriter.create( loc, Torch::ListType::get(dim.getType()), dim); Value keepDimCst = rewriter.create(loc, keepDim); Value dtype = rewriter.create(loc); Type resultType = computeReductionType( rewriter, op, input.getType().cast(), dim, keepDim); if (!resultType) return nullptr; return rewriter.create(loc, resultType, input, dimList, keepDimCst, dtype); } // Redunction function to calculate max along given `dim`. 
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
                                     Operation *op, Value input, Value dim,
                                     bool keepDim) {
  // Check for a null type before casting: casting a null Type asserts, which
  // would make the failure path below unreachable.
  Type reducedType = computeReductionType(
      rewriter, op, input.getType().cast<BaseTensorType>(), dim, keepDim);
  if (!reducedType)
    return nullptr;
  BaseTensorType valueType = reducedType.cast<BaseTensorType>();

  // The indices result has the same shape as the values result, with a signed
  // 64-bit integer element type.
  std::optional<ArrayRef<int64_t>> indexSizes;
  if (valueType.hasSizes())
    indexSizes = valueType.getSizes();
  BaseTensorType indexType =
      valueType
          .getWithSizesAndDtype(
              indexSizes,
              IntegerType::get(op->getContext(), 64, IntegerType::Signed))
          .cast<BaseTensorType>();

  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
  return rewriter
      .create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
      .getValues();
}

// Helper for creating `aten::sub_tensor_op`.
// Emits `lhs - rhs` with `alpha` fixed to the float constant 1.
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
                             Type tensorType, Value lhs, Value rhs) {
  Value alpha =
      rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
  Value sub =
      rewriter.create<AtenSubTensorOp>(loc, tensorType, lhs, rhs, alpha);
  return sub;
}

// Helper to create a tensor filled with the given scalar. Scalar would be
// converted to the element type of the given tensor type.
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
                              BaseTensorType resultType, Value scalar,
                              Value sizeList) {
  assert(resultType.hasDtype() && "result must have dtype");
  Value noneVal = rewriter.create<ConstantNoneOp>(loc);
  Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
  return rewriter.create<AtenFullOp>(loc, resultType, sizeList, scalar, dtype,
                                     /*layout=*/noneVal,
                                     /*device=*/noneVal,
                                     /*pin_memory=*/noneVal);
}

// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc, BaseTensorType inputType, Value scalar) { assert(inputType.hasDtype() && "input must have dtype"); SmallVector sizes; BaseTensorType rank0TensorTy = inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) .cast(); Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), ValueRange{}); return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); } // Share code between `softmax_backward` and `log_softmax_backward` ops. // Returns x - y * sum(z, dim). static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, Location loc, Operation *op, Type tensorType, Value x, Value y, Value z, Value dim) { Value sum = createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true); if (!sum) return nullptr; auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(op->getContext())); Value broadcastSize = rewriter.create(loc, broadcastSizeType, z); Value sumBroadcast = rewriter.create(loc, tensorType, sum, broadcastSize); Value temp = rewriter.create(loc, tensorType, y, sumBroadcast); Value sub = createTensorSub(rewriter, loc, tensorType, x, temp); return sub; } static SmallVector computeDimsOrderForMoveDim(int64_t srcDimInt, int64_t dstDimInt, unsigned inputRank) { llvm::iota_range dimsOrderIR(0, inputRank, /*inclusive=*/false); SmallVector dimsOrder(dimsOrderIR.begin(), dimsOrderIR.end()); dimsOrder.erase(dimsOrder.begin() + srcDimInt); dimsOrder.insert(dimsOrder.begin() + dstDimInt, srcDimInt); return dimsOrder; } namespace { /// We decompose aten.amax into a set of aten.max.dim op(s) depending on the /// number of dimensions across which the max needs to be computed. 
/// Eg: /// INPUT: /// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) /// /// OUTPUT: /// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 /// input_2 = aten.max.dim(input_1, 1, keepdim) #2 /// final_output = aten.max.dim(input_2, 0, keepdim) #3 /// /// NOTE: We iterate over, in reverse order, every dimension included in `dim` /// of the `aten.amax` op and create an `aten.amax.dim` op. /// Input tensor to the next `aten.amax.dim` op is thus the output of the /// previous `aten.amax.dim` op. class DecomposeAtenAmaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenAmaxOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); SmallVector dims; if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); bool keepDim; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) return rewriter.notifyMatchFailure( op, "Expected a constant boolean value for keepDim"); Value input = op.getSelf(); auto inputTy = input.getType().dyn_cast(); if (!inputTy || !inputTy.hasSizes()) { return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); } // For every dimension included in `dim` of the op, iterated over in // reverse order, we create a call to aten.max.dim. std::sort(dims.begin(), dims.end()); std::reverse(dims.begin(), dims.end()); for (int64_t dimInt : dims) { int64_t inputRank = inputTy.getSizes().size(); dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(dimInt)); // The input to the next invocation of aten.max.dim is the output of the // previous aten.max.dim op. 
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); } rewriter.replaceOp(op, input); return success(); } }; } // end namespace namespace { class DecomposeAtenSizeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSizeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; SmallVector sizes; for (unsigned i = 0; i < rank; i++) { Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); sizes.push_back(rewriter.create(loc, self, dim)); } Value sizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), sizes); rewriter.replaceOp(op, sizeList); return success(); } }; } // namespace namespace { class DecomposeAtenSelectIntOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSelectIntOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value start = op.getIndex(); Value dim = op.getDim(); Value self = op.getSelf(); // convert `start` to non-negative: start += int(start < 0) * dimSize Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value isNegative = rewriter.create(loc, start, zero); isNegative = rewriter.create(loc, isNegative); Value dimSize = rewriter.create(loc, self, dim); Value indexOffset = rewriter.create(loc, isNegative, dimSize); start = rewriter.create(loc, start, indexOffset); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value startPlusOne = rewriter.create(loc, one.getType(), start, one); Value slice = rewriter.create( loc, computeReductionType(rewriter, op, self.getType().cast(), dim, /*keepDim=*/true), op.getSelf(), dim, start, startPlusOne, /*step=*/one); // 
`aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after // slicing, while `aten.select.int` does. rewriter.replaceOpWithNewOp(op, op.getResult().getType(), slice, op.getDim()); return success(); } }; } // namespace namespace { class DecomposeAtenNarrowOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNarrowOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value start = op.getStart(); Value dim = op.getDim(); Value length = op.getLength(); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value startPlusLength = rewriter.create(loc, one.getType(), start, length); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start, /*end=*/startPlusLength, /*step=*/one); return success(); } }; } // namespace namespace { // Decompose `aten.narrow.Tensor` to `aten.narrow` op class DecomposeAtenNarrowTensorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNarrowTensorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto *context = op.getContext(); // PyTorch makes sure that `start` param is an 0-dim integral tensor. // REF: https://pytorch.org/docs/stable/generated/torch.narrow.html. 
auto start = rewriter.create( loc, Torch::IntType::get(context), op.getStart()); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength()); return success(); } }; } // namespace namespace { class DecomposeAtenZeroOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenZeroOp op, PatternRewriter &rewriter) const override { Value zero = rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), zero); return success(); } }; } // namespace namespace { class DecomposeAtenIsnanOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIsnanOp op, PatternRewriter &rewriter) const override { Value input = op.getSelf(); // Create a new aten.ne operation with the same type and input value. rewriter.replaceOpWithNewOp(op, op.getType(), input, input); return success(); } }; } // namespace namespace { class DecomposeAtenReshapeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenReshapeOp op, PatternRewriter &rewriter) const override { Value input = op.getSelf(); // TODO: Handle non value tensor type operands. if (!input.getType().isa()) { return rewriter.notifyMatchFailure( op, "unimplemented: only value tensor type operands are supported"); } rewriter.replaceOpWithNewOp(op, op.getType(), input, op.getShape()); return success(); } }; } // namespace // Calculates the softmax function on the given `input` tensor. Softmax(x) = // exp(x)/sum(exp(x)). 
// To avoid overflow we use the following decomposition rule: // x_max = max(input, dim, keepdim = True) // unnorm = aten.exp(input - x_max) // softmax = unnorm / sum(unnorm, dim, keepdim = True) template static Value getSoftmaxResult(OpTy op, Value self, Type resultType, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value dim = op.getDim(); Value xMax = createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true); if (!xMax) return nullptr; Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax); Value unNormalizedExp = rewriter.create(loc, resultType, unNormalized); Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim, /*keepDim=*/true); if (!sum) return nullptr; return rewriter.create(loc, resultType, unNormalizedExp, sum); } // Decompose softmax into: exp(x) / sum(exp(x)) namespace { class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); BaseTensorType resultTensorType = op.getType().cast(); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Type resultTensorDtype = resultTensorType.getDtype(); if (!resultTensorDtype.isa()) return rewriter.notifyMatchFailure(op, "Only support floating-point type"); // If `dtype` arg is non-none then convert the input to `dtype`. 
if (!op.getDtype().getType().isa()) { Location loc = op.getLoc(); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); self = rewriter.create( loc, resultTensorType, self, getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } Value result = getSoftmaxResult(op, self, resultTensorType, rewriter); if (!result) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), result); return success(); } }; } // namespace namespace { class DecomposeAten_SoftmaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_SoftmaxOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); BaseTensorType tensorType = self.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); bool halfToFloat; if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) return rewriter.notifyMatchFailure( op, "Expected a boolean value for half_to_float"); BaseTensorType resultTensorType = op.getType().cast(); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Type resultTensorDtype = resultTensorType.getDtype(); // `torch.ops.aten._softmax`'s softmax with half to float conversion is not // supported on CPU, but we go ahead with the decomposing. // TODO: Add an e2e test once upstream support is added. // If `half_to_float` is set, we convert the input's elemental type to match // that of output's. 
if (halfToFloat) { Location loc = op.getLoc(); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); self = rewriter.create( loc, resultTensorType, self, getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } Value result = getSoftmaxResult(op, self, resultTensorType, rewriter); if (!result) return op.emitError("failed to get softmax result"); rewriter.replaceOpWithNewOp(op, resultTensorType, result); return success(); } }; } // namespace // Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) => // newGrad = gradOutput * output // result = newGrad - output * sum(newGrad, dim)) // // Refer to // https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31 namespace { class DecomposeAten_SoftmaxBackwardDataOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value gradOutput = op.getGradOutput(); Value output = op.getOutput(); Value dim = op.getDim(); BaseTensorType tensorType = gradOutput.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value newGrad = rewriter.create(loc, tensorType, gradOutput, output); Value result = createSoftmaxBackwardCommonKernel( rewriter, loc, op, tensorType, newGrad, output, newGrad, dim); if (!result) return rewriter.notifyMatchFailure( op, "nullptr returned by createSoftmaxBackwardCommonKernel function."); rewriter.replaceOp(op, result); return success(); } }; } // namespace // AtenTanhBackwardOp(gradOutput, output) => // result = gradOutput * (1 - output^2) // To get away from broadcasts the above formula is expanded i.e., // result = gradOutput - (gradOutput * output^2) namespace { class 
DecomposeAtenTanhBackwardOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTanhBackwardOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value gradOutput = op.getGradOutput(); // `output` is the value flowing out from tanh. Hence, tanh(x) = output. // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2). Value output = op.getOutput(); BaseTensorType tensorType = gradOutput.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value tanhSquare = rewriter.create(loc, tensorType, output, output); Value gradMulTanhSquare = rewriter.create( loc, tensorType, tanhSquare, gradOutput); Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput, gradMulTanhSquare); rewriter.replaceOp(op, newGrad); return success(); } }; } // namespace // Aten_LogSoftmaxBackwardDataOp(gradOutput, output, dim) => // result = gradOutput - (exp(output) * sum(gradOutput, dim)) namespace { class DecomposeAten_LogSoftmaxBackwardDataOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value gradOutput = op.getGradOutput(); Value output = op.getOutput(); Value dim = op.getDim(); BaseTensorType tensorType = gradOutput.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value expOut = rewriter.create(loc, tensorType, output); Value result = createSoftmaxBackwardCommonKernel( rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim); if (!result) return rewriter.notifyMatchFailure( op, "nullptr returned by createSoftmaxBackwardCommonKernel function."); rewriter.replaceOp(op, result); return success(); } }; } // namespace // 
Decompose `AtenArgMaxOp` into `AtenMaxDimOp`. namespace { class DecomposeAtenArgMaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenArgmaxOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value dim = op.getDim(); Value keepDim = op.getKeepdim(); Value result = op.getResult(); BaseTensorType inputType = input.getType().cast(); BaseTensorType indicesTensorType = result.getType().cast(); std::optional maybeInputRank = getTensorRank(input); if (!maybeInputRank) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a rank"); } unsigned inputRank = *maybeInputRank; if (!indicesTensorType.hasSizes()) return failure(); BaseTensorType valueTensorType = inputType .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(), inputType.getOptionalDtype()) .cast(); // If the dim type is `NoneType` i.e. reduce along all the dimensions. // `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input // tensor is flattened to 1d tensor and then the reduction happens on the // 0th dimension. 
if (dim.getType().isa()) { BaseTensorType flattenType = inputType .getWithSizesAndDtype({kUnknownSize}, inputType.getOptionalDtype()) .cast(); dim = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(inputRank - 1)); input = rewriter.create(loc, flattenType, input, dim, end); } Value maxResult = rewriter .create(loc, valueTensorType, indicesTensorType, input, dim, keepDim) .getIndices(); rewriter.replaceOp(op, maxResult); return success(); } }; } // namespace // Decompose `aten.bucketize` into the following op sequence: // // def aten_bucketize(input, boundaries, out_int32, right): // unsqz_input = input.unsqueeze(-1) // if not right: // comparison = unsqz_input <= boundaries // else: // comparison = unsqz_input < boundaries // indices = torch.argmax(comparison.float(), dim=-1) // within_bound = comparison[..., -1] // result = torch.where(within_bound, indices, boundaries.shape[0]) // if out_int32: // result = result.int() // return result // namespace { class DecomposeAtenBucketizeTensorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenBucketizeTensorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); auto inputType = input.getType().cast(); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "unimplemented: input must have known sizes"); } ArrayRef inputShape = inputType.getSizes(); Value boundaries = op.getBoundaries(); auto boundariesType = boundaries.getType().cast(); if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) { return rewriter.notifyMatchFailure(op, "unimplemented: boundaries must have " "known sizes and must be a 1D array"); } int64_t boundariesSize = boundariesType.getSizes()[0]; bool outInt32; if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) { return rewriter.notifyMatchFailure( op, "unimplemented: out_int32 
must be a constant bool"); } bool right; if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) { return rewriter.notifyMatchFailure( op, "unimplemented: right must be a constant bool"); } // unsqueeze input at the last dim to make it broadcastable with boundaries Value constMinusOne = rewriter.create( loc, rewriter.getI64IntegerAttr(-1)); auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne); if (failed(unsqzTensorInfo)) { return rewriter.notifyMatchFailure(op, "cannot generate unsqueeze tensor"); } Value unsqzInput = *unsqzTensorInfo; // compare unsqueezed input with boundaries SmallVector compareShape(inputShape); compareShape.push_back(boundariesSize); Type compareType = inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type()); Value compare; if (!right) { compare = rewriter.create(loc, compareType, unsqzInput, boundaries); } else { compare = rewriter.create(loc, compareType, unsqzInput, boundaries); } // convert the comparison results to float32 as the argmax op input, // which does not support integer dtype in LINALG backend Value compareF32 = convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type()); // get the first boundary index where the input element is less than (or // equal to) the boundary value Type indicesType = inputType.getWithSizesAndDtype( inputShape, rewriter.getIntegerType(64, IntegerType::Signed)); Value constFalse = rewriter.create(loc, false); Value indices = rewriter.create(loc, indicesType, compareF32, /*dim=*/constMinusOne, /*keepdim=*/constFalse); // get the comparison results between each input element and the rightmost // boundary value Type withinUpperBoundType = inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type()); Value withinUpperBound = rewriter.create( loc, withinUpperBoundType, compare, /*dim=*/constMinusOne, /*index=*/constMinusOne); // If the input element is less than (or equal to) the rightmost boundary, // take the max index as result. 
Otherwise, the element is beyond the // rightmost boundary, so take the boundary size. Value constZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value upperBound = rewriter.create(loc, boundaries, /*dim=*/constZero); Value result = rewriter.create( loc, indicesType, withinUpperBound, indices, upperBound); if (outInt32) { result = convertTensorToDtype( rewriter, loc, result, rewriter.getIntegerType(32, IntegerType::Signed)); } rewriter.replaceOp(op, result); return success(); } }; } // namespace // To avoid overflow we use the following decomposition rule: // x_max = aten.max(x, dim, keepdim=True)[0] // shifted = x - x_max // shifted_logsumexp = aten.log(aten.sum(aten.exp(shifted), dim, keepdim=True)) // log_softmax = shifted - shifted_logsumexp template static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value dim = op.getDim(); Value self = op.getSelf(); BaseTensorType tensorType = self.getType().cast(); Value xMax = createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true); if (!xMax) return nullptr; Value shifted = createTensorSub(rewriter, loc, tensorType, self, xMax); Value shiftedExp = rewriter.create(loc, tensorType, shifted); Value shiftedSumExp = createSumAlongDimension(rewriter, loc, op, shiftedExp, dim, /*keepDim=*/true); if (!shiftedSumExp) return nullptr; Value shiftedLogSumExp = rewriter.create(loc, shiftedSumExp.getType(), shiftedSumExp); Value result = createTensorSub(rewriter, loc, op.getType(), shifted, shiftedLogSumExp); return result; } namespace { class DecomposeAtenLogSoftmaxIntOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); if (!op.getDtype().getType().isa()) return rewriter.notifyMatchFailure( op, "Unimplemented non-None dtype for log_softmax"); BaseTensorType tensorType = self.getType().cast(); if 
(!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value logSoftmax = getLogSoftmaxResult(op, rewriter); if (!logSoftmax) return rewriter.notifyMatchFailure( op, "getLogSoftmaxResult function returned nullptr"); rewriter.replaceOp(op, logSoftmax); return success(); } }; } // namespace namespace { class DecomposeAten_LogSoftmaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op, PatternRewriter &rewriter) const override { bool halfToFloat; if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) return rewriter.notifyMatchFailure( op, "Expected a boolean value for half_to_float"); // Currently, setting `halfToFloat` is not supported as the E2E testing for // the same is not present on CPU. if (halfToFloat) return rewriter.notifyMatchFailure( op, "halfToFloat is currently not supported."); Value _logSoftmax = getLogSoftmaxResult(op, rewriter); if (!_logSoftmax) return rewriter.notifyMatchFailure( op, "getLogSoftmaxResult function returned nullptr"); rewriter.replaceOp(op, _logSoftmax); return success(); } }; } // namespace // Decompose aten.matmul into: aten.mm and aten.bmm according to ranks. namespace { class DecomposeAtenMatmulOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMatmulOp op, PatternRewriter &rewriter) const override { Value lhs = op.getSelf(); Value rhs = op.getOther(); std::optional maybeLhsRank = getTensorRank(lhs); std::optional maybeRhsRank = getTensorRank(rhs); if (!maybeLhsRank || !maybeRhsRank) { return rewriter.notifyMatchFailure( op, "expected input tensors to have a rank"); } unsigned lhsRank = *maybeLhsRank; unsigned rhsRank = *maybeRhsRank; if (lhsRank == 2 && rhsRank == 2) { // If both lhs and rhs ranks are 2 then map it to `aten.mm` op. 
rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs);
    } else if (lhsRank == 3 && rhsRank == 3) {
      // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
      rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs);
    } else {
      // No rewrite for the remaining rank combinations.
      return failure();
    }

    return success();
  }
};
} // namespace

// NOTE(review): template argument lists (e.g. on `OpRewritePattern`,
// `rewriter.create`, `cast`) appear to be missing in this copy of the file;
// the code tokens below are kept exactly as found -- confirm against upstream.

// Decompose aten.mv into: aten.matmul.
//
// The replacement op receives `self` and `vec` unchanged and reuses the
// original result type. This pattern has no failure path.
namespace {
class DecomposeAtenMvOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMvOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getSelf();
    Value rhs = op.getVec();
    rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs);
    return success();
  }
};
} // namespace

// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
//
// Emits the ops for the formula above at `loc` and returns the final value.
// Every intermediate value is typed with `input`'s own tensor type.
// NOTE(review): `createRank0Tensor` is called here without a preceding
// `hasDtype()` check on `inputType`, whereas the other call sites nearby guard
// it -- confirm that callers are expected to guarantee a dtype.
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                             Value input) {
  BaseTensorType inputType = input.getType().cast();

  // Relu(x)
  Value relu = rewriter.create(loc, inputType, input);
  // The integer constant 6, wrapped into a tensor via createRank0Tensor.
  Value cst6 = rewriter.create(loc, rewriter.getI64IntegerAttr(6));
  Value sixTensor = createRank0Tensor(rewriter, loc, inputType, cst6);
  // min(Relu(x), 6)
  Value relu6Out = rewriter.create(loc, inputType, relu, sixTensor);
  return relu6Out;
}

// Replaces aten.relu6 with the value built by `getRelu6Results`.
// Fails to match when the result type has no dtype.
namespace {
class DecomposeAtenRelu6Op : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRelu6Op op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resType = op.getType().cast();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    Value relu6 = getRelu6Results(rewriter, loc, op.getSelf());
    rewriter.replaceOp(op, relu6);
    return success();
  }
};
} // namespace

// Hardswish(x) = x * Relu6(x+3)/6
//
// NOTE(review): unlike DecomposeAtenRelu6Op, nothing here checks for a dtype
// before calling `getRelu6Results` -- confirm whether that guard is needed.
namespace {
class DecomposeAtenHardswishOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHardswishOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Type inputType = input.getType();

    // Integer constants 1, 3 and 6 used by the formula.
    Value constantOne = rewriter.create(
        loc, rewriter.getI64IntegerAttr(1));
    Value constantThree = rewriter.create(
        loc, rewriter.getI64IntegerAttr(3));
    Value constantSix = rewriter.create(
        loc, rewriter.getI64IntegerAttr(6));
    // x + 3 (with alpha = 1)
    Value inputPlusThree = rewriter.create(
        loc, inputType, input, constantThree, /*alpha=*/constantOne);
    // Relu6(x + 3)
    Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree);
    // Relu6(x + 3) / 6
    Value divTensor = rewriter.create(loc, inputType, relu6, constantSix);
    // (Relu6(x + 3) / 6) * x
    Value mulTensor = rewriter.create(loc, inputType, divTensor, input);

    rewriter.replaceOp(op, mulTensor);
    return success();
  }
};
} // namespace

// LeakyRelu = max(0,x) + negative_slope * min(0,x)
//
// Fails to match when the result type has no dtype.
namespace {
class DecomposeAtenLeakyReluOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLeakyReluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value negativeSlope = op.getNegativeSlope();
    auto resType = op.getType().cast();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    // Integer 0 (wrapped into a tensor below) and float 1.0 (passed as the
    // last operand of the final op).
    Value constantZero =
        rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Value constantOne =
        rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
    // max(0, x) and min(0, x) per the formula above.
    Value positiveOutput =
        rewriter.create(loc, resType, zeroTensor, input);
    Value negativeOutput =
        rewriter.create(loc, resType, zeroTensor, input);
    // negative_slope * min(0, x)
    Value scaledNegativeOutput = rewriter.create(
        loc, resType, negativeOutput, negativeSlope);
    // Sum of the positive part and the scaled negative part.
    Value leakyReluOutput = rewriter.create(
        loc, resType, positiveOutput, scaledNegativeOutput, constantOne);

    rewriter.replaceOp(op, leakyReluOutput);
    return success();
  }
};
} // namespace

// LeakyReluBackward = max(0,grad) + negative_slope * min(0,x)
//
// Fails to match when the result type has no dtype, or when `self_is_result`
// is not a constant bool equal to false.
namespace {
class DecomposeAtenLeakyReluBackwardOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLeakyReluBackwardOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value gradOutput = op.getGradOutput();
    Value input = op.getSelf();
    Value negativeSlope = op.getNegativeSlope();
    auto resType = op.getType().cast();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    bool selfIsResult = false;
    if (!matchPattern(op.getSelfIsResult(),
                      m_TorchConstantBool(&selfIsResult)) ||
        selfIsResult)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: self_is_result should be false");

    Value constantZero =
        rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Value constantOne =
        rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
    // max(0, grad) -- note this term uses `gradOutput`, not `input`.
    Value positiveOutput =
        rewriter.create(loc, resType, zeroTensor, gradOutput);
    // min(0, x)
    Value negativeOutput =
        rewriter.create(loc, resType, zeroTensor, input);
    // negative_slope * min(0, x)
    Value scaledNegativeOutput = rewriter.create(
        loc, resType, negativeOutput, negativeSlope);
    Value leakyReluBackwardOutput = rewriter.create(
        loc, resType, positiveOutput, scaledNegativeOutput, constantOne);

    rewriter.replaceOp(op, leakyReluBackwardOutput);
    return success();
  }
};
} // namespace

// Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1)
//
// Fails to match when the result type has no dtype.
namespace {
class DecomposeAtenEluOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenEluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value alpha = op.getAlpha();
    Value scale = op.getScale();
    Value inputScale = op.getInputScale();
    auto resType = op.getType().cast();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    Value constantZero =
        rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Value constantOne =
        rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
    // Positive branch: scale * max(0, x)
    Value maxZeroX =
        rewriter.create(loc, resType, zeroTensor, input);
    Value positiveOutput =
        rewriter.create(loc, resType, maxZeroX, scale);
    // Negative branch, built up step by step:
    //   min(0, x) -> * input_scale -> exp(.) -> (. - 1) -> * scale -> * alpha
    Value minZeroX =
        rewriter.create(loc, resType, zeroTensor, input);
    Value scaledMinZeroX =
        rewriter.create(loc, resType, minZeroX, inputScale);
    Value expX = rewriter.create(loc, resType, scaledMinZeroX);
    Value expXM1 =
        rewriter.create(loc, resType, expX, constantOne, constantOne);
    Value scaledExpXM1 =
        rewriter.create(loc, resType, expXM1, scale);
    Value negativeOutput =
        rewriter.create(loc, resType, scaledExpXM1, alpha);

    // Sum of the positive and negative branches.
    Value eluOutput = rewriter.create(
        loc, resType, positiveOutput, negativeOutput, constantOne);

    rewriter.replaceOp(op, eluOutput);
    return success();
  }
};
} // namespace

// Decompose aten.t based on the rank of `self`:
//   - rank unknown or rank > 2: match failure;
//   - rank < 2: the op is replaced by its input unchanged;
//   - rank == 2: the op is replaced by a new op taking (self, 0, 1),
//     presumably a transpose of dims 0 and 1 -- op type not visible here.
namespace {
class DecomposeAtenTOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getSelf();
    std::optional lhsRank = getTensorRank(lhs);
    auto loc = op.getLoc();

    if (!lhsRank) {
      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
    } else if (*lhsRank > 2) {
      std::string errorMessage =
          "t() expects a tensor with <=2 dimensions, but self is " +
          std::to_string(*lhsRank) + "D";
      return rewriter.notifyMatchFailure(op, errorMessage.c_str());
    } else if (*lhsRank < 2)
      rewriter.replaceOp(op, lhs);
    else {
      Value zero =
          rewriter.create(loc, rewriter.getI64IntegerAttr(0));
      Value one =
          rewriter.create(loc, rewriter.getI64IntegerAttr(1));
      rewriter.replaceOpWithNewOp(op, op.getType(), lhs, zero, one);
    }
    return success();
  }
};
} // namespace

// Decompose `aten.stack` into `aten.unsqueeze` and `aten.cat`.
// NOTE(review): template argument lists (e.g. on `OpRewritePattern`,
// `SmallVector`, `rewriter.create`, `cast`) appear to be missing in this copy
// of the file; code tokens are kept exactly as found -- confirm against
// upstream.
//
// Each element of the input list is unsqueezed at `dim` (via
// `unsqueezeTensor`), the results are gathered into a new list, and the op is
// replaced by a new op over that list and the same `dim`.
// Fails to match when the tensor list is not a list construct, when any
// element lacks sizes, or when `unsqueezeTensor` fails.
namespace {
class DecomposeAtenStackOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStackOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector tensors;
    if (!getListConstructElements(op.getTensors(), tensors)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the tensor list is not from list construct");
    }
    // Ensure all tensors have known sizes
    // (checked for every element before the second loop starts).
    for (Value tensor : tensors) {
      BaseTensorType tensorType = tensor.getType().cast();
      if (!tensorType.hasSizes()) {
        return rewriter.notifyMatchFailure(
            op, "unimplemented: one tensor does not have known sizes");
      }
    }

    // Unsqueeze every element at the stack dimension.
    SmallVector unsqueezedTensors;
    for (Value tensor : tensors) {
      auto unsqueezedInfo = unsqueezeTensor(rewriter, op, tensor, op.getDim());
      if (failed(unsqueezedInfo)) {
        return rewriter.notifyMatchFailure(
            op, "cannot generate unsqueeze tensor op");
      }
      unsqueezedTensors.push_back(*unsqueezedInfo);
    }

    // The list element type is derived from the result type with both sizes
    // and dtype dropped.
    Type listElemType =
        op.getType().cast().getWithSizesAndDtype(
            /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
    Type listType = Torch::ListType::get(listElemType);
    Value unsqueezedTensorList = rewriter.create(
        op.getLoc(), listType, unsqueezedTensors);
    rewriter.replaceOpWithNewOp(op, op.getType(), unsqueezedTensorList,
                                op.getDim());
    return success();
  }
};
} // namespace

// Decompose aten.roll into aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html namespace { class DecomposeAtenRollOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRollOp op, PatternRewriter &rewriter) const override { SmallVector shifts; if (!getListConstructElements(op.getShifts(), shifts)) return rewriter.notifyMatchFailure( op, "unimplemented: shifts not list of Scalar"); SmallVector dims; if (!getListConstructElements(op.getDims(), dims)) return rewriter.notifyMatchFailure( op, "unimplemented: dims not list of Scalar"); if (shifts.size() != dims.size()) return op.emitError("list sizes of shifts and dims are not the same"); auto loc = op.getLoc(); Value constNone = rewriter.create(loc); Value constZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value constOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); auto self = op.getSelf(); auto selfTy = self.getType().cast(); // roll(input, shift, dim) = cat({ // slice(input, dim, -shift, none), // slice(input, dim, 0, -shift)}, dim) auto imitateRoll = [&](Value input, Value shift, Value dim, int64_t cstDim) { Value negShift = rewriter.create(loc, shift); ArrayRef inputShape = selfTy.getSizes(); SmallVector sizes; sizes.append(inputShape.begin(), inputShape.end()); sizes[cstDim] = kUnknownSize; Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes), selfTy.getOptionalDtype()); Value slice0 = rewriter.create( loc, sliceTy, input, dim, negShift, constNone, constOne); Value slice1 = rewriter.create( loc, sliceTy, input, dim, constZero, negShift, constOne); Type listType = Torch::ListType::get(sliceTy); Value slices = rewriter.create( loc, listType, llvm::ArrayRef{slice0, slice1}); return rewriter.create(loc, self.getType(), slices, dim); }; std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; Value output = self; auto nShifts = 
shifts.size(); for (size_t k = 0; k < nShifts; ++k) { auto dim = dims[k]; int64_t cstDim = -1; if (!matchPattern(dim, m_TorchConstantInt(&cstDim))) return rewriter.notifyMatchFailure( op, "unimplemented: dim must be constant"); cstDim = toPositiveDim(cstDim, rank); output = imitateRoll(output, shifts[k], dim, cstDim); } rewriter.replaceOp(op, output); return success(); } }; } // namespace // Decompose aten.repeat into aten.expand and aten.view ops. // // Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html // // For shape [S1, S2, S3] and repeats [M0, M1, M2, M3] // MS0 = M0; MS1 = M1 * S1; MS2 = M2 * S2; MS3 = M3 * S3 // // def aten_repeat(self, repeats): // sizes = self.size() // unsqueezed_sizes = [] // expanded_sizes = [] // reshape_sizes = [] // leading_rank = repeats.size() - sizes.size() // for r in range(leading_rank): // unsqueezed_sizes.append(1) // expanded_sizes.append(repeats[r]) // reshaped_sizes.append(repeats[r]) // // for s, m in zip(sizes, repeats[leading_rank:]): // unsqueezed_sizes += [1, s] // expanded_sizes += [m, s] // reshaped_sizes += [m * s] // return // self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) // namespace { class DecomposeAtenRepeatOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRepeatOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; SmallVector repeats; if (!getListConstructElements(op.getRepeats(), repeats)) return rewriter.notifyMatchFailure( op, "Unimplemented: repeats not list of Scalar"); if (rank > repeats.size()) { return rewriter.notifyMatchFailure( op, "repeats are not matched with self's rank"); } auto insertDimSizes = [](SmallVector &dimSizes, SmallVector 
&shape, const ArrayRef &vals) { dimSizes.insert(dimSizes.end(), vals.begin(), vals.end()); std::transform(vals.begin(), vals.end(), std::back_inserter(shape), [&](Value val) -> int64_t { int64_t cst_val; if (matchPattern(val, m_TorchConstantInt(&cst_val))) { return cst_val; } else { return kUnknownSize; } }); }; Value one = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); SmallVector unsqueezedSizes, expandedSizes, reshapedSizes; SmallVector unsqueezedIntSizes, expandedIntSizes; assert(repeats.size() >= rank && "leadingRank should greater than 0"); auto leadingRank = repeats.size() - rank; for (size_t i = 0; i < leadingRank; ++i) { insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one}); insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{repeats[i]}); reshapedSizes.push_back(repeats[i]); } auto selfType = self.getType().dyn_cast(); auto selfShape = selfType.getSizes(); for (unsigned i = 0; i < rank; i++) { auto scale = repeats[i + leadingRank]; Value dimSize; if (selfShape[i] == kUnknownSize) { Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); dimSize = rewriter.create(loc, self, dim); } else { dimSize = rewriter.create( loc, rewriter.getI64IntegerAttr(selfShape[i])); } insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one, dimSize}); insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{scale, dimSize}); Value scaledSize = rewriter.create(loc, dimSize, scale); reshapedSizes.push_back(scaledSize); } Type dtype = self.getType().cast().getOptionalDtype(); Type unsqueezedType = ValueTensorType::get( context, llvm::ArrayRef(unsqueezedIntSizes), dtype); Type expandedType = ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype); auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value unsqueezedDims = rewriter.create(loc, listType, unsqueezedSizes); Value expandedDims = rewriter.create(loc, listType, expandedSizes); Value reshapedDims = rewriter.create(loc, listType, 
reshapedSizes); auto reshaped = rewriter.create(loc, unsqueezedType, op.getSelf(), unsqueezedDims); auto expanded = rewriter.create(loc, expandedType, reshaped, expandedDims); rewriter.replaceOpWithNewOp(op, op.getType(), expanded, reshapedDims); return success(); } }; } // namespace // Decompose aten.flatten.using_ints into aten.view op. namespace { class DecomposeAtenFlattenUsingIntsOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); unsigned rank = *maybeRank; int64_t start, end; if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&start)) || !matchPattern(op.getEndDim(), m_TorchConstantInt(&end))) { return rewriter.notifyMatchFailure( op, "unimplemented: requires start and end dims to be constants"); } SmallVector newSizes; if (rank == 0) { Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); newSizes.push_back(one); } else { start = toPositiveDim(start, rank); end = toPositiveDim(end, rank); if (start > end) { return rewriter.notifyMatchFailure( op, "expected end dim larger than start dim"); } newSizes.reserve(rank - end + start); for (int64_t k = 0; k < start; ++k) { Value dim = rewriter.create(loc, rewriter.getI64IntegerAttr(k)); newSizes.push_back( rewriter.create(loc, self, /*dim=*/dim)); } Value flattenDimSize = rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); newSizes.push_back(flattenDimSize); for (int64_t k = end + 1; k < rank; ++k) { Value dim = rewriter.create(loc, rewriter.getI64IntegerAttr(k)); newSizes.push_back( rewriter.create(loc, self, /*dim=*/dim)); } } Value newSizeList = rewriter.create( loc, ListType::get(IntType::get(context)), newSizes); 
rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), newSizeList); return success(); } }; } // namespace // Decompose aten.expand into aten.broadcast_to op. namespace { class DecomposeAtenExpandOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenExpandOp op, PatternRewriter &rewriter) const override { bool implicit = false; if (!matchPattern(op.getImplicit(), m_TorchConstantBool(&implicit)) || implicit) { return rewriter.notifyMatchFailure( op, "unimplemented: requires implicit to be false"); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getSize()); return success(); } }; } // namespace // Decompose aten.where.Scalar into aten.where.self op. namespace { class DecomposeAtenWhereScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenWhereScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf()); Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther()); rewriter.replaceOpWithNewOp(op, resType, op.getCondition(), selfTensor, otherTensor); return success(); } }; } // namespace // Decompose aten.where.ScalarOther into aten.where.self op. 
namespace {
// Rewrites the op so that the scalar `other` operand is first turned into a
// tensor (createRank0Tensor, typed from the result) and then handed, together
// with the untouched `condition` and `self`, to the replacement op.
// The match is rejected when the result type carries no dtype.
class DecomposeAtenWhereScalarOtherOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = op.getType().cast();
    if (!resultTy.hasDtype())
      return rewriter.notifyMatchFailure(op, "result should have dtype");

    Value otherAsTensor =
        createRank0Tensor(rewriter, op.getLoc(), resultTy, op.getOther());
    rewriter.replaceOpWithNewOp(op, resultTy, op.getCondition(), op.getSelf(),
                                otherAsTensor);
    return success();
  }
};
} // namespace

// Decompose aten.where.ScalarSelf into aten.where.self op.
namespace {
// Mirror image of the pattern above: here the scalar `self` operand is the
// one wrapped into a tensor, while `condition` and `other` pass through.
// The match is rejected when the result type carries no dtype.
class DecomposeAtenWhereScalarSelfOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = op.getType().cast();
    if (!resultTy.hasDtype())
      return rewriter.notifyMatchFailure(op, "result should have dtype");

    Value selfAsTensor =
        createRank0Tensor(rewriter, op.getLoc(), resultTy, op.getSelf());
    rewriter.replaceOpWithNewOp(op, resultTy, op.getCondition(), selfAsTensor,
                                op.getOther());
    return success();
  }
};
} // namespace

// Decompose aten.masked_fill.Scalar into aten.where.self op.
namespace { class DecomposeAtenMaskedFillScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value mask = op.getMask(); Value value = createRank0Tensor(rewriter, loc, resType, op.getValue()); rewriter.replaceOpWithNewOp(op, resType, mask, value, op.getSelf()); return success(); } }; } // namespace // Decompose aten._convolution-like to aten.convolution namespace { template class DecomposeAten_ConvolutionLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConvolutionLikeOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups()); return success(); } }; } // namespace // Decompose aten.conv2d to aten.convolution namespace { class DecomposeAtenConv2dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConv2dOp op, PatternRewriter &rewriter) const override { Value emptyList = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, op.getGroups()); return success(); } }; } // namespace // Decompose aten.conv_transpose2d to aten.convolution namespace { class DecomposeAtenConvTranspose2dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult 
matchAndRewrite(AtenConvTranspose2dInputOp op, PatternRewriter &rewriter) const override { Value cstTrue = rewriter.create(op.getLoc(), true); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups()); return success(); } }; } // namespace static LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, int64_t dimB, Type &transposedType) { if (!inType.hasSizes()) return failure(); SmallVector shape(inType.getSizes()); int64_t tmp = shape[0]; shape[0] = shape[1]; shape[1] = tmp; transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape), inType.getOptionalDtype()); return success(); } // The convolution backward op is decomposed as follows: // inputH, inputW = input.shape[2:] // output_padding_ = [ // inputH // - 1 // + 2 * padding_[0] // - dilation_[0] * (weight.shape[2] - 1) // - (grad_output.shape[2] - 1) * stride_[0], // inputW // - 1 // + 2 * padding_[1] // - dilation_[1] * (weight.shape[3] - 1) // - (grad_output.shape[3] - 1) * stride_[1], // ] // // decomp_grad_input = torch.nn.functional.conv_transpose2d( // grad_output, // weight, // None, // stride_, // padding_, // output_padding_, // groups_, // dilation_, // ) // // input_transposed = torch.ops.aten.transpose(input, 0, 1) // grad_output_transposed = grad_output.view( // grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] // ) // decomp_grad_weight = torch.ops.aten.convolution( // input_transposed, // grad_output_transposed, // bias=None, // stride=dilation_, // padding=padding_, // dilation=stride_, // transposed=False, // output_padding=[0, 0], // groups=input.shape[0], // ) // decomp_grad_weight = torch.narrow(decomp_grad_weight, 2, 0, weight.shape[2]) // decomp_grad_weight = torch.narrow(decomp_grad_weight, 3, 0, weight.shape[3]) // decomp_grad_weight = decomp_grad_weight.view( // input_transposed.shape[0], 
// input_transposed.shape[1], // grad_output.shape[1], // *decomp_grad_weight.shape[2:] // ) // decomp_grad_weight = decomp_grad_weight.movedim(0, 2) // decomp_grad_weight = decomp_grad_weight.sum(dim=0) // // decomp_grad_bias = torch.sum(grad_output, dim=[0, 2, 3]) namespace { class DecomposeAtenConvolutionBackwardOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = op.getContext(); Value input = op.getInput(); Value weight = op.getWeight(); Value gradOutput = op.getGradOutput(); std::optional maybeGradRank = getTensorRank(gradOutput); if (!maybeGradRank) { return rewriter.notifyMatchFailure(op, "expected grad output to have a rank"); } unsigned gradRank = *maybeGradRank; if (gradRank != 4) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D convolutions supported."); Value cstZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value cstTwo = rewriter.create( loc, rewriter.getI64IntegerAttr(2)); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); SmallVector padding, dilation, stride; SmallVector paddingInt, dilationInt, strideInt, outputPaddingInt; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInt))) return rewriter.notifyMatchFailure( op, "padding must be a list of constant ints"); if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInt))) return rewriter.notifyMatchFailure( op, "stride must be a list of constant ints"); if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInt))) return rewriter.notifyMatchFailure( op, "dilation must be a list of constant ints"); if (!llvm::all_of(dilationInt, [](int64_t dilationVal) { return dilationVal == 1; })) return rewriter.notifyMatchFailure( op, 
"unimplemented: only dilations of 1 supported."); if (!matchPattern(op.getOutputPadding(), m_TorchListOfConstantInts(outputPaddingInt))) return rewriter.notifyMatchFailure( op, "output padding must be a list of constant ints"); if (!llvm::all_of(outputPaddingInt, [](int64_t outPad) { return outPad == 0; })) return rewriter.notifyMatchFailure( op, "unimplemented: only output padding of 0 supported."); SmallVector outMask; if (!matchPattern(op.getOutputMask(), m_TorchListOfConstantBools(outMask))) return rewriter.notifyMatchFailure( op, "only constant bool output_mask is supported."); for (unsigned i = 0; i < outMask.size(); i++) { if (outMask[i] == false) { Value result = op->getResults()[i]; if (!result.getUsers().empty()) return rewriter.notifyMatchFailure( op, "unimplemented: false value supported for output_mask only " "when the result tensor corresponding to that has no users."); } } bool transposed; if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) return rewriter.notifyMatchFailure( op, "transposed arg should be a constant bool."); if (transposed) return rewriter.notifyMatchFailure( op, "unimplemented: transposed convolutions are not supported."); getListConstructElements(op.getPadding(), padding); getListConstructElements(op.getStride(), stride); getListConstructElements(op.getDilation(), dilation); // Computing Grad Input. // Calculate output padding for first convolution. 
// output_padding_ = [ // inputH - 1 + (2 * padding_[0]) - (dilation_[0] * (weight.size()[2] // - 1)) - ((grad_out.size()[2] - 1) * stride_[0]), inputW - 1 + (2 * // padding_[1]) - (dilation_[1] * (weight.size()[3] - 1)) - // ((grad_out.size()[3] - 1) * stride_[1]), // ] SmallVector outputPaddingValues; for (unsigned i = 2; i < gradRank; i++) { Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); Value inputVecDim = rewriter.create(loc, input, dim); Value gradOutDim = rewriter.create(loc, gradOutput, dim); Value weightDim = rewriter.create(loc, weight, dim); Value inputVecDimMinusOne = rewriter.create(loc, inputVecDim, cstOne); Value gradOutDimMinusOne = rewriter.create(loc, gradOutDim, cstOne); Value weightDimMinusOne = rewriter.create(loc, weightDim, cstOne); Value twoTimesPadding = rewriter.create(loc, padding[i - 2], cstTwo); Value tmpA = rewriter.create(loc, weightDimMinusOne, dilation[i - 2]); Value tmpB = rewriter.create(loc, gradOutDimMinusOne, stride[i - 2]); Value outputPaddingVal = rewriter.create( loc, inputVecDimMinusOne, twoTimesPadding); outputPaddingVal = rewriter.create(loc, outputPaddingVal, tmpA); outputPaddingVal = rewriter.create(loc, outputPaddingVal, tmpB); outputPaddingValues.push_back(outputPaddingVal); } Value outputPaddingForGradInput = rewriter.create( loc, ListType::get(IntType::get(context)), outputPaddingValues); Value gradInput = rewriter.create( loc, op.getResultTypes()[0], gradOutput, weight, cstNone, op.getStride(), op.getPadding(), outputPaddingForGradInput, op.getGroups(), op.getDilation()); Type transposedType; if (failed(getTransposedType(input.getType().cast(), 0, 1, transposedType))) return failure(); Value inputTransposed = rewriter.create( loc, transposedType, input, cstZero, cstOne); // For the cases where the stride is non-unit, we compute the `GradWeight` // through this implementation. Value gradWeight; if (!llvm::all_of(strideInt, [](int64_t stride) { return stride == 1; })) { // Computing Grad Weight. 
SmallVector gradOutputSize; for (unsigned i = 0; i < gradRank; i++) { gradOutputSize.push_back(rewriter.create( loc, gradOutput, rewriter.create( loc, rewriter.getI64IntegerAttr(i)))); } Value gradOutputViewDimZero = rewriter.create( loc, gradOutputSize[0], gradOutputSize[1]); Value gradOutputViewShapeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], gradOutputSize[3]}); BaseTensorType gradOutputTy = gradOutput.getType().cast(); if (!gradOutputTy.hasSizes()) return failure(); SmallVector gradOutputSizesInt(gradOutputTy.getSizes()); SmallVector gradOutputViewSizesInt(gradOutputSizesInt); if (gradOutputViewSizesInt[0] != kUnknownSize && gradOutputViewSizesInt[1] != kUnknownSize) gradOutputViewSizesInt[0] *= gradOutputViewSizesInt[1]; else gradOutputViewSizesInt[0] = kUnknownSize; gradOutputViewSizesInt[1] = 1; BaseTensorType gradOutputTypeForView = gradOutputTy .getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt), gradOutputTy.getOptionalDtype()) .cast(); Value gradOutputView = rewriter.create( loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); BaseTensorType inputTransposedTy = inputTransposed.getType().cast(); if (!inputTransposedTy.hasSizes()) return failure(); SmallVector inputTransposedSizesInt( inputTransposedTy.getSizes()); SmallVector gradWeightSizesInt{inputTransposedSizesInt[0], gradOutputViewSizesInt[0]}; for (unsigned i = 2; i < gradRank; i++) { if (inputTransposedSizesInt[i] != kUnknownSize && gradOutputViewSizesInt[i] != kUnknownSize) { int64_t kernelSizeInt = strideInt[i - 2] * (gradOutputViewSizesInt[i] - 1) + 1; gradWeightSizesInt.push_back( ((inputTransposedSizesInt[i] + (paddingInt[i - 2] * 2) - kernelSizeInt) / dilationInt[i - 2]) + 1); } else { gradWeightSizesInt.push_back(kUnknownSize); } } BaseTensorType gradWeightTy = inputTransposedTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), 
inputTransposedTy.getOptionalDtype()) .cast(); Value numGroup = rewriter.create(loc, input, cstZero); gradWeight = rewriter.create( loc, gradWeightTy, inputTransposed, gradOutputView, cstNone, /*stride=*/op.getDilation(), op.getPadding(), /*dilation=*/op.getStride(), op.getTransposed(), op.getOutputPadding(), numGroup); BaseTensorType weightTy = weight.getType().cast(); if (!weightTy.hasSizes()) return failure(); SmallVector weightSizes(weightTy.getSizes()); for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { gradWeightSizesInt[i + 2] = weightSizes[i + 2]; BaseTensorType gradWeightNarrowTy = gradWeightTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), gradWeightTy.getOptionalDtype()) .cast(); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i + 2)); Value length = rewriter.create(loc, weight, dim); gradWeight = rewriter.create( loc, gradWeightNarrowTy, gradWeight, dim, /*start=*/cstZero, length); } SmallVector gradWeightViewShapeInt{ inputTransposedSizesInt[0], inputTransposedSizesInt[1]}; gradWeightViewShapeInt.push_back(gradOutputSizesInt[1]); gradWeightViewShapeInt.insert( gradWeightViewShapeInt.end(), {gradWeightSizesInt[2], gradWeightSizesInt[3]}); SmallVector gradWeightViewShapeValue; for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) { gradWeightViewShapeValue.push_back( rewriter.create( loc, rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); } Value gradWeightViewShapeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), gradWeightViewShapeValue); BaseTensorType gradWeightTypeForView = gradWeightTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt), gradWeightTy.getOptionalDtype()) .cast(); gradWeight = rewriter.create( loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); gradWeightTy = gradWeight.getType().cast(); SmallVector gradWeightDimsOrder = computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size()); SmallVector gradWeightMoveDimShape; 
for (unsigned i = 0; i < gradWeightDimsOrder.size(); i++) { gradWeightMoveDimShape.push_back( gradWeightViewShapeInt[gradWeightDimsOrder[i]]); } BaseTensorType gradWeightTypeForMoveDim = gradWeightTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape), gradWeightTy.getOptionalDtype()) .cast(); gradWeight = rewriter.create( loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero, /*destination=*/cstTwo); Value gradIntList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), llvm::ArrayRef{cstZero}); gradWeight = rewriter.create( loc, op.getResultTypes()[1], /*self=*/gradWeight, /*dim=*/gradIntList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); } else { if (failed(getTransposedType(gradOutput.getType().cast(), 0, 1, transposedType))) return failure(); Value gradOutputTransposed = rewriter.create( loc, transposedType, gradOutput, cstZero, cstOne); // Convolve input with grad_output. if (failed( getTransposedType(op.getResultTypes()[1].cast(), 0, 1, transposedType))) return failure(); gradWeight = rewriter.create( loc, transposedType, inputTransposed, gradOutputTransposed, cstNone, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups()); gradWeight = rewriter.create( loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); } // Computing Grad Bias. SmallVector dimIntList{cstZero}; for (unsigned i = 2; i < gradRank; i++) dimIntList.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); Value gradIntList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dimIntList); // Sum grad_output along dim 1. Value gradBias = rewriter.create( loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse, cstNone); rewriter.replaceOp(op, {gradInput, gradWeight, gradBias}); return success(); } }; } // namespace // Decompose aten.addmm into aten.mm and aten.add.Tensor op. 
namespace {
// Rewrites addmm as three steps: the product of `mat1` and `mat2`, `self`
// scaled by `beta`, and a final op combining the two with `alpha`.
// Only applies when both matrices are known to be rank 2 and `self` has a
// dtype accepted by the floating-point check below.
class DecomposeAtenAddmmOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenAddmmOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value selfTensor = op.getSelf();
    Value lhsMatrix = op.getMat1();
    Value rhsMatrix = op.getMat2();

    // Both matrix operands of aten.addmm are required to be rank 2.
    std::optional lhsRank = getTensorRank(lhsMatrix);
    std::optional rhsRank = getTensorRank(rhsMatrix);
    if (!(lhsRank && rhsRank && *lhsRank == 2 && *rhsRank == 2))
      return rewriter.notifyMatchFailure(
          op, "expected mat1, mat2 operands to aten.addmm to be rank 2");

    // TODO: Handle integer type operands.
    auto selfType = selfTensor.getType().cast();
    if (!selfType.hasDtype() || !selfType.getDtype().isa())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: non-floating point dtype");

    // product = mat1 @ mat2
    Value product = rewriter.create(loc, op.getType(), lhsMatrix, rhsMatrix);
    // scaledSelf = self * beta
    Value scaledSelf =
        rewriter.create(loc, selfTensor.getType(), selfTensor, op.getBeta());
    // result = scaledSelf + alpha * product
    rewriter.replaceOpWithNewOp(op, op.getType(), scaledSelf, product,
                                op.getAlpha());
    return success();
  }
};
} // namespace

// Decompose aten.mean into: sum(x)/div(numTensorElements).
namespace {
// Replaces a full-tensor mean by a sum over `self` (honouring the op's
// `dtype` operand) divided by a value computed from `self` that the original
// comment calls the number of tensor elements.
class DecomposeAtenMeanOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMeanOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value output = op.getResult();
    BaseTensorType outputTensorType = output.getType().cast();
    Value sum =
        rewriter.create(loc, outputTensorType, input, op.getDtype());
    Value numTensorElements = rewriter.create(loc, input);
    rewriter.replaceOpWithNewOp(op, outputTensorType, sum, numTensorElements);
    return success();
  }
};
} // namespace

// productDimSize = product(size(dim) for dim in dims)
// aten.mean(x, dims) = aten.sum(x, dims) / productDimSize.
namespace {
// Requirements checked before any op is created:
//  - `self` has a known rank and a known dtype passing the float check,
//  - the `dtype` operand passes `isNoneOrFloatDtype`,
//  - `dim` is either a list construct or passes the type check reported as
//    "None" in the diagnostic.
class DecomposeAtenMeanDimOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMeanDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    std::optional maybeInputRank = getTensorRank(input);
    if (!maybeInputRank) {
      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
    }
    unsigned inputRank = *maybeInputRank;

    Value dimList = op.getDim();
    Value keepDim = op.getKeepdim();
    Value dtype = op.getDtype();
    Type outputType = op.getType();
    MLIRContext *context = op.getContext();

    BaseTensorType inputType = input.getType().cast();
    if (!inputType.hasDtype() || !inputType.getDtype().isa() ||
        !isNoneOrFloatDtype(context, dtype)) {
      return rewriter.notifyMatchFailure(
          op, "only floating-point type is supported");
    }

    SmallVector dimListElements;
    if (!getListConstructElements(dimList, dimListElements) &&
        !dimList.getType().isa()) {
      return rewriter.notifyMatchFailure(
          op, "expected `dim` to be `None` or constructed from list construct");
    }

    // Compute sum along dimensions specified in `dimList`.
    Value sumAlongDims = rewriter.create(loc, outputType, input, dimList,
                                         keepDim, dtype);

    // `productDimSize` is product of sizes of dimensions to be reduced.
    Value productDimSize;
    // Case: Reduce along all dims.
    if (dimListElements.empty() && inputRank != 0) {
      productDimSize = rewriter.create(loc, input);
    } else {
      // Start from the constant 1 and fold in the size of every listed dim.
      productDimSize = rewriter.create(loc, rewriter.getI64IntegerAttr(1));
      for (Value dim : dimListElements) {
        Value dimSize = rewriter.create(loc, input, dim);
        productDimSize = rewriter.create(loc, productDimSize, dimSize);
      }
    }
    rewriter.replaceOpWithNewOp(op, outputType, sumAlongDims, productDimSize);
    return success();
  }
};
} // namespace

namespace {
// Replaces aten.square by a binary op whose two operands are both `self`.
class DecomposeAtenSquareOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSquareOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.getSelf();
    rewriter.replaceOpWithNewOp(op, op.getType(), self, self);
    return success();
  }
};
} // namespace

// Silu(x) = sigmoid(x) * x
namespace {
class DecomposeAtenSiluOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSiluOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.getSelf();
    Value sigmoid = rewriter.create(op.getLoc(), op.getType(), self);
    rewriter.replaceOpWithNewOp(op, op.getType(), sigmoid, self);
    return success();
  }
};
} // namespace

// pDash = 1.0 - p
// boolMask = aten.rand_like(input) < pDash
// dropout(input, p, train=True) = (boolMask * input) / pDash
// dropout(input, p, train=False) = input
namespace {
class DecomposeAtenDropoutOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenDropoutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getInput();
    Value prob = op.getP();
    bool train = false;
    if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train)))
      return rewriter.notifyMatchFailure(op,
                                         "train must be a boolean constant");
    // Inference mode: dropout is the identity.
    if (!train) {
      rewriter.replaceOp(op, input);
      return success();
    }
    BaseTensorType inputType = input.getType().cast();
    if (!inputType.hasDtype() || !inputType.getDtype().isa())
      return rewriter.notifyMatchFailure(
          op, "only support floating type input for training mode");
    Value noneVal = rewriter.create(loc);
    Value floatOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    Value oneMinusP = rewriter.create(loc, floatOne, prob);
    Value boolMask = rewriter.create(loc, inputType, input, oneMinusP,
                                     /*generator=*/noneVal);
    Value maskedInput = rewriter.create(loc, inputType, boolMask, input);
    rewriter.replaceOpWithNewOp(op, op.getType(), maskedInput, oneMinusP);
    return success();
  }
};

// Two-result variant of the dropout decomposition above: result 0 is the
// (possibly masked and rescaled) tensor, result 1 is the mask.
//  - `train` operand failing the type check below must be a boolean constant;
//    otherwise `train` keeps its default `false`.
//  - Not training: results are `input` and a mask built from the input's size
//    list, filled with 1, using the dtype value derived from a 1-bit integer.
//  - Training: results are (boolMask * input) / (1 - p) and `boolMask`
//    converted to a 1-bit integer dtype.
class DeomposeAtenNativeDropoutOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op->getContext();
    Value input = op.getInput();
    Value prob = op.getP();
    bool train = false;
    if (!op.getTrain().getType().isa()) {
      if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
        return rewriter.notifyMatchFailure(
            op, "train must be a boolean constant or none");
      }
    }

    // Validate the training-mode precondition BEFORE creating any op: a
    // pattern must leave the IR untouched when it reports a match failure.
    BaseTensorType inputType = input.getType().cast();
    if (train && (!inputType.hasDtype() || !inputType.getDtype().isa())) {
      return rewriter.notifyMatchFailure(
          op, "only support floating type input for training mode");
    }

    Value noneVal = rewriter.create(loc);
    if (!train) {
      Value i1Type =
          getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1));
      Value inputSize = rewriter.create(
          loc, Torch::ListType::get(Torch::IntType::get(context)), input);
      Value trueValue = rewriter.create(loc, 1);
      Value trueMask = rewriter.create(
          loc, op->getResultTypes()[1], inputSize, trueValue, i1Type,
          /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
      rewriter.replaceOp(op, ArrayRef{input, trueMask});
      return success();
    }

    Value floatOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    Value oneMinusP = rewriter.create(loc, floatOne, prob);
    Value boolMask = rewriter.create(loc, inputType, input, oneMinusP,
                                     /*generator=*/noneVal);
    Value maskedInput = rewriter.create(loc, inputType, boolMask, input);
    Value output = rewriter.create(loc, op->getResultTypes()[0], maskedInput,
                                   oneMinusP);
    rewriter.replaceOp(
        op, ArrayRef{output,
                     convertTensorToDtype(rewriter, loc, boolMask,
                                          IntegerType::get(context, 1))});
    return success();
  }
};
} // namespace

// Decompose aten.var into: aten.var.dim op.
namespace {
// Builds the dim list [0, 1, ..., rank-1] and forwards `unbiased` with
// keepdim=false. Requires a ranked input and a rank-0 result type.
class DecomposeAtenVarOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenVarOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();
    std::optional maybeInputRank = getTensorRank(self);
    if (!maybeInputRank) {
      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
    }
    unsigned inputRank = *maybeInputRank;
    BaseTensorType rank0FloatTensorTy = op.getType().cast();
    if (!rank0FloatTensorTy.hasSizes() ||
        rank0FloatTensorTy.getSizes().size() != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected aten.var to have a rank 0 tensor type");
    }

    // One integer constant per input dimension.
    SmallVector dims;
    for (unsigned i = 0; i < inputRank; i++)
      dims.push_back(rewriter.create(loc, rewriter.getI64IntegerAttr(i)));
    Value dimList = rewriter.create(
        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dims);
    Value cstFalse = rewriter.create(op.getLoc(), false);
    rewriter.replaceOpWithNewOp(op, rank0FloatTensorTy, self, dimList,
                                op.getUnbiased(),
                                /*keepdim=*/cstFalse);
    return success();
  }
};
} // namespace

// Decompose aten.std to sqrt(var(x))
namespace {
class DecomposeAtenStdOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStdOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.getSelf();
    BaseTensorType inputTensorTy = self.getType().cast();
    if (!inputTensorTy.hasDtype() || !inputTensorTy.getDtype().isa()) {
      return rewriter.notifyMatchFailure(op,
                                         "Only aten.std support floating type");
    }
    Value var = rewriter.create(op->getLoc(), op.getType(), op.getSelf(),
                                op.getUnbiased());
    rewriter.replaceOpWithNewOp(op, op.getType(), var);
    return success();
  }
};
} // namespace

// Softplus(x, beta, threshold) =
//   x * beta > threshold ? x : log(1 + exp(x * beta)) / beta
namespace {
class DecomposeAtenSoftplusOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSoftplusOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    BaseTensorType inputType = input.getType().cast();
    // The boolean condition type below is built from `inputType.getSizes()`,
    // so bail out (before creating any op) when the sizes are unknown.
    if (!inputType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "expected input tensor to have known sizes");
    }

    Value inputTimesBeta =
        rewriter.create(loc, inputType, input, op.getBeta());

    // out = log1p(exp(input * beta)) / beta
    Value exp = rewriter.create(loc, inputType, inputTimesBeta);
    Value log1p = rewriter.create(loc, inputType, exp);
    Value out = rewriter.create(loc, inputType, log1p, op.getBeta());

    // Select where x * beta > threshold
    auto boolResType =
        inputType.getWithSizesAndDtype(inputType.getSizes(),
                                       rewriter.getI1Type());
    Value condition = rewriter.create(loc, boolResType, inputTimesBeta,
                                      op.getThreshold());

    rewriter.replaceOpWithNewOp(op, op.getType(), condition, input, out);
    return success();
  }
};
} // namespace

// Decompose aten.std.dim to sqrt(var.dim(x))
namespace {
class DecomposeAtenStdDimOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStdDimOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.getSelf();
    BaseTensorType inputTensorType = self.getType().cast();
    if (!inputTensorType.hasDtype() || !inputTensorType.getDtype().isa()) {
      return rewriter.notifyMatchFailure(
          op, "aten.std.dim expects input tensor of floating-point type");
    }

    Value varDim = rewriter.create(op->getLoc(), op.getType(), self,
                                   op.getDim(), op.getUnbiased(),
                                   op.getKeepdim());
    rewriter.replaceOpWithNewOp(op, op.getType(), varDim);
    return success();
  }
};
} // namespace

// Decompose aten.std.correction to sqrt(var.correction(x))
namespace { class DecomposeAtenStdCorrectionOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenStdCorrectionOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); BaseTensorType inputTensorType = self.getType().cast(); if (!inputTensorType.hasDtype() || !inputTensorType.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "aten.std.correction expects input tensor of floating-point type"); } Value varCorrection = rewriter.create( op->getLoc(), op.getType(), self, op.getDim(), op.getCorrection(), op.getKeepdim()); rewriter.replaceOpWithNewOp(op, op.getType(), varCorrection); return success(); } }; } // namespace // Hardsigmoid(x) = max(0, min(1, (x+3)/6)) namespace { class DecomposeAtenHardsigmoidOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenHardsigmoidOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); BaseTensorType inputType = input.getType().cast(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } // outputTensor = (input + 3) / 6. 
Value constantOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value constantThree = rewriter.create( loc, rewriter.getI64IntegerAttr(3)); Value constantSix = rewriter.create( loc, rewriter.getI64IntegerAttr(6)); Value inputPlusThree = rewriter.create( loc, inputType, input, constantThree, /*alpha=*/constantOne); Value outputTensor = rewriter.create( loc, inputType, inputPlusThree, constantSix); // result = max(0, min(1, (input+3)/6)) Value constantZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value oneTensor = createRank0Tensor(rewriter, loc, inputType, constantOne); Value minResult = rewriter.create(loc, inputType, oneTensor, outputTensor); Value zeroTensor = createRank0Tensor(rewriter, loc, inputType, constantZero); rewriter.replaceOpWithNewOp(op, op.getType(), zeroTensor, minResult); return success(); } }; } // namespace namespace { class DecomposeAtenHardtanhOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenHardtanhOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); BaseTensorType inputType = input.getType().cast(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } // result = min(maxVal, max(minVal, x)) Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal()); Value maxResult = rewriter.create(loc, inputType, input, minVal); Value maxVal = createRank0Tensor(rewriter, loc, inputType, op.getMaxVal()); rewriter.replaceOpWithNewOp(op, op.getType(), maxVal, maxResult); return success(); } }; } // namespace namespace { class DecomposeAtenRandLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandLikeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Type resultType = op.getType(); auto inputType = 
input.getType().cast(); if (!inputType.hasDtype() || !inputType.getDtype().isa()) { return rewriter.notifyMatchFailure(op, "only support floating-point type"); } // Create a uniform random op with low and high set to 0.0 and 1.0, // respectively. Value none = rewriter.create(loc); Value zero = rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); Value one = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value emptyTensor = rewriter.create( loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); rewriter.replaceOpWithNewOp(op, resultType, emptyTensor, /*from=*/zero, /*to=*/one, /*generator=*/none); return success(); } }; } // namespace namespace { // Bernoulli(x, p) = (randLike(float(x)) < p).cast(type(x)). Here, // 1. p must be a float tensor. // 2. The shape of p should be broadcastable to the shape of x. // 3. Bernoulli(x, p) returns a tensor of the same type as that of x. static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op, Location loc, Value input, Value prob, Value &output) { auto inputType = input.getType().cast(); auto probType = prob.getType().cast(); // Both the `input` and `prob` must be ranked tensors. if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() || !probType.hasDtype()) { return rewriter.notifyMatchFailure( op, "can't decompose bernoulli like ops without sizes or dtype"); } // The `prob` is expected to be a float type tensor. if (!probType.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "probabilities must be a float type tensor"); } // Since the `aten.randLike` op expects float-type operand, create a // float-type tensor with the same shape as that of the `input`. 
Value floatTensor = convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type()); Value none = rewriter.create(loc); Value randomVal = rewriter.create( loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none); // Bernoulli(x, p) = randLike(float(x)) < p. auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), rewriter.getI1Type()); Value lessThanP = rewriter.create(loc, boolResType, randomVal, prob); // As the `output` is expected to be of the `input` type, convert the boolean // tensor `lessThanP` to a `input` type tensor. output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype()); return success(); } // aten.bernoulli(x) = randLike(x) < x. Here, the input x is a tensor // containing probabilities to be used for drawing the binary random number. class DecomposeAtenBernoulliOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenBernoulliOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( op, "The generator has to ben None because only global default " "generator is supported"); Value output; if (failed( decomposeBernoulliLikeOp(rewriter, op, loc, input, input, output))) return rewriter.notifyMatchFailure( op, "decomposeBernoulliLikeOp failed to decompose the op"); rewriter.replaceOp(op, output); return success(); } }; // aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)). // Since the input x can be an integer tensor, it's important to cast it to // float type before passing it to the `aten.randLike` op. 
// Wraps the scalar probability `p` in a rank-0 f64 tensor and defers to
// `decomposeBernoulliLikeOp`. Only a None generator is accepted.
template class DecomposeAtenBernoulliLikeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(BernoulliLikeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value p = op.getP();
    if (!op.getGenerator().getType().template isa())
      return rewriter.notifyMatchFailure(
          op, "The generator has to be None because only global default "
              "generator is supported");

    auto inputType = input.getType().cast();
    // `decomposeBernoulliLikeOp` rejects inputs without sizes or dtype. Check
    // that here, before the probability tensor op is created, so a failed
    // match leaves the IR unmodified.
    if (!inputType.hasSizes() || !inputType.hasDtype())
      return rewriter.notifyMatchFailure(
          op, "can't decompose bernoulli like ops without sizes or dtype");

    // Rank-0 (empty shape) f64 tensor type holding the probability.
    SmallVector empty;
    Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
                                                     rewriter.getF64Type());
    Value prob = rewriter.create(loc, tensorType, p);
    Value output;
    if (failed(
            decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");
    rewriter.replaceOp(op, output);
    return success();
  }
};

// aten.bernoulli.Tensor(x, p) = (randLike(float(x)) < p).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.randLike` op.
class DecomposeAtenBernoulliTensorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenBernoulliTensorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value prob = op.getP(); if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( op, "The generator has to ben None because only global default " "generator is supported"); Value output; if (failed( decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output))) return rewriter.notifyMatchFailure( op, "decomposeBernoulliLikeOp failed to decompose the op"); rewriter.replaceOp(op, output); return success(); } }; } // namespace namespace { template class DecomposeAtenAddCLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value tensor1 = op.getTensor1(); Value tensor2 = op.getTensor2(); Value value = op.getValue(); Value product = rewriter.create(loc, op.getType(), tensor1, tensor2); rewriter.replaceOpWithNewOp(op, op.getType(), input, product, value); return success(); } }; class DecomposeAtenLayerNormOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLayerNormOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto input = op.getInput().getType().cast(); if (!input.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); int64_t inputRank = input.getSizes().size(); Value normalizedShape = op.getNormalizedShape(); SmallVector normalizedShapeSizesTorchInt; getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); std::vector meanVarSizes(inputRank, 1); for (int i = 0; i < axis; i++) meanVarSizes[i] = input.getSizes()[i]; 
auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes), input.getOptionalDtype()); auto nativeLayerNorm = rewriter.create( loc, op.getType(), meanVarType, meanVarType, op.getInput(), op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps()); rewriter.replaceOp(op, nativeLayerNorm.getResult(0)); return success(); } }; } // namespace namespace { class DecomposeAtenNativeLayerNormOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNativeLayerNormOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto context = op.getContext(); auto inputTy = op.getInput().getType().cast(); if (!inputTy.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); int64_t inputRank = inputTy.getSizes().size(); Value normalizedShape = op.getNormalizedShape(); SmallVector normalizedShapeSizesTorchInt; getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); auto reduceDimInts = llvm::to_vector<4>(llvm::seq(axis, inputRank)); auto reducedTy = op.getResult(1).getType(); auto sizeListType = ListType::get(IntType::get(context)); // build reduce dims SmallVector reduceDimVals; reduceDimVals.reserve(reduceDimInts.size()); std::transform(reduceDimInts.begin(), reduceDimInts.end(), std::back_inserter(reduceDimVals), [&](int64_t d) { return rewriter.create( loc, rewriter.getI64IntegerAttr(d)); }); Value reduceDimList = rewriter.create(loc, sizeListType, reduceDimVals); Value one = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value cstTrue = rewriter.create(loc, true); Value none = rewriter.create(loc); // mean(x) Value inputMean = rewriter.create( loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); // x - mean(x) Value inputMeanExpanded = rewriter.create(loc, inputTy, inputMean, op.getInput()); Value inputZeroMean = rewriter.create( loc, inputTy, 
op.getInput(), inputMeanExpanded, one); // var(x) = mean((x - mean(x))^2) Value inputZeroMeanSquare = rewriter.create( loc, inputTy, inputZeroMean, inputZeroMean); Value inputVar = rewriter.create( loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none); // rsqrt(var(x) + eps) Value inputVarPlusEps = rewriter.create( loc, reducedTy, inputVar, op.getEps(), one); Value inputRsqrtVar = rewriter.create(loc, reducedTy, inputVarPlusEps); // (x - mean(x)) * rsqrt(var(x) + eps) Value inputRsqrtVarExpanded = rewriter.create( loc, inputTy, inputRsqrtVar, op.getInput()); Value inputNormalized = rewriter.create( loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); Value out = rewriter.create( loc, op.getResult(0).getType(), inputNormalized); Value weight = op.getWeight(); Value bias = op.getBias(); if (!weight.getType().isa()) { out = rewriter.create(loc, out.getType(), out, weight); } if (!bias.getType().isa()) { out = rewriter.create(loc, out.getType(), out, bias, one); } rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar}); return success(); } }; } // namespace namespace { // Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops. class DecomposeAtenEmptyLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenEmptyLikeOp op, PatternRewriter &rewriter) const override { auto sizeListType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value sizeList = rewriter.create(op.getLoc(), sizeListType, op.getSelf()); rewriter.replaceOpWithNewOp( op, op.getType(), sizeList, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); return success(); } }; } // namespace namespace { // The `aten.arange` op is converted to `aten.arange.startStep` op. 
class DecomposeAtenArangeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenArangeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); // The AtenArangeOp doesn't have a start and step value. Therefore we set // them as default values 0 and 1, respectively. Value start, step; start = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); step = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp( op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { // The `aten.arange.start` op is converted to `aten.arange.startStep` op. class DecomposeAtenArangeStartOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenArangeStartOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); // The AtenArangeStartOp doesn't have a step value. Therefore we set it as // default value 1. Value step; step = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp( op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { // Decompose constant tensor full like ops. 
template class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value constVal = rewriter.create( loc, rewriter.getI64IntegerAttr(fillVal)); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), constVal, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); return success(); } }; } // namespace namespace { class DecomposeAtenNativeBatchNormOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNativeBatchNormOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = op.getContext(); Value input = op.getInput(); Value weight = op.getWeight(); Value bias = op.getBias(); Value runningMean = op.getRunningMean(); Value runningVar = op.getRunningVar(); Value eps = op.getEps(); // TODO: Add support for `training` mode. bool training = false; if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training)) || training) return rewriter.notifyMatchFailure( op, "unimplemented: training mode is not supported"); // Rank of the input tensor must be greater than or equal to 2. The shape of // the `input` is supposed to be (N, C, D?, H?, W?). std::optional maybeInputRank = getTensorRank(input); if (!maybeInputRank || *maybeInputRank < 2) return rewriter.notifyMatchFailure( op, "input must have rank greater than or equal to 2"); unsigned inputRank = *maybeInputRank; // In the inference mode, the `runningMean` and `runningVar` must not be // None. if (runningMean.getType().isa() || runningVar.getType().isa()) return rewriter.notifyMatchFailure( op, "running stats must not be None in inference mode"); // Rank of `runningMean` and `runningVar` must be exactly 1. 
std::optional runningMeanRank = getTensorRank(runningMean); std::optional runningVarRank = getTensorRank(runningVar); if (!runningMeanRank || !runningVarRank || *runningMeanRank != 1 || *runningVarRank != 1) return rewriter.notifyMatchFailure( op, "expected runningMean and runningVar to be rank 1"); Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value numFeatures = rewriter.create(loc, input, /*dim=*/one); // TODO: Add Runtime Asserts to check the shape of weight, bias, // runningMean and runningVar to be (numFeatures). // The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?) // to make it broadcast-compatible with (N, C, D?, H?, W?). // 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?) // 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?) SmallVector runningStatsShape(inputRank, one); runningStatsShape[1] = numFeatures; Value runningStatsSizeList = rewriter.create( loc, ListType::get(IntType::get(context)), runningStatsShape); SmallVector runningStatsShapeInt(inputRank, 1); runningStatsShapeInt[1] = runningMean.getType().cast().getSizes()[0]; Type dtype = input.getType().cast().getOptionalDtype(); Type reshapeType = ValueTensorType::get( context, llvm::ArrayRef(runningStatsShapeInt), dtype); runningMean = rewriter.create(loc, reshapeType, runningMean, runningStatsSizeList); runningVar = rewriter.create(loc, reshapeType, runningVar, runningStatsSizeList); // normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)). Value inputSubMean = rewriter.create( loc, input.getType(), input, runningMean, /*alpha=*/one); Value varEps = rewriter.create( loc, runningVar.getType(), runningVar, eps, /*alpha=*/one); Value invStd = rewriter.create(loc, varEps.getType(), varEps); Value normalizedInput = rewriter.create( loc, inputSubMean.getType(), inputSubMean, invStd); // The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) 
to make it // broadcast-compatible with (N, C, D?, H?, W?). // 1. weight = weight.view(1, C, 1?, 1?, 1?) // 2. bias = bias.view(1, C, 1?, 1?, 1?) // 3. output = normalizedInput * weight + bias Value batchNormOutput = normalizedInput; if (!weight.getType().isa()) { // Rank of `weight` must be exactly 1. std::optional weightRank = getTensorRank(weight); if (!weightRank || *weightRank != 1) return rewriter.notifyMatchFailure(op, "expected weight to be rank 1"); weight = rewriter.create(loc, reshapeType, weight, runningStatsSizeList); batchNormOutput = rewriter.create( loc, batchNormOutput.getType(), batchNormOutput, weight); } if (!bias.getType().isa()) { // Rank of `bias` must be exactly 1. std::optional biasRank = getTensorRank(bias); if (!biasRank || *biasRank != 1) return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); bias = rewriter.create(loc, reshapeType, bias, runningStatsSizeList); batchNormOutput = rewriter.create( loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one); } // The `mean` and `invstd` outputs are empty tensors in inference mode. Value zeroList = rewriter.create( loc, Torch::ListType::get(zero.getType()), zero); Value none = rewriter.create(loc); Value emptyMeanTensor = rewriter.create( loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none); Value emptyInvStdTensor = rewriter.create( loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none); rewriter.replaceOp(op, {batchNormOutput, emptyMeanTensor, emptyInvStdTensor}); return success(); } }; } // namespace // Decompse `Aten_UnsafeViewOp` into `AtenViewOp`. UnsafeView() differs from // view() in that the returned tensor isn't treated as a view for the purposes // of automatic differentiation. It's only safe to use if the `self` tensor is // temporary. 
For example, the viewed tensor here (a + b) is discarded // immediately after viewing: // // res = UnsafeView(a + b, size); // // This is a hack because in-place operations on tensors treated like views // can be much more expensive than the same operations on non-view tensors. // Refer to // https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072 namespace { class DecomposeAten_UnsafeViewOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_UnsafeViewOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getSize()); return success(); } }; } // namespace // In PyTorch, ReshapeAlias just uses an already computed stride. // See // https://github.com/pytorch/pytorch/blob/d8c31a819d4a65e732b5901e3b994e1869851f1a/aten/src/ATen/native/TensorShape.cpp#L1153 // Note that this is the same decomposition as in AOTAutograd // https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/Src/aotAutograd.py#L104 namespace { class DecomposeAten_ReshapeAliasOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getSize()); return success(); } }; } // namespace namespace { // Decompose constant tensor like ops. 
template class DecomposeConstantTensorNewLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); if (dtype.getType().isa()) { BaseTensorType tensorType = op.getSelf().getType().template cast(); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); } dtype = getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype()); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { // Decompose `aten.full` op into `aten.broadcastTo` class DecomposeAtenFullOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFullOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); BaseTensorType outTy = op.getType().template cast(); if (!outTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Value fillVal = rewriter.create(loc, tensorType, op.getFillValue()); fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype()); rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, op.getSize()); return success(); } }; } // namespace namespace { // Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops. 
class DecomposeAtenLinearOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLinearOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getInput(); Value weight = op.getWeight(); Value bias = op.getBias(); BaseTensorType inputType = input.getType().cast(); if (!inputType.hasSizes() || inputType.getSizes().size() < 2) return rewriter.notifyMatchFailure( op, "expected input to be rank 2 or greater"); BaseTensorType weightType = weight.getType().cast(); // `weight` must be a rank 2 matrix. if (!weightType.hasSizes() || weightType.getSizes().size() != 2) return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2"); SmallVector transposeShape = llvm::to_vector(llvm::reverse(weightType.getSizes())); Type transposeType = weightType.getWithSizesAndDtype( llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); Value transposeWeight = rewriter.create(loc, transposeType, weight); Value matmul = rewriter.create(loc, op.getType(), input, transposeWeight); if (bias.getType().isa()) { rewriter.replaceOp(op, matmul); return success(); } BaseTensorType biasType = bias.getType().cast(); if (!biasType.hasSizes() || biasType.getSizes().size() != 1) return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); Value alpha = rewriter.create(loc, rewriter.getF64FloatAttr(1)); rewriter.replaceOpWithNewOp(op, op.getType(), matmul, op.getBias(), alpha); return success(); } }; } // namespace namespace { // Decompose `aten.mish` op into `aten.tanh` and `aten.softplus` ops. 
// Mish(x) = x * Tanh(Softplus(x)) class DecomposeAtenMishOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMishOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Type type = op.getType(); auto inputType = input.getType().cast(); if (!inputType.hasDtype()) return rewriter.notifyMatchFailure(op, "Dtype not present"); Type dType = inputType.getDtype(); // Form default Value tensors for `beta` and `threshold` operands // of `aten.softplus` op. Value beta = getConstantWithGivenDtypeAndValue(rewriter, loc, 1.0, dType); Value threshold = getConstantWithGivenDtypeAndValue(rewriter, loc, 20.0, dType); Value softplusOp = rewriter.create(loc, type, input, beta, threshold); Value tanhOp = rewriter.create(loc, type, softplusOp); rewriter.replaceOpWithNewOp(op, type, input, tanhOp); return success(); } }; } // namespace namespace { // Decompose `aten.fullLike` op into `aten.emptyLike` and `aten.fill` ops. class DecomposeAtenFullLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFullLikeOp op, PatternRewriter &rewriter) const override { BaseTensorType outTy = op.getType().template cast(); if (!outTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Value fillVal = rewriter.create( op.getLoc(), tensorType, op.getFillValue()); fillVal = convertTensorToDtype(rewriter, op.getLoc(), fillVal, outTy.getDtype()); rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, op.getSelf()); return success(); } }; } // namespace namespace { // Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op. 
class DecomposeAtenIndexPutOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIndexPutOp op, PatternRewriter &rewriter) const override { Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), /*unsafe=*/cstFalse); return success(); } }; } // namespace namespace { class DecomposeAtenExpandAsOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenExpandAsOp op, PatternRewriter &rewriter) const override { auto sizeListType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value sizeList = rewriter.create(op.getLoc(), sizeListType, op.getOther()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), sizeList); return success(); } }; } // namespace namespace { // Decompose `aten.ToCopy` op into `valsem.aten.copy` op. class DecomposeAten_ToCopyOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ToCopyOp op, PatternRewriter &rewriter) const override { auto resultType = op.getType().cast(); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Type resultDtype = resultType.getDtype(); Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0, resultDtype); Value emptyTensor = rewriter.create( op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); rewriter.replaceOpWithNewOp(op, op.getType(), emptyTensor, op.getSelf(), op.getNonBlocking()); return success(); } }; } // namespace namespace { // Decompose `aten.copy` op into `aten.to.dtype` and `aten.expand_as`. 
class DecomposeAtenCopyOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCopyOp op, PatternRewriter &rewriter) const override { auto resultType = op.getType().cast(); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } auto srcTy = op.getSrc().getType().cast(); if (!srcTy.hasSizes() || !srcTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected src type to have a known rank and dtype"); } Type resultDtype = resultType.getDtype(); Value srcToDtype = convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype); rewriter.replaceOpWithNewOp(op, op.getType(), srcToDtype, op.getSelf()); return success(); } }; } // namespace namespace { // Decompose `aten.newEmpty` op into `aten.empty.memoryFormat` op. class DecomposeAtenNewEmptyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNewEmptyOp op, PatternRewriter &rewriter) const override { Value noneVal = rewriter.create(op.getLoc()); Value dtype = op.getDtype(); if (dtype.getType().isa()) { BaseTensorType tensorType = op.getSelf().getType().cast(); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); } dtype = getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype()); } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal); return success(); } }; } // namespace namespace { // Decompose `aten.indexPut.hackedTwin` op into `valsem.aten.indexPutImpl` // op. 
class DecomposeAtenIndexPutHackedTwinOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op, PatternRewriter &rewriter) const override { Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), /*unsafe=*/cstFalse); return success(); } }; } // namespace namespace { // Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl` // op. class DecomposeAten_UnsafeIndexPutHackedTwinOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op, PatternRewriter &rewriter) const override { Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), /*unsafe=*/cstFalse); return success(); } }; } // namespace namespace { // Decompose `aten.pad` op into `aten.constantPadNd` op. class DecomposeAtenPadOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenPadOp op, PatternRewriter &rewriter) const override { Value value = op.getValue(); if (value.getType().isa()) return rewriter.notifyMatchFailure(op, "optional type not supported"); if (value.getType().isa()) value = rewriter.create( op.getLoc(), rewriter.getF64FloatAttr(0)); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getPad(), value); return success(); } }; } // namespace namespace { // Decompose `aten.to.dtypeLayout` op into `aten.to.dtype` op. class DecomposeAtenToDtypeLayoutOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op, PatternRewriter &rewriter) const override { // TODO: Add support for pinMemory arg equal to `True`. 
if (!op.getPinMemory().getType().isa()) { bool pinMemory; if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) return rewriter.notifyMatchFailure( op, "unimplemented: pinMemory must be a constant"); else if (pinMemory) return rewriter.notifyMatchFailure( op, "unimplemented: pinMemory is expected to be false"); } // TODO: Add support for non-None device arg. if (!op.getDevice().getType().isa()) { return rewriter.notifyMatchFailure( op, "unimplemented: device arg must be None"); } // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. if (!op.getLayout().getType().isa()) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( op, "unimplemented: layout must be a constant"); else if (tensorLayout != torch_upstream::Layout::Strided) return rewriter.notifyMatchFailure( op, "unimplemented: layout is expected to be strided"); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); return success(); } }; } // namespace namespace { // Decompose `aten.to.device` op into `aten.to.dtype` op. class DecomposeAtenToDeviceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenToDeviceOp op, PatternRewriter &rewriter) const override { // Device information isn't relevant to torch-mlir, so we can drop that info // here. rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); return success(); } }; } // namespace namespace { // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op. // The logic of this decomposition is totally same with // the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two // cases are supported: // 1. inputSize = outputSize // 2. 
outputSize = 1 class DecomposeAtenAdaptiveAvgPool1dOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op.getContext(); Value input = op.getSelf(); std::optional maybeRank = getTensorRank(input); if (!maybeRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned rank = *maybeRank; Value sizeDim = rewriter.create( loc, rewriter.getI64IntegerAttr(rank - 1)); Value inputSize = rewriter.create(loc, input, sizeDim); Value outputShape = op.getOutputSize(); SmallVector outputShapeSizesTorchInt; getListConstructElements(outputShape, outputShapeSizesTorchInt); Value outputSize = outputShapeSizesTorchInt[0]; Value constantOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value constantZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value constantFalse = rewriter.create(loc, false); Value constantTrue = rewriter.create(loc, true); int64_t outputSizeInt; if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) { return rewriter.notifyMatchFailure( op, "the output size of adaptive_pool_1d must be a constant int"); } SmallVector kernelSize; if (outputSizeInt == 1) { BaseTensorType inputTensorType = input.getType().cast(); ArrayRef inputShape = inputTensorType.getSizes(); kernelSize.push_back( inputShape[rank - 1] == kUnknownSize ? 
inputSize : rewriter.create( loc, rewriter.getI64IntegerAttr(inputShape[rank - 1]))); } else { Value cond = rewriter.create(loc, inputSize, outputSize); rewriter.create( loc, cond, "unimplemented: only support cases where input and output size are " "equal for non-unit output size"); kernelSize.push_back(constantOne); } Value kernelSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); Value strideList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantOne}); Value paddingSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantZero}); rewriter.replaceOpWithNewOp( op, op.getType(), input, kernelSizeList, strideList, paddingSizeList, /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue); return success(); } }; } // namespace namespace { // Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op. // // For AdaptiveAvgPool2d op, when the input size is an integer multiple of // output size the kernelSize, stride and padding is calculated as follows: // strideH = inH // outH // strideW = inH // outH // kernelH = inH - [(outH - 1) * strideH] // kernelW = inW - [(outW - 1) * strideW] // paddingH = 0, paddingW = 0 // // For the special case, when the output size is one for all dimensions, // the kernel size is same as the input size. 
class DecomposeAtenAdaptiveAvgPool2dOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = op.getContext(); Value input = op.getSelf(); std::optional maybeRank = getTensorRank(input); if (!maybeRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned rank = *maybeRank; SmallVector inputHW; Value dimH = rewriter.create( loc, rewriter.getI64IntegerAttr(rank - 2)); inputHW.push_back( /*inH=*/rewriter.create(loc, input, dimH)); Value dimW = rewriter.create( loc, rewriter.getI64IntegerAttr(rank - 1)); inputHW.push_back( /*inW=*/rewriter.create(loc, input, dimW)); Value outputShape = op.getOutputSize(); SmallVector outputShapeSizesTorchInt; getListConstructElements(outputShape, outputShapeSizesTorchInt); // TODO: Add support for cases other than: // 1.) inH == outH and inW == outW. // 2.) outH == outW == 1 bool unitOutputSize = true; for (Value outShape : outputShapeSizesTorchInt) { int64_t outShapeInt; if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) { return rewriter.notifyMatchFailure( op, "output size is expected to be a constant"); } if (outShapeInt != 1) { unitOutputSize = false; break; } } Value constantOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value constantZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value constantFalse = rewriter.create(loc, false); Value constantTrue = rewriter.create(loc, true); Value constantNone = rewriter.create(loc); SmallVector kernelSize; for (unsigned i = 0; i < inputHW.size(); i++) { if (unitOutputSize) { BaseTensorType inputTensorType = input.getType().cast(); ArrayRef inputShape = inputTensorType.getSizes(); kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize ? 
inputHW[i] : rewriter.create( loc, rewriter.getI64IntegerAttr( inputShape[rank - 2 + i]))); } else { Value cond = rewriter.create(loc, inputHW[i], outputShapeSizesTorchInt[i]); rewriter.create( loc, cond, "unimplemented: only support cases where input and output size are " "equal for non-unit output size"); Value outMinusOne = rewriter.create( loc, outputShapeSizesTorchInt[i], constantOne); kernelSize.push_back( rewriter.create(loc, inputHW[i], outMinusOne)); } } Value kernelSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); // Currently we only support cases where input size is equal to the output // size or unit output size. For the former case, stride is always equal to // one and for the latter the stride value doesn't matter, since the kernel // size is same as the input size. Therfore, keeping the stride as one for // the latter case as well for the ease of implementation. Value strideList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantOne, constantOne}); Value paddingSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantZero, constantZero}); rewriter.replaceOpWithNewOp( op, op.getType(), input, kernelSizeList, strideList, paddingSizeList, /*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue, /*divisorOverride=*/constantNone); return success(); } }; } // namespace namespace { // Decompose `aten.clampMin` op into `aten.clamp` op. class DecomposeAtenClampMinOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMinOp op, PatternRewriter &rewriter) const override { Value constantNone = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getMin(), /*max=*/constantNone); return success(); } }; } // namespace namespace { // Decompose `aten.clampMax` op into `aten.clamp` op. 
class DecomposeAtenClampMaxOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMaxOp op, PatternRewriter &rewriter) const override { Value constantNone = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), /*min=*/constantNone, op.getMax()); return success(); } }; } // namespace namespace { // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and // `aten.add.Tensor` op. class DecomposeAtenBaddbmmOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenBaddbmmOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value bmm = rewriter.create(loc, op.getType(), op.getBatch1(), op.getBatch2()); Value alphaTimesBmm = rewriter.create(loc, op.getType(), bmm, op.getAlpha()); Value input = op.getSelf(); BaseTensorType inputType = input.getType().cast(); BaseTensorType resultType = op->getResult(0).getType().cast(); if (inputType.hasDtype() && resultType.hasDtype() && inputType.getDtype() != resultType.getDtype()) { input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype()); } rewriter.replaceOpWithNewOp( op, op.getType(), alphaTimesBmm, op.getSelf(), op.getBeta()); return success(); } }; } // namespace namespace { // Decompose `aten.floorDivide` op into `aten.div.TensorMode` op. class DecomposeAtenFloorDivideOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFloorDivideOp op, PatternRewriter &rewriter) const override { // https://pytorch.org/docs/stable/generated/torch.floorDivide.html // PyTorch aten.floorDivide is a misnomer because it actually rounds // the quotient towards zero instead of taking its floor. 
Value cstStrFloor = rewriter.create(op.getLoc(), "trunc"); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getOther(), /*roundingMode=*/cstStrFloor); return success(); } }; } // namespace namespace { // Decompose `aten.numpyT` op into `aten.permute` op. class DecomposeAtenNumpyTOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNumpyTOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); std::optional maybeInputRank = getTensorRank(self); if (!maybeInputRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned inputRank = *maybeInputRank; SmallVector dimListElements; SmallVector dimListInts(llvm::reverse( llvm::iota_range(0, inputRank, /*inclusive=*/false))); for (int dimListInt : dimListInts) { dimListElements.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(dimListInt))); } Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), dimListElements); rewriter.replaceOpWithNewOp(op, op.getType(), self, dimList); return success(); } }; } // namespace template static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, bool unbiased, double correction) { Location loc = op.getLoc(); Value self = op.getSelf(); Value dimList = op.getDim(); Value keepDim = op.getKeepdim(); BaseTensorType inputTensorTy = self.getType().cast(); Type outputType = op.getType(); BaseTensorType outputTensorType = outputType.cast(); if (!outputTensorType.hasDtype()) { return rewriter.notifyMatchFailure(op, "expected result type to have a dtype"); } Type newOutputType = outputTensorType.getWithSizesAndDtype( outputTensorType.getSizes(), rewriter.getF64Type()); if (!inputTensorTy.hasDtype() || !inputTensorTy.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "support floating-point type input only"); } // Upcasting the input tensor to `F64` dtype for higher precision 
during the // computation of the result. if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) { self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type()); inputTensorTy = self.getType().cast(); } std::optional maybeInputRank = getTensorRank(self); if (!maybeInputRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned inputRank = *maybeInputRank; SmallVector dimListElements; bool isNoneOrEmpty = true; if (!dimList.getType().template isa()) { if (!getListConstructElements(dimList, dimListElements)) return rewriter.notifyMatchFailure( op, "expect dimList to be constructed from list construct"); if (!dimListElements.empty() || inputRank == 0) isNoneOrEmpty = false; } if (isNoneOrEmpty) { for (unsigned i = 0; i < inputRank; i++) dimListElements.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dimListElements); } Type meanDimResultType = inputTensorTy; for (unsigned i = 0; i < dimListElements.size(); i++) meanDimResultType = computeReductionType( rewriter, op, meanDimResultType.cast(), dimListElements[i], /*keepDim=*/true); Value constantNone = rewriter.create(loc); Value constantTrue = rewriter.create(loc, true); Value meanAlongDims = rewriter.create( loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, /*dtype=*/constantNone); Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); Value square = rewriter.create(loc, inputTensorTy, subMean); if (!unbiased) { Value result = rewriter.create( loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); return success(); } // Divide the square sum by productDimSize - correction. 
Value squareSum = rewriter.create( loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); // `productDimSize` is product of sizes of dimensions to be reduced. Value constantOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value productDimSize = constantOne; for (Value dim : dimListElements) { Value dimSize = rewriter.create(loc, self, dim); productDimSize = rewriter.create(loc, productDimSize, dimSize); } productDimSize = rewriter.create(loc, productDimSize); constantOne = rewriter.create( loc, rewriter.getF64FloatAttr(1.0)); Value cstCorrection = rewriter.create( loc, rewriter.getF64FloatAttr(correction)); // The `correction` value should be less than or equal to `productDimSize + // 1`. Value productDimSizePlusOne = rewriter.create( loc, productDimSize.getType(), productDimSize, constantOne); Value cond = rewriter.create(loc, productDimSizePlusOne, cstCorrection); rewriter.create( loc, cond, "correction value should be less than or equal to productDimSize + 1"); Value productDimSizeSubCorrection = rewriter.create(loc, productDimSize, cstCorrection); Value result = rewriter.create(loc, newOutputType, squareSum, productDimSizeSubCorrection); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); return success(); } // Decompose aten.var(x, dims) into: // sub = aten.sub(x, aten.mean(x, dims)) // square = aten.square(sub) // For Unbiased case: // out = aten.sum(square, dims) / (productDimSize-1) // For Biased case: // out = aten.mean(square, dims) namespace { class DecomposeAtenVarDimOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarDimOp op, PatternRewriter &rewriter) const override { bool unbiased; if (!matchPattern(op.getUnbiased(), m_TorchConstantBool(&unbiased))) { return rewriter.notifyMatchFailure( op, "Only support constant unbiased for aten.var"); } double correction = unbiased ? 
1.0 : 0.0; if (failed(calculateVariance(op, rewriter, unbiased, correction))) return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } }; } // namespace // Decompose aten.var(x, dims) into: // sub = aten.sub(x, aten.mean(x, dims)) // square = aten.square(sub) // out = aten.sum(square, dims) / (productDimSize - correction) namespace { class DecomposeAtenVarCorrectionOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarCorrectionOp op, PatternRewriter &rewriter) const override { int64_t correctionValInt; double correctionValFloat = 1.0; if (!op.getCorrection().getType().isa()) { if (op.getCorrection().getType().isa()) { if (!matchPattern(op.getCorrection(), m_TorchConstantFloat(&correctionValFloat))) return rewriter.notifyMatchFailure( op, "Only support constant int or float correction value for " "aten.var"); } else if (op.getCorrection().getType().isa()) { if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correctionValInt))) return rewriter.notifyMatchFailure( op, "Only support constant int or float correction value for " "aten.var"); correctionValFloat = (double)correctionValInt; } else { return rewriter.notifyMatchFailure( op, "unimplemented: correction value should be only constant int " "or float for aten.var"); } } bool unbiased = correctionValFloat == 0.0 ? false : true; if (failed(calculateVariance(op, rewriter, unbiased, correctionValFloat))) return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } }; } // namespace namespace { // Decompose the `aten.selectScatter` operation into `aten.sliceScatter` op. 
class DecomposeAtenSelectScatterOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSelectScatterOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value start = op.getIndex(); Value dim = op.getDim(); Value self = op.getSelf(); Value src = op.getSrc(); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value startPlusOne = rewriter.create(loc, one.getType(), start, one); auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim); if (failed(unsqueezedInfo)) { return rewriter.notifyMatchFailure(op, "cannot generate unsqueeze tensor op"); } src = *unsqueezedInfo; rewriter.replaceOpWithNewOp( op, op.getSelf().getType(), self, src, dim, start, startPlusOne, /*step=*/one); return success(); } }; } // namespace namespace { class DecomposeAten_EmbeddingBagOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op, PatternRewriter &rewriter) const override { Value weight = op.getWeight(); Value indices = op.getIndices(); Value offsets = op.getOffsets(); Value scaleGradByFreq = op.getScaleGradByFreq(); Value mode = op.getMode(); Value sparse = op.getSparse(); Value perSampleWeights = op.getPerSampleWeights(); Value includeLastOffset = op.getIncludeLastOffset(); Value paddingIdx = op.getPaddingIdx(); auto resultType0 = op->getResult(0).getType(); auto resultType1 = op->getResult(1).getType(); auto resultType2 = op->getResult(2).getType(); auto resultType3 = op->getResult(3).getType(); mlir::TypeRange returnTypes{resultType0, resultType1, resultType2, resultType3}; rewriter.replaceOpWithNewOp( op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode, sparse, perSampleWeights, includeLastOffset, paddingIdx); return success(); } }; } // namespace namespace { // Decompose `aten.liftFreshCopy` op into `aten.clone` op. 
// Rewrite pattern rooted at `AtenLiftFreshCopyOp`: replaces the op with one
// that takes `self` and a None memory format.
// NOTE(review): unlike most sibling patterns there is no `public:` specifier
// here, so `matchAndRewrite` is declared private -- confirm this is intended.
class DecomposeAtenLiftFreshCopyOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(AtenLiftFreshCopyOp op,
                                PatternRewriter &rewriter) const override {
    Value constantNone = rewriter.create(op.getLoc());
    rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(),
                                /*memoryFormat=*/constantNone);
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenMseLossOp`: builds `self - target` via
// `createTensorSub`, applies one unary op to it, and optionally reduces the
// result depending on the constant `reduction` argument.
// NOTE(review): the op kinds given to `rewriter.create` are not visible in
// this listing (the unary op is presumably a square) -- confirm.
class DecomposeAtenMseLossOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when `reduction` is not a constant int or when the input
  // tensor type carries no sizes.
  LogicalResult matchAndRewrite(AtenMseLossOp op,
                                PatternRewriter &rewriter) const override {

    // The `reduction` arg would have only three valid values.
    // 0 means no reduction.
    // 1 means mean reduction.
    // 2 means sum reduction.
    int64_t reductionType;
    if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reductionType)))
      return rewriter.notifyMatchFailure(
          op, "Expected a constant integer value for reduction");

    Location loc = op.getLoc();
    BaseTensorType resultType = op.getType().cast();
    BaseTensorType inputType = op.getSelf().getType().cast();
    if (!inputType.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "Expected the input tensor to have sizes");
    // Elementwise intermediate type: the input's sizes with the result's
    // dtype.
    BaseTensorType subType =
        inputType
            .getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
                                  resultType.getOptionalDtype())
            .cast();

    Value sub =
        createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
    Value result = rewriter.create(loc, subType, sub);
    // No reduction requested: the elementwise value is the final result.
    if (reductionType == torch_upstream::Reduction::None) {
      rewriter.replaceOp(op, result);
      return success();
    }
    Value cstFalse = rewriter.create(loc, false);
    Value cstNone = rewriter.create(loc);
    // NOTE(review): every value other than None/Mean takes the `else` branch;
    // the constant is not checked to be the sum value -- confirm that other
    // values cannot reach this pattern.
    if (reductionType == torch_upstream::Reduction::Mean)
      result = rewriter.create(loc, resultType, result,
                               /*dim=*/cstNone,
                               /*keepdim=*/cstFalse,
                               /*dtype=*/cstNone);
    else
      result = rewriter.create(
          loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse,
          /*dtype=*/cstNone);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
//
// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
// When `p` passes the type check below it is replaced by the constant 2.0;
// `dim` and `keepdim` are forwarded and `dtype` is None.
// NOTE(review): the type tested by `isa()` is not visible in this listing
// (presumably the None type) -- confirm.
class DecomposeAtenNormScalarOptDimOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value none = rewriter.create(loc);
    Value ord = op.getP();
    if (ord.getType().isa()) {
      ord = rewriter.create(
          loc, rewriter.getF64FloatAttr(2.0));
    }
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
        /*dtype=*/none);
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenRandintLowOp`: creates an f32 tensor of the
// requested size, fills it through an op taking `from`/`to`/`generator`
// arguments with the constant bounds `low` and `high`, and converts the
// result to the dtype of the original result type.
// NOTE(review): the op kinds given to `rewriter.create` are not visible in
// this listing; the description follows the inline argument labels.
class DecomposeAtenRandintLowOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when the result type has no dtype or when `low` / `high`
  // are not constant integers.
  LogicalResult matchAndRewrite(AtenRandintLowOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    Type resultType = op.getType();
    BaseTensorType resultTensorType = resultType.cast();
    if (!resultTensorType.hasDtype()) {
      return rewriter.notifyMatchFailure(
          op, "expected result type to have a dtype");
    }

    int64_t cstLow, cstHigh;
    if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: low must be a constant integer");
    if (!matchPattern(op.getHigh(), m_TorchConstantInt(&cstHigh)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: high must be a constant integer");

    Value none = rewriter.create(loc);
    Value cstFalse = rewriter.create(loc, false);
    // The integer bounds are re-materialized as f64 float constants.
    Value low = rewriter.create(
        loc, rewriter.getF64FloatAttr((double)cstLow));
    Value high = rewriter.create(
        loc, rewriter.getF64FloatAttr((double)cstHigh));

    // Same sizes as the result, but with an f32 element type.
    // NOTE(review): `getSizes()` is called without a preceding `hasSizes()`
    // check in this function -- confirm it is safe for unsized result types.
    BaseTensorType floatResultType =
        resultTensorType
            .getWithSizesAndDtype(resultTensorType.getSizes(),
                                  rewriter.getF32Type())
            .cast();
    Value emptyTensor = rewriter.create(
        loc, floatResultType, op.getSize(), /*dtype=*/none,
        /*layout=*/op.getLayout(),
        /*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
        /*memoryFormat=*/none);

    Value result =
        rewriter.create(loc, floatResultType, emptyTensor,
                        /*from=*/low,
                        /*to=*/high,
                        /*generator=*/none);
    // Convert to the dtype carried by the original result type.
    rewriter.replaceOpWithNewOp(
        op, resultType, result,
        getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
        /*nonBlocking=*/cstFalse, /*copy=*/cstFalse, /*memoryFormat=*/none);
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenRandintOp`: replaces the op with one that
// takes an explicit lower bound, supplying the constant 0 for it and
// forwarding all other operands.
// NOTE(review): the replacement op kind is not visible in this listing.
class DecomposeAtenRandintOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(AtenRandintOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    Type resultType = op.getType();

    Value low = rewriter.create(
        loc, rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp(
        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());

    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.varMean.correction` op into `aten.var.correction` and
// `aten.mean.dim` op.
class DecomposeAtenVarMeanCorrectionOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds; the two results are replaced by `var` and `mean`, in
  // that order.
  LogicalResult matchAndRewrite(AtenVarMeanCorrectionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value noneVal = rewriter.create(loc);
    Value var = rewriter.create(
        loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(),
        op.getKeepdim());
    // NOTE(review): `mean` replaces result 1 but is created with
    // `op.getType(0)`; confirm the two result types are always identical.
    Value mean = rewriter.create(loc, op.getType(0), op.getSelf(),
                                 op.getDim(), op.getKeepdim(),
                                 /*dtype=*/noneVal);
    rewriter.replaceOp(op, {var, mean});
    return success();
  }
};
} // namespace

namespace {
// Decompose `prims.convertElementType` op into `aten.to.dtype` op.
// Forwards the input `a` and `dtype`, with `nonBlocking` and `copy` set to
// false and the memory format set to None.
class DecomposePrimsConvertElementTypeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(PrimsConvertElementTypeOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    Value cstFalse = rewriter.create(loc, false);
    Value cstNone = rewriter.create(loc);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getA(), op.getDtype(), /*nonBlocking=*/cstFalse,
        /*copy=*/cstFalse, /*memoryFormat=*/cstNone);
    return success();
  }
};
} // namespace

namespace {
// Decompose `prims.var` op into `aten.var.correction` op.
// The replacement always uses keepdim = false.
class DecomposePrimsVarOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when `output_dtype` does not pass the type check below
  // (per the failure message: when it is not None).
  LogicalResult matchAndRewrite(PrimsVarOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getOutputDtype().getType().isa())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None dtype for prims::var op");
    Value cstFalse = rewriter.create(op.getLoc(), false);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getInp(), op.getDims(), op.getCorrection(),
        /*keepdim=*/cstFalse);
    return success();
  }
};
} // namespace

namespace {
// Decompose `prims.sqrt` op into `aten.sqrt` op.
class DecomposePrimsSqrtOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds; the result type and the single operand are forwarded.
  LogicalResult matchAndRewrite(PrimsSqrtOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf());
    return success();
  }
};
} // namespace

namespace {
// The op is decomposed using the Box-Muller transform.
// Refer: https://en.wikipedia.org/wiki/Box-Muller_transform class DecomposeAtenRandnGeneratorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandnGeneratorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resultType = op.getType().cast(); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); Value none = rewriter.create(loc); Value low = rewriter.create( loc, rewriter.getF64FloatAttr((double)0.0)); Value high = rewriter.create( loc, rewriter.getF64FloatAttr((double)1.0)); Value cstMinusTwo = rewriter.create( loc, rewriter.getF64FloatAttr((double)-2.0)); Value cstTwoPie = rewriter.create( loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159))); Value emptyTensorA = rewriter.create( loc, resultType, op.getSize(), /*dtype=*/dtype, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); Value emptyTensorB = rewriter.create( loc, resultType, op.getSize(), /*dtype=*/dtype, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); Value uOne = rewriter.create(loc, resultType, emptyTensorA, /*from=*/low, /*to=*/high, /*generator=*/op.getGenerator()); Value uTwo = rewriter.create(loc, resultType, emptyTensorB, /*from=*/low, /*to=*/high, /*generator=*/op.getGenerator()); Value logUOne = rewriter.create(loc, resultType, uOne); Value minusTwoLogUOne = rewriter.create(loc, resultType, logUOne, cstMinusTwo); Value r = rewriter.create(loc, resultType, minusTwoLogUOne); Value theta = rewriter.create(loc, resultType, uTwo, cstTwoPie); Value cosTheta = rewriter.create(loc, resultType, theta); rewriter.replaceOpWithNewOp(op, op.getType(), r, cosTheta); return success(); } }; } // namespace namespace { // Decompose `aten.randn` op into 
// `aten.randn.generator` op.
// The generator operand of the replacement is None; all other operands are
// forwarded.
class DecomposeAtenRandnOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(AtenRandnOp op,
                                PatternRewriter &rewriter) const override {
    Value none = rewriter.create(op.getLoc());
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getSize(), /*generator=*/none, op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.randn_like` op into `aten.randn.generator` op.
// The size list of the replacement is computed from `self`.
class DecomposeAtenRandnLikeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when `memory_format` is neither accepted by the type check
  // below nor a constant int equal to Contiguous or Preserve.
  LogicalResult matchAndRewrite(AtenRandnLikeOp op,
                                PatternRewriter &rewriter) const override {
    // Only `none`, `contiguous` and `preserve` memory_format is supported.
    if (!op.getMemoryFormat().getType().isa()) {
      int64_t memoryFormat;
      if (!matchPattern(op.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: the memory format should be specified in "
                "an integer constant");
      if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
          memoryFormat != torch_upstream::MemoryFormat::Preserve)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: only none, contiguous and preserve "
                "memory_format is supported");
    }
    Value none = rewriter.create(op.getLoc());
    // A list-of-int value derived from `self` supplies the size operand.
    auto sizeListType =
        Torch::ListType::get(Torch::IntType::get(op.getContext()));
    Value sizeList =
        rewriter.create(op.getLoc(), sizeListType, op.getSelf());
    rewriter.replaceOpWithNewOp(
        op, op.getType(), sizeList, /*generator=*/none, op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenRandOp`: creates a tensor of the requested
// size and replaces the op with one taking `from` = 0.0, `to` = 1.0 and a
// None generator.
// NOTE(review): the op kinds created here are not visible in this listing.
class DecomposeAtenRandOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when the result type has no dtype.
  LogicalResult matchAndRewrite(AtenRandOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto resultType = op.getType().cast();

    if (!resultType.hasDtype()) {
      return rewriter.notifyMatchFailure(
          op, "expected result type to have a dtype");
    }
    Value noneVal = rewriter.create(loc);
    Value low = rewriter.create(
        loc, rewriter.getF64FloatAttr((double)0.0));
    Value high = rewriter.create(
        loc, rewriter.getF64FloatAttr((double)1.0));
    Value emptyTensor = rewriter.create(
        loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
        /*layout=*/op.getLayout(),
        /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
        /*memory_format=*/noneVal);
    rewriter.replaceOpWithNewOp(op, resultType, emptyTensor,
                                /*from=*/low,
                                /*to=*/high,
                                /*generator=*/noneVal);
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenVarMeanOp`: produces `var` (dim = None,
// keepdim = false, the original `unbiased`) and `mean` (dtype = None) over
// `self`, and replaces the two results with them, in that order.
class DecomposeAtenVarMeanOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(AtenVarMeanOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value falseVal = rewriter.create(loc, false);
    Value noneVal = rewriter.create(loc);
    Value var = rewriter.create(loc, op.getType(0), op.getSelf(),
                                /*dim=*/noneVal, op.getUnbiased(),
                                /*keepdim=*/falseVal);
    // NOTE(review): `mean` replaces result 1 but is created with
    // `op.getType(0)`; confirm the two result types are always identical.
    Value mean = rewriter.create(loc, op.getType(0), op.getSelf(),
                                 /*dtype=*/noneVal);
    rewriter.replaceOp(op, {var, mean});
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenNewEmptyStridedOp`: when the constant
// strides equal the default strides implied by the constant sizes, the op is
// replaced by one that takes no stride operand.
class DecomposeAtenNewEmptyStridedOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when sizes or strides are not lists of constant ints, or
  // when the strides differ from the defaults.
  LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector sizeListInts, strideListInts;
    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
      return rewriter.notifyMatchFailure(
          op, "all size list elements must be constant ints");
    if (!matchPattern(op.getStride(),
                      m_TorchListOfConstantInts(strideListInts)))
      return rewriter.notifyMatchFailure(
          op, "all stride list elements must be constant ints");

    // We only support the cases with default stride values.
    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
    // stride[2] == 1.
    // NOTE(review): the loop is bounded by the stride list only; the two
    // lists are not checked to have the same length -- confirm that is
    // guaranteed elsewhere.
    bool isDefaultStride = true;
    for (unsigned i = 0; i < strideListInts.size(); i++) {
      // Default stride of dim i is the product of all sizes after it.
      int64_t defaultStride = 1;
      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
        defaultStride *= sizeListInts[j];
      if (defaultStride != strideListInts[i]) {
        isDefaultStride = false;
        break;
      }
    }

    if (!isDefaultStride)
      return rewriter.notifyMatchFailure(
          op, "only default strides supported for new_empty_strided op");

    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());

    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `PrimsSqueezeOp`: applies `squeezeTensor` once
// per constant dimension, visiting the dimensions in descending order.
class DecomposePrimsSqueezeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when `dimensions` is not a list of constant ints or when
  // any `squeezeTensor` call fails.
  LogicalResult matchAndRewrite(PrimsSqueezeOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    Value input = op.getA();
    SmallVector dimensions;
    if (!matchPattern(op.getDimensions(),
                      m_TorchListOfConstantInts(dimensions)))
      return rewriter.notifyMatchFailure(
          op, "all dimensions must be constant ints");

    // Sort ascending, then reverse: the largest dimension comes first.
    std::sort(dimensions.begin(), dimensions.end());
    std::reverse(dimensions.begin(), dimensions.end());

    // Nothing to squeeze: the input itself replaces the op.
    if (dimensions.size() == 0) {
      rewriter.replaceOp(op, input);
      return success();
    }
    Value result = input;
    for (unsigned i = 0; i < dimensions.size(); i++) {
      auto squeezeTensorInfo =
          squeezeTensor(rewriter, op, loc, dimensions[i], result);
      // NOTE(review): the message below says "unsqueeze" although the failed
      // call is `squeezeTensor` -- confirm whether the wording is intended.
      if (failed(squeezeTensorInfo)) {
        return rewriter.notifyMatchFailure(op,
                                           "cannot generate unsqueeze tensor");
      }
      result = *squeezeTensorInfo;
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenMovedimIntOp`: computes a dimension order
// with `computeDimsOrderForMoveDim` and replaces the op with one that takes
// the input and that order as a list of constant ints.
// NOTE(review): the replacement op kind is not visible in this listing
// (the local name `permuteDimsOrder` suggests a permute) -- confirm.
class DecomposeAtenMovedimIntOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when the input has no rank, or when `source` /
  // `destination` are not constant ints naming valid dimensions.
  LogicalResult matchAndRewrite(AtenMovedimIntOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    std::optional maybeInputRank = getTensorRank(input);
    if (!maybeInputRank) {
      return rewriter.notifyMatchFailure(
          op, "expected input tensor to have a rank");
    }
    unsigned inputRank = *maybeInputRank;
    // For rank 0 and rank 1 the input itself replaces the op.
    if (inputRank <= 1) {
      rewriter.replaceOp(op, input);
      return success();
    }

    // Both dims are normalized with `toPositiveDim` before validation.
    int64_t srcDimInt, dstDimInt;
    if (matchPattern(op.getSource(), m_TorchConstantInt(&srcDimInt))) {
      srcDimInt = toPositiveDim(srcDimInt, inputRank);
      if (!isValidDim(srcDimInt, inputRank))
        return rewriter.notifyMatchFailure(op, "source is not a valid dim");
    } else {
      return rewriter.notifyMatchFailure(op, "source is not a constant int");
    }
    if (matchPattern(op.getDestination(), m_TorchConstantInt(&dstDimInt))) {
      dstDimInt = toPositiveDim(dstDimInt, inputRank);
      if (!isValidDim(dstDimInt, inputRank))
        return rewriter.notifyMatchFailure(op,
                                           "destination is not a valid dim");
    } else {
      return rewriter.notifyMatchFailure(op,
                                         "destination is not a constant int");
    }

    SmallVector dimsOrder =
        computeDimsOrderForMoveDim(srcDimInt, dstDimInt, inputRank);
    // Materialize each entry of the order as a constant value.
    SmallVector cstDimsOrder;
    for (int64_t dim : dimsOrder)
      cstDimsOrder.push_back(rewriter.create(
          loc, rewriter.getI64IntegerAttr(dim)));
    Value permuteDimsOrder = rewriter.create(
        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
        cstDimsOrder);
    rewriter.replaceOpWithNewOp(op, op.getType(), input, permuteDimsOrder);
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenCrossEntropyLossOp`, restricted to a rank-2
// input with a rank-1 target and a constant label_smoothing of 0.0.
class DecomposeAtenCrossEntropyLossOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match for unranked operands, other rank combinations, and
  // non-constant or non-zero label_smoothing.
  LogicalResult matchAndRewrite(AtenCrossEntropyLossOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();
    Value target = op.getTarget();
    std::optional maybeRank = getTensorRank(self);
    if (!maybeRank)
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: unranked input tensor");
    unsigned selfRank = maybeRank.value();
    maybeRank = getTensorRank(target);
    if (!maybeRank)
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: unranked target tensor");
    unsigned targetRank = maybeRank.value();

    // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
    // of the form [minibatch] the cross entropy loss decomposes to the
    // combination of softmax and nll loss as follows:
    // cross_entropy_loss = NLLLoss(LogSoftmax(input, dim=1), target)
    // Currently, we only support the above-mentioned case.
    if (selfRank != 2 || targetRank != 1) {
      return rewriter.notifyMatchFailure(
          op,
          "unimplemented: only support cases with 2-d input and 1-d target");
    }

    // TODO: Add support for label_smoothing value other than 0.0 (default
    // value).
    double labelSmoothing;
    if (!matchPattern(op.getLabelSmoothing(),
                      m_TorchConstantFloat(&labelSmoothing))) {
      return rewriter.notifyMatchFailure(
          op, "Only support constant float label_smoothing value");
    } else if (labelSmoothing != 0.0) {
      return rewriter.notifyMatchFailure(op,
                                         "unimplemented: only support default "
                                         "value of 0.0 for label_smoothing");
    }

    Value noneVal = rewriter.create(loc);
    // Constant 1: the `dim` operand of the log-softmax step.
    Value dim = rewriter.create(
        loc, rewriter.getI64IntegerAttr(1));
    Value logSoftmax = rewriter.create(
        loc, self.getType(), self, dim, /*dtype=*/noneVal);
    // The op created here has two results; only result 0 replaces the
    // original op.
    Value nllLoss =
        rewriter
            .create(
                loc, op.getType(), target.getType(), logSoftmax, target,
                op.getWeight(), op.getReduction(), op.getIgnoreIndex())
            ->getResult(0);
    rewriter.replaceOp(op, nllLoss);
    return success();
  }
};
} // namespace

namespace {
// Rewrite pattern rooted at `AtenOneHotOp`: builds an si64 tensor of length
// `num_classes`, unsqueezes the input at its last position, compares the two
// into an i1 tensor, and converts that to si64.
// NOTE(review): the op kinds created here are not visible in this listing;
// the steps follow the inline comments and local names. Unlike most sibling
// patterns there is no `public:` specifier, so `matchAndRewrite` is private.
class DecomposeAtenOneHotOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when the input has no sizes or `num_classes` is not a
  // constant int.
  LogicalResult matchAndRewrite(AtenOneHotOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();

    Value input = op.getSelf();
    auto inputType = input.getType().cast();
    if (!inputType.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = inputType.getSizes().size();
    // NOTE(review): no range check is applied to the matched value (e.g. a
    // negative constant) -- confirm how such inputs should be handled.
    int64_t numClasses;
    if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: num_classes must be constant");
    Value none = rewriter.create(loc);

    // arange tensor
    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
    auto arangeType =
        ValueTensorType::get(context, llvm::ArrayRef(numClasses), si64Type);
    Value arangeTensor = rewriter.create(
        loc, arangeType, op.getNumClasses(), /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);

    // unsqueeze input
    // Result shape: the input sizes followed by a trailing 1.
    llvm::SmallVector unsqueezeShape(inputType.getSizes());
    unsqueezeShape.push_back(1);
    auto unsqueezeType =
        ValueTensorType::get(context, unsqueezeShape, si64Type);
    Value unsqueezeTensor = rewriter.create(
        loc, unsqueezeType, input,
        rewriter.create(loc, rewriter.getI64IntegerAttr(inputRank)));

    // compare
    // The comparison result has the op's result sizes with a 1-bit integer
    // element type.
    auto eqType = ValueTensorType::get(
        context, op.getType().cast().getSizes(),
        IntegerType::get(context, 1));
    Value eqTensor = rewriter.create(
        loc, eqType, unsqueezeTensor, arangeTensor);

    // convert to si64
    Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.var_mean.dim` op into `aten.var.dim` and
// `aten.mean.dim` op.
// Both created ops receive the original `self`, `dim` and `keepdim`; the
// variance additionally receives `unbiased`, and the mean a None dtype.
class DecomposeAtenVarMeanDimOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds; the two results are replaced by `var` and `mean`, in
  // that order.
  LogicalResult matchAndRewrite(AtenVarMeanDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value noneVal = rewriter.create(loc);
    Value var = rewriter.create(loc, op.getType(0), op.getSelf(),
                                op.getDim(), op.getUnbiased(),
                                op.getKeepdim());
    // NOTE(review): `mean` replaces result 1 but is created with
    // `op.getType(0)`; confirm the two result types are always identical.
    Value mean = rewriter.create(
        loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
        /*dtype=*/noneVal);
    rewriter.replaceOp(op, {var, mean});
    return success();
  }
};
} // namespace

namespace {
// decompose aten.scalar_tensor to prim.NumToTensor.Scalar and
// aten.to.dtype_layout
class DecomposeAtenScalarTensor : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(AtenScalarTensorOp op,
                                PatternRewriter &rewriter) const override {

    auto resultTy = op.getResult().getType().cast();
    // Builtin type corresponding to the scalar operand's Torch type.
    auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
    // First step: a tensor with the result's sizes but the scalar's own
    // element type.
    Value numToTensor = rewriter.create(
        op.getLoc(),
        resultTy.getWithSizesAndDtype(resultTy.getOptionalSizes(), scalarTy),
        op.getS());

    Value cstNone = rewriter.create(op.getLoc());
    Value cstFalse = rewriter.create(op.getLoc(), false);
    // NOTE(review): `getDtype()` is called without a preceding `hasDtype()`
    // check in this function -- confirm the result type always has a dtype.
    Value dtype =
        getDtypeIntValueForType(rewriter, op.getLoc(), resultTy.getDtype());
    // Second step: convert to the requested dtype, forwarding layout, device
    // and pin_memory from the original op.
    Value toDTypeLayout = rewriter.create(
        op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(),
        op.getDevice(), op.getPinMemory(), /*non_blocking=*/cstFalse,
        /*copy=*/cstFalse, /*memory_format=*/cstNone);
    rewriter.replaceOp(op, toDTypeLayout);
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op.
// The input is sorted along `dim` with `descending` = `largest`; values and
// indices are then each sliced along `dim` over [0, k) with step 1.
class DecomposeAtenTopkOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match unless `sorted` is the constant boolean true.
  LogicalResult matchAndRewrite(AtenTopkOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();

    bool sorted;
    if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted)))
      return rewriter.notifyMatchFailure(
          op, "Expected a constant boolean value for sorted");
    if (!sorted)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: sorted value arg must be set to True");

    Value self = op.getSelf();
    Value dim = op.getDim();
    auto selfType = self.getType().cast();
    // Indices result of the sort: the input's sizes with a signed 64-bit
    // integer element type.
    auto sortIndicesType = selfType.getWithSizesAndDtype(
        selfType.getOptionalSizes(),
        IntegerType::get(context, 64, IntegerType::Signed));
    auto sortOpResult = rewriter.create(
        loc, self.getType(), sortIndicesType, self, dim,
        /*descending=*/op.getLargest());
    Value start = rewriter.create(
        loc, rewriter.getI64IntegerAttr(0));
    Value step = rewriter.create(
        loc, rewriter.getI64IntegerAttr(1));
    // Result 0 of the sort feeds the values, result 1 the indices.
    Value resultValue = rewriter.create(
        loc, op->getResultTypes()[0], sortOpResult->getResult(0), dim, start,
        /*end=*/op.getK(), step);
    Value resultIndices = rewriter.create(
        loc, op->getResultTypes()[1], sortOpResult->getResult(1), dim, start,
        /*end=*/op.getK(), step);
    rewriter.replaceOp(op, {resultValue, resultIndices});
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
// The scalar `value` is turned into a `src` tensor via `createInitTensor`,
// with a size list built from the dimensions of `index` and the dtype of
// `self`.
class DecomposeAtenScatterValueOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when the index tensor has no rank.
  LogicalResult matchAndRewrite(AtenScatterValueOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();
    Value self = op.getSelf();
    Value index = op.getIndex();
    std::optional maybeIndexRank = getTensorRank(index);
    if (!maybeIndexRank) {
      return rewriter.notifyMatchFailure(
          op, "expected index tensor to have a rank");
    }
    unsigned indexRank = *maybeIndexRank;
    // One value per dimension of `index`, each derived from `index` and a
    // constant dimension number.
    SmallVector sizes;
    for (int64_t i = 0; i < indexRank; ++i) {
      Value dim = rewriter.create(loc, rewriter.getI64IntegerAttr(i));
      sizes.push_back(rewriter.create(loc, index, /*dim=*/dim));
    }
    Value sizeList = rewriter.create(
        loc, ListType::get(IntType::get(context)), sizes);

    auto selfType = self.getType().cast();
    auto indexType = index.getType().cast();
    // `src` type: the sizes of `index` with the dtype of `self`.
    BaseTensorType srcType =
        selfType
            .getWithSizesAndDtype(indexType.getOptionalSizes(),
                                  selfType.getOptionalDtype())
            .cast();
    Value src =
        createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
    rewriter.replaceOpWithNewOp(op, op.getType(), self, op.getDim(), index,
                                src);
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.sign` op into comparisons and aten.where.
// Two comparisons of the input against 0.0 drive two nested selections that
// yield 1, 0 or -1 (see the pseudo code in the body).
class DecomposeAtenSignOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Fails to match when the result type does not pass the `dyn_cast` below.
  LogicalResult matchAndRewrite(AtenSignOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType = op.getType().dyn_cast();
    if (!outType)
      return rewriter.notifyMatchFailure(
          op, "Only tensor types input are currently supported");

    // The three scalar constants are created as f64 float attributes.
    auto zero = rewriter.create(loc, rewriter.getF64FloatAttr(0.0));
    auto one = rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    auto minusOne = rewriter.create(loc, rewriter.getF64FloatAttr(-1.0));

    // Comparison results share the output sizes but use an i1 element type.
    auto compTy = outType.getWithSizesAndDtype(outType.getOptionalSizes(),
                                               rewriter.getI1Type());

    auto greater = rewriter.create(loc, compTy, op.getSelf(), zero);
    auto greaterEqual = rewriter.create(loc, compTy, op.getSelf(), zero);

    // Pseudo code:
    // if (in >= 0)
    //   if (in > 0)
    //     return 1
    //   else
    //     return 0
    // else
    //   return -1
    // NOTE(review): an input for which both comparisons are false takes the
    // -1 branch; confirm this is the intended result for NaN inputs.
    auto selectGreater = rewriter.create(loc, outType, greater, one, zero);

    rewriter.replaceOpWithNewOp(op, outType, greaterEqual, selectGreater,
                                minusOne);
    return success();
  }
};
} // namespace

namespace {
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
// `torch.to.dtype`.
class DecomposeAtenTypeAsOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  // Always succeeds.
  LogicalResult matchAndRewrite(AtenTypeAsOp op,
                                PatternRewriter &rewriter) const override {
    auto input = op.getSelf();
    auto other = op.getOther();
    Location loc = op.getLoc();

    // The target dtype is taken from `other`; non_blocking and copy are
    // false, memory format is None.
    Value targetDtype = rewriter.create(loc, other);
    Value nonBlocking = rewriter.create(loc, false);
    Value copy = rewriter.create(loc, false);
    Value memoryFormat = rewriter.create(loc);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
    return success();
  }
};
} // namespace

// AtenIndexTensorOp
namespace {
// The goal of this pattern is to eliminate none index in aten.Index.Tensor's
// `indices` param for the ease of various backend. The detailed steps are:
// 1.
reorder input tensor so that the non-none index appears at adjacent // positions. // 2. manually generate index tensor with some ops like iota, to replace the // none index in `indices` // 3. replace the old aten.Index.Tensor with a new // aten.Index.Tensor_hacked_twin. class DecomposeAtenIndexTensorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; // TODO: It might be better to use aten.view op instead of mulitple // aten.unsqueeze. But currently, torch-to-linalg pass has limited support for // view on dynamic shapes, such as [?] -> [?,1,1,1]. Using aten.view op will // cause relevant e2e tests fail. static FailureOr unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter, Value input, int count) { Location loc = op->getLoc(); Value constMinusOne = rewriter.create( loc, rewriter.getI64IntegerAttr(-1)); Value result = input; while (count--) { auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne); if (failed(unsqzTensorInfo)) { return failure(); } result = *unsqzTensorInfo; } return result; } static Value createIndexToReplaceNone(Operation *op, PatternRewriter &rewriter, Value input, int dimInt, int64_t dimSize) { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); Value none = rewriter.create(loc); auto int64Dtype = getDtypeIntValueForType( rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto resultType = ValueTensorType::get( context, {dimSize}, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto dim = rewriter.create( loc, rewriter.getI64IntegerAttr(dimInt)); auto end = rewriter.create(loc, input, dim); auto v = rewriter.create( loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); return v; } LogicalResult matchAndRewrite(AtenIndexTensorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = op.getContext(); SmallVector indices; if 
(!getListConstructElements(op.getIndices(), indices)) return rewriter.notifyMatchFailure(op, "failed to get elements of `indices`"); auto input = op.getSelf(); auto inputType = input.getType().cast(); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "only input with shape information is supported"); } auto inputSizes = inputType.getSizes(); int64_t inputRank = inputSizes.size(); auto outputType = op.getType().cast(); if (!outputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "only output with shape information is supported"); } auto outputRank = outputType.getSizes().size(); auto isTensor = [](Value v) { return v.getType().isa(); }; // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin if (llvm::all_of(indices, isTensor)) { if (indices.size() == 0) { return rewriter.notifyMatchFailure( op, "the indices is empty, it should be folded as a nop"); } // By default, we regard the first index type as the list element type. auto indexElemType = indices[0] .getType() .template cast() .getWithSizesAndDtype(std::nullopt, nullptr); auto newIndex = rewriter.create( loc, Torch::ListType::get(indexElemType), indices); rewriter.replaceOpWithNewOp(op, op.getType(), input, newIndex); return success(); } SmallVector indexUsed = llvm::to_vector(llvm::map_range(indices, isTensor)); for (size_t i = indices.size(); i < inputRank; ++i) indexUsed.emplace_back(false); bool indexIsConsecutive = true; int64_t firstUsedIndex = -1; for (size_t i = 0; i < indices.size(); ++i) { if (indexUsed[i] && firstUsedIndex == -1) { firstUsedIndex = i; } else if (indexUsed[i] && !indexUsed[i - 1]) { indexIsConsecutive = false; break; } } // use aten.permute to reorder the input Value newInput; // `dims` stores the mapping from new index to the old index of input // tensor. 
SmallVector dims; if (!indexIsConsecutive) { SmallVector dimValues; SmallVector permutedSizes; for (int i = 0; i < inputRank; i++) { if (indexUsed[i]) { dims.emplace_back(i); dimValues.emplace_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); permutedSizes.emplace_back(inputSizes[i]); } } for (int i = 0; i < inputRank; i++) { if (!indexUsed[i]) { dims.emplace_back(i); dimValues.emplace_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); permutedSizes.emplace_back(inputSizes[i]); } } auto dimValueList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), dimValues); newInput = rewriter.create( loc, inputType.getWithSizesAndDtype(permutedSizes, inputType.getOptionalDtype()), input, dimValueList); } else { newInput = input; for (int i = 0; i < inputRank; i++) { dims.emplace_back(i); } } // manually generate new indices. SmallVector listElements(inputRank); int64_t trailingDimCnt = 0; int64_t i; // handle trailing none index. for (i = inputRank - 1; i >= 0; --i) { int64_t oldI = dims[i]; if (indexUsed[oldI]) break; Value v = createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]); auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt); if (failed(vInfo)) { return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); } listElements[i] = *vInfo; trailingDimCnt++; } // handle non-none index in between. for (; i >= 0; --i) { int64_t oldI = dims[i]; if (!indexUsed[oldI]) break; auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI], trailingDimCnt); if (failed(vInfo)) { return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor"); } listElements[i] = *vInfo; } // handle possible leading none dimensions. 
// Walk the positions that remain to the left of the tensor-indexed run. Each of
// them gets a generated index in place of `None`.
    for (; i >= 0; --i) {
      int64_t oldI = dims[i];
      // A used index at this point means the used indices did not form one
      // contiguous run, so the rewrite is abandoned.
      if (indexUsed[oldI]) {
        return rewriter.notifyMatchFailure(
            op, "the indices are still unconsecutive after reordering input "
                "tensor");
      }
      Value v =
          createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
      auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v,
                                                outputRank - 1 - i);
      if (failed(vInfo)) {
        return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
      }
      listElements[i] = *vInfo;
    }

    // Gather the per-dimension index values into a single list value and
    // replace `op` with an op that takes `newInput` and that list.
    auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
    auto newIndexList = rewriter.create(
        loc, Torch::ListType::get(listElemType), listElements);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), newInput, newIndexList);
    return success();
  }
};
} // namespace

namespace {
// Unconditionally decompose `aten.tile` into `aten.repeat`.
//
// The `dims` operand must be a list whose elements can be retrieved with
// `getListConstructElements`, and the input tensor type must carry sizes;
// otherwise the pattern reports a match failure. When `dims` has fewer entries
// than the input has dimensions, constant 1s are inserted at the front of the
// list until both lengths agree. When `dims` is at least as long as the input
// rank, the original `dims` value is forwarded unchanged.
//
// NOTE(review): template argument lists appear to be missing in this copy of
// the file (e.g. `OpRewritePattern`, `rewriter.create(`) - confirm against the
// upstream source.
class DecomposeAtenTileOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTileOp op,
                                PatternRewriter &rewriter) const override {
    auto input = op.getSelf();
    auto repeats = op.getDims();
    // Collect the individual elements of the `dims` list.
    SmallVector dimsElements;
    if (!getListConstructElements(repeats, dimsElements)) {
      return rewriter.notifyMatchFailure(
          op, "failed to get elements of `dims` param");
    }
    auto dimsSize = dimsElements.size();
    auto inputType = input.getType().cast();
    // The input rank is needed below, so sizes must be present on the type.
    if (!inputType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "only support input tensor with shape information");
    }
    auto inputRank = inputType.getSizes().size();
    if (dimsSize < inputRank) {
      // One shared constant-1 value is reused for every padded position.
      auto constantOne = rewriter.create(
          op.getLoc(), rewriter.getI64IntegerAttr(1));
      // Insert at the front once per missing entry so the padding precedes the
      // caller-provided elements.
      for (auto i = dimsSize; i < inputRank; ++i) {
        dimsElements.insert(dimsElements.begin(), constantOne);
      }
      // Build a new int list from the padded elements.
      repeats = rewriter.create(
          op.getLoc(),
          Torch::ListType::get(Torch::IntType::get(op.getContext())),
          dimsElements);
    }
    rewriter.replaceOpWithNewOp(op, op.getType(), input, repeats);
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass : public DecomposeComplexOpsBase {
private:
// Names of ops to be treated as legal; rebuilt from `legalOps` on every call
// to `runOnOperation`.
  llvm::StringSet<> legalOpsSet;

  // Adds `DecomposePattern` to `patterns` unless the name of the op it is
  // rooted on (after the `ltrim(kTorchOpPrefix)` call below) is contained in
  // `legalOpsSet`.
  template void addPatternIfTargetOpIsIllegal(RewritePatternSet &patterns) {
    MLIRContext *context = &getContext();
    // A temporary pattern instance is constructed only to query its root op.
    std::optional opName = DecomposePattern(context).getRootKind();
    // Because the `DecomposeComplexOpsPass` uses a greedy algorithm
    // to apply patterns, only patterns that we for sure know we want to run
    // must be added. This restricts the set of patterns allowed in this file to
    // patterns that apply to a single op. In other words, patterns that match
    // on `Operation *` are not allowed, since there is no way of telling if
    // that pattern will match on an op in the `legalOpsSet` or not.
    assert(opName && "All decomposition patterns must target a single op");
    // NOTE(review): `StringRef::ltrim` drops every leading character that is
    // contained in its argument; it does not remove a literal prefix. Confirm
    // that this cannot over-trim an op name (a prefix-removal helper such as
    // `consume_front` may be what is intended).
    if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix)))
      patterns.add(context);
  }

public:
  DecomposeComplexOpsPass() = default;
  // Stores the caller-provided legal op names in the `legalOps` member
  // (presumably a pass option declared by the base class - TODO confirm).
  DecomposeComplexOpsPass(ArrayRef legalOps) {
    this->legalOps = legalOps;
  }
  // Registers the decomposition patterns whose root op is not listed as legal
  // and applies them greedily to the operation this pass runs on.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    // The strings in the `legalOps` ArrayRef don't exist during the call to the
    // constructor `DecomposeComplexOpsPass`, so the creation of the
    // `legalOpsSet` must be delayed to when `runOnOperation` gets called.
    legalOpsSet.clear();
    legalOpsSet.insert(legalOps.begin(), legalOps.end());

    // One registration per decomposition pattern.
    // NOTE(review): the template arguments naming the patterns appear to be
    // missing in this copy of the file - confirm against the upstream source.
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorAllocLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorAllocLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(
        patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(
        patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenAddCLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenAddCLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAten_ConvolutionLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAten_ConvolutionLikeOp>(
        patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenBernoulliLikeOp>(
        patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenBernoulliLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorNewLikeOp>(
        patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorNewLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);

    // Apply the collected patterns top-down with no cap on the number of
    // iterations; a failure reported by the driver fails the pass.
    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;
    config.maxIterations = GreedyRewriteConfig::kNoLimit;

    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns), config))) {
      return signalPassFailure();
    }
  }
};
} // namespace

// Creates a `DecomposeComplexOpsPass` constructed with the given `legalOps`.
std::unique_ptr>
mlir::torch::Torch::createDecomposeComplexOpsPass(
    ArrayRef legalOps) {
  return std::make_unique(legalOps);
}