//===------------------------------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with // sinceVersion==1. We interpret this to include the following spec // diffs that are irrelevant to this level of lowering: // * Supported element types. // * Limited broadcasting to full broadcasting support. // // There are a lot of spec revisions that basically generalized elementwise // to be more normal and a direct translation vs a special case. This // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { patterns.onOp("HardSigmoid", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value tensorOperand; float alpha, beta; if (binder.tensorOperand(tensorOperand) || binder.f32FloatAttr(alpha, "alpha", 0.2) || binder.f32FloatAttr(beta, "beta", 0.5) || binder.tensorResultType(resultType)) return failure(); // HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta)) Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); Value constBeta = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(beta)); // Expression: alpha * x + beta Value alpha_x_plus_beta = rewriter.create( binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha); // Expression: min(1, alpha * x + beta) Value constantOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constantOne); Value minExpression = rewriter.create( binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta); // Expression: max(0, min(1, alpha * x + beta)) Value constantZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constantZero); rewriter.replaceOpWithNewOp( binder.op, resultType, zeroTensor, minExpression); return success(); }); patterns.onOp( "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value operand; Torch::ValueTensorType resultType; std::string approximate; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType) || binder.customOpNameStringAttr(approximate, "approximate", "none")) return failure(); Value vApproximate = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getStringAttr(approximate)); rewriter.replaceOpWithNewOp(binder.op, resultType, operand, vApproximate); return success(); }); patterns.onOp("Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("LessOrEqual", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("Log", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("MatMul", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("Mul", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("GreaterOrEqual", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; llvm::SmallVector operands; if (binder.tensorOperandsList(operands) || binder.tensorResultType(resultType) || operands.size() == 0) { return failure(); } Value result = operands[0]; for (int i = 1; i < operands.size(); i++) { result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } rewriter.replaceOp( binder.op, result.getDefiningOp()); return success(); }); patterns.onOp("Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; llvm::SmallVector operands; if (binder.tensorOperandsList(operands) || binder.tensorResultType(resultType) || operands.size() == 0) { return failure(); } Value result = operands[0]; for (int i = 1; i < operands.size(); i++) { result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } rewriter.replaceOp( binder.op, result.getDefiningOp()); return success(); }); patterns.onOp("Neg", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Not", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Or", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( "GatherElements", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data, indices; int64_t axis; if (binder.tensorOperandAtIndex(data, 0) || binder.tensorOperandAtIndex(indices, 1) || binder.tensorResultType(resultType) || binder.s64IntegerAttr(axis, "axis", 0)) return failure(); Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); Value sparseGrad = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( binder.op, resultType, data, constAxis, indices, sparseGrad); return success(); }); patterns.onOp( "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value a, b, c; float alpha, beta; int64_t transA, transB; if (binder.tensorOperandAtIndex(a, 0) || binder.tensorOperandAtIndex(b, 1) || binder.tensorOperandAtIndex(c, 2) || binder.s64IntegerAttr(transA, "transA", 0) || binder.s64IntegerAttr(transB, "transB", 0) || binder.f32FloatAttr(alpha, "alpha", 1.0) || binder.f32FloatAttr(beta, "beta", 1.0) || binder.tensorResultType(resultType)) return failure(); Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); Value one = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); auto transpose = [&](Value m) -> Value { auto tty = m.getType().cast(); auto shape = tty.getOptionalSizes(); if (shape.has_value()) { llvm::SmallVector newShape(shape.value()); std::reverse(newShape.begin(), newShape.end()); shape = std::move(newShape); } auto oty = Torch::ValueTensorType::get(tty.getContext(), shape, tty.getOptionalDtype()); return rewriter.create(binder.getLoc(), oty, m, zero, one); }; if (transA) { a = transpose(a); } if (transB) { b = transpose(b); } Value mm = rewriter.create(binder.getLoc(), resultType, a, b); if (alpha == 1.0 && beta == 1.0) { rewriter.replaceOpWithNewOp( binder.op, resultType, mm, c, one); return success(); } if (alpha != 1.0 && beta != 1.0) { Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); mm = rewriter.create( binder.getLoc(), resultType, mm, constAlpha); alpha = 1.0; } if (alpha != 1.0) { std::swap(alpha, beta); std::swap(mm, c); } Value constBeta = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(beta)); rewriter.replaceOpWithNewOp( binder.op, resultType, mm, c, constBeta); return success(); }); patterns.onOp( "GlobalAveragePool", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); auto inputTensorType = operand.getType().cast(); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected input type having sizes"); } ArrayRef inputShape = inputTensorType.getSizes(); unsigned inputRank = inputShape.size(); if (!resultType || !resultType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected result type having sizes"); } ArrayRef resultShape = resultType.getSizes(); SmallVector cstKernel, cstPadding, cstStrides; Value cstZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); for (unsigned i = 2; i < inputRank; i++) { int64_t kernelSize = inputShape[i] - resultShape[i] + 1; cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); cstPadding.push_back(cstZero); cstStrides.push_back(cstOne); } Value kernelSizeList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstKernel); Value paddingList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); Value stridesList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value cstFalse = rewriter.create(binder.getLoc(), false); Value cstCeilMode = cstFalse; Value cstCountIncludePad = cstFalse; Value cstNone = rewriter.create(binder.getLoc()); if (inputRank == 3) { rewriter.replaceOpWithNewOp( binder.op, resultType, operand, kernelSizeList, stridesList, paddingList, cstCeilMode, cstCountIncludePad); return success(); } else if (inputRank == 4) { rewriter.replaceOpWithNewOp( binder.op, resultType, operand, kernelSizeList, stridesList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } else if (inputRank == 5) { rewriter.replaceOpWithNewOp( binder.op, resultType, operand, kernelSizeList, stridesList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } return failure(); }); patterns.onOp("LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; float alpha; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType) || binder.f32FloatAttr(alpha, "alpha", 0.01)) return failure(); Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constAlpha); return success(); }); patterns.onOp("Pow", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); }