//===------------------------------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with // sinceVersion==1. We interpret this to include the following spec // diffs that are irrelevant to this level of lowering: // * Supported element types. // * Limited broadcasting to full broadcasting support. // // There are a lot of spec revisions that basically generalized elementwise // to be more normal and a direct translation vs a special case. This // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. 
// Registers the ONNX -> Torch rewrite callbacks for the default-domain
// operators handled in this file (HardSigmoid through PRelu).
//
// Every `patterns.onOp(name, sinceVersion, callback)` call registers one
// callback. A callback binds operands, attributes and result types through the
// `OpBinder` (the binder helpers are used as "true means failure"), returns a
// failure when binding fails or the requested configuration is unsupported,
// and otherwise replaces `binder.op` with the ops it builds.
//
// NOTE(review): in this copy of the file the template argument lists on
// `rewriter.create`, `rewriter.getType`, `cast`, `SmallVector`, ... are not
// visible. They are reproduced as found rather than guessed -- confirm against
// the upstream file before building.
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
    OnnxCustomOpConversionPattern &patterns) {
  // HardSigmoid: max(0, min(1, alpha * x + beta)).
  patterns.onOp(
      "HardSigmoid", 6,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensorOperand;
        float alpha, beta;
        if (binder.tensorOperand(tensorOperand) ||
            binder.f32FloatAttr(alpha, "alpha", 0.2f) ||
            binder.f32FloatAttr(beta, "beta", 0.5f) ||
            binder.tensorResultType(resultType))
          return failure();

        // HardSigmoid computes the following expression:
        //   max(0, min(1, alpha * x + beta))
        Value constAlpha = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getF64FloatAttr(alpha));

        Value constBeta = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getF64FloatAttr(beta));

        // Expression: alpha * x + beta
        // NOTE(review): if the op built here is an "add scalar with alpha"
        // op, the operand order (x, beta, alpha) evaluates x + alpha * beta
        // rather than alpha * x + beta -- confirm.
        Value alpha_x_plus_beta = rewriter.create(
            binder.getLoc(), resultType, tensorOperand, constBeta,
            /*alpha=*/constAlpha);

        // Expression: min(1, alpha * x + beta)
        Value constantOne = rewriter.create(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
                                            resultType, constantOne);
        Value minExpression = rewriter.create(
            binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta);

        // Expression: max(0, min(1, alpha * x + beta))
        Value constantZero = rewriter.create(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(),
                                             resultType, constantZero);
        rewriter.replaceOpWithNewOp(
            binder.op, resultType, zeroTensor, minExpression);
        return success();
      });

  // Gelu: forwards the operand together with the "approximate" string
  // attribute (default "none").
  patterns.onOp(
      "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Value operand;
        Torch::ValueTensorType resultType;
        std::string approximate;

        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType) ||
            binder.customOpNameStringAttr(approximate, "approximate", "none"))
          return failure();

        Value vApproximate = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getStringAttr(approximate));

        rewriter.replaceOpWithNewOp(binder.op, resultType, operand,
                                    vApproximate);
        return success();
      });

  // GridSample: only rank-4 input/grid, mode "bilinear", padding_mode "zeros"
  // and align_corners == 0 are accepted; everything else is a match failure.
  patterns.onOp(
      "GridSample", 20,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value input;
        Value grid;
        if (binder.tensorOperandAtIndex(input, 0) ||
            binder.tensorOperandAtIndex(grid, 1) ||
            binder.tensorResultType(resultType))
          return rewriter.notifyMatchFailure(
              binder.op, "operand grid_sampler bind failure");

        auto inputTensorType = input.getType().cast();
        ArrayRef inputShape = inputTensorType.getSizes();
        uint32_t inputRank = inputShape.size();
        auto gridTensorType = grid.getType().cast();
        ArrayRef gridShape = gridTensorType.getSizes();
        uint32_t gridRank = gridShape.size();

        if (inputRank != 4)
          return rewriter.notifyMatchFailure(binder.op,
                                             "only input rank 4 supported");
        if (gridRank != 4)
          return rewriter.notifyMatchFailure(binder.op,
                                             "only grid rank 4 supported");
        if (inputShape[0] != gridShape[0])
          return rewriter.notifyMatchFailure(
              binder.op, "N must be same for input and grid");
        if (gridShape[3] != 2)
          return rewriter.notifyMatchFailure(binder.op,
                                             "gridShape[3] expected to be 2");

        std::string mode;
        if (binder.customOpNameStringAttr(mode, "mode", "bilinear"))
          return rewriter.notifyMatchFailure(binder.op, "mode bind failure");
        if (mode != "bilinear")
          return rewriter.notifyMatchFailure(
              binder.op, "currently only mode : bilinear supported");

        std::string padding;
        if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
          return rewriter.notifyMatchFailure(binder.op,
                                             "padding_mode bind failure");
        if (padding != "zeros")
          return rewriter.notifyMatchFailure(
              binder.op, "currently only padding_mode : zeros supported");

        int64_t align;
        if (binder.s64IntegerAttr(align, "align_corners", 0))
          return rewriter.notifyMatchFailure(binder.op,
                                             "align_corners bind failure");
        if (align != 0)
          return rewriter.notifyMatchFailure(
              binder.op, "currently only align_corners : 0 supported");

        // The only supported configuration is encoded as the constants
        // (0, 0, false).
        Value interpolationMode = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        Value paddingMode = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        Value alignCorners = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getBoolAttr(false));

        rewriter.replaceOpWithNewOp(
            binder.op, resultType, input, grid, interpolationMode, paddingMode,
            alignCorners);
        return success();
      });

  // Less: binary tensor op on (lhs, rhs).
  patterns.onOp("Less", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // LessOrEqual: binary tensor op on (lhs, rhs).
  patterns.onOp("LessOrEqual", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // Log: unary tensor op.
  patterns.onOp("Log", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, operand);
                  return success();
                });

  // MatMul: binary tensor op on (lhs, rhs).
  patterns.onOp("MatMul", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // MatMulInteger: wraps both operands with a scale of 1.0 and their
  // (optional) zero points before multiplying them.
  patterns.onOp(
      "MatMulInteger", 10,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value lhs, rhs, lhsZp, rhsZp;
        if (binder.tensorOperandAtIndex(lhs, 0) ||
            binder.tensorOperandAtIndex(rhs, 1) ||
            binder.tensorResultType(resultType))
          return failure();

        // Missing zero-point operands default to the integer constant 0.
        if (binder.tensorOperandAtIndex(lhsZp, 2)) {
          lhsZp = rewriter.create(
              binder.getLoc(), rewriter.getType(),
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        }

        if (binder.tensorOperandAtIndex(rhsZp, 3)) {
          rhsZp = rewriter.create(
              binder.getLoc(), rewriter.getType(),
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        }

        auto lhsTy = dyn_cast(lhs.getType());
        auto rhsTy = dyn_cast(rhs.getType());

        // A tensor zero point is only accepted when every dimension is 1; it
        // is then converted to a scalar value.
        if (auto zpTy = dyn_cast(lhsZp.getType())) {
          for (auto dim : zpTy.getSizes())
            if (dim != 1)
              return failure();
          lhsZp = rewriter.create(
              binder.getLoc(), rewriter.getType(), lhsZp);
        }

        if (auto zpTy = dyn_cast(rhsZp.getType())) {
          for (auto dim : zpTy.getSizes())
            if (dim != 1)
              return failure();
          rhsZp = rewriter.create(
              binder.getLoc(), rewriter.getType(), rhsZp);
        }

        Value scale = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getF64FloatAttr(1.0));

        // Maps si8 / ui8 / si32 element types to a corresponding type; any
        // other element type yields a null Type.
        auto q = [&](Type qty) -> Type {
          if (qty.isSignedInteger(8))
            return rewriter.getType();
          if (qty.isUnsignedInteger(8))
            return rewriter.getType();
          if (qty.isSignedInteger(32))
            return rewriter.getType();
          return {};
        };

        Type lhsQTy = rewriter.getType(
            lhsTy.getOptionalSizes(), q(lhsTy.getDtype()));
        Type rhsQTy = rewriter.getType(
            rhsTy.getOptionalSizes(), q(rhsTy.getDtype()));

        lhs = rewriter.create(
            binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
        rhs = rewriter.create(
            binder.getLoc(), rhsQTy, rhs, scale, rhsZp);

        rewriter.replaceOpWithNewOp(binder.op, resultType, lhs, rhs);
        return success();
      });

  // Mul: binary tensor op on (lhs, rhs).
  patterns.onOp("Mul", 7,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // NonZero: unary tensor op.
  patterns.onOp("NonZero", 13,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, operand);
                  return success();
                });

  // MaxPool: auto_pad must be NOTSET and storage_order must be 0. Only rank-4
  // and rank-5 inputs are lowered; rank 3 is reported as unimplemented.
  patterns.onOp(
      "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        std::string autoPad;
        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
          return rewriter.notifyMatchFailure(binder.op,
                                             "auto_pad bind failure");
        if (autoPad != "NOTSET")
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported conversion: auto_pad != NOTSET");

        Torch::ValueTensorType resultType;
        Value operand;
        bool ceilMode;
        int64_t storageOrder;
        // TODO: Add support for indices output and storage_order
        if (binder.tensorOperand(operand) ||
            binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
            binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
            binder.tensorResultType(resultType))
          return rewriter.notifyMatchFailure(
              binder.op,
              "operand/ceil_mode/storage_order/resultType bind failure");
        if (storageOrder != 0)
          return rewriter.notifyMatchFailure(
              binder.op, "storage_order setting is not supported.");

        // Determine the rank of input tensor.
        std::optional maybeRank = Torch::getTensorRank(operand);
        if (!maybeRank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "Unimplemented: unranked tensor");
        int64_t rank = *maybeRank;
        // The first two dimensions are not spatial.
        int64_t spatial = rank - 2;

        SmallVector kernel, padding, strides, dilations;
        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
          return rewriter.notifyMatchFailure(binder.op,
                                             "kernel_shape bind failure");
        if (kernel.size() != static_cast(spatial))
          return rewriter.notifyMatchFailure(
              binder.op, "kernel list size does not match the number of axes");
        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
        if (!padding.empty() && padding.size() != static_cast(2 * spatial))
          return rewriter.notifyMatchFailure(
              binder.op, "padding list must contain (begin,end) pair for each "
                         "spatial axis");
        if (binder.s64IntegerArrayAttr(strides, "strides", {}))
          return rewriter.notifyMatchFailure(binder.op,
                                             "strides bind failure");
        if (!strides.empty() && strides.size() != static_cast(spatial))
          return rewriter.notifyMatchFailure(
              binder.op, "strides list size does not match the number of axes");
        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
          return rewriter.notifyMatchFailure(binder.op,
                                             "dilations bind failure");

        // Fill in defaults for the attributes that were left empty.
        if (padding.empty())
          padding.resize(spatial, 0);
        if (strides.empty())
          strides.resize(spatial, 1);
        if (dilations.empty())
          dilations.resize(spatial, 1);

        // If the padding is symmetric we can push the padding operation to the
        // torch operator.
        if (padding.size() == static_cast(2 * spatial)) {
          bool equal = true;
          for (int i = 0; i < spatial; ++i) {
            equal = equal && (padding[i] == padding[i + spatial]);
          }
          if (equal)
            padding.resize(spatial);
        }

        // Torch pool operators require equal padding on each side of each
        // dimension so we materialize the padding behavior explicitly and set
        // the padding to 0.
        if (padding.size() == static_cast(2 * spatial)) {
          auto operandTy = cast(operand.getType());
          llvm::SmallVector shuffledPadding(spatial * 2);
          llvm::SmallVector paddedShape(operandTy.getSizes());
          for (int i = 0; i < spatial; ++i) {
            // Dimensions of unknown size stay unknown after padding.
            if (paddedShape[i + 2] != Torch::kUnknownSize)
              paddedShape[i + 2] += padding[i] + padding[i + spatial];
            // `padding` holds all begins followed by all ends. The explicit
            // pad list uses (begin, end) pairs starting from the LAST
            // dimension, the same arrangement the "Pad" lowering in this file
            // builds.
            shuffledPadding[2 * i] = padding[spatial - i - 1];
            shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
          }

          Value shuffledPaddingList =
              createConstantIntList(binder, rewriter, shuffledPadding);

          // Pad with the lowest representable value so the padded elements
          // can never win the max.
          Value zero;
          if (resultType.getDtype().isa()) {
            zero = rewriter.create(
                binder.getLoc(), rewriter.getType(),
                rewriter.getF64FloatAttr(
                    std::numeric_limits::lowest()));
          } else if (resultType.getDtype().isa()) {
            zero = rewriter.create(
                binder.getLoc(), rewriter.getI64IntegerAttr(
                                     std::numeric_limits::lowest()));
          } else {
            // Without a fill value the pad op would receive a null operand.
            return rewriter.notifyMatchFailure(
                binder.op, "unsupported result element type for padding");
          }

          auto paddedInputTy = rewriter.getType(
              paddedShape, operandTy.getDtype());
          operand = rewriter.create(
              binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
              zero);
          padding.clear();
          padding.resize(spatial, 0);
        }

        Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
        Value paddingList = createConstantIntList(binder, rewriter, padding);
        Value stridesList = createConstantIntList(binder, rewriter, strides);
        Value dilationsList =
            createConstantIntList(binder, rewriter, dilations);
        Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode);

        if (rank == 3)
          return rewriter.notifyMatchFailure(
              binder.op, "Unimplemented: AtenMaxPool1dOp");
        if (rank == 4) {
          rewriter.replaceOpWithNewOp(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        if (rank == 5) {
          rewriter.replaceOpWithNewOp(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        return rewriter.notifyMatchFailure(binder.op, "No rank is matched.");
      });

  // Greater: binary tensor op on (lhs, rhs).
  patterns.onOp("Greater", 16,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // GreaterOrEqual: binary tensor op on (lhs, rhs).
  patterns.onOp("GreaterOrEqual", 16,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // InstanceNormalization: exactly three operands (input, weight, bias); no
  // running statistics are passed and a momentum of 0 is used.
  patterns.onOp(
      "InstanceNormalization", 6,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        llvm::SmallVector operands;
        float eps;

        if (binder.tensorOperands(operands, 3) ||
            binder.tensorResultType(resultType) || operands.size() != 3 ||
            binder.f32FloatAttr(eps, "epsilon", 1e-05f)) {
          return failure();
        }
        Value none = rewriter.create(binder.getLoc());
        Value boolTrue = rewriter.create(binder.getLoc(), true);
        Value boolFalse = rewriter.create(binder.getLoc(), false);
        auto epsValue = rewriter.create(
            binder.getLoc(), rewriter.getF64FloatAttr(eps));

        auto momentum = rewriter.create(
            binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
        rewriter.replaceOpWithNewOp(
            binder.op, resultType, /* input */ operands[0],
            /* weight */ operands[1],
            /* bias */ operands[2], /* running mean */ none,
            /* running var */ none,
            /* use input stats */ boolTrue, momentum, epsValue,
            /* cudnn enabled */ boolFalse);
        return success();
      });

  // Max: left fold of a binary op over all operands.
  patterns.onOp(
      "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        llvm::SmallVector operands;
        if (binder.tensorOperandsList(operands) ||
            binder.tensorResultType(resultType) || operands.size() == 0) {
          return failure();
        }
        Value result = operands[0];
        for (uint64_t i = 1; i < operands.size(); i++) {
          result = rewriter.create(
              binder.getLoc(), resultType, result, operands[i]);
        }
        // Replace with the value itself: with a single operand `result` may
        // be a block argument, which has no defining op.
        rewriter.replaceOp(binder.op, result);
        return success();
      });

  // Min: left fold of a binary op over all operands.
  patterns.onOp(
      "Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        llvm::SmallVector operands;
        if (binder.tensorOperandsList(operands) ||
            binder.tensorResultType(resultType) || operands.size() == 0) {
          return failure();
        }
        Value result = operands[0];
        for (uint64_t i = 1; i < operands.size(); i++) {
          result = rewriter.create(
              binder.getLoc(), resultType, result, operands[i]);
        }
        // Replace with the value itself: with a single operand `result` may
        // be a block argument, which has no defining op.
        rewriter.replaceOp(binder.op, result);
        return success();
      });

  // Neg: unary tensor op.
  patterns.onOp("Neg", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, operand);
                  return success();
                });

  // Not: unary tensor op.
  patterns.onOp("Not", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, operand);
                  return success();
                });

  // Or: binary tensor op on (lhs, rhs).
  patterns.onOp("Or", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // Gather: flattens the indices to one dimension, gathers along `axis`, then
  // expands the gathered dimension back to the shape of the indices.
  patterns.onOp(
      "Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, indices;
        int64_t axis;
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(indices, 1) ||
            binder.tensorResultType(resultType) ||
            binder.s64IntegerAttr(axis, "axis", 0))
          return failure();
        Location loc = binder.getLoc();
        auto ctx = binder.op->getContext();
        auto indicesTy = cast(indices.getType());
        auto dataTy = cast(data.getType());
        // Both shapes are read below, so both types must carry sizes.
        if (!dataTy || !dataTy.hasSizes() || !indicesTy.hasSizes())
          return failure();
        if (axis < 0)
          axis += dataTy.getSizes().size();

        Value index = rewriter.create(
            loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis));

        // Apply bounds checking on the input:
        // (presumably: where indices < 0, use indices + size of `axis`.)
        auto intTy = rewriter.getType();
        auto boolTy = rewriter.getType(
            indicesTy.getSizes(), rewriter.getI1Type());
        Value zero = rewriter.create(
            loc, intTy, rewriter.getI64IntegerAttr(0));
        Value one = rewriter.create(
            loc, intTy, rewriter.getI64IntegerAttr(1));
        Value lt = rewriter.create(loc, boolTy, indices, zero);
        Value dim = rewriter.create(loc, intTy, data, index);
        Value add = rewriter.create(loc, indicesTy, indices, dim, one);
        indices = rewriter.create(loc, indicesTy, lt, add, indices);

        auto intListTy = rewriter.getType(
            rewriter.getType());
        auto indicesSize = rewriter.create(loc, intListTy, indices);

        // Determine the collapsed dim size:
        // 64-bit so that large static products and kUnknownSize are not
        // narrowed.
        int64_t indicesCt = 1;
        for (auto sz : indicesTy.getSizes()) {
          if (sz == Torch::kUnknownSize) {
            indicesCt = Torch::kUnknownSize;
            break;
          }

          indicesCt *= sz;
        }

        auto flattenTy = rewriter.getType(
            SmallVector{indicesCt}, indicesTy.getOptionalDtype());
        Value rank = rewriter.create(loc, intTy, indices);
        Value end = rewriter.create(loc, rank, one);
        indices = rewriter.create(
            loc, flattenTy, indices, zero, end);

        llvm::SmallVector gatherShape(dataTy.getSizes());
        gatherShape[axis] = indicesCt;

        auto gatherTy = rewriter.getType(
            gatherShape, dataTy.getOptionalDtype());
        Value gather = rewriter.create(
            loc, gatherTy, data, index, indices);
        rewriter.replaceOpWithNewOp(
            binder.op, resultType, gather, index, indicesSize);
        return success();
      });

  // GatherElements: forwards (data, axis, indices) with a constant `false`
  // as the last argument.
  patterns.onOp(
      "GatherElements", 13,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, indices;
        int64_t axis;
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(indices, 1) ||
            binder.tensorResultType(resultType) ||
            binder.s64IntegerAttr(axis, "axis", 0))
          return failure();
        Value constAxis = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
        Value sparseGrad = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getBoolAttr(false));
        rewriter.replaceOpWithNewOp(
            binder.op, resultType, data, constAxis, indices, sparseGrad);
        return success();
      });

  // Gemm: alpha * (A' x B') + beta * C, where A' / B' are transposed when
  // transA / transB are non-zero. All three operands are required here.
  patterns.onOp(
      "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value a, b, c;
        float alpha, beta;
        int64_t transA, transB;
        if (binder.tensorOperandAtIndex(a, 0) ||
            binder.tensorOperandAtIndex(b, 1) ||
            binder.tensorOperandAtIndex(c, 2) ||
            binder.s64IntegerAttr(transA, "transA", 0) ||
            binder.s64IntegerAttr(transB, "transB", 0) ||
            binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
            binder.f32FloatAttr(beta, "beta", 1.0f) ||
            binder.tensorResultType(resultType))
          return failure();

        Value zero = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        Value one = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

        // Swaps dimensions 0 and 1 of `m`; the static shape, when known, is
        // reversed to match.
        auto transpose = [&](Value m) -> Value {
          auto tty = m.getType().cast();
          auto shape = tty.getOptionalSizes();
          if (shape.has_value()) {
            llvm::SmallVector newShape(shape.value());
            std::reverse(newShape.begin(), newShape.end());
            shape = std::move(newShape);
          }
          auto oty = Torch::ValueTensorType::get(tty.getContext(), shape,
                                                 tty.getOptionalDtype());
          return rewriter.create(binder.getLoc(), oty, m, zero, one);
        };

        if (transA) {
          a = transpose(a);
        }

        if (transB) {
          b = transpose(b);
        }

        Value mm = rewriter.create(binder.getLoc(), resultType, a, b);
        if (alpha == 1.0 && beta == 1.0) {
          rewriter.replaceOpWithNewOp(
              binder.op, resultType, mm, c, one);
          return success();
        }

        // Both scales differ from 1: apply alpha to the product up front so
        // that only beta remains for the final add.
        if (alpha != 1.0 && beta != 1.0) {
          Value constAlpha = rewriter.create(
              binder.getLoc(), rewriter.getType(),
              rewriter.getF64FloatAttr(alpha));
          mm = rewriter.create(
              binder.getLoc(), resultType, mm, constAlpha);
          alpha = 1.0;
        }

        // Only alpha differs from 1: swap roles so the final add scales the
        // product instead of C.
        if (alpha != 1.0) {
          std::swap(alpha, beta);
          std::swap(mm, c);
        }

        Value constBeta = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getF64FloatAttr(beta));
        rewriter.replaceOpWithNewOp(
            binder.op, resultType, mm, c, constBeta);
        return success();
      });

  // GlobalAveragePool: expressed as an average pool whose kernel is derived
  // from the static input and result shapes. Supports input ranks 3, 4 and 5.
  patterns.onOp(
      "GlobalAveragePool", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
        if (binder.tensorOperand(operand) ||
            binder.tensorResultType(resultType))
          return failure();

        auto inputTensorType = operand.getType().cast();
        if (!inputTensorType || !inputTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
        }
        ArrayRef inputShape = inputTensorType.getSizes();
        unsigned inputRank = inputShape.size();
        if (!resultType || !resultType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected result type having sizes");
        }
        ArrayRef resultShape = resultType.getSizes();

        // The kernel computation below indexes both shapes with the same
        // index and does arithmetic on the sizes, so the ranks must agree and
        // the spatial sizes must be static.
        if (resultShape.size() != inputRank)
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input and result to have the same rank");
        for (unsigned i = 2; i < inputRank; i++) {
          if (inputShape[i] == Torch::kUnknownSize ||
              resultShape[i] == Torch::kUnknownSize)
            return rewriter.notifyMatchFailure(
                binder.op, "Unimplemented: dynamic spatial dimensions");
        }

        SmallVector cstKernel, cstPadding, cstStrides;
        Value cstZero = rewriter.create(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value cstOne = rewriter.create(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        for (unsigned i = 2; i < inputRank; i++) {
          int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
          cstKernel.push_back(rewriter.create(
              binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
          cstPadding.push_back(cstZero);
          cstStrides.push_back(cstOne);
        }
        Value kernelSizeList = rewriter.create(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstKernel);
        Value paddingList = rewriter.create(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstPadding);
        Value stridesList = rewriter.create(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            cstStrides);
        Value cstFalse = rewriter.create(binder.getLoc(), false);
        Value cstCeilMode = cstFalse;
        Value cstCountIncludePad = cstFalse;
        Value cstNone = rewriter.create(binder.getLoc());

        if (inputRank == 3) {
          rewriter.replaceOpWithNewOp(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad);
          return success();
        } else if (inputRank == 4) {
          rewriter.replaceOpWithNewOp(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad,
              /*divisor_override=*/cstNone);
          return success();
        } else if (inputRank == 5) {
          rewriter.replaceOpWithNewOp(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, cstCeilMode, cstCountIncludePad,
              /*divisor_override=*/cstNone);
          return success();
        }
        return failure();
      });

  // LayerNormalization: normalizes over dimensions [axis, rank). Supports the
  // one-result form (Y only) and the three-result form.
  patterns.onOp(
      "LayerNormalization", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType yType, meanType, invStdDevType;
        Value x, scale, b;
        int64_t axis, stashType;
        float epsilon;
        if (binder.tensorOperandAtIndex(x, 0) ||
            binder.tensorOperandAtIndex(scale, 1) ||
            binder.tensorOperandAtIndex(b, 2) ||
            binder.tensorResultTypeAtIndex(yType, 0) ||
            binder.s64IntegerAttr(axis, "axis", -1) ||
            binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
            binder.s64IntegerAttr(stashType, "stash_type", 1))
          return failure();
        Value constEpsilon = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getF64FloatAttr(epsilon));
        unsigned rank = 1;
        if (std::optional maybeRank = Torch::getTensorRank(x))
          rank = *maybeRank;
        SmallVector normalized;
        axis = Torch::toPositiveDim(axis, rank);
        auto xType = x.getType().cast();
        if (!xType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input (X) to have sizes");
        }
        ArrayRef xShape = xType.getSizes();
        // The normalized shape is the trailing part of X's shape.
        for (int64_t n = axis; n < rank; n++) {
          normalized.push_back(rewriter.create(
              binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n])));
        }
        Value normalized_shape = rewriter.create(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
            normalized);

        int64_t numResults = binder.op->getNumResults();
        if (numResults == 1) {
          // The two extra results are not requested, so their type is
          // synthesized: X's leading sizes, with 1 for each normalized dim.
          SmallVector reducedShape(rank, 1);
          for (int64_t i = 0; i < axis; i++)
            reducedShape[i] = xShape[i];
          auto reducedType = xType.getWithSizesAndDtype(
              reducedShape, xType.getOptionalDtype());
          Value y = rewriter
                        .create(
                            binder.getLoc(), yType, /*meanType=*/reducedType,
                            /*invStdDevType=*/reducedType, x, normalized_shape,
                            scale, b, constEpsilon)
                        .getResult0();
          rewriter.replaceOp(binder.op, y);
          return success();
        }
        if (numResults == 3) {
          if (binder.tensorResultTypeAtIndex(meanType, 1) ||
              binder.tensorResultTypeAtIndex(invStdDevType, 2))
            return failure();
          rewriter.replaceOpWithNewOp(
              binder.op, yType, meanType, invStdDevType, x, normalized_shape,
              scale, b, constEpsilon);
          return success();
        }
        return rewriter.notifyMatchFailure(
            binder.op, "Unimplemented: expected either 1 or 3 results");
      });

  // LeakyRelu: unary op with the "alpha" attribute (default 0.01).
  patterns.onOp("LeakyRelu", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  float alpha;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType) ||
                      binder.f32FloatAttr(alpha, "alpha", 0.01f))
                    return failure();
                  Value constAlpha = rewriter.create(
                      binder.getLoc(), rewriter.getType(),
                      rewriter.getF64FloatAttr(alpha));
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, operand, constAlpha);
                  return success();
                });

  // Pad: requires a 1-D `pads` tensor of static length; the optional `axes`
  // operand is rejected.
  patterns.onOp(
      "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value data, pads, axes;
        std::string mode;

        // TODO: The `axes` parameter is not supported yet.
        // (Binding operand 3 succeeding means `axes` was supplied.)
        if (!binder.tensorOperandAtIndex(axes, 3)) {
          return rewriter.notifyMatchFailure(
              binder.op, "The axes parameter is not supported yet");
        }
        if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(pads, 1) ||
            binder.tensorResultType(resultType) ||
            binder.customOpNameStringAttr(mode, "mode", "constant"))
          return failure();
        Location loc = binder.getLoc();

        // Use the supplied constant value when present; otherwise build a
        // zero splat of the data tensor's type.
        Value constantValue;
        if (binder.getNumOperands() >= 3) {
          if (binder.tensorOperandAtIndex(constantValue, 2))
            return rewriter.notifyMatchFailure(binder.op,
                                               "failed to bind to index 2");
        } else {
          auto dataTensorType = data.getType().cast();

          auto maybeZeroAttr = [&]() -> std::optional {
            if (dataTensorType.getDtype().isa()) {
              return rewriter.getI64IntegerAttr(0);
            }
            if (dataTensorType.getDtype().isa()) {
              return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f);
            }
            return std::nullopt;
          }();

          if (!maybeZeroAttr) {
            return rewriter.notifyMatchFailure(
                binder.op, "expected integer or float data tensor");
          }

          auto shapedType = dataTensorType.toBuiltinTensor();
          auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr);
          constantValue = rewriter.create(
              loc, dataTensorType, splat);
        }

        // Get pads shape and rank. The pads tensor is expected to be 1-D
        // tensor.
        auto padsTensorType = pads.getType().cast();
        if (!padsTensorType || !padsTensorType.hasSizes()) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non empty pad tensor");
        }
        ArrayRef padsShape = padsTensorType.getSizes();
        int64_t padsRank = padsShape.size();
        if (padsRank != 1) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect 1-D pad tensor");
        }

        // Extract all the values of 1-D pad tensor and create a list of all
        // these values as torch.pad op expects pad list.
        int64_t padsSize = padsShape[0];
        // With an unknown length the extraction loops below would not run and
        // an empty pad list would be emitted silently.
        if (padsSize == Torch::kUnknownSize)
          return rewriter.notifyMatchFailure(
              binder.op, "Expect pad tensor with static size");
        Value constZero = rewriter.create(
            loc, rewriter.getI64IntegerAttr(0));
        SmallVector padsTensorValue;
        SmallVector emptyShape;
        Type padsElemType =
            Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
                                        padsTensorType.getOptionalDtype());
        for (uint32_t i = 0; i < padsSize; ++i) {
          Value index = rewriter.create(
              loc, rewriter.getI64IntegerAttr(i));
          padsTensorValue.emplace_back(rewriter.create(
              loc, padsElemType, pads, constZero, index));
        }

        // The torch.pad op expects a different arrangement of padding pairs for
        // each dimension as compared to the onnx.pad op. So, rearranging pad
        // tensor to satisfy torch.pad op semantics.
        SmallVector padsRearrange;
        for (uint32_t i = 0; i < padsSize / 2; i++) {
          padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) - 1 - i]);
          padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]);
        }

        Value padsSizeList =
            rewriter
                .create(
                    loc, Torch::ListType::get(rewriter.getType()),
                    padsRearrange)
                .getResult(0);
        Value modeVal = rewriter.create(
            loc, rewriter.getStringAttr(mode));

        // The constant value is a 0-d tensor, which needs to be converted to a
        // float scalar as torch.pad op expects a float scalar.
        auto constValueType = constantValue.getType().cast();
        if (!constValueType) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expect non-none constant value");
        }
        auto resultTensorType = Torch::ValueTensorType::get(
            constValueType.getContext(), emptyShape, rewriter.getF64Type());
        Value none = rewriter.create(loc);
        Value cstFalse = rewriter.create(loc, false);
        Value constFloatValue = rewriter.create(
            loc, resultTensorType, constantValue,
            Torch::getDtypeIntValueForType(rewriter, loc,
                                           resultTensorType.getOptionalDtype()),
            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
            /*memory_format=*/none);
        Value constScalar = rewriter.create(
            loc, rewriter.getType(), constFloatValue);

        rewriter.replaceOpWithNewOp(
            binder.op, resultType, data, padsSizeList, modeVal, constScalar);
        return success();
      });

  // Pow: binary tensor op on (lhs, rhs).
  patterns.onOp("Pow", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value lhs, rhs;
                  if (binder.tensorOperands(lhs, rhs) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });

  // Identity: unary op that also receives a `none` memory_format.
  patterns.onOp(
      "Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensor;
        if (binder.tensorOperand(tensor) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }
        Value noneVal = rewriter.create(binder.getLoc());
        rewriter.replaceOpWithNewOp(
            binder.op, resultType, tensor, /*memory_format=*/noneVal);
        return success();
      });

  // Mean: sums all operands pairwise and divides by the operand count. A
  // single operand is forwarded unchanged.
  patterns.onOp(
      "Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        if (binder.op->getNumOperands() == 1) {
          Torch::ValueTensorType resultType;
          Value x;
          if (binder.tensorOperand(x) || binder.tensorResultType(resultType))
            return failure();
          rewriter.replaceOp(binder.op, x);
          return success();
        }
        Torch::ValueTensorType resultType;
        SmallVector valList;
        int64_t numOperands = binder.op->getNumOperands();
        // The code below reads valList[0] and valList[1] unconditionally.
        if (numOperands < 2)
          return rewriter.notifyMatchFailure(binder.op,
                                             "expected at least one operand");
        Value numOperandsConstant = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands));
        if (binder.tensorOperands(valList, numOperands) ||
            binder.tensorResultType(resultType))
          return failure();
        Value constOne = rewriter.create(
            binder.getLoc(), rewriter.getType(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
        // Short circuit to binary add
        Value curr = rewriter.create(
            binder.getLoc(), resultType, valList[0], valList[1], constOne);
        if (numOperands == 2) {
          rewriter.replaceOpWithNewOp(
              binder.op, resultType, curr, numOperandsConstant);
          return success();
        }
        // When binder.op->getNumOperands() > 2
        // Intermediate sums use a type with least static information; only
        // the last add is given the result type.
        auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
            binder.op->getContext());
        for (int i = 2; i < numOperands; i++) {
          if (i == numOperands - 1) {
            curr = rewriter.create(
                binder.getLoc(), resultType, curr, valList[i], constOne);
          } else {
            curr = rewriter.create(
                binder.getLoc(), baseType, curr, valList[i], constOne);
          }
        }
        rewriter.replaceOpWithNewOp(
            binder.op, resultType, curr, numOperandsConstant);
        return success();
      });

  // IsInf: detect_negative / detect_positive (both default 1) are honoured by
  // pre-processing the tensor before the final check.
  patterns.onOp(
      "IsInf", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value tensor;
        int64_t neg;
        int64_t pos;
        if (binder.tensorOperand(tensor) ||
            binder.s64IntegerAttr(neg, "detect_negative", 1) ||
            binder.s64IntegerAttr(pos, "detect_positive", 1) ||
            binder.tensorResultType(resultType)) {
          return failure();
        }
        if (neg == 0) {
          // replace all negative infs with 0
          tensor = rewriter.create(
              binder.getLoc(), dyn_cast(tensor.getType()), tensor);
        }
        if (pos == 0) {
          // first use neg op to flip positive inf to negative inf. Then relu to
          // replace all positive infs with 0.
          Value flip = rewriter.create(
              binder.getLoc(), dyn_cast(tensor.getType()), tensor);
          tensor = rewriter.create(
              binder.getLoc(), dyn_cast(flip.getType()), flip);
        }
        rewriter.replaceOpWithNewOp(binder.op, resultType, tensor);
        return success();
      });

  // IsNaN: unary tensor op.
  patterns.onOp("IsNaN", 9,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value tensor;
                  if (binder.tensorOperand(tensor) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, tensor);
                  return success();
                });

  // PRelu: binary tensor op on (tensor, slope).
  patterns.onOp("PRelu", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value tensor;
                  Value slope;
                  if (binder.tensorOperands(tensor, slope) ||
                      binder.tensorResultType(resultType)) {
                    return failure();
                  }
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, tensor, slope);
                  return success();
                });
}