//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"
// NOTE(review): the header name of the next #include is missing in this copy
// of the file. std::iota is used further down in this file, which suggests
// <numeric> -- confirm against upstream.
#include

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

namespace {

// Shared lowering that builds a window tensor and replaces binder.op with it.
// No caller of this helper is visible in this part of the file.
//
// NOTE(review): in this copy of the file the template arguments of every
// b.create(...) call (and of static_cast, dyn_cast, rewriter.getType,
// getItemOp, std::optional and replaceOpWithNewOp) are missing. The op kinds
// described below are therefore inferred from the value names -- confirm
// against upstream.
//
// Going by those names, element n of the produced window appears to be
//   a0 + a1 * cos(tau * n / D) + a2 * cos(2 * tau * n / D),   tau = 2 * pi,
// for n in [0, size), where D = size when `periodic` == 1 and
// D = size - 1 when `periodic` == 0.
//
// Parameters:
//   binder          - supplies the location and the op (binder.op) that is
//                     replaced on success.
//   rewriter        - rewriter used to build the replacement IR.
//   size            - tensor holding the window length. It is converted to
//                     f32 and also read as a scalar through getItemOp
//                     (presumably a single-element tensor -- TODO confirm).
//   a0, a1, a2      - coefficients of the constant, cos(x) and cos(2x) terms,
//                     passed in scalar operand positions (presumably torch
//                     float scalars -- TODO confirm).
//   resultType      - tensor type used for every intermediate window tensor
//                     and for the final replacement op.
//   output_datatype - ONNX dtype code of the result. It is mapped through
//                     onnxDtypeIntToTorchDtypeInt.
//   periodic        - selects the denominator D (see above). Only 0 and 1
//                     appear to be meaningful values.
// Returns:
//   success() after replacing binder.op, or a match failure when
//   output_datatype has no torch dtype mapping.
LogicalResult windowFunctionImpl(OpBinder binder,
                                 ConversionPatternRewriter &rewriter,
                                 Value size, Value a0, Value a1, Value a2,
                                 Torch::ValueTensorType resultType,
                                 int64_t output_datatype, int64_t periodic) {

  Location loc = binder.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  // `periodic` as a floating-point weight. It scales the periodic
  // denominator, and (1 - periodic) scales the symmetric one below.
  double isPeriodicFp = static_cast(periodic);

  Value zero = b.create(rewriter.getF64FloatAttr(0.0));
  Value one = b.create(rewriter.getF64FloatAttr(1.0));
  Value two = b.create(rewriter.getF64FloatAttr(2.0));

  // tau = 2 * pi: the angle covered by one full period.
  constexpr double pi = llvm::numbers::pi;
  Value tau = b.create(
      rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));

  Value noneVal = b.create();
  Value cstFalse = b.create(false);
  // Torch dtype code 6, labelled as float32 by the inline comment.
  Value float32Type = b.create(
      rewriter.getI64IntegerAttr(/*float32Type*/ 6));

  // Create an f32 ValueTensorType with the same sizes as `size`, the
  // operand.
  // NOTE(review): the dyn_cast result is dereferenced without a null check.
  // If size's type is not the cast-to tensor type, this is a null
  // dereference. A checked cast (or an explicit match failure) would be
  // safer.
  auto shapeOfOperand =
      dyn_cast(size.getType()).getOptionalSizes();
  auto f32ResultType = rewriter.getType(
      shapeOfOperand, rewriter.getF32Type());
  // `size` converted to f32 (f32ResultType + float32Type): the periodic
  // denominator N.
  Value periodicSizeFloat = b.create(
      f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
  // Presumably N - 1 (scalar op on N with `one`, alpha `one`): the symmetric
  // denominator -- TODO confirm the op kind.
  Value symmetricSizeFloat = b.create(
      periodicSizeFloat.getType(), periodicSizeFloat, one, one);

  // Blend weights p = periodic and (1 - p).
  Value isPeriodic =
      b.create(rewriter.getF64FloatAttr(isPeriodicFp));
  Value isSymmetricFloat = b.create(
      rewriter.getF64FloatAttr(1.0 - isPeriodicFp));

  // D = N * p + (N - 1) * (1 - p). This selects the denominator
  // arithmetically, with no control flow. It only picks exactly N or N - 1
  // when p is 0 or 1.
  Value periodicComponent = b.create(
      periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
  Value symmetricComponent = b.create(
      symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
  Value sizeFloat = b.create(
      symmetricComponent.getType(), symmetricComponent, periodicComponent,
      one);

  // Here, size can be used in the place of periodicSizeFloat, as the
  // latter is just a float representation of the former.
  Value scalarLimit = getItemOp(binder, rewriter, size);

  // Sample indices n. This is presumably a range from `zero` to
  // `scalarLimit` with step `one`, with the trailing none operands left
  // unset -- TODO confirm.
  Value rangeArr = b.create(
      resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);

  // x = tau * n / D, and 2x.
  Value rangeTimesTau = b.create(resultType, rangeArr, tau);
  Value rangeAngular = b.create(resultType, rangeTimesTau, sizeFloat);
  Value twoRangeAngular = b.create(resultType, rangeAngular, two);

  // cos(x) and cos(2x).
  Value cosRangeAngular = b.create(resultType, rangeAngular);
  Value cosTwoRangeAngular = b.create(resultType, twoRangeAngular);

  // a1 * cos(x) and a2 * cos(2x).
  Value a1Component = b.create(resultType, cosRangeAngular, a1);
  Value a2Component = b.create(resultType, cosTwoRangeAngular, a2);

  // AtenSubScalarOp actually requires a tensor operand as the LHS, that
  // is, operand #1. Therefore, to avoid errors, the onnx implementation
  // has been modified. a1 has been changed to negative half, and the
  // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
  // operation is commutative.
  Value subA1Component = b.create(resultType, a1Component, a0, one);
  // window = (a1 * cos(x) + a0) + a2 * cos(2x).
  Value result = b.create(resultType, subA1Component, a2Component, one);

  // Map the ONNX dtype code to a torch dtype code, and bail out if there is
  // no mapping.
  // NOTE(review): this check runs only after all of the ops above have been
  // created. Hoisting it to the top of the function would avoid building IR
  // that is then abandoned.
  std::optional dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(output_datatype);
  if (!dtypeIntTorch.has_value()) {
    return rewriter.notifyMatchFailure(
        binder.op, "unimplemented support for the given dtype conversion");
  }
  Value outputDtype = b.create(
      rewriter.getType(),
      rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                              dtypeIntTorch.value()));

  // Replace the ONNX op with the window converted to the requested dtype.
  rewriter.replaceOpWithNewOp(
      binder.op, resultType, result, outputDtype,
      /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
      /*memory_format=*/noneVal);

  return success();
}

} // namespace

// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
// sinceVersion==1. We interpret this to include the following spec
// diffs that are irrelevant to this level of lowering:
// * Supported element types.
// * Limited broadcasting to full broadcasting support.
//
// There are a lot of spec revisions that basically generalized elementwise
// to be more normal and a direct translation vs a special case. This
// results in a lot of ONNX test cases that all reduce to the exact same
// thing here, so we simplify.
void mlir::torch::onnx_c::populateDefaultDomainAtoF(
    OnnxCustomOpConversionPattern &patterns) {
  patterns.onOp("Abs", 1,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
                  Value operand;
                  if (binder.tensorOperand(operand) ||
                      binder.tensorResultType(resultType))
                    return failure();
                  rewriter.replaceOpWithNewOp(
                      binder.op, resultType, operand);
                  return success();
                });
  // Add became forward compatible with Torch in version 7.
patterns.onOp("Add", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); Value const1 = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs, const1); return success(); }); // TODO: AffineGrid patterns.onOp("And", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( "ArgMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; bool keepDims; int64_t axis; bool selectLastIndex; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType) || binder.s64BoolAttr(keepDims, "keepdims", true) || binder.s64IntegerAttr(axis, "axis", 0) || binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) return failure(); // ONNX allows negative axis. 
auto operandSizes = cast(operand.getType()).getSizes(); if (axis < 0) axis += operandSizes.size(); Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); Value constKeepDims = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); if (selectLastIndex) { Value dims = createConstantIntList(binder, rewriter, {axis}); auto operandTy = dyn_cast(operand.getType()); operand = rewriter.create( binder.getLoc(), operandTy, operand, dims); Value argmax = rewriter.create( binder.getLoc(), resultType, operand, constAxis, constKeepDims); Value offset = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); Value alpha = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value sub = rewriter.create( binder.getLoc(), resultType, argmax, offset, alpha); rewriter.replaceOpWithNewOp(binder.op, resultType, sub); return success(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constAxis, constKeepDims); return success(); }); patterns.onOp( "ArgMin", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; bool keepDims; int64_t axis; bool selectLastIndex; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType) || binder.s64BoolAttr(keepDims, "keepdims", true) || binder.s64IntegerAttr(axis, "axis", 0) || binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) return failure(); // ONNX allows negative axis. 
auto operandSizes = cast(operand.getType()).getSizes(); if (axis < 0) axis += operandSizes.size(); Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); Value constKeepDims = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); if (selectLastIndex) { Value dims = createConstantIntList(binder, rewriter, {axis}); auto operandTy = dyn_cast(operand.getType()); operand = rewriter.create( binder.getLoc(), operandTy, operand, dims); Value argmin = rewriter.create( binder.getLoc(), resultType, operand, constAxis, constKeepDims); Value offset = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); Value alpha = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value sub = rewriter.create( binder.getLoc(), resultType, argmin, offset, alpha); rewriter.replaceOpWithNewOp(binder.op, resultType, sub); return success(); } rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constAxis, constKeepDims); return success(); }); patterns.onOp("Asin", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Asinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Atan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); 
return success(); }); patterns.onOp("Atanh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Acos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Acosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("BatchNormalization", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input, weight, bias, runningMean, runningVar; bool training; float momentum, eps; if (binder.s64BoolAttr(training, "training_mode", 0)) return failure(); if (training) { // TODO: Add support for training = true return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: training = true"); } if (binder.tensorOperandAtIndex(input, 0) || binder.tensorOperandAtIndex(weight, 1) || binder.tensorOperandAtIndex(bias, 2) || binder.tensorOperandAtIndex(runningMean, 3) || binder.tensorOperandAtIndex(runningVar, 4) || binder.f32FloatAttr(momentum, "momentum", 0.9f) || binder.f32FloatAttr(eps, "epsilon", 1e-05f) || binder.tensorResultType(resultType)) return failure(); Value cstFalse = rewriter.create( binder.getLoc(), false); Value cstMomentum = rewriter.create( binder.getLoc(), rewriter.getF64FloatAttr(momentum)); Value cstEps = rewriter.create( binder.getLoc(), 
rewriter.getF64FloatAttr(eps)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, weight, bias, runningMean, runningVar, /*training=*/cstFalse, cstMomentum, cstEps, /*cudnn_enabled=*/cstFalse); return success(); }); patterns.onOp( "AveragePool", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; SmallVector dilations; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); if (autoPad != "NOTSET") { // TODO: Add support for `auto_pad` != "NOTSET" return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); } Torch::ValueTensorType resultType; Value operand; bool ceilMode, countIncludePad; if (binder.tensorOperand(operand) || binder.s64BoolAttr(ceilMode, "ceil_mode", false) || binder.s64BoolAttr(countIncludePad, "count_include_pad", false) || binder.tensorResultType(resultType)) return failure(); // Determine the rank of input tensor. std::optional maybeRank = Torch::getTensorRank(operand); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; SmallVector kernel, padding, strides; if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) { return failure(); } if (kernel.size() != rank - 2) { return rewriter.notifyMatchFailure( binder.op, "kernel list size does not match the number of axes"); } SmallVector defaultPadding(2 * (rank - 2), 0); if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { return failure(); } if (padding.size() != 2 * (rank - 2)) { return rewriter.notifyMatchFailure( binder.op, "padding list size does not match twice the number of axes"); } if (binder.s64IntegerArrayAttr( strides, "strides", llvm::SmallVector(rank - 2, 1))) { return failure(); } if (strides.size() != 1 && strides.size() != rank - 2) { return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); } SmallVector cstKernel, cstPadding, cstStridesDilations; 
for (int64_t i : kernel) { cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all // axes x. int64_t paddingSizeHalf = padding.size() / 2; for (int64_t i = 0; i < paddingSizeHalf; ++i) { // Check if onnx padding attribute is symmetric. if (padding[i] != padding[i + paddingSizeHalf]) return rewriter.notifyMatchFailure( binder.op, "onnx padding attribute is not symmetric"); cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } for (int64_t i : strides) { cstStridesDilations.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } // No dilations attribute in pytorch avgpool op, so use this trick to // encode dilation into strides. Then in the following torchtolinalg // lowering, decode strides into strides + dilation. // [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...] 
if (binder.s64IntegerArrayAttr( dilations, "dilations", llvm::SmallVector(rank - 2, 1))) { return failure(); } for (auto dilation : dilations) { cstStridesDilations.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(dilation))); } Value kernelSizeList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstKernel); Value paddingList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); Value stridesDilationsList = rewriter.create( binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstStridesDilations); Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode); Value cstCountIncludePad = rewriter.create( binder.getLoc(), countIncludePad); Value cstNone = rewriter.create(binder.getLoc()); if (rank == 3) { rewriter.replaceOpWithNewOp( binder.op, resultType, operand, kernelSizeList, stridesDilationsList, paddingList, cstCeilMode, cstCountIncludePad); return success(); } else if (rank == 4) { rewriter.replaceOpWithNewOp( binder.op, resultType, operand, kernelSizeList, stridesDilationsList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } else if (rank == 5) { rewriter.replaceOpWithNewOp( binder.op, resultType, operand, kernelSizeList, stridesDilationsList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } return failure(); }); patterns.onOp( "Bernoulli", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input; int64_t dtypeIntOnnx; if (binder.tensorOperand(input) || binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) || binder.tensorResultType(resultType)) return failure(); SmallString<64> name("torch.onnx."); name.append("seed"); auto attr = binder.op->getAttr(name); if (attr) { return rewriter.notifyMatchFailure( binder.op, "unimplemented: support not 
present for seed attribute"); } Value none = rewriter.create(binder.getLoc()); Value bernoulli = rewriter.create( binder.getLoc(), input.getType(), input, /*generator=*/none); if (dtypeIntOnnx == -1) { // True, if dtype attribute value is not present. rewriter.replaceOp(binder.op, bernoulli); return success(); } std::optional dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); if (!dtypeIntTorch.has_value()) { return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given dtype conversion"); } Value constDtype = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); Value cstFalse = rewriter.create(binder.getLoc(), false); rewriter.replaceOpWithNewOp( binder.op, resultType, bernoulli, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); return success(); }); patterns.onOp( "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType) || binder.customOpNameStringAttr(direction, "direction", "")) return failure(); if (direction == "LEFT") { rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); } else { rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); } return success(); }); patterns.onOp("BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( 
binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("BitwiseNot", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; int64_t dtypeIntOnnx; if (binder.tensorOperand(operand) || binder.s64IntegerAttr(dtypeIntOnnx, "to") || binder.tensorResultType(resultType)) return failure(); std::optional dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); if (!dtypeIntTorch.has_value()) { return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given dtype conversion"); } Value constDtype = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); Value none = rewriter.create(binder.getLoc()); Value cstFalse = rewriter.create(binder.getLoc(), false); rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); return success(); }); patterns.onOp( "CastLike", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input, target; if (binder.tensorOperands(input, target) || binder.tensorResultType(resultType)) return failure(); // TODO: Add support to handle the `saturate` attribute. 
// Ignoring it right now, since it's only using during the float8 // conversions which are not supported in Torch-MLIR right now. Torch::ValueTensorType targetTy = cast(target.getType()); if (!targetTy.hasDtype()) { return rewriter.notifyMatchFailure(binder.op, "target tensor must have a dtype"); } Type targetDtype = targetTy.getDtype(); Value constDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), targetDtype); Value none = rewriter.create(binder.getLoc()); Value cstFalse = rewriter.create(binder.getLoc(), false); rewriter.replaceOpWithNewOp( binder.op, resultType, input, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); return success(); }); patterns.onOp("Ceil", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp( "Celu", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; float alpha; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType) || binder.f32FloatAttr(alpha, "alpha", 1.0f)) return failure(); // exp(x/alpha) Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); Value xDivAlpha = rewriter.create( binder.getLoc(), resultType, operand, constAlpha); Value expXDivAlpha = rewriter.create( binder.getLoc(), resultType, xDivAlpha); // alpha * (exp(x/alpha) - 1) Value constantOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value subOne = rewriter.create( binder.getLoc(), resultType, expXDivAlpha, constantOne, constantOne); Value mulAlpha = rewriter.create( binder.getLoc(), resultType, subOne, constAlpha); Value constantZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value zeroTensor = 
createRank0Tensor(rewriter, binder.getLoc(), resultType, constantZero); // min(0, alpha * (exp(x/alpha) - 1)) Value minExpression = rewriter.create( binder.getLoc(), resultType, zeroTensor, mulAlpha); // max(0, x) Value maxExpression = rewriter.create( binder.getLoc(), resultType, zeroTensor, operand); // max(0,x) + min(0, alpha * (exp(x/alpha) - 1)) rewriter.replaceOpWithNewOp( binder.op, resultType, maxExpression, minExpression, constantOne); return success(); }); patterns.onOp( "CenterCropPad", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input, shape; if (binder.tensorOperands(input, shape) || binder.tensorResultType(resultType)) return failure(); auto inputTy = cast(input.getType()); SmallVector inputShape(inputTy.getSizes()); SmallVector resultShape(resultType.getSizes()); int64_t rank = inputShape.size(); SmallVector axes, defaultAxes(rank); std::iota(defaultAxes.begin(), defaultAxes.end(), 0); if (binder.s64IntegerArrayAttr(axes, "axes", defaultAxes)) { return failure(); } int64_t axesSize = axes.size(); Value none = rewriter.create(binder.getLoc()); Value cstZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value cstTwo = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(2)); auto scalarTensorType = rewriter.getType( ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)); int64_t lastChangeDim = 0; llvm::SmallVector interShape(inputShape); for (int i = 0; i < rank; i++) { if (inputShape[i] != resultShape[i]) { interShape[i] = -1; lastChangeDim = i; } if (interShape[i] == ShapedType::kDynamic) interShape[i] = Torch::kUnknownSize; } auto interType = rewriter.getType( interShape, resultType.getOptionalDtype()); Value modeVal = rewriter.create( binder.getLoc(), rewriter.getStringAttr("floor")); for (int i = 0; i < axesSize; i++) { if (axes[i] < 0) axes[i] += rank; if (inputShape[axes[i]] 
== resultShape[axes[i]]) continue; auto opType = axes[i] == lastChangeDim ? resultType : interType; Value axis = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(axes[i])); Value k = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i)); Value kTensor = rewriter.create( binder.getLoc(), scalarTensorType, k); Value sel = rewriter.create( binder.getLoc(), scalarTensorType, shape, cstZero, kTensor); Value outputDimSize = rewriter.create( binder.getLoc(), rewriter.getType(), sel); Value inputDimSize = rewriter.create( binder.getLoc(), input, rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(axes[i]))); if (inputShape[axes[i]] > resultShape[axes[i]]) { Value sub = rewriter.create( binder.getLoc(), inputDimSize, outputDimSize); Value subTensor = rewriter.create( binder.getLoc(), scalarTensorType, sub); Value div = rewriter.create( binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); Value start = rewriter.create( binder.getLoc(), rewriter.getType(), div); Value end = rewriter.create( binder.getLoc(), start, outputDimSize); input = rewriter.create( binder.getLoc(), opType, input, axis, start, end, cstOne); } else { Value sub = rewriter.create( binder.getLoc(), outputDimSize, inputDimSize); Value subTensor = rewriter.create( binder.getLoc(), scalarTensorType, sub); Value div = rewriter.create( binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); Value start = rewriter.create( binder.getLoc(), rewriter.getType(), div); Value end = rewriter.create( binder.getLoc(), start, inputDimSize); SmallVector zerosShapeValues; for (int j = 0; j < rank; j++) { if (j == axes[i]) { zerosShapeValues.push_back(outputDimSize); } else { Value dimSize = rewriter.create( binder.getLoc(), input, rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(j))); zerosShapeValues.push_back(dimSize); } } Value zerosShapeList = rewriter.create( binder.getLoc(), rewriter.getType( rewriter.getType()), zerosShapeValues); Value zeros = 
rewriter.create( binder.getLoc(), opType, zerosShapeList, none, none, none, none); input = rewriter.create( binder.getLoc(), opType, zeros, input, axis, start, end, cstOne); } } rewriter.replaceOp(binder.op, input); return success(); }); patterns.onOp( "Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // https://onnx.ai/onnx/operators/onnx__Clip.html // Inputs and outputs must be tensors. Value source; Torch::ValueTensorType resultType; if (binder.tensorOperandAtIndex(source, 0) || binder.tensorResultType(resultType)) { return failure(); } // Min and max can be args (version 11+) or attributes (version 6-). // They default to numeric_limits::lowest() and numeric_limits::max(). Value min; Value max; if (binder.op->getNumOperands() >= 2) min = binder.op->getOperand(1); if (binder.op->getNumOperands() == 3) max = binder.op->getOperand(2); // Note: attribute versions of the op only support float types. auto resultDtype = resultType.getDtype(); if (!min && binder.op->hasAttr("torch.onnx.min")) { float minValue; if (binder.f32FloatAttr(minValue, "min", std::numeric_limits::lowest())) return failure(); auto minSplatAttr = SplatElementsAttr::get( resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, minValue)); min = rewriter.create( binder.getLoc(), resultType, minSplatAttr); } if (!max && binder.op->hasAttr("torch.onnx.max")) { float maxValue; if (binder.f32FloatAttr(maxValue, "max", std::numeric_limits::max())) return failure(); auto maxSplatAttr = SplatElementsAttr::get( resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, maxValue)); max = rewriter.create( binder.getLoc(), resultType, maxSplatAttr); } if (!min && !max) { // Cliping with no limits is a no-op. 
rewriter.replaceOp(binder.op, source); return success(); } if (!max) { rewriter.replaceOpWithNewOp( binder.op, resultType, source, min); return success(); } rewriter.replaceOpWithNewOp( binder.op, resultType, source, min, max); return success(); }); patterns.onOp( "Compress", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand, conditionTensor; int64_t axis; if (binder.tensorOperands(operand, conditionTensor) || binder.s64IntegerAttr(axis, "axis", INT64_MAX) || binder.tensorResultType(resultType)) return failure(); auto shapeSizes = dyn_cast(operand.getType()).getSizes(); auto resultSizes = resultType.getSizes(); // flatten input tensor if using default axis if (axis == INT64_MAX) { SmallVector nonzeroShape = {resultSizes[0]}; auto dtype = dyn_cast(conditionTensor.getType()) .getDtype(); auto nonzeroType = rewriter.getType(nonzeroShape, dtype); Value indexVal = rewriter.create( binder.getLoc(), nonzeroType, conditionTensor); Value cstZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value cstNegOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(-1)); int64_t numElements = 1; for (auto i : shapeSizes) { numElements *= i; } SmallVector flattenShape = {numElements}; auto flattenType = rewriter.getType( flattenShape, resultType.getDtype()); Value flattenTensor = rewriter.create( binder.getLoc(), flattenType, operand, cstZero, cstNegOne); rewriter.replaceOpWithNewOp( binder.op, resultType, flattenTensor, cstZero, indexVal); return success(); } // Negative axis value means counting dimensions from the back if (axis < 0) axis += shapeSizes.size(); SmallVector nonzeroShape = {resultSizes[axis]}; auto dtype = dyn_cast(conditionTensor.getType()) .getDtype(); auto nonzeroType = rewriter.getType(nonzeroShape, dtype); Value indexVal = rewriter.create( binder.getLoc(), nonzeroType, conditionTensor); Value dimVal = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(axis)); 
rewriter.replaceOpWithNewOp( binder.op, resultType, operand, dimVal, indexVal); return success(); }); patterns.onOp( "Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; SmallVector tensors; int64_t dim; if (binder.tensorOperands(tensors, binder.op->getNumOperands()) || binder.s64IntegerAttr(dim, "axis", 0) || binder.tensorResultType(resultType)) return failure(); Type listElemType = cast(tensors[0].getType()) .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value tensorList = rewriter.create( binder.op->getLoc(), listType, tensors); Value cstDim = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(dim)); rewriter.replaceOpWithNewOp(binder.op, resultType, tensorList, cstDim); return success(); }); patterns.onOp( "Constant", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; if (binder.tensorResultType(resultType)) return failure(); auto dtype = resultType.getDtype(); float floatValue; if (binder.op->hasAttr("torch.onnx.value_float") && !binder.f32FloatAttr(floatValue, "value_float", 0.0)) { auto splatAttr = SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getFloatAttr(dtype, floatValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); return success(); } int64_t intValue; if (binder.op->hasAttr("torch.onnx.value_int") && !binder.s64IntegerAttr(intValue, "value_int", 0)) { auto splatAttr = SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getIntegerAttr(dtype, intValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); return success(); } if (DenseResourceElementsAttr attr = dyn_cast_or_null( binder.op->getAttr("torch.onnx.value"))) { // Bytes are stored in little endian order. Big endian support will // require swizzling. 
if (!Endian::little) { binder.op->emitError( "unimplemented: importing on big endian systems"); return failure(); } auto ty = cast(attr.getType()); ElementsAttr denseAttr; auto ptr = attr.getRawHandle().getBlob(); if (!ptr) { denseAttr = DenseResourceElementsAttr::get( ty, "__onnx_constant_not_found_possibly_due_to_being_elided__", AsmResourceBlob()); rewriter.replaceOpWithNewOp( binder.op, resultType, denseAttr); return success(); } auto data = ptr->getData(); if (cast(attr.getType()).getElementType().isInteger(1)) { llvm::SmallVector newContents; for (auto val : data) { APInt apval(1, val); newContents.push_back(apval); } denseAttr = DenseElementsAttr::get(ty, newContents); } else { denseAttr = DenseElementsAttr::getFromRawBuffer(ty, data); } rewriter.replaceOpWithNewOp( binder.op, resultType, denseAttr); return success(); } if (ElementsAttr attr = dyn_cast_or_null( binder.op->getAttr("torch.onnx.value"))) { rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); } llvm::SmallVector intValues; if (!binder.s64IntegerArrayAttr(intValues, "value_ints", {}) && !intValues.empty()) { llvm::SmallVector apValues; for (auto intVal : intValues) { apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); } auto attr = DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues); rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); } return failure(); }); patterns.onOp( "Col2Im", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input, blockShape, imageShape; SmallVector dilations, strides, pads; // TODO: The length of dilations should be len(imageShape), and the same // goes for strides. The length of pads should be 2 * len(imageShape). // But, as at the moment we are only supporting 3D or 4D input, // len(imageShape) must necessarily be 2, hence the lengths of the // default values. 
if (binder.tensorOperandAtIndex(input, 0) || binder.tensorOperandAtIndex(imageShape, 1) || binder.tensorOperandAtIndex(blockShape, 2) || binder.tensorResultType(resultType) || binder.s64IntegerArrayAttr(dilations, "dilations", SmallVector{1, 1}) || binder.s64IntegerArrayAttr(strides, "strides", SmallVector{1, 1}) || binder.s64IntegerArrayAttr(pads, "pads", SmallVector{0, 0, 0, 0})) return failure(); auto imageShapeTy = cast(imageShape.getType()); auto imageShapeSizes = imageShapeTy.getSizes(); auto blockShapeTy = cast(blockShape.getType()); auto blockShapeSizes = blockShapeTy.getSizes(); // Check that neither imageShape nor blockShape have dynamic shapes. if (imageShapeSizes[0] == Torch::kUnknownSize || blockShapeSizes[0] == Torch::kUnknownSize) { return rewriter.notifyMatchFailure( binder.op, "Dynamic shapes are not allowed for imageShape and blockShape"); } // TODO: Add support for 5D input tensors. if (imageShapeSizes[0] != 2) { return rewriter.notifyMatchFailure( binder.op, "Expected length of imageShape to be equal to 2"); } if (blockShapeSizes[0] != 2) { return rewriter.notifyMatchFailure( binder.op, "Expected length of blockShape to be equal to 2"); } if (dilations.size() != 2) { return rewriter.notifyMatchFailure( binder.op, "Expected length of dilations to be equal to 2"); } if (strides.size() != 2) { return rewriter.notifyMatchFailure( binder.op, "Expected length of strides to be equal to 2"); } // TODO: Disable this check and add support for different // paddings on lower and higher ends of each axis. // Because we have already checked that imageShape has 2 elements, // we can safely assume that len(padding) will be 4. 
if (pads[0] != pads[2] || pads[1] != pads[3]) return rewriter.notifyMatchFailure( binder.op, "padding on the lower end and the higher end " "on each axis should be the same"); // Since we know that the padding on the lower end and the higher // end on each axis is the same, we can reduce the size of the // padding list, and filter out the duplicate elements. // (Also, Torch::AtenCol2imOp requires len(padding) to be 2). SmallVector padOnEachAxis = {pads[0], pads[1]}; Value dilationsList = createConstantIntList(binder, rewriter, dilations); Value stridesList = createConstantIntList(binder, rewriter, strides); Value paddingList = createConstantIntList(binder, rewriter, padOnEachAxis); Value zero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); // Index the imageShape and blockShape tensors, as AtenCol2imOp expects // them to be int lists. auto select = [&](Value v, Value k, Torch::ValueTensorType ty) -> Value { Value kTensor = rewriter.create( binder.getLoc(), Torch::ValueTensorType::get( binder.op->getContext(), ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)), k); auto sel = rewriter.create( binder.getLoc(), Torch::ValueTensorType::get(ty.getContext(), ArrayRef{1}, ty.getOptionalDtype()), v, zero, kTensor); Value item = rewriter.create( binder.getLoc(), rewriter.getType(), sel); return item; }; SmallVector imageShapeContainer, blockShapeContainer; for (int64_t i = 0; i < imageShapeSizes[0]; ++i) { Value k = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i)); // Passing in the shapeType of each of these tensors avoids // repeated casts, as these have already been calculated. 
imageShapeContainer.push_back(select(imageShape, k, imageShapeTy)); blockShapeContainer.push_back(select(blockShape, k, blockShapeTy)); } Value imageShapeAsList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), imageShapeContainer); Value blockShapeAsList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), blockShapeContainer); rewriter.replaceOpWithNewOp( binder.op, resultType, input, imageShapeAsList, blockShapeAsList, dilationsList, paddingList, stridesList); return success(); }); patterns.onOp( "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); if (autoPad != "NOTSET") { // TODO: Add support for `auto_pad` != "NOTSET" return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); } Torch::ValueTensorType resultType; Value input, weight; int64_t group; if (binder.tensorOperandAtIndex(input, 0) || binder.tensorOperandAtIndex(weight, 1) || binder.s64IntegerAttr(group, "group", 1) || binder.tensorResultType(resultType)) return failure(); auto weightTensorType = cast(weight.getType()); if (!weightTensorType || !weightTensorType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected weight type having sizes"); } ArrayRef weightShape = weightTensorType.getSizes(); SmallVector kernelShape; if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) return failure(); if (kernelShape.size()) { if (kernelShape.size() != weightShape.size() - 2) { return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: kernel_shape list size should have " "number of values equal to weight_rank - 2"); } else { for (unsigned i = 0; i < kernelShape.size(); i++) { if (weightShape[i + 2] != kernelShape[i]) { return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: kernel_shape value " 
"should be equal to the weight tensor shape"); } } } } // Determine the rank of input tensor. std::optional maybeRank = Torch::getTensorRank(input); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; SmallVector padding, strides, dilations; SmallVector defaultPadding, defaultStrides, defaultDilations; for (unsigned i = 0; i < rank - 2; i++) { defaultPadding.push_back(0); defaultStrides.push_back(1); defaultDilations.push_back(1); } // Padding for the beginning and ending along each spatial axis, it can // take any value greater than or equal to 0. The value represent the // number of pixels added to the beginning and end part of the // corresponding axis. pads format should be as follow [x1_begin, // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added // at the beginning of axis i and xi_end, the number of pixels added at // the end of axis i. if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { return failure(); } if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { return rewriter.notifyMatchFailure( binder.op, "padding list size does not match the number of axes"); } if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations)) { return failure(); } if (dilations.size() != rank - 2) { return rewriter.notifyMatchFailure( binder.op, "dilations list size does not match the number of axes"); } if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { return failure(); } if (strides.size() != rank - 2) { return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); } SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; Value paddedInput = input; Value paddingList; if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } paddingList = rewriter.create( binder.getLoc(), 
Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); } else { // ONNX offers pads in the format listing all starting dims, then all // ending dims, e.g. {t, l, b, r} for conv2d. Torch by default accepts // only starting dims, e.g. {t, l}. However, we can support padding at // the beginning and end of each dimension by first performing // torch.nn.functional.pad on the input. But this requires the pad // values to be rearranged since torch pad() takes pads in the order // rightmost dim start and end, then next to last, and so on, e.g. {l, // r, t, b}. bool matchedPads = true; for (unsigned i = 0; i < padding.size() / 2; i++) { if (padding[i] != padding[i + (padding.size() / 2)]) { matchedPads = false; break; } } if (matchedPads) { for (unsigned i = 0; i < padding.size() / 2; i++) { cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } paddingList = rewriter.create( binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); } else { SmallVector padsRearrange; SmallVector inputPaddingList; for (uint32_t i = 0; i < padding.size() / 2; i++) { padsRearrange.emplace_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); padsRearrange.emplace_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr( padding[(padding.size() / 2) + i]))); inputPaddingList.emplace_back( rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0))); } // The conv op itself will have no padding since the actual padding // is performed using the torch.pad preceding it. 
paddingList = rewriter.create( binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), inputPaddingList); Value padsSizeList = rewriter .create( binder.getLoc(), Torch::ListType::get( rewriter.getType()), padsRearrange) .getResult(); Value modeVal = rewriter.create( binder.getLoc(), rewriter.getStringAttr("constant")); Value constantValue; auto inputTensorType = cast(input.getType()); if (isa(inputTensorType.getDtype())) constantValue = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); if (isa(inputTensorType.getDtype())) constantValue = rewriter.create( binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); // Pad output shape must be computed explicitly from the pad values SmallVector newInputShape(inputTensorType.getSizes()); for (uint32_t i = 0; i < padding.size() / 2; i++) { newInputShape[2 + i] += padding[i] + padding[(padding.size() / 2) + i]; } auto padTy = rewriter.getType( newInputShape, inputTensorType.getDtype()); paddedInput = rewriter.create( binder.getLoc(), padTy, input, padsSizeList, modeVal, constantValue); } } for (int64_t i : dilations) { cstDilations.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } Value cstZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); cstOutputPadding = {cstZero, cstZero}; Value dilationsList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); Value stridesList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value outputPaddingList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstOutputPadding); Value transposed = rewriter.create(binder.getLoc(), false); Value bias; if (binder.op->getNumOperands() == 3) { if 
(binder.tensorOperandAtIndex(bias, 2)) { return failure(); } } else { bias = rewriter.create(binder.getLoc()); } Value cstGroup = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( binder.op, resultType, paddedInput, weight, bias, stridesList, paddingList, dilationsList, transposed, outputPaddingList, cstGroup); return success(); }); patterns.onOp( "ConvInteger", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); if (autoPad != "NOTSET") // TODO: Add support for `auto_pad` != "NOTSET" return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); Torch::ValueTensorType resultType; Value input, weight, inputZp, weightZp; int64_t group; if (binder.tensorOperandAtIndex(input, 0) || binder.tensorOperandAtIndex(weight, 1) || binder.s64IntegerAttr(group, "group", 1) || binder.tensorResultType(resultType)) return failure(); auto inputTy = dyn_cast(input.getType()); auto weightTy = dyn_cast(weight.getType()); if (!weightTy || !weightTy.hasSizes()) return rewriter.notifyMatchFailure( binder.op, "Expected weight type having sizes"); ArrayRef weightShape = weightTy.getSizes(); SmallVector kernelShape; if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) return failure(); if (kernelShape.size()) { if (kernelShape.size() != weightShape.size() - 2) { return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: kernel_shape list size should have " "number of values equal to weight_rank - 2"); } else { for (unsigned i = 0; i < kernelShape.size(); i++) { if (weightShape[i + 2] != kernelShape[i]) return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: kernel_shape value " "should be equal to the weight tensor shape"); } } } // Determine the rank of input tensor. 
std::optional maybeRank = Torch::getTensorRank(input); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; SmallVector padding, strides, dilations; SmallVector defaultPadding(rank - 2, 0), defaultStrides(rank - 2, 1), defaultDilations(rank - 2, 1); // Padding for the beginning and ending along each spatial axis, it can // take any value greater than or equal to 0. The value represent the // number of pixels added to the beginning and end part of the // corresponding axis. pads format should be as follow [x1_begin, // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added // at the beginning of axis i and xi_end, the number of pixels added at // the end of axis i. if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) return failure(); if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) return rewriter.notifyMatchFailure( binder.op, "padding list size does not match the number of axes"); if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations)) return failure(); if (dilations.size() != rank - 2) return rewriter.notifyMatchFailure( binder.op, "dilations list size does not match the number of axes"); if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) return failure(); if (strides.size() != rank - 2) return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); Value scale = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); if (binder.tensorOperandAtIndex(inputZp, 2)) { inputZp = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); } else { inputZp = rewriter.create( binder.getLoc(), rewriter.getType(), inputZp); } if (binder.tensorOperandAtIndex(weightZp, 3)) weightZp = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); // TODO: support per channel quantization if weightZp is a 1-D tensor if (auto zpTy = 
dyn_cast(weightZp.getType())) { for (auto dim : zpTy.getSizes()) if (dim != 1) return failure(); weightZp = rewriter.create( binder.getLoc(), rewriter.getType(), weightZp); } SmallVector cstPadding; if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } } else { for (unsigned i = 0; i < padding.size() / 2; i++) { if (padding[i] != padding[i + (padding.size() / 2)]) // TODO: Add support for different padding values for the // beginning and ending along each spatial axis return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: padding values for the beginning " "and ending along each spatial axis must be equal"); cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } } Value paddingList = rewriter.create( binder.getLoc(), rewriter.getType( rewriter.getType()), cstPadding); Value dilationsList = createConstantIntList(binder, rewriter, dilations); Value stridesList = createConstantIntList(binder, rewriter, strides); Value outputPaddingList = createConstantIntList(binder, rewriter, {0, 0}); Value transposed = rewriter.create(binder.getLoc(), false); Value bias = rewriter.create(binder.getLoc()); Value cstGroup = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(group)); Type inputQTy = getQTorchTypeFromTorchIntType(inputTy); Type weightQTy = getQTorchTypeFromTorchIntType(weightTy); input = rewriter.create( binder.getLoc(), inputQTy, input, scale, inputZp); weight = rewriter.create( binder.getLoc(), weightQTy, weight, scale, weightZp); rewriter.replaceOpWithNewOp( binder.op, resultType, input, weight, bias, stridesList, paddingList, dilationsList, transposed, outputPaddingList, cstGroup); return success(); }); patterns.onOp( "ConvTranspose", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return 
failure(); if (autoPad != "NOTSET") { // TODO: Add support for `auto_pad` != "NOTSET" return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); } SmallVector outputShape; if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {})) return failure(); if (outputShape.size()) { // TODO: Add support for non-None output_shape value. return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: output_shape should be absent"); } Torch::ValueTensorType resultType; Value input, weight; int64_t group; if (binder.tensorOperandAtIndex(input, 0) || binder.tensorOperandAtIndex(weight, 1) || binder.s64IntegerAttr(group, "group", 1) || binder.tensorResultType(resultType)) return failure(); auto weightTensorType = cast(weight.getType()); if (!weightTensorType || !weightTensorType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected weight type having sizes"); } ArrayRef weightShape = weightTensorType.getSizes(); SmallVector kernelShape; if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) return failure(); if (kernelShape.size()) { if (kernelShape.size() != weightShape.size() - 2) { return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: kernel_shape list size should have " "number of values equal to weight_rank - 2"); } else { for (unsigned i = 0; i < kernelShape.size(); i++) { if (weightShape[i + 2] != kernelShape[i]) { return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: kernel_shape value " "should be equal to the weight tensor shape"); } } } } // Determine the rank of input tensor. 
std::optional maybeRank = Torch::getTensorRank(input); if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; SmallVector padding, strides, dilations, outputPadding; SmallVector defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding; for (unsigned i = 0; i < rank - 2; i++) { defaultPadding.push_back(0); defaultStrides.push_back(1); defaultDilations.push_back(1); defaultOutputPadding.push_back(0); } // Padding for the beginning and ending along each spatial axis, it can // take any value greater than or equal to 0. The value represent the // number of pixels added to the beginning and end part of the // corresponding axis. pads format should be as follow [x1_begin, // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added // at the beginning of axis i and xi_end, the number of pixels added at // the end of axis i. if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { return failure(); } if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { return rewriter.notifyMatchFailure( binder.op, "padding list size does not match the number of axes"); } if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations)) { return failure(); } if (dilations.size() != rank - 2) { return rewriter.notifyMatchFailure( binder.op, "dilations list size does not match the number of axes"); } if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { return failure(); } if (strides.size() != rank - 2) { return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); } if (binder.s64IntegerArrayAttr(outputPadding, "output_padding", defaultOutputPadding)) { return failure(); } if (outputPadding.size() != rank - 2) { return rewriter.notifyMatchFailure( binder.op, "output_padding list size does not match the number of axes"); } SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; if 
(padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } } else { for (unsigned i = 0; i < padding.size() / 2; i++) { if (padding[i] != padding[i + (padding.size() / 2)]) { // TODO: Add support for different padding values for the // beginning and ending along each spatial axis return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: padding values for the beginning " "and ending along each spatial axis must be equal"); } cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } } for (int64_t i : dilations) { cstDilations.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } for (int64_t i : outputPadding) { cstOutputPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } Value paddingList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); Value dilationsList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); Value stridesList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value outputPaddingList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstOutputPadding); Value transposed = rewriter.create(binder.getLoc(), true); Value bias; if (binder.op->getNumOperands() == 3) { if (binder.tensorOperandAtIndex(bias, 2)) { return failure(); } } else { bias = rewriter.create(binder.getLoc()); } Value cstGroup = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, weight, bias, stridesList, paddingList, dilationsList, 
transposed, outputPaddingList, cstGroup); return success(); }); patterns.onOp("Cos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Cosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp( "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand, axisTensor; int64_t exclusive, reverse; if (binder.tensorOperands(operand, axisTensor) || binder.s64IntegerAttr(exclusive, "exclusive", 0) || binder.s64IntegerAttr(reverse, "reverse", 0) || binder.tensorResultType(resultType)) return failure(); Torch::BaseTensorType resultTensorType = cast(resultType); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( binder.op, "expected result type to have a dtype"); } // deal with neg axis: if (axis < 0) axis += rank int64_t rank = cast(operand.getType()).getSizes().size(); Value rankVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank)); Value cstZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value axisScalar = rewriter.create( binder.getLoc(), rewriter.getType(), axisTensor); Value isNegative = rewriter.create( binder.getLoc(), axisScalar, cstZero); isNegative = rewriter.create(binder.getLoc(), isNegative); Value finalOffset = rewriter.create( binder.getLoc(), isNegative, rankVal); Value axis = rewriter.create( binder.getLoc(), 
axisScalar, finalOffset); Value none = rewriter.create(binder.getLoc()); Value res; if (reverse) { Value dims = rewriter.create( binder.getLoc(), rewriter.getType( rewriter.getType()), SmallVector{axis}); Value flip = rewriter.create( binder.getLoc(), resultType, operand, dims); Value cumsum = rewriter.create( binder.getLoc(), resultType, flip, axis, none); res = rewriter.create(binder.getLoc(), resultType, cumsum, dims); } else { res = rewriter.create( binder.getLoc(), resultType, operand, axis, none); } if (exclusive) res = rewriter.create( binder.getLoc(), resultType, res, operand, cstOne); rewriter.replaceOp(binder.op, res); return success(); }); patterns.onOp( "DepthToSpace", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input; int64_t blockSize; std::string mode; if (binder.tensorOperand(input) || binder.s64IntegerAttr(blockSize, "blocksize") || binder.customOpNameStringAttr(mode, "mode", "DCR") || binder.tensorResultType(resultType)) return failure(); auto inputTy = dyn_cast(input.getType()); if (!inputTy || !inputTy.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected input type having sizes"); } SmallVector inputSizes{inputTy.getSizes()}; if (inputSizes.size() != 4) { return rewriter.notifyMatchFailure(binder.op, "Expected input rank to be 4"); } Value b = rewriter.create( binder.getLoc(), input, rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0))); Value c = rewriter.create( binder.getLoc(), input, rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1))); Value h = rewriter.create( binder.getLoc(), input, rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(2))); Value w = rewriter.create( binder.getLoc(), input, rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(3))); Value cstBlockSize = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); Value cstBlockSizeSquare = rewriter.create( binder.getLoc(), 
rewriter.getI64IntegerAttr(blockSize * blockSize)); Value cDivBlockSizeSquare = rewriter.create( binder.getLoc(), c, cstBlockSizeSquare); cDivBlockSizeSquare = rewriter.create( binder.getLoc(), cDivBlockSizeSquare); Value reshapeSizesList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(input.getContext())), llvm::SmallVector{b, cstBlockSize, cstBlockSize, cDivBlockSizeSquare, h, w}); int64_t cDivBlockSizeSquareInt = inputSizes[1] == Torch::kUnknownSize ? Torch::kUnknownSize : inputSizes[1] / (blockSize * blockSize); SmallVector reshapeSizesInt{ inputSizes[0], blockSize, blockSize, cDivBlockSizeSquareInt, inputSizes[2], inputSizes[3]}; Value reshapedInput = rewriter.create( binder.getLoc(), inputTy.getWithSizesAndDtype(reshapeSizesInt, inputTy.getOptionalDtype()), input, reshapeSizesList); Value transposedInput; if (mode == "DCR") { if (failed(createTorchTransposeOp( rewriter, binder.getLoc(), reshapedInput, /*dimA=*/1, /*dimB=*/3, transposedInput))) return rewriter.notifyMatchFailure( binder.op, "Failed to create TorchTranspose op"); if (failed(createTorchTransposeOp( rewriter, binder.getLoc(), transposedInput, /*dimA=*/2, /*dimB=*/4, transposedInput))) return rewriter.notifyMatchFailure( binder.op, "Failed to create TorchTranspose op"); } else { // mode == "CRD" if (failed(createTorchTransposeOp( rewriter, binder.getLoc(), reshapedInput, /*dimA=*/2, /*dimB=*/4, transposedInput))) return rewriter.notifyMatchFailure( binder.op, "Failed to create TorchTranspose op"); if (failed(createTorchTransposeOp( rewriter, binder.getLoc(), transposedInput, /*dimA=*/3, /*dimB=*/4, transposedInput))) return rewriter.notifyMatchFailure( binder.op, "Failed to create TorchTranspose op"); } if (failed(createTorchTransposeOp( rewriter, binder.getLoc(), transposedInput, /*dimA=*/4, /*dimB=*/5, transposedInput))) return rewriter.notifyMatchFailure( binder.op, "Failed to create TorchTranspose op"); Value hMulBlockSize = rewriter.create( binder.getLoc(), h, 
cstBlockSize); Value wMulBlockSize = rewriter.create( binder.getLoc(), w, cstBlockSize); reshapeSizesList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(input.getContext())), llvm::SmallVector{b, cDivBlockSizeSquare, hMulBlockSize, wMulBlockSize}); rewriter.replaceOpWithNewOp( binder.op, resultType, transposedInput, reshapeSizesList); return success(); }); patterns.onOp( "DeformConv", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { auto loc = binder.getLoc(); // get operands llvm::SmallVector operands; Torch::ValueTensorType resultType; if (binder.tensorOperandsList(operands) || binder.tensorResultType(resultType)) return failure(); if (operands.size() < 3 || operands.size() > 5) return failure(); auto inputType = dyn_cast(operands[0].getType()); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 4) return rewriter.notifyMatchFailure( binder.op, "Unsupported: DeformConv with input rank != 4"); unsigned rank = inputType.getSizes().size(); auto weightType = dyn_cast(operands[1].getType()); if (!weightType || !weightType.hasSizes()) return failure(); auto offsetType = dyn_cast(operands[2].getType()); if (!offsetType || !offsetType.hasSizes()) return failure(); // get attributes SmallVector dilations, kernelShape, pads, strides; SmallVector defaultDilations(rank - 2, 0); SmallVector defaultPads(2 * (rank - 2), 0); SmallVector defaultStrides(rank - 2, 1); int64_t group, offsetGroup; if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations) || binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}) || binder.s64IntegerArrayAttr(pads, "pads", defaultPads) || binder.s64IntegerArrayAttr(strides, "strides", defaultStrides) || binder.s64IntegerAttr(group, "group", 1) || binder.s64IntegerAttr(offsetGroup, "offset_group", 1)) return failure(); for (unsigned i = 0; i < rank - 2; i++) { if (pads[i] != pads[rank + i - 2]) return rewriter.notifyMatchFailure( binder.op, "unsupported: asymmetric 
padding"); } // Identify and assign names to operands Value input, weight, offset, bias, mask; bool useMask = false; input = operands[0]; weight = operands[1]; offset = operands[2]; if (operands.size() == 4) { auto unknownOpdRank = Torch::getTensorRank(operands[3]); if (!unknownOpdRank) return failure(); if (*unknownOpdRank == 1) bias = operands[3]; else if (*unknownOpdRank == rank) { mask = operands[3]; useMask = true; } else llvm_unreachable("onnx.DeformConv: optional 4th operand of " "unexpected rank encountered"); } if (operands.size() == 5) { bias = operands[3]; mask = operands[4]; useMask = true; } // assign default operand values if necessary ArrayRef weightSizes = weightType.getSizes(); ArrayRef offsetSizes = offsetType.getSizes(); if (!bias) { int64_t outputChannels = weightSizes[0]; SmallVector biasShape(1, outputChannels); Value biasShapeList = mlir::torch::onnx_c::createConstantIntList( binder, rewriter, biasShape); Value cstZero = Torch::getConstantWithGivenDtypeAndValue( rewriter, loc, 0.0f, inputType.getDtype()); bias = Torch::createInitTensor(rewriter, loc, rewriter.getType( biasShape, inputType.getDtype()), cstZero, biasShapeList); } if (!mask) { int64_t batchSize = inputType.getSizes()[0]; int64_t kernelHeight = weightSizes[2]; int64_t kernelWidth = weightSizes[3]; int64_t outputHeight = offsetSizes[2]; int64_t outputWidth = offsetSizes[3]; int64_t maskDimOne = offsetGroup * kernelHeight * kernelWidth; SmallVector maskShape( {batchSize, maskDimOne, outputHeight, outputWidth}); Value cstOne = Torch::getConstantWithGivenDtypeAndValue( rewriter, loc, 1.0f, inputType.getDtype()); Value maskShapeList = mlir::torch::onnx_c::createConstantIntList( binder, rewriter, maskShape); mask = Torch::createInitTensor(rewriter, loc, rewriter.getType( maskShape, inputType.getDtype()), cstOne, maskShapeList); } // get attributes as constant values SmallVector dilationValues, padValues, strideValues; for (auto i : dilations) dilationValues.push_back(rewriter.create( 
loc, rewriter.getI64IntegerAttr(i))); for (auto i : pads) padValues.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); for (auto i : strides) strideValues.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); Value groupValue = rewriter.create( loc, rewriter.getI64IntegerAttr(group)); Value offsetGroupValue = rewriter.create( loc, rewriter.getI64IntegerAttr(offsetGroup)); Value useMaskValue = rewriter.create( loc, rewriter.getBoolAttr(useMask)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, weight, offset, mask, bias, strideValues[0], strideValues[1], padValues[0], padValues[1], dilationValues[0], dilationValues[1], groupValue, offsetGroupValue, useMaskValue); return success(); }); patterns.onOp( "Det", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input; if (binder.tensorOperand(input) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp(binder.op, resultType, input); return success(); }); patterns.onOp( "DequantizeLinear", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; llvm::SmallVector operands; if (binder.tensorOperands(operands, 3) || binder.tensorResultType(resultType)) return failure(); Value operand = operands[0]; Value scale = operands[1]; Value zeropoint = operands[2]; auto operandTy = cast(operand.getType()); auto scaleTy = dyn_cast(scale.getType()); if (!scaleTy || !scaleTy.hasSizes()) return rewriter.notifyMatchFailure(binder.op, "requires known rank"); if (!resultType.hasDtype()) return rewriter.notifyMatchFailure(binder.op, "requires known result dtype"); if (scaleTy.getSizes().size() == 0 || (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) { auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy); if (!qTensorTy) { return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype"); } scale = rewriter.create( binder.getLoc(), rewriter.getType(), 
scale); zeropoint = rewriter.create( binder.getLoc(), rewriter.getType(), zeropoint); auto quantize = rewriter.create( binder.getLoc(), qTensorTy, operand, scale, zeropoint); rewriter.replaceOpWithNewOp( binder.op, resultType, quantize); return success(); } return rewriter.notifyMatchFailure(binder.op, "unimplemented: non-scalar scale"); }); patterns.onOp("Div", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( "Dropout", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Location loc = binder.getLoc(); Torch::ValueTensorType resultType; int64_t numOperands = binder.op->getNumOperands(); SmallVector operands; int64_t seed; if (binder.tensorOperands(operands, numOperands) || binder.s64IntegerAttr(seed, "seed", 0) || binder.tensorResultTypeAtIndex(resultType, 0)) return failure(); // Global Seed value is 0. 
if (seed != 0) { return rewriter.notifyMatchFailure(binder.op, "expected seed value to be 0"); } Value ratio, trainingMode; if (numOperands == 3) { ratio = rewriter.create(loc, operands[1]); Value trainVal = operands[2]; auto trainTensorType = dyn_cast(trainVal.getType()); if (!trainTensorType) return rewriter.notifyMatchFailure(binder.op, "train tensor must have a type"); Type inputDtype = trainTensorType.getOptionalDtype(); if (!inputDtype || !inputDtype.isInteger(1)) return rewriter.notifyMatchFailure( binder.op, "train tensor must have an integer dtype of width 1"); std::optional inputRank = Torch::getTensorRank(trainVal); if (!inputRank || *inputRank != 0) return rewriter.notifyMatchFailure(binder.op, "train tensor must have rank 0"); if (auto valueTensorLiteralOp = trainVal.getDefiningOp()) { auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); trainingMode = rewriter.create(loc, val); } else { Value trainingModeScalar = rewriter.create(loc, operands[2]); Value cstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); trainingMode = rewriter.create( loc, trainingModeScalar, cstOne); } } else if (numOperands == 2) { ratio = rewriter.create(loc, operands[1]); trainingMode = rewriter.create(loc, false); } else { ratio = rewriter.create( loc, rewriter.getF64FloatAttr(0.5)); trainingMode = rewriter.create(loc, false); } Value dropout = rewriter.create( loc, resultType, /*input=*/operands[0], ratio, trainingMode); if (binder.op->getNumResults() == 1) { rewriter.replaceOp(binder.op, dropout); return success(); } Torch::ValueTensorType maskType; if (binder.tensorResultTypeAtIndex(maskType, 1)) return failure(); Value dtype = rewriter.create( loc, rewriter.getI64IntegerAttr( (int64_t)torch_upstream::ScalarType::Bool)); Value none = rewriter.create(loc); Value mask = rewriter.create( loc, maskType, operands[0], dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); rewriter.replaceOp(binder.op, {dropout, mask}); return 
success(); }); patterns.onOp( "DynamicQuantizeLinear", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Location loc = binder.getLoc(); Value input; Torch::ValueTensorType resultType, scaleType, zeroPointType; if (binder.tensorOperand(input) || binder.tensorResultTypeAtIndex(resultType, 0) || binder.tensorResultTypeAtIndex(scaleType, 1) || binder.tensorResultTypeAtIndex(zeroPointType, 2)) return failure(); Value scale, zeroPoint; // scale = ( max(0, max(input)) - min(0, min(input)) ) / 255 Value inputMax = rewriter.create(loc, scaleType, input); Value inputMin = rewriter.create(loc, scaleType, input); Value constantZero = rewriter.create( loc, rewriter.getF64FloatAttr(0)); Value constantOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value zeroTensor = createRank0Tensor(rewriter, loc, scaleType, constantZero); Value inputMaxW0 = rewriter.create( loc, scaleType, inputMax, zeroTensor); Value inputMinW0 = rewriter.create( loc, scaleType, inputMin, zeroTensor); Value scaleTensor = rewriter.create( loc, scaleType, inputMaxW0, inputMinW0, constantOne); // Note: the following is hard-coded for ui8 Value width = rewriter.create( loc, rewriter.getF64FloatAttr(255)); Value widthTensor = createRank0Tensor(rewriter, loc, scaleType, width); scaleTensor = rewriter.create( loc, scaleType, scaleTensor, widthTensor); // compute the preZeroPoint = 0 - (inputMin/scale) // compute the zeroPoint = cast ( round (clip or saturate // (preZeroPoint))) Value preZeroPoint = rewriter.create( loc, scaleType, inputMin, scaleTensor); preZeroPoint = rewriter.create( loc, scaleType, zeroTensor, preZeroPoint, constantOne); // saturate to interval [0, 255] preZeroPoint = rewriter.create( loc, scaleType, preZeroPoint, /*min=*/constantZero, /*max=*/width); // round, then cast to uint8 preZeroPoint = rewriter.create(loc, scaleType, preZeroPoint); Type qTy = rewriter.getType(); auto qTensorTy = rewriter.getType( resultType.getOptionalSizes(), qTy); auto torchqTy = 
Torch::getScalarTypeForType(qTy); Value tyConst = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); Value zeroPointTensor = rewriter.create( loc, zeroPointType, preZeroPoint, tyConst, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); // extract scale and zeroPoint scalars to pass to // AtenQuantizePerTensorOp zeroPoint = rewriter.create( loc, rewriter.getType(), zeroPointTensor); scale = rewriter.create( loc, rewriter.getType(), scaleTensor); Value quantizedTensor = rewriter.create( loc, qTensorTy, input, scale, zeroPoint, tyConst); // get uint8 tensor output Value output = rewriter.create(loc, resultType, quantizedTensor); rewriter.replaceOp(binder.op, {output, scaleTensor, zeroPointTensor}); return success(); }); patterns.onOp("Equal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp("Elu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Location loc = binder.getLoc(); Torch::ValueTensorType resultType; Value input; float alpha; if (binder.tensorOperand(input) || binder.f32FloatAttr(alpha, "alpha") || binder.tensorResultType(resultType)) return failure(); Value cstAlpha = rewriter.create( loc, rewriter.getF64FloatAttr(alpha)); Value cstOne = rewriter.create( loc, rewriter.getF64FloatAttr(1.0)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, cstAlpha, /*scale=*/cstOne, /*input_scale=*/cstOne); return success(); }); patterns.onOp("Erf", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; std::string direction; if 
(binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp("Exp", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp( "Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // uses ideas and code from onnx.Reshape auto loc = binder.getLoc(); Torch::ValueTensorType resultType; Value data, shape; if (binder.tensorOperands(data, shape) || binder.tensorResultType(resultType)) return failure(); auto dataType = cast(data.getType()); auto shapeType = cast(shape.getType()); if (!dataType.hasSizes() || !shapeType.hasSizes()) return failure(); auto shapeSizes = shapeType.getSizes(); int64_t dataRank = dataType.getSizes().size(); int64_t shapeRank = shapeSizes.size(); if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize) return failure(); auto rankDifference = dataRank - shapeSizes[0]; SmallVector selectSizes; Type selectResultType = shapeType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); // Variable to store 1-D onnx shape tensor, shapeSizes[0] has the // dimension size // A constant zero value Value zero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); // Variable to store pytorch int list of shape (dimension) SmallVector dimList; // Convert the shape tensor from vector of int64_t to torch int list as // we are using torch implementation Torch::AtenBroadcastToOp which // takes list of int for (int i = 0; i < shapeSizes[0]; i++) { Value selectIndex = rewriter.create( loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); Value extract = rewriter.create( loc, selectResultType, shape, zero, selectIndex); Value 
dim = rewriter.create( loc, rewriter.getType(), extract); if (i + rankDifference >= 0) { Value iv = rewriter.create(loc, i + rankDifference); auto sz = rewriter.create( loc, rewriter.getType(), data, iv); dim = rewriter.create(loc, dim, sz); } dimList.push_back(dim); } Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), dimList); rewriter.replaceOpWithNewOp( binder.op, resultType, data, dimValueList); return success(); }); patterns.onOp( "EyeLike", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; int64_t dtypeIntOnnx, diagonalIndex; if (binder.tensorOperand(operand) || binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || binder.s64IntegerAttr(diagonalIndex, "k", 0) || binder.tensorResultType(resultType)) return failure(); auto operandTy = cast(operand.getType()); SmallVector shape(operandTy.getSizes()); for (unsigned i = 0; i < shape.size(); i++) { if (shape[i] == ShapedType::kDynamic) shape[i] = Torch::kUnknownSize; } Value cst0 = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value cst1 = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value nVal = rewriter.create(binder.getLoc(), operand, cst0); Value mVal = rewriter.create(binder.getLoc(), operand, cst1); Value noneVal = rewriter.create(binder.getLoc()); std::optional dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); if (!dtypeIntTorch.has_value()) { return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given dtype conversion"); } Value dtypeVal = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); // diagonalIndex = 0 populates the main diagonal // diagonalIndex > 0 populates an upper diagonal // diagonalIndex < 0 populates a lower diagonal if (diagonalIndex == 0) { rewriter.replaceOpWithNewOp( binder.op, resultType, nVal, mVal, dtypeVal, noneVal, noneVal, noneVal); return 
success(); } Value diagVal = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(std::abs(diagonalIndex))); Value newN, newM, dimVal, startVal; // get shapes of main diag eye op and zeros op if (diagonalIndex > 0) { newN = nVal; newM = rewriter.create(binder.getLoc(), mVal, diagVal); if (shape[1] != Torch::kUnknownSize) { shape[1] -= diagonalIndex; } dimVal = cst1; startVal = mVal; } else { newN = rewriter.create(binder.getLoc(), nVal, diagVal); newM = mVal; if (shape[0] != Torch::kUnknownSize) { shape[0] += diagonalIndex; } dimVal = cst0; startVal = nVal; } // create main diag eye op auto eyeResultType = rewriter.getType( shape, resultType.getOptionalDtype()); Value eyeOp = rewriter.create( binder.getLoc(), eyeResultType, newN, newM, dtypeVal, noneVal, noneVal, noneVal); // create zeros op SmallVector zerosShapeValues = {nVal, mVal}; Value zerosShapeList = rewriter.create( binder.getLoc(), rewriter.getType( rewriter.getType()), zerosShapeValues); Value zerosOp = rewriter.create( binder.getLoc(), resultType, zerosShapeList, dtypeVal, noneVal, noneVal, noneVal); // embeds the values of the eye matrix into zeros rewriter.replaceOpWithNewOp( binder.op, resultType, zerosOp, eyeOp, dimVal, /*start=*/diagVal, /*end=*/startVal, /*step=*/cst1); return success(); }); patterns.onOp( "Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // Flatten means to partition the input tensor's dimensions // into a "left range" spanning 0 to axis - 1 and a "right range" // spanning axis to rank - 1. Each range is then collapsed // into a single dimension, resulting in a 2-D tensor. // If either range is empty, it is replaced with a single // dimension of size 1. // // For example, for a 4-D input tensor of shape (a, b, c, d) // and axis==2, flatten produces a 2-D tensor of shape // (a*b, c*d). // // If instead axis==0, the left range is empty, and the result // is (1, a*b*c*d). 
Torch::ValueTensorType resultType; Value operand; int64_t axis; if (binder.tensorOperand(operand) || binder.s64IntegerAttr(axis, "axis", 1) || binder.tensorResultType(resultType)) return failure(); auto operandTy = cast(operand.getType()); llvm::SmallVector shape(operandTy.getSizes()); int64_t rank = shape.size(); // If axis is negative, count from the right instead of left if (axis < 0) axis = rank + axis; // We collapse in the dimensions to the right of the axis. for (int i = axis + 1; i < rank; ++i) { bool dynamic = shape[axis] == Torch::kUnknownSize || shape[i] == Torch::kUnknownSize; if (dynamic) { shape[axis] = Torch::kUnknownSize; } else { shape[axis] = shape[axis] * shape[i]; } } shape.resize(axis + 1, 1); auto baseType = rewriter.getType( shape, operandTy.getDtype()); Value collapsedRight; if (axis >= rank) { // If the right range is empty, add a dim of size 1 to the // right side of the shape: // cr = torch.unsqueeze(x, x.ndim) Value rankConst = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(rank)); collapsedRight = rewriter.create( binder.getLoc(), baseType, operand, rankConst); } else { // Otherwise, collapse the right range into a single dimension: // cr = torch._prims.collapse(x, axis, x.ndim - 1) Value axisConst = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(axis)); Value rankLess1Const = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); collapsedRight = rewriter.create( binder.getLoc(), baseType, operand, axisConst, rankLess1Const); } Value zero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); if (axis <= 0) { // If the left range is empty, add a dim of size 1 to the // left side of the shape: // torch.unsqueeze(cr, 0) rewriter.replaceOpWithNewOp( binder.op, resultType, collapsedRight, zero); return success(); } // Otherwise, collapse the left range into a single dimension: // torch._prims.collapse(cr, 0, axis - 1) Value axisLess1Const = rewriter.create( binder.getLoc(), 
rewriter.getI64IntegerAttr(axis - 1)); rewriter.replaceOpWithNewOp( binder.op, resultType, collapsedRight, zero, axisLess1Const); return success(); }); patterns.onOp("Floor", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); patterns.onOp( "ConstantOfShape", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value shape; if (binder.tensorOperand(shape) || binder.tensorResultType(resultType)) return failure(); // convert shape tensor to list of ints auto shapeSizes = dyn_cast(shape.getType()).getSizes(); SmallVector dimList; Torch::BaseTensorType shapeType = cast(shape.getType()); Type selectResultType = rewriter.getType( ArrayRef({}), shapeType.getOptionalDtype()); Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); for (int i = 0; i < shapeSizes[0]; i++) { Value selectIndex = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); Value extract = rewriter.create( binder.getLoc(), selectResultType, shape, zero, selectIndex); Value dim = rewriter.create( binder.getLoc(), rewriter.getType(), extract); dimList.push_back(dim); } Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), dimList); Value noneVal = rewriter.create(binder.getLoc()); // Get fill_value if it is present. // Assumption : resultDType and value attr type match. 
auto attr = binder.op->getAttr("torch.onnx.value"); auto resultDType = resultType.getDtype(); // Extract the fill value and dtype // ONNX requires value attr to be a tensor if (!attr) { attr = DenseElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDType, 0.0)); } // If its a dense resource attr we need to convert to a dense type: if (DenseResourceElementsAttr rattr = dyn_cast_or_null(attr)) { // Bytes are stored in little endian order. Big endian support will // require swizzling. if (!Endian::little) { binder.op->emitError( "unimplemented: importing on big endian systems"); return failure(); } auto ty = cast(rattr.getType()); auto ptr = rattr.getRawHandle().getBlob()->getData(); auto denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); attr = dyn_cast_or_null(denseAttr); } Attribute splattr; if (isa(attr)) { auto denseAttr = cast(attr); splattr = denseAttr.getSplatValue(); } if (!isa(splattr)) { return rewriter.notifyMatchFailure( binder.op, "`value` attr tensor only supports types int and float for now."); } Value splatvalue; if (auto intattr = dyn_cast(splattr)) { IntegerType intty = cast(intattr.getType()); int64_t value; if (intty.isUnsignedInteger()) { value = intattr.getUInt(); } else if (intty.isSignedInteger()) { value = intattr.getSInt(); } else { value = intattr.getInt(); } splatvalue = rewriter.create(binder.getLoc(), value); } if (auto fpattr = dyn_cast(splattr)) splatvalue = rewriter.create( binder.getLoc(), rewriter.getF64FloatAttr(fpattr.getValueAsDouble())); rewriter.replaceOpWithNewOp( binder.op, resultType, dimValueList, splatvalue, /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); return success(); }); patterns.onOp( "Einsum", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; SmallVector tensors; std::string equation; if (binder.tensorOperands(tensors, binder.op->getNumOperands()) || binder.customOpNameStringAttr(equation, "equation") 
|| binder.tensorResultType(resultType)) return failure(); Type listElemType = cast(tensors[0].getType()) .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value tensorList = rewriter.create( binder.op->getLoc(), listType, tensors); Value cstEquation = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getStringAttr(equation)); Value cstNone = rewriter.create(binder.getLoc()); rewriter.replaceOpWithNewOp( binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone); return success(); }); patterns.onOp( "BlackmanWindow", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value size; Torch::ValueTensorType resultType; int64_t periodic, output_datatype; if (binder.tensorOperand(size) || binder.s64IntegerAttr(output_datatype, "output_datatype", 1) || binder.s64IntegerAttr(periodic, "periodic", 1) || binder.tensorResultType(resultType)) { return failure(); } Location loc = binder.getLoc(); Value a0 = rewriter.create( loc, rewriter.getF64FloatAttr(0.42)); Value a1 = rewriter.create( loc, rewriter.getF64FloatAttr(-0.5)); Value a2 = rewriter.create( loc, rewriter.getF64FloatAttr(0.08)); auto windowFunctionResult = windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType, output_datatype, periodic); if (failed(windowFunctionResult)) return failure(); return success(); }); patterns.onOp( "HannWindow", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value size; Torch::ValueTensorType resultType; int64_t periodic, output_datatype; if (binder.tensorOperand(size) || binder.s64IntegerAttr(output_datatype, "output_datatype", 1) || binder.s64IntegerAttr(periodic, "periodic", 1) || binder.tensorResultType(resultType)) { return failure(); } Location loc = binder.getLoc(); Value a0 = rewriter.create( loc, rewriter.getF64FloatAttr(0.5)); Value a1 = rewriter.create( loc, rewriter.getF64FloatAttr(-0.5)); Value a2 = rewriter.create( loc, 
rewriter.getF64FloatAttr(0.0)); auto windowFunctionResult = windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType, output_datatype, periodic); if (failed(windowFunctionResult)) return failure(); return success(); }); patterns.onOp( "HammingWindow", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value size; Torch::ValueTensorType resultType; int64_t periodic, output_datatype; if (binder.tensorOperand(size) || binder.s64IntegerAttr(output_datatype, "output_datatype", 1) || binder.s64IntegerAttr(periodic, "periodic", 1) || binder.tensorResultType(resultType)) { return failure(); } Location loc = binder.getLoc(); Value a0 = rewriter.create( loc, rewriter.getF64FloatAttr(0.543478)); Value a1 = rewriter.create( loc, rewriter.getF64FloatAttr(-0.456522)); Value a2 = rewriter.create( loc, rewriter.getF64FloatAttr(0.0)); auto windowFunctionResult = windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType, output_datatype, periodic); if (failed(windowFunctionResult)) return failure(); return success(); }); patterns.onOp( "DFT", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value inTensor, dftLength, axis; Torch::ValueTensorType resultType; int64_t inverse, onesided; if (binder.tensorOperandAtIndex(inTensor, 0) || binder.s64IntegerAttr(inverse, "inverse", 0) || binder.s64IntegerAttr(onesided, "onesided", 0) || binder.tensorResultType(resultType)) return rewriter.notifyMatchFailure( binder.op, "Input Tensor / attrs / resultType bind failed"); if (!binder.tensorOperandAtIndex(dftLength, 1)) { // Convert to int and pass as n dftLength = rewriter.create( binder.getLoc(), rewriter.getType(), dftLength); } else { // Default for torch is None dftLength = rewriter.create(binder.getLoc()); } // Default is same for onnx and torch if (!binder.tensorOperandAtIndex(axis, 2)) { // convert to int and pass to dims axis = rewriter.create( binder.getLoc(), rewriter.getType(), axis); } else { // Default in torch is -1 and onnx is -2 (since -1 
is for real / img) axis = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(-2)); } if (onesided == 1) return rewriter.notifyMatchFailure(binder.op, "Unsupported option : onesided"); // norm default string attr Value norm = rewriter.create( binder.getLoc(), rewriter.getStringAttr(Twine("backward"))); // Convert from [....., 2] complex number repr for fft consumption. Torch::ValueTensorType inType = binder.toValidTensorType(inTensor.getType()); int64_t lastIndex = inType.getSizes().back(); if (lastIndex != 1 && lastIndex != 2) return rewriter.notifyMatchFailure( binder.op, "Expected input tensor to have dims [..., 1] or [..., 2]"); // concat with zeros to make it [..., 2] Value inForComplexVal = inTensor; ArrayRef inForComplexSizes = inType.getSizes().drop_back(); if (lastIndex == 1) { Value constZeroVal = rewriter.create( binder.getLoc(), rewriter.getF64FloatAttr(0)); Value constOne = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value constZero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value padSizeList = rewriter .create( binder.getLoc(), Torch::ListType::get(rewriter.getType()), SmallVector({constZero, constOne})) .getResult(); Value modeVal = rewriter.create( binder.getLoc(), rewriter.getStringAttr("constant")); SmallVector resSize(inForComplexSizes); resSize.push_back(2); inForComplexVal = rewriter.create( binder.getLoc(), inType.getWithSizesAndDtype(resSize, inType.getOptionalDtype()), inTensor, padSizeList, modeVal, constZeroVal); } Type inComplexTensorType = Torch::ValueTensorType::get( binder.op->getContext(), inForComplexSizes, mlir::ComplexType::get(inType.getDtype())); Value inComplexTensor = rewriter.create( binder.getLoc(), inComplexTensorType, inForComplexVal); Value ftOp; if (inverse == 0) { ftOp = rewriter.create( binder.getLoc(), inComplexTensorType, inComplexTensor, /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); } else { ftOp = rewriter.create( binder.getLoc(), inComplexTensorType, 
inComplexTensor, /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); } rewriter.replaceOpWithNewOp(binder.op, resultType, ftOp); return success(); }); }