//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
// sinceVersion==1. We interpret this to include the following spec
// diffs that are irrelevant to this level of lowering:
//   * Supported element types.
//   * Limited broadcasting to full broadcasting support.
//
// A lot of spec revisions simply generalized elementwise ops so that they
// become a direct translation rather than a special case. This results in a
// lot of ONNX test cases that all reduce to the exact same thing here, so we
// simplify.
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
    OnnxCustomOpConversionPattern &patterns) {
  patterns.onOp(
      "HardSigmoid", 6,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensorOperand;
        float alpha, beta;
        if (binder.tensorOperand(tensorOperand) ||
            binder.f32FloatAttr(alpha, "alpha", 0.2f) ||
            binder.f32FloatAttr(beta, "beta", 0.5f) ||
            binder.tensorResultType(resultType))
          return failure();

        // HardSigmoid computes the following expression:
        //   max(0, min(1, alpha * x + beta))
        Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(alpha));
        Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(beta));

        // Expression: alpha * x + beta
        Value constantOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        Value alphaMulX = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), resultType, tensorOperand, constAlpha);
        Value alphaMulXPlusBeta = rewriter.create<Torch::AtenAddScalarOp>(
            binder.getLoc(), resultType, alphaMulX, constBeta,
            /*alpha=*/constantOne);

        // Expression: min(1, alpha * x + beta)
        Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
                                            resultType, constantOne);
        Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
            binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta);

        // Expression: max(0, min(1, alpha * x + beta))
        Value constantZero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(),
                                             resultType, constantZero);
        rewriter.replaceOpWithNewOp<Torch::AtenMaximumOp>(
            binder.op, resultType, zeroTensor, minExpression);
        return success();
      });
  patterns.onOp(
      "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Value operand;
        Torch::ValueTensorType resultType;
        std::string approximate;
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType) ||
            binder.customOpNameStringAttr(approximate, "approximate", "none"))
          return failure();
        Value vApproximate = rewriter.create<Torch::ConstantStrOp>(
            binder.getLoc(), rewriter.getType<Torch::StringType>(),
            rewriter.getStringAttr(approximate));
        rewriter.replaceOpWithNewOp<Torch::AtenGeluOp>(
            binder.op, resultType, operand, vApproximate);
        return success();
      });
  patterns.onOp("Less", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenLtTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp("LessOrEqual", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenLeTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp("Log", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
                      binder.op, resultType, operand);
                  return success();
                });
  patterns.onOp("MatMul", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp("Mul", 7,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenMulTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp("NonZero", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
                      binder.op, resultType, operand);
                  return success();
                });
  patterns.onOp(
      "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        std::string autoPad;
        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
          return rewriter.notifyMatchFailure(binder.op,
                                             "auto_pad bind failure");
        if (autoPad != "NOTSET")
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported conversion: auto_pad != NOTSET");

        Torch::ValueTensorType resultType;
        Value operand;
        bool ceilMode;
        int64_t storageOrder;
        // TODO: Add support for indices output and storage_order
        if (binder.tensorOperand(operand) ||
            binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
            binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
            binder.tensorResultType(resultType))
          return rewriter.notifyMatchFailure(
              binder.op,
              "operand/ceil_mode/storage_order/resultType bind failure");
        if (storageOrder != 0)
          return rewriter.notifyMatchFailure(
              binder.op, "storage_order setting is not supported.");

        // Determine the rank of input tensor.
        std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
        if (!maybeRank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "Unimplemented: unranked tensor");
        unsigned rank = *maybeRank;

        SmallVector<int64_t> kernel, padding, strides, dilations;
        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
          return rewriter.notifyMatchFailure(binder.op,
                                             "kernel_shape bind failure");
        if (kernel.size() != rank - 2)
          return rewriter.notifyMatchFailure(
              binder.op, "kernel list size does not match the number of axes");
        if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
        if (padding.size() != 1 && padding.size() != 2 * (rank - 2))
          return rewriter.notifyMatchFailure(
              binder.op, "padding list must contain (begin,end) pair for each "
                         "spatial axis");
        if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
        if (strides.size() != 1 && strides.size() != rank - 2)
          return rewriter.notifyMatchFailure(
              binder.op, "strides list size does not match the number of axes");
        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
          return rewriter.notifyMatchFailure(binder.op,
                                             "dilations bind failure");

        Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
        Value paddingList = createConstantIntList(binder, rewriter, padding);
        Value stridesList = createConstantIntList(binder, rewriter, strides);
        Value dilationsList =
            createConstantIntList(binder, rewriter, dilations);
        Value cstCeilMode =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);

        if (rank == 3)
          return rewriter.notifyMatchFailure(binder.op,
                                             "Unimplemented: AtenMaxPool1dOp");
        if (rank == 4) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        if (rank == 5) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        return rewriter.notifyMatchFailure(binder.op, "No rank is matched.");
      });
  patterns.onOp("Greater", 16,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  std::string direction;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp<Torch::AtenGtTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp("GreaterOrEqual", 16,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  std::string direction;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp<Torch::AtenGeTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp("Max", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  llvm::SmallVector<Value> operands;
                  if (binder.tensorOperandsList(operands) ||
                      binder.tensorResultType(resultType) ||
                      operands.size() == 0) {
                    return failure();
                  }
                  Value result = operands[0];
                  for (uint64_t i = 1; i < operands.size(); i++) {
                    result = rewriter.create<Torch::AtenMaximumOp>(
                        binder.getLoc(), resultType, result, operands[i]);
                  }
                  rewriter.replaceOp(binder.op, result.getDefiningOp());
                  return success();
                });
  patterns.onOp("Min", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  llvm::SmallVector<Value> operands;
                  if (binder.tensorOperandsList(operands) ||
                      binder.tensorResultType(resultType) ||
                      operands.size() == 0) {
                    return failure();
                  }
                  Value result = operands[0];
                  for (uint64_t i = 1; i < operands.size(); i++) {
                    result = rewriter.create<Torch::AtenMinimumOp>(
                        binder.getLoc(), resultType, result, operands[i]);
                  }
                  rewriter.replaceOp(binder.op, result.getDefiningOp());
                  return success();
                });
  patterns.onOp("Neg", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenNegOp>(
                      binder.op, resultType, operand);
                  return success();
                });
  patterns.onOp("Not", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
                      binder.op, resultType, operand);
                  return success();
                });
  patterns.onOp("Or", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp(
      "Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, indices;
        int64_t axis;
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(indices, 1) ||
            binder.tensorResultType(resultType) ||
            binder.s64IntegerAttr(axis, "axis", 0))
          return failure();
        Location loc = binder.getLoc();

        // 1. Get data shape and rank.
        auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
        if (!dataTensorType || !dataTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non empty input data");
        }
        ArrayRef<int64_t> dataShape = dataTensorType.getSizes();
        unsigned dataRank = dataShape.size();

        // 2. Get indices shape and rank.
        auto indexType = indices.getType().cast<Torch::ValueTensorType>();
        if (!indexType || !indexType.hasSizes()) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non empty index tensor");
        }
        ArrayRef<int64_t> indexShape = indexType.getSizes();
        unsigned indexRank = indexShape.size();

        // 3. Compute total elements in the indices tensor, as we will collapse
        // the indices tensor to a unary tensor. Also compute index shape and
        // data shape tensors as they will be used for creating output types.
        int64_t indexElemCount = 1;
        for (int64_t dim : indexShape) {
          if (dim == Torch::kUnknownSize) {
            indexElemCount = Torch::kUnknownSize;
            break;
          }
          indexElemCount *= dim;
        }
        Value constOne = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(1));
        SmallVector<Value> indexShapeTensor;
        Value indexElemCountVal = constOne;
        for (unsigned i = 0; i < indexRank; ++i) {
          Value indexDimVal = rewriter.create<Torch::AtenSizeIntOp>(
              loc, indices,
              rewriter.create<Torch::ConstantIntOp>(
                  loc, rewriter.getI64IntegerAttr(i)));
          indexShapeTensor.emplace_back(indexDimVal);
          indexElemCountVal = rewriter.create<Torch::AtenMulIntOp>(
              loc, indexElemCountVal, indexDimVal);
        }
        SmallVector<Value> dataShapeTensor;
        for (unsigned i = 0; i < dataRank; ++i) {
          dataShapeTensor.emplace_back(rewriter.create<Torch::AtenSizeIntOp>(
              loc, data,
              rewriter.create<Torch::ConstantIntOp>(
                  loc, rewriter.getI64IntegerAttr(i))));
        }

        // 4. We can not directly perform torch.gather as the onnx.gather op
        // collects the input data at different locations of the output
        // compared to the torch.gather op, i.e. their outputs are indexed
        // differently.
        // Check https://onnx.ai/onnx/operators/onnx__Gather.html for more
        // details. So we will collapse the indices tensor to a unary tensor
        // and materialize it to the non-axis dimensions of the data tensor.
        // For example, assuming indices is of shape (4, 5, 6), data is
        // (8, 10, 11, 12) and axis=1, we will collapse indices into a (120,)
        // unary tensor and materialize it to the non-axis dimensions of data,
        // i.e. reshape the unary indices tensor to (1, 120, 1, 1) and then
        // perform the torch.gather operation. Now broadcast the output of the
        // gather operation to the non-axis dimensions of the data tensor. This
        // would make the result of shape (8, 10, 120, 12). Post the
        // broadcasting, expand the indices dimensions by reshaping
        // (8, 10, 120, 12) to (8, 10, 4, 5, 6, 12), which is our expected
        // final result.
        SmallVector<int64_t> collapsedIndexShape(dataRank, 1);
        collapsedIndexShape[axis] = indexElemCount;
        Type collapsedIndexType = Torch::ValueTensorType::get(
            indexType.getContext(), llvm::ArrayRef(collapsedIndexShape),
            indexType.getOptionalDtype());

        SmallVector<Value> collapsedIndexSize(dataRank, constOne);
        collapsedIndexSize[axis] = indexElemCountVal;
        auto collapsedIndexSizeList =
            rewriter.create<Torch::PrimListConstructOp>(
                loc,
                rewriter.getType<Torch::ListType>(
                    rewriter.getType<Torch::IntType>()),
                collapsedIndexSize);

        auto collapsedIndices = rewriter.create<Torch::AtenViewOp>(
            loc, collapsedIndexType, indices, collapsedIndexSizeList);

        // 5. Compute gather result type and perform gather operation.
        Type gatherResultType = Torch::ValueTensorType::get(
            dataTensorType.getContext(), llvm::ArrayRef(collapsedIndexShape),
            dataTensorType.getOptionalDtype());
        Value constAxis = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
        Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getType<Torch::BoolType>(),
            rewriter.getBoolAttr(false));
        auto gatherOp = rewriter.create<Torch::AtenGatherOp>(
            loc, gatherResultType, data, constAxis, collapsedIndices,
            /*sparseGrad=*/constFalse);

        // 6. Broadcast the gather output to the non-axis dimensions of the
        // data tensor.
        SmallVector<int64_t> dataShapeVector(dataShape);
        dataShapeVector[axis] = indexElemCount;
        Type expandResultType = Torch::ValueTensorType::get(
            dataTensorType.getContext(), llvm::ArrayRef(dataShapeVector),
            dataTensorType.getOptionalDtype());

        dataShapeTensor[axis] = indexElemCountVal;
        auto expandSizeList = rewriter.create<Torch::PrimListConstructOp>(
            loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
            dataShapeTensor);
        auto expandedGather = rewriter.create<Torch::AtenExpandOp>(
            loc, expandResultType, gatherOp, expandSizeList,
            /*implicit=*/constFalse);

        // 7. Compute the result type of the reshape op which expands the
        // collapsed indices shape back to the original indices shape, and
        // reshape the output produced at step 6. This produces the expected
        // result of the onnx.gather op.
        SmallVector<Value> resultShapeTensor;
        for (unsigned i = 0; i < dataRank; ++i) {
          if (i == axis) {
            resultShapeTensor.insert(resultShapeTensor.end(),
                                     indexShapeTensor.begin(),
                                     indexShapeTensor.end());
            continue;
          }
          resultShapeTensor.emplace_back(dataShapeTensor[i]);
        }
        auto resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
            loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
            resultShapeTensor);

        rewriter.replaceOpWithNewOp<Torch::AtenViewOp>(
            binder.op, resultType, expandedGather, resultSizeList);
        return success();
      });
  patterns.onOp(
      "GatherElements", 13,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, indices;
        int64_t axis;
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(indices, 1) ||
            binder.tensorResultType(resultType) ||
            binder.s64IntegerAttr(axis, "axis", 0))
          return failure();
        Value constAxis = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
        Value sparseGrad = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getType<Torch::BoolType>(),
            rewriter.getBoolAttr(false));
        rewriter.replaceOpWithNewOp<Torch::AtenGatherOp>(
            binder.op, resultType, data, constAxis, indices, sparseGrad);
        return success();
      });
  patterns.onOp(
      "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value a, b, c;
        float alpha, beta;
        int64_t transA, transB;
        if (binder.tensorOperandAtIndex(a, 0) ||
            binder.tensorOperandAtIndex(b, 1) ||
            binder.tensorOperandAtIndex(c, 2) ||
            binder.s64IntegerAttr(transA, "transA", 0) ||
            binder.s64IntegerAttr(transB, "transB", 0) ||
            binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
            binder.f32FloatAttr(beta, "beta", 1.0f) ||
            binder.tensorResultType(resultType))
          return failure();

        // Gemm computes alpha * op(A) @ op(B) + beta * C. The matmul is
        // emitted first; any non-unit alpha is folded into the product, and
        // the remaining non-unit scalar is applied through the add op's
        // alpha operand.
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        Value one = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

        auto transpose = [&](Value m) -> Value {
          auto tty = m.getType().cast<Torch::ValueTensorType>();
          auto shape = tty.getOptionalSizes();
          // Keep the reversed shape alive until the result type is built.
          llvm::SmallVector<int64_t> newShape;
          if (shape.has_value()) {
            newShape.append(shape.value().begin(), shape.value().end());
            std::reverse(newShape.begin(), newShape.end());
            shape = llvm::ArrayRef(newShape);
          }
          auto oty = Torch::ValueTensorType::get(tty.getContext(), shape,
                                                 tty.getOptionalDtype());
          return rewriter.create<Torch::AtenTransposeIntOp>(binder.getLoc(),
                                                            oty, m, zero, one);
        };

        if (transA) {
          a = transpose(a);
        }

        if (transB) {
          b = transpose(b);
        }

        Value mm =
            rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
        if (alpha == 1.0 && beta == 1.0) {
          rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
              binder.op, resultType, mm, c, one);
          return success();
        }

        if (alpha != 1.0 && beta != 1.0) {
          Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
              rewriter.getF64FloatAttr(alpha));
          mm = rewriter.create<Torch::AtenMulScalarOp>(
              binder.getLoc(), resultType, mm, constAlpha);
          alpha = 1.0;
        }

        if (alpha != 1.0) {
          std::swap(alpha, beta);
          std::swap(mm, c);
        }

        Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(beta));
        rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
            binder.op, resultType, mm, c, constBeta);
        return success();
      });
  patterns.onOp(
      "GlobalAveragePool", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType))
          return failure();

        auto inputTensorType =
            operand.getType().cast<Torch::ValueTensorType>();
        if (!inputTensorType || !inputTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
        }
        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
        unsigned inputRank = inputShape.size();
        if (!resultType || !resultType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected result type having sizes");
        }
        ArrayRef<int64_t> resultShape = resultType.getSizes();

        // Global average pooling collapses every spatial dimension, so derive
        // the equivalent fixed avg_pool kernel size from the static input and
        // result shapes.
        SmallVector<Value> cstKernel, cstPadding, cstStrides;
        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        for (unsigned i = 2; i < inputRank; i++) {
          int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
          cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
          cstPadding.push_back(cstZero);
          cstStrides.push_back(cstOne);
        }
        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstKernel);
        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstPadding);
        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstStrides);
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        Value cstCeilMode = cstFalse;
        Value cstCountIncludePad = cstFalse;
        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

        if (inputRank == 3) {
          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad);
          return success();
        } else if (inputRank == 4) {
          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad,
              /*divisor_override=*/cstNone);
          return success();
        } else if (inputRank == 5) {
          rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad,
              /*divisor_override=*/cstNone);
          return success();
        }
        return failure();
      });
  patterns.onOp(
      "LayerNormalization", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType yType, meanType, invStdDevType;
        Value x, scale, b;
        int64_t axis, stashType;
        float epsilon;
        if (binder.tensorOperandAtIndex(x, 0) ||
            binder.tensorOperandAtIndex(scale, 1) ||
            binder.tensorOperandAtIndex(b, 2) ||
            binder.tensorResultTypeAtIndex(yType, 0) ||
            binder.s64IntegerAttr(axis, "axis", -1) ||
            binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
            binder.s64IntegerAttr(stashType, "stash_type", 1))
          return failure();
        Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(epsilon));
        unsigned rank = 1;
        if (std::optional<unsigned> maybeRank = Torch::getTensorRank(x))
          rank = *maybeRank;
        SmallVector<Value> normalized;
        axis = Torch::toPositiveDim(axis, rank);
        auto xType = x.getType().cast<Torch::ValueTensorType>();
        if (!xType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input (X) to have sizes");
        }
        ArrayRef<int64_t> xShape = xType.getSizes();
        for (int64_t n = axis; n < rank; n++) {
          normalized.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n])));
        }
        Value normalized_shape = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            normalized);

        int64_t numResults = binder.op->getNumResults();
        if (numResults == 1) {
          SmallVector<int64_t> reducedShape(rank, 1);
          for (int64_t i = 0; i < axis; i++)
            reducedShape[i] = xShape[i];
          auto reducedType = xType.getWithSizesAndDtype(
              reducedShape, xType.getOptionalDtype());
          Value y = rewriter
                        .create<Torch::AtenNativeLayerNormOp>(
                            binder.getLoc(), yType,
                            /*meanType=*/reducedType,
                            /*invStdDevType=*/reducedType, x, normalized_shape,
                            scale, b, constEpsilon)
                        .getResult0();
          rewriter.replaceOp(binder.op, y);
          return success();
        }
        if (numResults == 3) {
          if (binder.tensorResultTypeAtIndex(meanType, 1) ||
              binder.tensorResultTypeAtIndex(invStdDevType, 2))
            return failure();
          rewriter.replaceOpWithNewOp<Torch::AtenNativeLayerNormOp>(
              binder.op, yType, meanType, invStdDevType, x, normalized_shape,
              scale, b, constEpsilon);
          return success();
        }
        return rewriter.notifyMatchFailure(
            binder.op, "Unimplemented: expected either 1 or 3 results");
      });
  patterns.onOp("LeakyRelu", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  float alpha;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType) ||
                      binder.f32FloatAttr(alpha, "alpha", 0.01f))
                    return failure();
                  Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
                      binder.getLoc(), rewriter.getType<Torch::FloatType>(),
                      rewriter.getF64FloatAttr(alpha));
                  rewriter.replaceOpWithNewOp<Torch::AtenLeakyReluOp>(
                      binder.op, resultType, operand, constAlpha);
                  return success();
                });
  patterns.onOp(
      "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, pads, constantValue, axes;
        std::string mode;

        // TODO: The `axes` parameter is not supported yet.
        if (!binder.tensorOperandAtIndex(axes, 3)) {
          return rewriter.notifyMatchFailure(
              binder.op, "The axes parameter is not supported yet");
        }

        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(pads, 1) ||
            binder.tensorOperandAtIndex(constantValue, 2) ||
            binder.tensorResultType(resultType) ||
            binder.customOpNameStringAttr(mode, "mode", "constant"))
          return failure();
        Location loc = binder.getLoc();

        // Get pads shape and rank. The pads tensor is expected to be a
        // 1-D tensor.
        auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
        if (!padsTensorType || !padsTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non empty pad tensor");
        }
        ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
        int64_t padsRank = padsShape.size();
        if (padsRank != 1) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect 1-D pad tensor");
        }

        // Extract all the values of the 1-D pad tensor and create a list of
        // all these values, as torch.pad op expects a pad list.
        int64_t padsSize = padsShape[0];
        Value constZero = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(0));
        SmallVector<Value> padsTensorValue;
        SmallVector<int64_t> emptyShape;
        Type padsElemType = Torch::ValueTensorType::get(
            padsTensorType.getContext(), emptyShape,
            padsTensorType.getOptionalDtype());
        for (uint32_t i = 0; i < padsSize; ++i) {
          Value index = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i));
          padsTensorValue.emplace_back(rewriter.create<Torch::AtenSelectIntOp>(
              loc, padsElemType, pads, constZero, index));
        }

        // The torch.pad op expects a different arrangement of padding pairs
        // for each dimension as compared to the onnx.pad op. So, rearrange the
        // pad tensor to satisfy torch.pad op semantics.
        SmallVector<Value> padsRearrange;
        for (uint32_t i = 0; i < padsSize / 2; i++) {
          padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) - 1 - i]);
          padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]);
        }

        Value padsSizeList =
            rewriter
                .create<Torch::PrimTolistOp>(
                    loc,
                    Torch::ListType::get(rewriter.getType<Torch::IntType>()),
                    padsRearrange)
                .getResult(0);
        Value modeVal = rewriter.create<Torch::ConstantStrOp>(
            loc, rewriter.getStringAttr(mode));

        // The constant value is a 0-d tensor, which needs to be converted to a
        // float scalar as torch.pad op expects a float scalar.
        auto constValueType =
            constantValue.getType().cast<Torch::ValueTensorType>();
        if (!constValueType) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non-none constant value");
        }
        auto resultTensorType = Torch::ValueTensorType::get(
            constValueType.getContext(), emptyShape, rewriter.getF64Type());
        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
        Value constFloatValue = rewriter.create<Torch::AtenToDtypeOp>(
            loc, resultTensorType, constantValue,
            Torch::getDtypeIntValueForType(
                rewriter, loc, resultTensorType.getOptionalDtype()),
            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
            /*memory_format=*/none);
        Value constScalar = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::FloatType>(), constFloatValue);

        rewriter.replaceOpWithNewOp<Torch::AtenPadOp>(
            binder.op, resultType, data, padsSizeList, modeVal, constScalar);
        return success();
      });
  patterns.onOp("Pow", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp<Torch::AtenPowTensorTensorOp>(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp(
      "Identity", 14,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensor;
        if (binder.tensorOperand(tensor) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }
        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        rewriter.replaceOpWithNewOp<Torch::AtenCloneOp>(
            binder.op, resultType, tensor, /*memory_format=*/noneVal);
        return success();
      });
}