//===------------------------------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with // sinceVersion==1. We interpret this to include the following spec // diffs that are irrelevant to this level of lowering: // * Supported element types. // * Limited broadcasting to full broadcasting support. // // There are a lot of spec revisions that basically generalized elementwise // to be more normal and a direct translation vs a special case. This // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. 
void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { patterns.onOp( "HardSigmoid", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value tensorOperand; float alpha, beta; if (binder.tensorOperand(tensorOperand) || binder.f32FloatAttr(alpha, "alpha", 0.2f) || binder.f32FloatAttr(beta, "beta", 0.5f) || binder.tensorResultType(resultType)) return failure(); // HardSigmoid computes the following expression: // max(0, min(1, alpha * x + beta)) Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); Value constBeta = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(beta)); // Expression: alpha * x + beta Value alpha_x_plus_beta = rewriter.create( binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha); // Expression: min(1, alpha * x + beta) Value constantOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constantOne); Value minExpression = rewriter.create( binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta); // Expression: max(0, min(1, alpha * x + beta)) Value constantZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constantZero); rewriter.replaceOpWithNewOp( binder.op, resultType, zeroTensor, minExpression); return success(); }); patterns.onOp( "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value operand; Torch::ValueTensorType resultType; std::string approximate; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType) || binder.customOpNameStringAttr(approximate, "approximate", "none")) return failure(); Value vApproximate = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getStringAttr(approximate)); 
rewriter.replaceOpWithNewOp(binder.op, resultType, operand, vApproximate); return success(); }); patterns.onOp( "GridSample", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input; Value grid; if (binder.tensorOperandAtIndex(input, 0) || binder.tensorOperandAtIndex(grid, 1) || binder.tensorResultType(resultType)) return rewriter.notifyMatchFailure( binder.op, "operand grid_sampler bind failure"); auto inputTensorType = cast(input.getType()); ArrayRef inputShape = inputTensorType.getSizes(); uint32_t inputRank = inputShape.size(); auto gridTensorType = cast(grid.getType()); ArrayRef gridShape = gridTensorType.getSizes(); uint32_t gridRank = gridShape.size(); if (inputRank != 4) return rewriter.notifyMatchFailure(binder.op, "only input rank 4 supported"); if (gridRank != 4) return rewriter.notifyMatchFailure(binder.op, "only grid rank 4 supported"); if (inputShape[0] != gridShape[0]) return rewriter.notifyMatchFailure( binder.op, "N must be same for input and grid"); if (gridShape[3] != 2) return rewriter.notifyMatchFailure(binder.op, "gridShape[3] expected to be 2"); std::string iModeString; int64_t iModeInt; if (binder.customOpNameStringAttr(iModeString, "mode", "linear")) return rewriter.notifyMatchFailure(binder.op, "mode bind failure"); if (iModeString == "linear" || iModeString == "bilinear") { iModeInt = 0; } else if (iModeString == "nearest") { iModeInt = 1; } else { return rewriter.notifyMatchFailure( binder.op, "currently only mode : linear and nearest supported"); } std::string padding; if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) return rewriter.notifyMatchFailure(binder.op, "padding_mode bind failure"); if (padding != "zeros") return rewriter.notifyMatchFailure( binder.op, "currently only padding_mode : zeros supported"); int64_t align; if (binder.s64IntegerAttr(align, "align_corners", 0)) return rewriter.notifyMatchFailure(binder.op, "align_corners bind failure"); Value 
interpolationMode = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt)); Value paddingMode = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); bool alignMode = align; Value alignCorners = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(alignMode)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, grid, interpolationMode, paddingMode, alignCorners); return success(); }); patterns.onOp( "If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value conditionTensor; if (binder.tensorOperand(conditionTensor)) { return rewriter.notifyMatchFailure(binder.op, "condition bind failure"); } auto conditionType = conditionTensor.getType().cast(); if (!conditionType || conditionType.getSizes().size() != 1) return rewriter.notifyMatchFailure( binder.op, "condition must have one single element per " "https://onnx.ai/onnx/operators/onnx__If.html"); auto conditionInt = rewriter.create( binder.getLoc(), rewriter.getType(), conditionTensor); auto conditionBool = rewriter.create( binder.getLoc(), rewriter.getType(), conditionInt); llvm::SmallVector resultTypes; if (binder.tensorResultTypes(resultTypes)) { return rewriter.notifyMatchFailure(binder.op, "result type bind failure"); } Region *thenRegion, *elseRegion; if (binder.getRegionAtIndex(elseRegion, 0) || binder.getRegionAtIndex(thenRegion, 1)) { return rewriter.notifyMatchFailure(binder.op, "region bind failure"); } auto primIfOp = rewriter.create( binder.getLoc(), TypeRange(resultTypes), conditionBool); auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) { rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin()); }; inlineIfCase(*thenRegion, primIfOp.getThenRegion()); inlineIfCase(*elseRegion, primIfOp.getElseRegion()); auto replaceTerminator = [&](Region ®ion) { PatternRewriter::InsertionGuard guard(rewriter); Operation *terminator = 
region.front().getTerminator(); rewriter.setInsertionPoint(terminator); rewriter.replaceOpWithNewOp( terminator, terminator->getOperands()); }; replaceTerminator(primIfOp.getThenRegion()); replaceTerminator(primIfOp.getElseRegion()); rewriter.replaceOp(binder.op, primIfOp.getResults()); return success(); }); patterns.onOp("Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("LessOrEqual", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("Log", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander); patterns.onOp( "LogSoftmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value input; Torch::ValueTensorType resultType; if (binder.tensorOperand(input) || binder.tensorResultType(resultType)) return failure(); int64_t axis; if (binder.s64IntegerAttr(axis, "axis", -1)) return rewriter.notifyMatchFailure(binder.op, "axis bind failure"); Value axisConst = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(axis)); Value none = rewriter.create(binder.getLoc()); rewriter.replaceOpWithNewOp( binder.op, resultType, input, axisConst, none); return success(); }); patterns.onOp( "LogSoftmax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { 
Value input; Torch::ValueTensorType resultType; if (binder.tensorOperand(input) || binder.tensorResultType(resultType)) return failure(); int64_t axis; if (binder.s64IntegerAttr(axis, "axis", 1)) return rewriter.notifyMatchFailure(binder.op, "axis bind failure"); std::optional maybeRank = Torch::getTensorRank(input); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unsupported: unranked tensor"); int64_t rank = *maybeRank; // if negative axis is provided, then flip it to a positive axis if (axis < 0) { axis = rank + axis; } // need input type and sizes to flatten/unflatten later. auto inputTy = cast(input.getType()); if (!inputTy || !inputTy.hasSizes()) return rewriter.notifyMatchFailure( binder.op, "failed to get input type or sizes"); Value axisConst = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(axis)); Value none = rewriter.create(binder.getLoc()); Value cstEnd = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); // The old version of LogSoftmax flattens post-axis dims, performs // LogSoftmax on the flattened dim, then unflattens back to the original // shape. 
// this section gets some size information necessary for // flattening/unflattening if (!inputTy || !inputTy.hasSizes()) return failure(); llvm::ArrayRef allDims(inputTy.getSizes()); llvm::ArrayRef rightDims(allDims.begin() + axis, allDims.end()); llvm::SmallVector leftDims(allDims.begin(), allDims.begin() + axis); int64_t prodRightSizes = 1; llvm::SmallVector rightDimConsts; for (int64_t n : rightDims) { rightDimConsts.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(n))); if (n == Torch::kUnknownSize) { prodRightSizes = -1; break; } prodRightSizes *= n; } leftDims.push_back(prodRightSizes); // the following list will be used to unflatten the right side Value rightDimsPrimList = rewriter.create( binder.getLoc(), rewriter.getType( rewriter.getType()), rightDimConsts); auto flatRightTy = rewriter.getType( leftDims, inputTy.getOptionalDtype()); // flatten input Value inputFlatRight = rewriter.create( binder.getLoc(), flatRightTy, input, axisConst, cstEnd); // compute lsm over flattened index Value outputFlatRight = rewriter.create( binder.getLoc(), flatRightTy, inputFlatRight, axisConst, none); // unflatten rewriter.replaceOpWithNewOp( binder.op, resultType, outputFlatRight, axisConst, rightDimsPrimList); return success(); }); patterns.onOp("MatMul", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( "MatMulInteger", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs, lhsZp, rhsZp; if (binder.tensorOperandAtIndex(lhs, 0) || binder.tensorOperandAtIndex(rhs, 1) || binder.tensorResultType(resultType)) return failure(); auto lhsTy = dyn_cast(lhs.getType()); auto rhsTy = dyn_cast(rhs.getType()); if (binder.tensorOperandAtIndex(lhsZp, 2)) 
{ lhsZp = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); } if (binder.tensorOperandAtIndex(rhsZp, 3)) { rhsZp = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); } if (auto zpTy = dyn_cast(lhsZp.getType())) { for (auto dim : zpTy.getSizes()) if (dim != 1) return failure(); lhsZp = rewriter.create( binder.getLoc(), rewriter.getType(), lhsZp); } if (auto zpTy = dyn_cast(rhsZp.getType())) { for (auto dim : zpTy.getSizes()) if (dim != 1) return failure(); rhsZp = rewriter.create( binder.getLoc(), rewriter.getType(), rhsZp); } Value scale = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); auto q = [&](Type qty) -> Type { if (qty.isSignedInteger(8)) return rewriter.getType(); if (qty.isUnsignedInteger(8)) return rewriter.getType(); if (qty.isSignedInteger(32)) return rewriter.getType(); return {}; }; Type lhsQTy = rewriter.getType( lhsTy.getOptionalSizes(), q(lhsTy.getDtype())); Type rhsQTy = rewriter.getType( rhsTy.getOptionalSizes(), q(rhsTy.getDtype())); lhs = rewriter.create( binder.getLoc(), lhsQTy, lhs, scale, lhsZp); rhs = rewriter.create( binder.getLoc(), rhsQTy, rhs, scale, rhsZp); rewriter.replaceOpWithNewOp(binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("Mul", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp( 
"MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return rewriter.notifyMatchFailure(binder.op, "auto_pad bind failure"); if (autoPad != "NOTSET") return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); Torch::ValueTensorType resultTypeOut; Value operand; int64_t ceilMode, storageOrder; // TODO: Add support for indices output and storage_order if (binder.tensorOperand(operand) || binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) || binder.s64IntegerAttr(storageOrder, "storage_order", 0) || binder.tensorResultTypeAtIndex(resultTypeOut, 0)) return rewriter.notifyMatchFailure( binder.op, "operand/ceil_mode/storage_order/resultType bind failure"); if (storageOrder != 0) return rewriter.notifyMatchFailure( binder.op, "storage_order setting is not supported."); // Determine the rank of input tensor. std::optional maybeRank = Torch::getTensorRank(operand); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); int64_t rank = *maybeRank; int64_t spatial = rank - 2; SmallVector kernel, padding, strides, dilations; if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) return rewriter.notifyMatchFailure(binder.op, "kernel_shape bind failure"); if (kernel.size() != static_cast(spatial)) return rewriter.notifyMatchFailure( binder.op, "kernel list size does not match the number of axes"); if (binder.s64IntegerArrayAttr(padding, "pads", {})) return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); if (!padding.empty() && padding.size() != static_cast(2 * spatial)) return rewriter.notifyMatchFailure( binder.op, "padding list must contain (begin,end) pair for each " "spatial axis"); if (binder.s64IntegerArrayAttr(strides, "strides", {})) return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); if (!strides.empty() && strides.size() != static_cast(spatial)) return 
rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) return rewriter.notifyMatchFailure(binder.op, "dilations bind failure"); if (padding.empty()) padding.resize(spatial, 0); if (strides.empty()) strides.resize(spatial, 1); if (dilations.empty()) dilations.resize(spatial, 1); // If the padding is symmetric we can push the padding operation to the // torch operator. if (padding.size() == static_cast(2 * spatial)) { bool equal = true; for (int i = 0; i < spatial; ++i) { equal = equal && (padding[i] == padding[i + spatial]); } if (equal) padding.resize(spatial); } // Torch pool operators require equal padding on each size of each // dimension so we materialize the padding behavior explicitly and set // the padding to 0. if (padding.size() == static_cast(2 * spatial)) { auto operandTy = cast(operand.getType()); llvm::SmallVector shuffledPadding(spatial * 2); llvm::SmallVector paddedShape(operandTy.getSizes()); shuffledPadding.resize(2 * rank); for (int i = 0; i < spatial; ++i) { paddedShape[i + 2] += padding[i] + padding[i + spatial]; shuffledPadding[2 * i] = padding[i]; shuffledPadding[2 * i + 1] = padding[i + spatial]; } Value shuffledPaddingList = createConstantIntList(binder, rewriter, padding); Value zero; if (resultTypeOut.getDtype().isa()) { zero = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr( std::numeric_limits::lowest())); } else if (resultTypeOut.getDtype().isa()) { zero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr( std::numeric_limits::lowest())); } auto paddedInputTy = rewriter.getType( paddedShape, operandTy.getDtype()); operand = rewriter.create( binder.getLoc(), paddedInputTy, operand, shuffledPaddingList, zero); padding.clear(); padding.resize(spatial, 0); } Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); Value paddingList = createConstantIntList(binder, rewriter, padding); 
Value stridesList = createConstantIntList(binder, rewriter, strides); Value dilationsList = createConstantIntList(binder, rewriter, dilations); Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode); if (rank == 3) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: AtenMaxPool1dOp"); if (binder.op->getNumResults() == 2) { Torch::ValueTensorType resultTypeIndices; if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1)) return failure(); if (rank == 4) { rewriter.replaceOpWithNewOp( binder.op, resultTypeOut, resultTypeIndices, operand, kernelSizeList, stridesList, paddingList, dilationsList, cstCeilMode); return success(); } if (rank == 5) { rewriter.replaceOpWithNewOp( binder.op, resultTypeOut, resultTypeIndices, operand, kernelSizeList, stridesList, paddingList, dilationsList, cstCeilMode); return success(); } } else { if (rank == 4) { rewriter.replaceOpWithNewOp( binder.op, resultTypeOut, operand, kernelSizeList, stridesList, paddingList, dilationsList, cstCeilMode); return success(); } if (rank == 5) { rewriter.replaceOpWithNewOp( binder.op, resultTypeOut, operand, kernelSizeList, stridesList, paddingList, dilationsList, cstCeilMode); return success(); } } return rewriter.notifyMatchFailure(binder.op, "No rank is matched."); }); patterns.onOp("Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("GreaterOrEqual", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( 
"InstanceNormalization", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; llvm::SmallVector operands; float eps; if (binder.tensorOperands(operands, 3) || binder.tensorResultType(resultType) || operands.size() != 3 || binder.f32FloatAttr(eps, "epsilon", 1e-05f)) { return failure(); } Value none = rewriter.create(binder.getLoc()); Value boolTrue = rewriter.create(binder.getLoc(), true); Value boolFalse = rewriter.create(binder.getLoc(), false); auto epsValue = rewriter.create( binder.getLoc(), rewriter.getF64FloatAttr(eps)); auto momentum = rewriter.create( binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); rewriter.replaceOpWithNewOp( binder.op, resultType, /* input */ operands[0], /* weight */ operands[1], /* bias */ operands[2], /* running mean */ none, /* running var */ none, /* use input stats */ boolTrue, momentum, epsValue, /* cudnn enabled */ boolFalse); return success(); }); patterns.onOp( "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; llvm::SmallVector operands; if (binder.tensorOperandsList(operands) || binder.tensorResultType(resultType) || operands.size() == 0) { return failure(); } Value result = operands[0]; for (uint64_t i = 1; i < operands.size(); i++) { result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } rewriter.replaceOp(binder.op, result); return success(); }); patterns.onOp( "Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; llvm::SmallVector operands; if (binder.tensorOperandsList(operands) || binder.tensorResultType(resultType) || operands.size() == 0) { return failure(); } Value result = operands[0]; for (uint64_t i = 1; i < operands.size(); i++) { result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } rewriter.replaceOp(binder.op, result); return success(); }); patterns.onOp("Neg", 1, [](OpBinder binder, ConversionPatternRewriter 
&rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp( "Not", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } auto loc = binder.getLoc(); auto operandTy = cast(operand.getType()); auto eTy = operandTy.getDtype(); if (!eTy.isInteger(1)) { auto i1ty = rewriter.getI1Type(); auto ty = rewriter.getType( operandTy.getSizes(), i1ty); auto torchqTy = Torch::getScalarTypeForType(i1ty); Value tyConst = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); operand = rewriter.create( loc, ty, operand, tyConst, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Or", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( "GatherND", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data, indices; int64_t batchDimCount; if (binder.tensorOperandAtIndex(data, 0) || binder.tensorOperandAtIndex(indices, 1) || binder.tensorResultType(resultType) || binder.s64IntegerAttr(batchDimCount, "batch_dims", 0)) return failure(); Location loc = binder.getLoc(); auto dataTy = cast(data.getType()); auto indicesTy = cast(indices.getType()); if (!dataTy || 
!dataTy.hasSizes()) return failure(); if (!indicesTy || !indicesTy.hasSizes()) return failure(); // step 1. Get shapes and ranks of data and indices. The last dimension // of indices is expected to be static. ArrayRef dataShape = dataTy.getSizes(); int64_t dataRank = dataShape.size(); ArrayRef indicesShape = indicesTy.getSizes(); int64_t indicesRank = indicesShape.size(); int64_t indicesLastDim = indicesShape.back(); // Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and // batch_dims integer b, onnx.gather_nd gathers slices of data into an // output tensor of rank q + r - indices_shape[-1] - 1 - b. // indices_shape[-1] must be static to have deterministic output rank. if (dataRank < 1 || indicesRank < 1) return rewriter.notifyMatchFailure( binder.op, "expected data and indices rank to be >= 1"); if (batchDimCount >= std::min(dataRank, indicesRank)) return rewriter.notifyMatchFailure( binder.op, "batch_dims should be strictly less than " "min(rank(data), rank(indices))"); if (indicesLastDim == Torch::kUnknownSize) return rewriter.notifyMatchFailure( binder.op, "expected last dimension of indices to be static"); // step 2. Get dimension list of data. SmallVector batchShape; SmallVector batchDims; SmallVector dataDims; for (int64_t i = 0; i < dataRank; ++i) { Value k = rewriter.create(binder.getLoc(), i); Value dataDim = rewriter.create(loc, data, k); dataDims.push_back(dataDim); if (i < batchDimCount) { batchShape.push_back(dataShape[i]); batchDims.push_back(dataDim); } } // step 3. Get dimension list of indices. 
Value constZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value constOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); SmallVector indicesDimsMinusOne; SmallVector unflattenIndicesDims; Value indicesFlattenDim = constOne; for (int64_t i = 0; i < indicesRank - 1; ++i) { Value k = rewriter.create(binder.getLoc(), i); Value indicesDim = rewriter.create(loc, indices, k); indicesDimsMinusOne.push_back(indicesDim); if (i >= batchDimCount) { unflattenIndicesDims.push_back(indicesDim); indicesFlattenDim = rewriter.create( loc, indicesFlattenDim, indicesDim); } } ArrayRef indicesShapeMinusOne = indicesShape.drop_back(); // Algorithm: We can not directly perform torch.gather as it requires // the ranks of data(`r`) and indices(`q`) to be same. So we will // perform collapse and reshape operations to match the ranks of data // and indices(making sure the semantics of the onnx.gather_nd are // preserved), perform torch.gather operation, later unflatten the // gather result to match onnx.gather_nd output. For example, assuming // indices is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and // batch_dims(`b`)=1. Firstly, modify indices to 1-D indexing as the // torch.gather op supports only single dimensional indexing. (this // algorithm would have been simpler if we can get a torch op that // supports indexing at multiple dimensions simultaneously). 1-D indexed // indices will be of shape (4, 5, 3, 1), now materialize it to // `r-b-indices_shape[-1]` dimension of data i.e. reshaping it to the // shape (4, 5, 3, 1, 1). Next step is to flatten+expand the indices and // flatten the data to (4, 15, 7, 4) and (4, 110, 7, 4) shapes // respectively and then perform the torch.gather operation. Post the // gather operation, unflatten the indices dimensions of result to (4, // 5, 3, 7, 4) which is our required result. // step 4. Convert indices_shape[-1] dimensional indexing to 1D // indexing. 
Value sliceDim = rewriter.create( loc, rewriter.getI64IntegerAttr(indicesRank - 1)); SmallVector indicesSliceShape(indicesShapeMinusOne); indicesSliceShape.push_back(1); auto indicesSliceTy = rewriter.getType( indicesSliceShape, indicesTy.getOptionalDtype()); Value start = constZero; Value updatedIndices; for (int64_t i = 0; i < indicesLastDim; ++i) { Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(i + 1)); Value indicesSlice = rewriter.create( loc, indicesSliceTy, indices, sliceDim, start, end, /*step=*/constOne); start = end; // Apply bounds checking on the indices slice. auto boolTy = rewriter.getType( indicesSliceShape, rewriter.getI1Type()); Value lt = rewriter.create( loc, boolTy, indicesSlice, constZero); Value add = rewriter.create( loc, indicesSliceTy, indicesSlice, dataDims[batchDimCount + i], /*alpha=*/constOne); indicesSlice = rewriter.create( loc, indicesSliceTy, lt, add, indicesSlice); if (i == 0) { updatedIndices = indicesSlice; continue; } updatedIndices = rewriter.create( loc, indicesSliceTy, indicesSlice, updatedIndices, dataDims[batchDimCount + i]); } // step 5. Compute all the required result types here. SmallVector reshapeIndicesShape(indicesShapeMinusOne); SmallVector reshapeIndicesDims(indicesDimsMinusOne); // Determine the collapsed dim size of indices(index_shape[-1] is not // part of collapsing as we already removed it by 1-D indexing). SmallVector flattenIndicesShape(batchShape); auto indicesCt = 1; for (int64_t i = batchDimCount; i < indicesRank - 1; ++i) { if (indicesShape[i] == Torch::kUnknownSize) { indicesCt = Torch::kUnknownSize; break; } indicesCt *= indicesShape[i]; } flattenIndicesShape.push_back(indicesCt); // Determine the collapsed dim size of data. 
SmallVector flattenDataShape(batchShape); auto dataCt = 1; for (int64_t i = 0; i < indicesLastDim; ++i) { int64_t sz = dataShape[i + batchDimCount]; if (sz == Torch::kUnknownSize) { dataCt = Torch::kUnknownSize; break; } dataCt *= sz; } flattenDataShape.push_back(dataCt); // Compute the shape of expand op. SmallVector expandIndicesDims(batchDims); expandIndicesDims.push_back(indicesFlattenDim); SmallVector expandIndicesShape(batchShape); expandIndicesShape.push_back(indicesCt); // Append `r-b-indices_shape[-1]` unit or data dims appropriately to all // result types. for (int64_t i = batchDimCount + indicesLastDim; i < dataRank; ++i) { reshapeIndicesShape.push_back(1); flattenIndicesShape.push_back(1); flattenDataShape.push_back(dataShape[i]); expandIndicesShape.push_back(dataShape[i]); reshapeIndicesDims.push_back(constOne); expandIndicesDims.push_back(dataDims[i]); } // step 6. Reshape 1-D indexed indices to match the rank of flattened // data by inserting unit dimensions. auto intListTy = rewriter.getType( rewriter.getType()); Value reshapeIndicesSizeList = rewriter.create(loc, intListTy, reshapeIndicesDims); auto reshapeIndicesTy = rewriter.getType( reshapeIndicesShape, indicesTy.getOptionalDtype()); Value reshapedIndices = rewriter.create( loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList); // step 7. Flatten `q-b-1` dimensions of the indices. auto flattenIndicesTy = rewriter.getType( flattenIndicesShape, indicesTy.getOptionalDtype()); Value batchDimCountVal = rewriter.create( loc, rewriter.getI64IntegerAttr(batchDimCount)); Value flattenedIndices = reshapedIndices; if (indicesRank == 1) { flattenedIndices = rewriter.create( loc, flattenIndicesTy, reshapedIndices, constZero); } else if (indicesRank > 1) { Value endDim = rewriter.create( loc, rewriter.getI64IntegerAttr(indicesRank - 2)); flattenedIndices = rewriter.create( loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim); } // step 8. 
Expand `r-b-indices_shape[-1]` dims of flattened indices. auto expandIndicesTy = rewriter.getType( expandIndicesShape, indicesTy.getOptionalDtype()); Value expandIndicesSizeList = rewriter.create(loc, intListTy, expandIndicesDims); Value constFalse = rewriter.create( loc, rewriter.getType(), rewriter.getBoolAttr(false)); Value expandedIndices = rewriter.create( loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList, /*implicit=*/constFalse); // step 9. Flatten indices_shape[-1] dimensions of data. auto flattenDataTy = rewriter.getType( flattenDataShape, dataTy.getOptionalDtype()); Value endDim = rewriter.create( loc, rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1)); Value flattenedData = rewriter.create( loc, flattenDataTy, data, batchDimCountVal, endDim); // step 10. Now we have flattenedData and expandedIndices of same rank // to perform gather operation. auto gatherTy = rewriter.getType( expandIndicesShape, dataTy.getOptionalDtype()); Value gather = rewriter.create( loc, gatherTy, flattenedData, batchDimCountVal, expandedIndices, /*sparseGrad=*/constFalse); // step 11. Unflatten the collapsed indices dims of gather result. 
// Finish step 11: squeeze dim 0 for rank-1 indices, otherwise unflatten.
        if (indicesRank == 1) {
          rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
              binder.op, resultType, gather, /*dim=*/constZero);
          return success();
        }
        Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
            loc, intListTy, unflattenIndicesDims);
        rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
            binder.op, resultType, gather, batchDimCountVal, unflattenSizeList);
        return success();
      });

  // Gather: negative indices are wrapped, the indices are flattened to 1-D,
  // aten.index_select gathers along `axis`, and the result is reshaped so the
  // indices' own shape replaces dim `axis`.
  patterns.onOp(
      "Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, indices;
        int64_t axis;
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(indices, 1) ||
            binder.tensorResultType(resultType) ||
            binder.s64IntegerAttr(axis, "axis", 0))
          return failure();
        Location loc = binder.getLoc();
        auto ctx = binder.op->getContext();
        // dyn_cast (not cast) so that the null checks below can actually fail
        // gracefully instead of asserting.
        auto indicesTy = dyn_cast<Torch::ValueTensorType>(indices.getType());
        auto dataTy = dyn_cast<Torch::ValueTensorType>(data.getType());
        if (!dataTy || !indicesTy || !dataTy.hasSizes() ||
            !indicesTy.hasSizes())
          return failure();

        int64_t dataRank = dataTy.getSizes().size();
        int64_t indicesRank = indicesTy.getSizes().size();
        axis = axis < 0 ? axis + dataRank : axis;
        // `axis` indexes `gatherShape` below; reject out-of-range values rather
        // than writing past the end of the shape vector.
        if (axis < 0 || axis >= dataRank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "axis out of range for data rank");

        Value index = rewriter.create<Torch::ConstantIntOp>(
            loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis));

        // Apply bounds checking on the input: indices < 0 are replaced by
        // indices + size(data, axis).
        auto intTy = rewriter.getType<Torch::IntType>();
        auto boolTy = rewriter.getType<Torch::ValueTensorType>(
            indicesTy.getSizes(), rewriter.getI1Type());
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            loc, intTy, rewriter.getI64IntegerAttr(0));
        Value one = rewriter.create<Torch::ConstantIntOp>(
            loc, intTy, rewriter.getI64IntegerAttr(1));
        Value lt =
            rewriter.create<Torch::AtenLtScalarOp>(loc, boolTy, indices, zero);
        Value dim =
            rewriter.create<Torch::AtenSizeIntOp>(loc, intTy, data, index);
        Value add = rewriter.create<Torch::AtenAddScalarOp>(loc, indicesTy,
                                                            indices, dim, one);
        indices = rewriter.create<Torch::AtenWhereSelfOp>(loc, indicesTy, lt,
                                                          add, indices);

        // Runtime sizes of every indices dim, used by the final unflatten.
        auto intListTy = rewriter.getType<Torch::ListType>(
            rewriter.getType<Torch::IntType>());
        llvm::SmallVector<Value> indicesDims;
        for (int i = 0, s = indicesTy.getSizes().size(); i < s; ++i) {
          Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
          indicesDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
              binder.getLoc(), indices, k));
        }
        Value indicesSizeList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(), intListTy, indicesDims);

        // Determine the collapsed dim size. int64_t: a product of static sizes
        // can exceed the range of int.
        int64_t indicesCt = 1;
        for (auto sz : indicesTy.getSizes()) {
          if (sz == Torch::kUnknownSize) {
            indicesCt = Torch::kUnknownSize;
            break;
          }
          indicesCt *= sz;
        }

        auto flattenTy = rewriter.getType<Torch::ValueTensorType>(
            SmallVector<int64_t>{indicesCt}, indicesTy.getOptionalDtype());

        if (indicesRank == 0) {
          indices = rewriter.create<Torch::AtenUnsqueezeOp>(
              binder.getLoc(), flattenTy, indices, zero);
        } else if (indicesRank > 1) {
          Value rank = rewriter.create<Torch::AtenDimOp>(loc, intTy, indices);
          Value end = rewriter.create<Torch::AtenSubIntOp>(loc, rank, one);
          indices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
              loc, flattenTy, indices, zero, end);
        }

        llvm::SmallVector<int64_t> gatherShape(dataTy.getSizes());
        gatherShape[axis] = indicesCt;

        auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
            gatherShape, dataTy.getOptionalDtype());
        Value gather = rewriter.create<Torch::AtenIndexSelectOp>(
            loc, gatherTy, data, index, indices);

        if (indicesRank == 1) {
          rewriter.replaceOp(binder.op, gather);
          return success();
        }

        if (indicesRank > 1) {
          rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
              binder.op, resultType, gather, index, indicesSizeList);
          return success();
        }

        // Scalar indices: remove only the unit dim introduced at `axis`. A
        // plain squeeze would also drop any other size-1 dims of `data`.
        rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
            binder.op, resultType, gather, index);
        return success();
      });

  // GatherElements maps directly onto aten.gather with sparse_grad = false.
  patterns.onOp(
      "GatherElements", 13,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, indices;
        int64_t axis;
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(indices, 1) ||
            binder.tensorResultType(resultType) ||
            binder.s64IntegerAttr(axis, "axis", 0))
          return failure();
        Value constAxis = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
        Value sparseGrad = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getType<Torch::BoolType>(),
            rewriter.getBoolAttr(false));
        rewriter.replaceOpWithNewOp<Torch::AtenGatherOp>(
            binder.op, resultType, data, constAxis, indices, sparseGrad);
        return success();
      });

  // Gemm: Y = alpha * A' * B' (+ beta * C when the optional third input is
  // present), where A' / B' are A / B transposed when transA / transB is set.
  patterns.onOp(
      "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value a, b, c;
        float alpha, beta;
        int64_t transA, transB;
        if (binder.tensorOperandAtIndex(a, 0) ||
            binder.tensorOperandAtIndex(b, 1) ||
            binder.s64IntegerAttr(transA, "transA", 0) ||
            binder.s64IntegerAttr(transB, "transB", 0) ||
            binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
            binder.f32FloatAttr(beta, "beta", 1.0f) ||
            binder.tensorResultType(resultType))
          return failure();

        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        Value one = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

        // Transposes dims 0 and 1, reversing the static sizes in the type.
        auto transpose = [&](Value m) -> Value {
          auto tty = cast<Torch::ValueTensorType>(m.getType());
          // `shape` is a non-owning ArrayRef view, so the reversed sizes must
          // live in a buffer that outlives the ValueTensorType::get call below
          // (a SmallVector scoped to the `if` would leave `shape` dangling).
          llvm::SmallVector<int64_t> reversedShape;
          std::optional<ArrayRef<int64_t>> shape = tty.getOptionalSizes();
          if (shape.has_value()) {
            reversedShape.assign(shape->begin(), shape->end());
            std::reverse(reversedShape.begin(), reversedShape.end());
            shape = ArrayRef<int64_t>(reversedShape);
          }
          auto oty = Torch::ValueTensorType::get(tty.getContext(), shape,
                                                 tty.getOptionalDtype());
          return rewriter.create<Torch::AtenTransposeIntOp>(binder.getLoc(),
                                                            oty, m, zero, one);
        };

        if (transA)
          a = transpose(a);
        if (transB)
          b = transpose(b);

        if (binder.getNumOperands() == 2) {
          // No C input: Y = alpha * A' * B'. alpha must still be honoured.
          if (alpha == 1.0f) {
            rewriter.replaceOpWithNewOp<Torch::AtenMmOp>(binder.op, resultType,
                                                         a, b);
            return success();
          }
          Value product = rewriter.create<Torch::AtenMmOp>(binder.getLoc(),
                                                           resultType, a, b);
          Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
              rewriter.getF64FloatAttr(alpha));
          rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
              binder.op, resultType, product, constAlpha);
          return success();
        }

        if (binder.tensorOperandAtIndex(c, 2))
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expected either 2 or 3 inputs");

        Value mm =
            rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
        if (alpha == 1.0 && beta == 1.0) {
          rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
              binder.op, resultType, mm, c, one);
          return success();
        }

        // Fold alpha into mm when both scales are non-trivial, so that a
        // single aten.add.Tensor with one scalar multiplier remains.
        if (alpha != 1.0 && beta != 1.0) {
          Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
              rewriter.getF64FloatAttr(alpha));
          mm = rewriter.create<Torch::AtenMulScalarOp>(
              binder.getLoc(), resultType, mm, constAlpha);
          alpha = 1.0;
        }

        // add.Tensor scales its second operand; swap so the non-unit scale
        // applies to the right term.
        if (alpha != 1.0) {
          std::swap(alpha, beta);
          std::swap(mm, c);
        }

        Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(beta));
        rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
            binder.op, resultType, mm, c, constBeta);
        return success();
      });

  // GlobalAveragePool: avg_pool{1,2,3}d whose kernel spans every dim from 2
  // onwards, with unit strides and zero padding.
  patterns.onOp(
      "GlobalAveragePool", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType))
          return failure();

        auto inputTensorType =
            dyn_cast<Torch::ValueTensorType>(operand.getType());
        if (!inputTensorType || !inputTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
        }
        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
        unsigned inputRank = inputShape.size();
        if (!resultType || !resultType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected result type having sizes");
        }

        SmallVector<Value> cstKernel, cstPadding, cstStrides;
        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        for (unsigned i = 2; i < inputRank; i++) {
          if (inputShape[i] == Torch::kUnknownSize) {
            Value dim = rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(i));
            Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
                binder.getLoc(), operand, dim);
            cstKernel.push_back(inputDimSize);
          } else {
            // A global pool covers the whole spatial extent, so the kernel is
            // the input size itself. (Deriving it as in - out + 1 breaks when
            // the result dim is dynamic, i.e. kUnknownSize.)
            cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[i])));
          }
          cstPadding.push_back(cstZero);
          cstStrides.push_back(cstOne);
        }
        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstKernel);
        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstPadding);
        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstStrides);
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        Value cstCeilMode = cstFalse;
        Value cstCountIncludePad = cstFalse;
        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

        if (inputRank == 3) {
          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad);
          return success();
        } else if (inputRank == 4) {
          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad,
              /*divisor_override=*/cstNone);
          return success();
        } else if (inputRank == 5) {
          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad,
              /*divisor_override=*/cstNone);
          return success();
        }
        return failure();
      });

  // GlobalMaxPool: max_pool{1,2,3}d whose kernel spans every dim from 2
  // onwards, with unit strides/dilations and zero padding.
  patterns.onOp(
      "GlobalMaxPool", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType))
          return failure();

        // Free-function dyn_cast replaces the deprecated Type::cast member,
        // and makes the null check below meaningful.
        auto inputTensorType =
            dyn_cast<Torch::ValueTensorType>(operand.getType());
        if (!inputTensorType || !inputTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
        }
        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
        unsigned inputRank = inputShape.size();
        if (!resultType || !resultType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected result type having sizes");
        }
        SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        for (unsigned i = 2; i < inputRank; i++) {
          if (inputShape[i] == Torch::kUnknownSize) {
            Value dim = rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(i));
            Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
                binder.getLoc(), operand, dim);
            cstKernel.push_back(inputDimSize);
          } else {
            cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[i])));
          }
          cstPadding.push_back(cstZero);
          cstDilations.push_back(cstOne);
          cstStrides.push_back(cstOne);
        }
        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstKernel);
        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstPadding);
        Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstDilations);
        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstStrides);
        Value cstCeilMode =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

        if (inputRank == 3) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        } else if (inputRank == 4) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        } else if (inputRank == 5) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        return failure();
      });

  // LayerNormalization: aten.native_layer_norm over dims [axis, rank).
  // Supports 1 result (Y) or 3 results (Y, Mean, InvStdDev).
  patterns.onOp(
      "LayerNormalization", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType yType, meanType, invStdDevType;
        Value x, scale, b;
        int64_t axis, stashType;
        float epsilon;
        if (binder.tensorOperandAtIndex(x, 0) ||
            binder.tensorOperandAtIndex(scale, 1) ||
            binder.tensorOperandAtIndex(b, 2) ||
            binder.tensorResultTypeAtIndex(yType, 0) ||
            binder.s64IntegerAttr(axis, "axis", -1) ||
            binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
            binder.s64IntegerAttr(stashType, "stash_type", 1))
          return failure();
        Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(epsilon));
        unsigned rank = 1;
        if (std::optional<unsigned> maybeRank = Torch::getTensorRank(x))
          rank = *maybeRank;
        SmallVector<Value> normalized;
        axis = Torch::toPositiveDim(axis, rank);
        auto xType = cast<Torch::ValueTensorType>(x.getType());
        if (!xType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input (X) to have sizes");
        }
        ArrayRef<int64_t> xShape = xType.getSizes();
        for (int64_t n = axis; n < rank; n++) {
          if (xShape[n] == Torch::kUnknownSize) {
            // Dynamic dim: query its size at runtime instead of baking the
            // kUnknownSize sentinel (-1) into normalized_shape.
            Value dimIdx = rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(n));
            normalized.push_back(rewriter.create<Torch::AtenSizeIntOp>(
                binder.getLoc(), x, dimIdx));
          } else {
            normalized.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n])));
          }
        }
        Value normalized_shape = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            normalized);

        int64_t numResults = binder.op->getNumResults();
        if (numResults == 1) {
          // Mean / InvStdDev are not requested; synthesize their types: the
          // leading `axis` dims of X followed by unit dims.
          SmallVector<int64_t> reducedShape(rank, 1);
          for (int64_t i = 0; i < axis; i++)
            reducedShape[i] = xShape[i];
          auto reducedType = xType.getWithSizesAndDtype(
              reducedShape, xType.getOptionalDtype());
          Value y = rewriter
                        .create<Torch::AtenNativeLayerNormOp>(
                            binder.getLoc(), yType, /*meanType=*/reducedType,
                            /*invStdDevType=*/reducedType, x, normalized_shape,
                            scale, b, constEpsilon)
                        .getResult0();
          rewriter.replaceOp(binder.op, y);
          return success();
        }
        if (numResults == 3) {
          if (binder.tensorResultTypeAtIndex(meanType, 1) ||
              binder.tensorResultTypeAtIndex(invStdDevType, 2))
            return failure();
          rewriter.replaceOpWithNewOp<Torch::AtenNativeLayerNormOp>(
              binder.op, yType, meanType, invStdDevType, x, normalized_shape,
              scale, b, constEpsilon);
          return success();
        }
        return rewriter.notifyMatchFailure(
            binder.op, "Unimplemented: expected either 1 or 3 results");
      });

  // LeakyRelu maps onto aten.leaky_relu with the ONNX `alpha` slope.
  patterns.onOp("LeakyRelu", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  float alpha;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType) ||
                      binder.f32FloatAttr(alpha, "alpha", 0.01f))
                    return failure();
                  Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
                      binder.getLoc(), rewriter.getType<Torch::FloatType>(),
                      rewriter.getF64FloatAttr(alpha));
                  rewriter.replaceOpWithNewOp<Torch::AtenLeakyReluOp>(
                      binder.op, resultType, operand, constAlpha);
                  return success();
                });

  // Pad: lowered to aten.pad with the ONNX `mode` string passed through. The
  // optional `axes` input (operand 3) is rejected for now.
  patterns.onOp(
      "Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, pads, axes;
        std::string mode;

        // TODO: The `axes` parameter is not supported yet.
        if (!binder.tensorOperandAtIndex(axes, 3)) {
          return rewriter.notifyMatchFailure(
              binder.op, "The axes parameter is not supported yet");
        }
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(pads, 1) ||
            binder.tensorResultType(resultType) ||
            binder.customOpNameStringAttr(mode, "mode", "constant"))
          return failure();
        Location loc = binder.getLoc();

        // Get pads shape and rank. The pads tensor is expected to be 1-D
        // tensor.
        auto padsTensorType = dyn_cast<Torch::ValueTensorType>(pads.getType());
        if (!padsTensorType || !padsTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non empty pad tensor");
        }
        ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
        int64_t padsRank = padsShape.size();
        if (padsRank != 1)
          return rewriter.notifyMatchFailure(binder.op,
                                             "expect 1-d pad tensor");

        int64_t padsSize = padsShape[0];
        if (padsSize == Torch::kUnknownSize) {
          // As per onnx.Pad documentation, padSize = 2*num_data_axes
          // (if axes param not passed). Need to be updated when adding
          // support for `axes` param.
          auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
          TensorType dataTensor = dataOpTy.toBuiltinTensor();
          if (!dataTensor || !dataTensor.hasRank())
            return rewriter.notifyMatchFailure(
                binder.op, "pad length unknown and data operand unranked");
          int64_t dataRank = dataTensor.getRank();
          padsSize = 2 * dataRank;
        }

        // Optional constant_value input (operand 2), converted to a scalar.
        Value constantValue;
        if (binder.getNumOperands() >= 3) {
          if (!binder.tensorOperandAtIndex(constantValue, 2)) {
            auto constTy =
                dyn_cast<Torch::BaseTensorType>(constantValue.getType());
            if (!constTy || !constTy.hasDtype())
              return rewriter.notifyMatchFailure(
                  binder.op, "constant ty is unsupport type");

            Type scalarTy = rewriter.getType<Torch::FloatType>();
            if (isa<IntegerType>(constTy.getDtype()))
              scalarTy = rewriter.getType<Torch::IntType>();
            constantValue = rewriter.create<Torch::AtenItemOp>(loc, scalarTy,
                                                               constantValue);
          }
        }

        // Default fill value: 0 of the data tensor's element kind.
        if (!constantValue) {
          auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
          if (isa<IntegerType>(dataTensorType.getDtype()))
            constantValue = rewriter.create<Torch::ConstantIntOp>(
                loc, rewriter.getI64IntegerAttr(0));
          if (isa<FloatType>(dataTensorType.getDtype()))
            constantValue = rewriter.create<Torch::ConstantFloatOp>(
                loc, rewriter.getF64FloatAttr(0.0f));

          if (!constantValue)
            return rewriter.notifyMatchFailure(
                binder.op, "expected integer or float data tensor");
        }

        // Extract all the values of 1-D pad tensor and create a list of all
        // these values as torch.pad op expects pad list.
        Value constZero = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(0));
        SmallVector<Value> padsTensorValue;
        SmallVector<int64_t> emptyShape;
        Type padsElemType =
            Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
                                        padsTensorType.getOptionalDtype());
        for (int64_t i = 0; i < padsSize; ++i) {
          Value index = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i));
          auto select = rewriter.create<Torch::AtenSelectIntOp>(
              loc, padsElemType, pads, constZero, index);
          Value selectInt = rewriter.create<Torch::AtenItemOp>(
              loc, rewriter.getType<Torch::IntType>(), select);
          padsTensorValue.push_back(selectInt);
        }

        // ONNX lays pads out first-axis-first as
        //   [x1_begin, ..., xn_begin, x1_end, ..., xn_end],
        // whereas torch.pad expects (begin, end) pairs starting from the LAST
        // axis:
        //   [xn_begin, xn_end, ..., x1_begin, x1_end].
        // Interleave the pairs and walk the axes in reverse order.
        SmallVector<Value> padsRearrange;
        const int64_t numPaddedAxes = padsSize / 2;
        for (int64_t d = numPaddedAxes - 1; d >= 0; --d) {
          padsRearrange.emplace_back(padsTensorValue[d]);
          padsRearrange.emplace_back(padsTensorValue[numPaddedAxes + d]);
        }

        Value padsSizeList =
            rewriter
                .create<Torch::PrimListConstructOp>(
                    loc,
                    Torch::ListType::get(rewriter.getType<Torch::IntType>()),
                    padsRearrange)
                .getResult();
        Value modeVal = rewriter.create<Torch::ConstantStrOp>(
            loc, rewriter.getStringAttr(mode));

        rewriter.replaceOpWithNewOp<Torch::AtenPadOp>(
            binder.op, resultType, data, padsSizeList, modeVal, constantValue);
        return success();
      });

  // Pow maps onto aten.pow.Tensor_Tensor.
  patterns.onOp("Pow", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenPowTensorTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // Identity is lowered to aten.clone with no memory_format.
  patterns.onOp(
      "Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensor;
        if (binder.tensorOperand(tensor) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }
        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        rewriter.replaceOpWithNewOp<Torch::AtenCloneOp>(
            binder.op, resultType, tensor, /*memory_format=*/noneVal);
        return success();
      });

  // Mean: sum of all inputs (chained aten.add.Tensor) divided by the number
  // of inputs. A single input is forwarded unchanged.
  patterns.onOp(
      "Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        if (binder.op->getNumOperands() == 1) {
          Torch::ValueTensorType resultType;
          Value x;
          if (binder.tensorOperand(x) || binder.tensorResultType(resultType))
            return failure();
          rewriter.replaceOp(binder.op, x);
          return success();
        }
        Torch::ValueTensorType resultType;
        SmallVector<Value> valList;
        int64_t numOperands = binder.op->getNumOperands();
        Value numOperandsConstant = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands));
        if (binder.tensorOperands(valList, numOperands) ||
            binder.tensorResultType(resultType))
          return failure();
        Value constOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
        // Short circuit to binary add
        Value curr = rewriter.create<Torch::AtenAddTensorOp>(
            binder.getLoc(), resultType, valList[0], valList[1], constOne);
        if (numOperands == 2) {
          rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
              binder.op, resultType, curr, numOperandsConstant);
          return success();
        }
        // When binder.op->getNumOperands() > 2: intermediate sums get the
        // least-static tensor type; only the last one carries resultType.
        auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
            binder.op->getContext());
        for (int i = 2; i < numOperands; i++) {
          if (i == numOperands - 1) {
            curr = rewriter.create<Torch::AtenAddTensorOp>(
                binder.getLoc(), resultType, curr, valList[i], constOne);
          } else {
            curr = rewriter.create<Torch::AtenAddTensorOp>(
                binder.getLoc(), baseType, curr, valList[i], constOne);
          }
        }
        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
            binder.op, resultType, curr, numOperandsConstant);
        return success();
      });

  // IsInf: infinities of an undetected sign are zeroed via relu (after a
  // negation for the positive case) before aten.isinf is applied.
  patterns.onOp(
      "IsInf", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensor;
        int64_t neg;
        int64_t pos;
        if (binder.tensorOperand(tensor) ||
            binder.s64IntegerAttr(neg, "detect_negative", 1) ||
            binder.s64IntegerAttr(pos, "detect_positive", 1) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }
        if (neg == 0) {
          // replace all negative infs with 0
          tensor = rewriter.create<Torch::AtenReluOp>(
              binder.getLoc(),
              dyn_cast<Torch::ValueTensorType>(tensor.getType()), tensor);
        }
        if (pos == 0) {
          // first use neg op to flip positive inf to negative inf. Then relu to
          // replace all positive infs with 0.
          Value flip = rewriter.create<Torch::AtenNegOp>(
              binder.getLoc(),
              dyn_cast<Torch::ValueTensorType>(tensor.getType()), tensor);
          tensor = rewriter.create<Torch::AtenReluOp>(
              binder.getLoc(), dyn_cast<Torch::ValueTensorType>(flip.getType()),
              flip);
        }
        rewriter.replaceOpWithNewOp<Torch::AtenIsinfOp>(binder.op, resultType,
                                                        tensor);
        return success();
      });

  // IsNaN maps onto aten.isnan.
  patterns.onOp("IsNaN", 9,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value tensor;
                  if (binder.tensorOperand(tensor) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenIsnanOp>(
                      binder.op, resultType, tensor);
                  return success();
                });

  // PRelu maps onto aten.prelu(input, slope).
  patterns.onOp("PRelu", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value tensor;
                  Value slope;
                  if (binder.tensorOperands(tensor, slope) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenPreluOp>(
                      binder.op, resultType, tensor, slope);
                  return success();
                });

  // Mod: fmod=1 selects aten.fmod, otherwise aten.remainder.
  patterns.onOp("Mod", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value self, other;
                  int64_t fmod;
                  if (binder.tensorOperands(self, other) ||
                      binder.tensorResultType(resultType) ||
                      binder.s64IntegerAttr(fmod, "fmod", 0)) {
                    return failure();
                  }

                  if (fmod) {
                    rewriter.replaceOpWithNewOp<Torch::AtenFmodTensorOp>(
                        binder.op, resultType, self, other);
                    return success();
                  }

                  rewriter.replaceOpWithNewOp<Torch::AtenRemainderTensorOp>(
                      binder.op, resultType, self, other);
                  return success();
                });

  // Mish maps onto aten.mish.
  patterns.onOp("Mish", 18,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value input;
                  if (binder.tensorOperand(input) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenMishOp>(
                      binder.op, resultType, input);
                  return success();
                });

  // OneHot: aten.one_hot(indices, depth) with the class dim moved to `axis`,
  // then aten.where picks values[1] ("on") or values[0] ("off").
  patterns.onOp(
      "OneHot", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        llvm::SmallVector<Value> inputs;
        Torch::ValueTensorType resultType;
        if (binder.tensorOperandsList(inputs) ||
            binder.tensorResultType(resultType))
          return failure();

        if (inputs.size() != 3)
          return rewriter.notifyMatchFailure(binder.op, "expected 3 operands");

        int64_t axis;
        if (binder.s64IntegerAttr(axis, "axis", -1))
          return rewriter.notifyMatchFailure(binder.op,
                                             "`axis` attr not found");

        auto loc = binder.getLoc();
        Value indices = inputs[0];
        Value depth = inputs[1];
        Value values = inputs[2];

        auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
        auto valuesTy = cast<Torch::ValueTensorType>(values.getType());
        auto depthTy = cast<Torch::ValueTensorType>(depth.getType());
        if (!indicesTy.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "expected indices with known rank");

        // The output has rank(indices) + 1 dims; a negative axis counts from
        // the end of the output, so the valid range is [0, rank(indices)].
        int64_t indicesRank = indicesTy.getSizes().size();
        if (axis < 0)
          axis += indicesRank + 1;
        if (axis < 0 || axis > indicesRank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "`axis` out of range");

        // Read `depth` as a scalar int (truncating if it is a float tensor).
        bool depthIsInt = isa<IntegerType>(depthTy.getDtype());
        Type intTy = rewriter.getType<Torch::IntType>();
        Type floatTy = rewriter.getType<Torch::FloatType>();
        Type depthETy = depthIsInt ? intTy : floatTy;
        depth = rewriter.create<Torch::AtenItemOp>(loc, depthETy, depth);

        if (!depthIsInt)
          depth = rewriter.create<Torch::AtenIntScalarOp>(
              loc, rewriter.getType<Torch::IntType>(), depth);

        // Wrap negative indices: indices < 0 become indices + depth.
        Type boolTy = rewriter.getType<Torch::ValueTensorType>(
            indicesTy.getSizes(), rewriter.getI1Type());
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(0));
        Value one = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(1));
        Value lt =
            rewriter.create<Torch::AtenLtScalarOp>(loc, boolTy, indices, zero);
        Value add = rewriter.create<Torch::AtenAddScalarOp>(
            loc, indicesTy, indices, depth, one);
        indices = rewriter.create<Torch::AtenWhereSelfOp>(loc, indicesTy, lt,
                                                          add, indices);

        // values[0] is the "off" value, values[1] the "on" value.
        auto selectTy = rewriter.getType<Torch::ValueTensorType>(
            llvm::SmallVector<int64_t>{1}, valuesTy.getDtype());

        bool valuesAreInt = isa<IntegerType>(valuesTy.getDtype());
        Type valueEty = valuesAreInt ? intTy : floatTy;

        Value off = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
                                                            values, zero, zero);
        off = rewriter.create<Torch::AtenItemOp>(loc, valueEty, off);

        Value on = rewriter.create<Torch::AtenSelectIntOp>(loc, selectTy,
                                                           values, zero, one);
        on = rewriter.create<Torch::AtenItemOp>(loc, valueEty, on);

        auto i32Ty = rewriter.getIntegerType(32, true);
        llvm::SmallVector<int64_t> onehotShape(indicesTy.getSizes());
        onehotShape.push_back(Torch::kUnknownSize);
        auto onehotTy =
            rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);

        Value onehot = rewriter.create<Torch::AtenOneHotOp>(
            binder.getLoc(), onehotTy, indices, depth);

        // one_hot appends the class dim last (position indicesRank). Bubble it
        // down to `axis` with adjacent transposes. The walk must start at the
        // one-hot result's last dim and DEcrement towards `axis`.
        for (int64_t i = indicesRank; i > axis; --i) {
          std::swap(onehotShape[i - 1], onehotShape[i]);
          Value iv0 = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i));
          Value iv1 = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i - 1));

          onehotTy =
              rewriter.getType<Torch::ValueTensorType>(onehotShape, i32Ty);
          onehot = rewriter.create<Torch::AtenTransposeIntOp>(loc, onehotTy,
                                                              onehot, iv1, iv0);
        }

        // Change one hot to an array of booleans to select value:
        auto i1Ty = rewriter.getI1Type();
        auto torchqTy = Torch::getScalarTypeForType(i1Ty);
        Value tyConst = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                    static_cast<int64_t>(torchqTy)));

        onehotTy = rewriter.getType<Torch::ValueTensorType>(onehotShape, i1Ty);
        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
        onehot = rewriter.create<Torch::AtenToDtypeOp>(
            loc, onehotTy, onehot, tyConst,
            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
            /*memory_format=*/none);

        onehot = rewriter.create<Torch::AtenWhereScalarOp>(loc, resultType,
                                                           onehot, on, off);

        rewriter.replaceOp(binder.op, onehot);
        return success();
      });

  // HardSwish maps onto aten.hardswish.
  patterns.onOp("HardSwish", 14,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value input;
                  if (binder.tensorOperand(input) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenHardswishOp>(
                      binder.op, resultType, input);
                  return success();
                });

  // Hardmax: 1 at the first maximum along `axis`, 0 elsewhere.
  patterns.onOp(
      "Hardmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // onnx.Hardmax can be expanded into the following python code:
        //
        // import torch.nn.functional as F
        // def hardmax(tensor, dim=-1):
        //   maximums = torch.argmax(tensor, dim=dim, keepdim=False)
        //   onehot = F.one_hot(maximums, num_classes=tensor.size(dim))
        //   return onehot.movedim(-1, dim).to(result_dtype)
        //
        // Given an example input (dim=-1):
        // tensor([[1, 2, 3],
        //         [4, 6, 5],
        //         [9, 8, 7]])
        // Above code yields the following:
        // tensor([[0, 0, 1],
        //         [0, 1, 0],
        //         [1, 0, 0]])

        Torch::ValueTensorType resultType;
        int64_t axisValue;
        Value input, axis;
        // ONNX Hardmax (opset 13) defaults `axis` to -1.
        if (binder.tensorOperand(input) ||
            binder.s64IntegerAttr(axisValue, "axis", -1) ||
            binder.tensorResultType(resultType))
          return failure();

        auto loc = binder.getLoc();
        auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
        if (!inputTy || !inputTy.hasSizes())
          return rewriter.notifyMatchFailure(binder.op,
                                             "expected input with known rank");
        if (!resultType.hasDtype())
          return rewriter.notifyMatchFailure(
              binder.op, "expected result with known dtype");

        // `axis` is a dimension index (not a dtype): only normalize its sign.
        ArrayRef<int64_t> inputShape = inputTy.getSizes();
        int64_t rank = inputShape.size();
        if (axisValue < 0)
          axisValue += rank;
        if (axisValue < 0 || axisValue >= rank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "`axis` out of range");
        axis = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(axisValue));

        // torch.argmax over `axis`, dropping that dim.
        Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
            loc, rewriter.getType<Torch::BoolType>(),
            rewriter.getBoolAttr(false));
        auto si64Ty = rewriter.getIntegerType(64, /*isSigned=*/true);
        SmallVector<int64_t> argmaxShape;
        for (int64_t i = 0; i < rank; ++i)
          if (i != axisValue)
            argmaxShape.push_back(inputShape[i]);
        Value argmax = rewriter.create<Torch::AtenArgmaxOp>(
            loc, rewriter.getType<Torch::ValueTensorType>(argmaxShape, si64Ty),
            input, axis, constKeepDims);

        // one_hot with one class per element along `axis`; the class dim is
        // appended as the last dim.
        Value numClasses =
            rewriter.create<Torch::AtenSizeIntOp>(loc, input, axis);
        SmallVector<int64_t> onehotShape(argmaxShape);
        onehotShape.push_back(inputShape[axisValue]);
        Value onehot = rewriter.create<Torch::AtenOneHotOp>(
            loc, rewriter.getType<Torch::ValueTensorType>(onehotShape, si64Ty),
            argmax, numClasses);

        // Move the class dim from the last position back to `axis`.
        for (int64_t i = rank - 1; i > axisValue; --i) {
          std::swap(onehotShape[i - 1], onehotShape[i]);
          Value dimHi = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i));
          Value dimLo = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i - 1));
          onehot = rewriter.create<Torch::AtenTransposeIntOp>(
              loc,
              rewriter.getType<Torch::ValueTensorType>(onehotShape, si64Ty),
              onehot, dimLo, dimHi);
        }

        // Cast the integer mask to the ONNX result dtype.
        Value dtype = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(static_cast<int64_t>(
                     Torch::getScalarTypeForType(resultType.getDtype()))));
        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
            binder.op, resultType, onehot, dtype,
            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
            /*memory_format=*/none);
        return success();
      });
}