diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 401cfb089..40f3f1076 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1343,11 +1343,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         SmallVector<Value> padsRearrange;
         SmallVector<Value> inputPaddingList;
         for (uint32_t i = 0; i < padding.size() / 2; i++) {
-          padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
           padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getI64IntegerAttr(
-                                   padding[(padding.size() / 2) + i])));
+                                   padding[padding.size() / 2 - i - 1])));
+          padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(),
+              rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
           inputPaddingList.emplace_back(
               rewriter.create<Torch::ConstantIntOp>(
                   binder.getLoc(), rewriter.getI64IntegerAttr(0)));
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 462eecf74..e5022cea1 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2233,41 +2233,84 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.op, "The axes parameter is not supported yet");
         }
         if (binder.tensorOperandAtIndex(data, 0) ||
-            binder.tensorOperandAtIndex(pads, 1) ||
             binder.tensorResultType(resultType) ||
             binder.customOpNameStringAttr(mode, "mode", "constant"))
           return failure();
 
+        bool cstMode = (mode == "constant");
+
+        // get input rank
+        auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
+        TensorType dataTensor = dataOpTy.toBuiltinTensor();
+        if (!dataTensor || !dataTensor.hasRank())
+          return rewriter.notifyMatchFailure(
+              binder.op, "pad length unknown and data operand unranked");
+        int64_t dataRank = dataTensor.getRank();
+        int64_t padsSize = 2 * dataRank;
+
         Location loc = binder.getLoc();
 
-        // Get pads shape and rank. The pads tensor is expected to be 1-D
-        // tensor.
-        auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
-        if (!padsTensorType || !padsTensorType.hasSizes()) {
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "Expect non empty pad tensor");
-        }
-        ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
-        int64_t padsRank = padsShape.size();
-        if (padsRank != 1)
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "expect 1-d pad tensor");
+        // get pads (earlier versions use an attribute, newer versions use a
+        // tensor input)
+        SmallVector<Value> padsTensorValue;
+        if (binder.tensorOperandAtIndex(pads, 1)) {
+          SmallVector<int64_t> defaultPads(2 * dataRank, 0);
+          SmallVector<int64_t> padInts;
+          if (binder.s64IntegerArrayAttr(padInts, "pads", defaultPads))
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "pads binder failure");
+          // opset_version 1 uses the attribute name "paddings"
+          if (padInts == defaultPads) {
+            SmallVector<int64_t> paddingsInts;
+            if (binder.s64IntegerArrayAttr(paddingsInts, "paddings",
+                                           defaultPads))
+              return rewriter.notifyMatchFailure(binder.op,
+                                                 "paddings binder failure");
+            padInts = paddingsInts;
+          }
+          for (auto p : padInts)
+            padsTensorValue.push_back(rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(p)));
+        } else {
+          // Get pads shape and rank. The pads tensor is expected to be 1-D
+          // tensor.
+          auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
+          if (!padsTensorType || !padsTensorType.hasSizes()) {
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "Expect non empty pad tensor");
+          }
+          ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
+          int64_t padsRank = padsShape.size();
+          if (padsRank != 1)
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "expect 1-d pad tensor");
+          if (padsShape[0] != Torch::kUnknownSize) {
+            // As per onnx.Pad documentation, padSize = 2*num_data_axes
+            // (if axes param not passed). Need to be updated when adding
+            // support for `axes` param.
+            padsSize = padsShape[0];
+          }
 
-        int64_t padsSize = padsShape[0];
-        if (padsSize == Torch::kUnknownSize) {
-          // As per onnx.Pad documentation, padSize = 2*num_data_axes
-          // (if axes param not passed). Need to be updated when adding
-          // support for `axes` param.
-          auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
-          TensorType dataTensor = dataOpTy.toBuiltinTensor();
-          if (!dataTensor || !dataTensor.hasRank())
-            return rewriter.notifyMatchFailure(
-                binder.op, "pad length unknown and data operand unranked");
-          int64_t dataRank = dataTensor.getRank();
-          padsSize = 2 * dataRank;
+          // Extract all the values of 1-D pad tensor and create a list of all
+          // these values as torch.pad op expects pad list.
+          Value constZero = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(0));
+          SmallVector<int64_t> emptyShape;
+          Type padsElemType = Torch::ValueTensorType::get(
+              padsTensorType.getContext(), emptyShape,
+              padsTensorType.getOptionalDtype());
+          for (uint32_t i = 0; i < padsSize; ++i) {
+            Value index = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(i));
+            auto select = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, padsElemType, pads, constZero, index);
+            Value selectInt = rewriter.create<Torch::AtenItemOp>(
+                loc, rewriter.getType<Torch::IntType>(), select);
+            padsTensorValue.push_back(selectInt);
+          }
         }
 
         Value constantValue;
-        if (binder.getNumOperands() >= 3) {
+        if (binder.getNumOperands() >= 3 && cstMode) {
           if (!binder.tensorOperandAtIndex(constantValue, 2)) {
             auto constTy =
                 dyn_cast<Torch::ValueTensorType>(constantValue.getType());
@@ -2283,38 +2326,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           }
         }
 
-        if (!constantValue) {
+        if (!constantValue && cstMode) {
           auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
           if (isa<IntegerType>(dataTensorType.getDtype()))
             constantValue = rewriter.create<Torch::ConstantIntOp>(
                 loc, rewriter.getI64IntegerAttr(0));
-          if (isa<FloatType>(dataTensorType.getDtype()))
+          // Earlier versions used a FLOAT attribute to store the constant
+          // value. The following will pick up on any non-default value attr if
+          // provided.
+          float constantFloat;
+          if (isa<FloatType>(dataTensorType.getDtype()) &&
+              !binder.f32FloatAttr(constantFloat, "value", 0.0f))
            constantValue = rewriter.create<Torch::ConstantFloatOp>(
-                loc, rewriter.getF64FloatAttr(0.0f));
+                loc, rewriter.getF64FloatAttr(constantFloat));
 
           if (!constantValue)
             return rewriter.notifyMatchFailure(
                 binder.op, "expected integer or float data tensor");
         }
 
-        // Extract all the values of 1-D pad tensor and create a list of all
-        // these values as torch.pad op expects pad list.
-        Value constZero = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
-        SmallVector<Value> padsTensorValue;
-        SmallVector<int64_t> emptyShape;
-        Type padsElemType =
-            Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
-                                        padsTensorType.getOptionalDtype());
-        for (uint32_t i = 0; i < padsSize; ++i) {
-          Value index = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(i));
-          auto select = rewriter.create<Torch::AtenSelectIntOp>(
-              loc, padsElemType, pads, constZero, index);
-          Value selectInt = rewriter.create<Torch::AtenItemOp>(
-              loc, rewriter.getType<Torch::IntType>(), select);
-          padsTensorValue.push_back(selectInt);
-        }
+        // for modes other than "constant" a value is not required
+        if (!cstMode)
+          constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);
 
         // The torch.pad op expects a different arrangement of padding pairs for
         // each dimension as compared to the onnx.pad op. Rearrange the pad
@@ -2335,6 +2368,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                 Torch::ListType::get(rewriter.getType<Torch::IntType>()),
                 padsRearrange)
                 .getResult();
+
+        // lowering to AtenConstantPadNdOp directly allows passing any torch
+        // scalar type for the value, whereas AtenPadOp takes an optional float
+        // type.
+        if (cstMode && !isa<Torch::NoneType>(constantValue.getType())) {
+          rewriter.replaceOpWithNewOp<Torch::AtenConstantPadNdOp>(
+              binder.op, resultType, data, padsSizeList, constantValue);
+          return success();
+        }
+
+        // translate a few mismatching mode names ONNX -> Torch
+        mode = (mode == "edge") ? "replicate" : mode;
+        mode = (mode == "wrap") ? "circular" : mode;
+
         Value modeVal = rewriter.create<Torch::ConstantStrOp>(
             loc, rewriter.getStringAttr(mode));
diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 06da3e001..02853b140 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -97,8 +97,12 @@ public:
     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = cast<RankedTensorType>(newResultType).getElementType();
+
+    auto dstOriginalDtype =
+        cast<Torch::ValueTensorType>(op.getType()).getDtype();
     Value castedValue =
-        convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
+        convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType,
+                             std::nullopt, dstOriginalDtype);
 
     Type padType = tensor::PadOp::inferResultType(
         cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
@@ -209,26 +213,38 @@ public:
     Value one = getConstant(rewriter, loc, 1, indexType);
     Value hDimSizeMinusOne = createSub(hDimSize, one);
     Value vDimSizeMinusOne = createSub(vDimSize, one);
-    SmallVector<Value> allOneStrides(numDims, one);
+    SmallVector<Value> allOneStridesVal(numDims, one);
+    SmallVector<OpFoldResult> allOneStrides =
+        getAsOpFoldResult(allOneStridesVal);
 
-    SmallVector<Value> extractOffsetsLT(numDims, zero);
-    extractOffsetsLT[hDim] = zero;
-    extractOffsetsLT[vDim] = zero;
-    SmallVector<Value> extractShapeLR(numDims, one);
-    extractShapeLR[hDim] = one;
-    extractShapeLR[vDim] = vDimSize;
+    SmallVector<Value> extractOffsetsLTVal(numDims, zero);
+    extractOffsetsLTVal[hDim] = zero;
+    extractOffsetsLTVal[vDim] = zero;
+    SmallVector<OpFoldResult> extractOffsetsLT =
+        getAsOpFoldResult(extractOffsetsLTVal);
+    SmallVector<Value> extractShapeLRVal(numDims, one);
+    extractShapeLRVal[hDim] = one;
+    extractShapeLRVal[vDim] = vDimSize;
+    SmallVector<OpFoldResult> extractShapeLR =
+        getAsOpFoldResult(extractShapeLRVal);
 
-    SmallVector<Value> extractOffsetsRight(numDims, zero);
-    extractOffsetsRight[hDim] = hDimSizeMinusOne;
-    extractOffsetsRight[vDim] = zero;
+    SmallVector<Value> extractOffsetsRightVal(numDims, zero);
+    extractOffsetsRightVal[hDim] = hDimSizeMinusOne;
+    extractOffsetsRightVal[vDim] = zero;
+    SmallVector<OpFoldResult> extractOffsetsRight =
+        getAsOpFoldResult(extractOffsetsRightVal);
 
-    SmallVector<Value> extractOffsetsBottom(numDims, zero);
-    extractOffsetsBottom[hDim] = zero;
-    extractOffsetsBottom[vDim] = vDimSizeMinusOne;
+    SmallVector<Value> extractOffsetsBottomVal(numDims, zero);
+    extractOffsetsBottomVal[hDim] = zero;
+    extractOffsetsBottomVal[vDim] = vDimSizeMinusOne;
+    SmallVector<OpFoldResult> extractOffsetsBottom =
+        getAsOpFoldResult(extractOffsetsBottomVal);
 
-    SmallVector<Value> extractShapeTB(numDims, one);
-    extractShapeTB[hDim] = hDimSize;
-    extractShapeTB[vDim] = one;
+    SmallVector<Value> extractShapeTBVal(numDims, one);
+    extractShapeTBVal[hDim] = hDimSize;
+    extractShapeTBVal[vDim] = one;
+    SmallVector<OpFoldResult> extractShapeTB =
+        getAsOpFoldResult(extractShapeTBVal);
 
     SmallVector<Value> tensorsLeft;
     SmallVector<Value> tensorsRight;
@@ -240,24 +256,26 @@ public:
       Value vCenterLeftSlice = rewriter.create<tensor::ExtractSliceOp>(
          loc, input, extractOffsetsLT, extractShapeLR, allOneStrides);
       Value vLeftSlice = vCenterLeftSlice;
+      SmallVector<Value> extractIndices(numDims, zero);
      if (hasTopPadding) {
-        Value topLeftValue = rewriter.create<tensor::ExtractOp>(
-            loc, input, ValueRange{zero, zero, zero, zero});
+        Value topLeftValue =
+            rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
         // pad vCenterLeftSlice on the top
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        lowPadding[2] = padInts[2];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        lowPadding[vDim] = padInts[2];
         vLeftSlice = torch_to_linalg::getPaddedTensor(
             op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
       }
       if (hasBottomPadding) {
-        Value bottomLeftValue = rewriter.create<tensor::ExtractOp>(
-            loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});
+        extractIndices[vDim] = vDimSizeMinusOne;
+        Value bottomLeftValue =
+            rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
         // pad vLeftSlice at the bottom
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        highPadding[2] = padInts[3];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        highPadding[vDim] = padInts[3];
         vLeftSlice = torch_to_linalg::getPaddedTensor(
             op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
       }
@@ -265,7 +283,7 @@ public:
        tensorsLeft.push_back(vLeftSlice);
      }
      Value leftPadTile =
-          rewriter.create<tensor::ConcatOp>(loc, 3, tensorsLeft);
+          rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsLeft);
      tensorsRes.push_back(leftPadTile);
    }
    if (hasTopPadding) {
@@ -283,33 +301,35 @@ public:
        tensorsCenter.push_back(bottomHcenterSlice);
      }
    }
-    centerTile = rewriter.create<tensor::ConcatOp>(loc, 2, tensorsCenter);
+    centerTile = rewriter.create<tensor::ConcatOp>(loc, vDim, tensorsCenter);
    tensorsRes.push_back(centerTile);
 
    if (hasRightPadding) {
      Value vCenterRightSlice = rewriter.create<tensor::ExtractSliceOp>(
          loc, input, extractOffsetsRight, extractShapeLR, allOneStrides);
      Value vRightSlice = vCenterRightSlice;
+      SmallVector<Value> extractIndices(numDims, zero);
+      extractIndices[hDim] = hDimSizeMinusOne;
      if (hasTopPadding) {
        Value topRightValue = rewriter.create<tensor::ExtractOp>(
            loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});
        // pad vCenterRightSlice on the top
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        lowPadding[2] = padInts[2];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        lowPadding[vDim] = padInts[2];
        vRightSlice = torch_to_linalg::getPaddedTensor(
            op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
      }
      if (hasBottomPadding) {
-        Value bottomRightValue = rewriter.create<tensor::ExtractOp>(
-            loc, input,
-            ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});
+        extractIndices[vDim] = vDimSizeMinusOne;
+        Value bottomRightValue =
+            rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
 
        // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom.
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        highPadding[2] = padInts[3];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        highPadding[vDim] = padInts[3];
        vRightSlice = torch_to_linalg::getPaddedTensor(
            op, rewriter, vRightSlice, lowPadding, highPadding,
            bottomRightValue);
@@ -318,10 +338,10 @@ public:
        tensorsRight.push_back(vRightSlice);
      }
      Value rightPadTile =
-          rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
+          rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRight);
      tensorsRes.push_back(rightPadTile);
    }
-    Value resTensor = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRes);
+    Value resTensor = rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRes);
    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, resTensor);
    return success();
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 33809cce5..491c6f2f9 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -6379,17 +6379,91 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenPadOp op,
                                 PatternRewriter &rewriter) const override {
+    std::string mode;
+    if (!matchPattern(op.getMode(), m_TorchConstantStr(mode)))
+      return rewriter.notifyMatchFailure(op, "mode must be a constant string");
 
-    Value value = op.getValue();
-    if (isa<Torch::OptionalType>(value.getType()))
-      return rewriter.notifyMatchFailure(op, "optional type not supported");
-    if (isa<Torch::NoneType>(value.getType()))
-      value = rewriter.create<Torch::ConstantFloatOp>(
-          op.getLoc(), rewriter.getF64FloatAttr(0));
+    if (mode == "constant") {
+      Value value = op.getValue();
+      if (isa<Torch::OptionalType>(value.getType()))
+        return rewriter.notifyMatchFailure(op, "optional type not supported");
+      if (isa<Torch::NoneType>(value.getType()))
+        value = rewriter.create<Torch::ConstantFloatOp>(
+            op.getLoc(), rewriter.getF64FloatAttr(0));
 
-    rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
-        op, op.getType(), op.getSelf(), op.getPad(), value);
-    return success();
+      rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
+          op, op.getType(), op.getSelf(), op.getPad(), value);
+      return success();
+    }
+
+    SmallVector<Value> padValues;
+    if (!getListConstructElements(op.getPad(), padValues))
+      return failure();
+    SmallVector<int64_t> padInts;
+    Value usefulPads = op.getPad();
+    uint64_t usefulPadIndexEnd = padValues.size();
+
+    // try to reduce the number of padding dims if possible
+    if (matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) {
+      if ((padInts.size() % 2) == 1)
+        return rewriter.notifyMatchFailure(op,
+                                           "expected an even number of pads");
+
+      for (uint64_t i = padInts.size() - 1; i > 0; i -= 2) {
+        if (padInts[i] != 0 || padInts[i - 1] != 0)
+          break;
+        usefulPadIndexEnd = i - 1;
+      }
+      if (usefulPadIndexEnd == 0) {
+        rewriter.replaceOp(op, op.getSelf());
+        return success();
+      }
+    }
+
+    // we don't have support for 1-D replicate pad, so pass it as 2d if
+    // possible.
+    // TODO: add support for AtenReplicatePad1dOp and remove this.
+    if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
+      usefulPadIndexEnd = 4;
+
+    // make a new list of padding ints if dimensionality reduction can be
+    // performed
+    if (usefulPadIndexEnd < padValues.size()) {
+      ArrayRef<Value> usefulPadValues(padValues.begin(),
+                                      padValues.begin() + usefulPadIndexEnd);
+      usefulPads = rewriter.create<PrimListConstructOp>(
+          op.getLoc(),
+          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+          usefulPadValues);
+    }
+
+    uint64_t numPadDims = usefulPadIndexEnd / 2;
+
+    if (mode == "reflect") {
+      // only support for reflection pad 1d and 2d
+      if (numPadDims == 2) {
+        rewriter.replaceOpWithNewOp<AtenReflectionPad2dOp>(
+            op, op.getType(), op.getSelf(), usefulPads);
+        return success();
+      }
+      if (numPadDims == 1) {
+        rewriter.replaceOpWithNewOp<AtenReflectionPad1dOp>(
+            op, op.getType(), op.getSelf(), usefulPads);
+        return success();
+      }
+      return failure();
+    }
+
+    if (mode == "replicate") {
+      // only support for replication pad 2d
+      if (numPadDims != 2)
+        return failure();
+      rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
+          op, op.getType(), op.getSelf(), usefulPads);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op, "unsupported mode: " + mode);
   }
 };
 } // namespace
diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 5925dd07e..7e52ea116 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -40,7 +40,8 @@ bool isQCommutingOp(mlir::Operation *op) {
   // if adding a new commuting op here, be sure to add a
   // RemoveUnused pattern for that op to clean up afterwards
   return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
-                   PrimsCollapseOp, AtenViewOp>(op);
+                   PrimsCollapseOp, AtenViewOp, AtenPadOp, AtenConstantPadNdOp>(
+      op);
 }
 
 // The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
@@ -65,7 +66,7 @@ public:
     for (unsigned i : QuantInfo::operandsToQuantize) {
       Value operand = operands[i];
       std::stack<mlir::Operation *> commutingOpStack;
-      Value dequantOpd, MPTQTOpd;
+      Value dequantOpd, MPTQTOpd, scale, zeroPoint;
       for (unsigned k = 0; k < depth + 1; k++) {
         auto currOp = operand.getDefiningOp();
         // Case 0 : currOp is a nullptr (e.g., operand is a block argument)
@@ -84,6 +85,8 @@ public:
           auto MPTQTOp =
              dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
           MPTQTOpd = MPTQTOp.getOperand(0);
+          scale = MPTQTOp.getOperand(1);
+          zeroPoint = MPTQTOp.getOperand(2);
         }
         // either a dequant was found or chain broken, so break loop
         break;
@@ -107,6 +110,47 @@ public:
        commutingOpStack.pop();
        llvm::SmallVector<Value> currOperands(currOp->getOperands());
        currOperands[0] = oldOpd;
+        // pad ops aren't quite commuting, so we include some extra logic to
+        // quantize the padding value
+        if (isa<AtenPadOp, AtenConstantPadNdOp>(currOp)) {
+          Value floatPadValue = currOperands.back();
+          Value quantPadValue;
+          if (isa<Torch::NoneType>(floatPadValue.getType()))
+            quantPadValue = rewriter.create<AtenFloatScalarOp>(loc, zeroPoint);
+          else {
+            floatPadValue =
+                rewriter.create<AtenFloatScalarOp>(loc, floatPadValue);
+            quantPadValue = rewriter.create<AtenDivFloatOp>(
+                loc, floatPadValue, scale);
+            quantPadValue = rewriter.create<AtenAddFloatIntOp>(
+                loc, quantPadValue, zeroPoint);
+          }
+          // clamp pad value to qint range
+          if (auto intType = dyn_cast<mlir::IntegerType>(intDType)) {
+            bool isSigned = intType.isSignedInteger();
+            int64_t width = intType.getWidth();
+            assert(width < 64 &&
+                   "quantized int bitwidth should be less than 64");
+            int64_t minInt = isSigned ? -(1 << (width - 1)) : 0;
+            int64_t maxInt = isSigned ? -minInt - 1 : ((1 << width) - 1);
+            Value minQValueFloat = rewriter.create<ConstantFloatOp>(
+                loc, rewriter.getF64FloatAttr(minInt));
+            Value maxQValueFloat = rewriter.create<ConstantFloatOp>(
+                loc, rewriter.getF64FloatAttr(maxInt));
+            SmallVector<int64_t> emptyShape;
+            auto floatTensorType = rewriter.getType<Torch::ValueTensorType>(
+                emptyShape, rewriter.getF64Type());
+            Value quantPadValueTensor = createRank0Tensor(
+                rewriter, loc, floatTensorType, quantPadValue);
+            Value clampedTensor = rewriter.create<AtenClampOp>(
+                loc, floatTensorType, quantPadValueTensor, minQValueFloat,
+                maxQValueFloat);
+            quantPadValue = rewriter.create<AtenItemOp>(
+                loc, rewriter.getType<Torch::FloatType>(), clampedTensor);
+          }
+          // quantPadValue is a float, but will get converted/truncated
+          currOperands.back() = quantPadValue;
+        }
        // get new result type
        auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
        auto intType =
@@ -374,7 +418,8 @@ public:
        RemoveUnused, RemoveUnused,
        RemoveUnused, RemoveUnused,
        RemoveUnused,
-        RemoveUnused,
+        RemoveUnused, RemoveUnused,
+        RemoveUnused,
        QuantizeOperandsPastCommutingOps,
        QuantizeOperandsPastCommutingOps,
        QuantizeOperandsPastCommutingOps,
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index bdc6beb0b..3196efe83 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1014,6 +1014,38 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>,
 
 // -----
 
+// CHECK-LABEL: @test_conv_with_asymmetric_padding
+func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_3:.*]] = torch.constant.int 0
+  // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int2]], %[[int2_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[str:.*]] = torch.constant.str "constant"
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000
+  // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,7,5],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[1,1,9,7],f32>
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_0:.*]] = torch.constant.int 2
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,9,7],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
+  // CHECK: return %[[Conv]]
+  %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [2 : si64, 0 : si64, 0 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32>
+  return %0 : !torch.vtensor<[1,1,4,3],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_conv_with_bias_strides_padding
 func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C3:.*]] = torch.constant.int 3
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 38f81f4c0..72cd012b2 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -815,32 +815,40 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t
 
 // -----
 
-// CHECK-LABEL: @test_grid_sampler03
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[B0:.*]] = torch.constant.bool true
-// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
-func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
+// CHECK-LABEL: func.func @test_oldest_pad
+func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} {
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
+  // CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
+  %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.paddings = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @test_less_or_equal
-func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
-  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
-  // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
-  %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
-  return %0 : !torch.vtensor<[3,4,5],i1>
+// CHECK-LABEL: func.func @test_old_pad
+func.func @test_old_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 11 : si64} {
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
+  // CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
+  %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.pads = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
 }
 
 // -----
 
 // CHECK-LABEL: func.func @test_pad
 func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
-  // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
@@ -854,9 +862,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
   // CHECK: %[[INT3:.+]] = torch.constant.int 3
   // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
   // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[STR:.+]] = torch.constant.str "constant"
-  // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
+  // CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
   // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
   %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
   return %0 : !torch.vtensor<[5,4],f32>
@@ -864,12 +872,36 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
 
 // -----
 
+// CHECK-LABEL: func.func @test_i32pad
+func.func @test_i32pad(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],si32>, !torch.list<int>, !torch.int -> !torch.vtensor<[5,4],si32>
+  // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],si32>
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32>
+  return %0 : !torch.vtensor<[5,4],si32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_pad_optional_constant
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
 // CHECK: %[[VAL:.+]] = torch.constant.float 0
-// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
-// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
+// CHECK: torch.aten.constant_pad_nd %[[ARG0]], %{{.*}}, %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
 func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
   %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
   return %0 : !torch.vtensor<[5,4],f32>
@@ -878,6 +910,34 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !
 
 // -----
 
+// CHECK-LABEL: @test_pad_wrap
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
+// CHECK: %[[VAL:.+]] = torch.constant.none
+// CHECK: %[[STR:.+]] = torch.constant.str "circular"
+// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
+
+func.func @test_pad_wrap(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "wrap"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_pad_edge
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
+// CHECK: %[[VAL:.+]] = torch.constant.none
+// CHECK: %[[STR:.+]] = torch.constant.str "replicate"
+// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
+
+func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "edge"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_pow
 func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir
index 594295d4e..cb39cbd53 100644
--- a/test/Dialect/Torch/fuse-quantized-ops.mlir
+++ b/test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -82,6 +82,48 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.
 // -----
 
+// CHECK-LABEL: func.func @mm_pad_commute
+func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
+  // CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
+  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
+  // CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
+  // CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
+  // CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
+  // CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
+  // CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
+  // CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
+  // CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
+  // CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
+  // CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
+  // CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
+  // CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
+  %scale = torch.constant.float 0.5
+  %false = torch.constant.bool false
+  %zero = torch.constant.int 0
+  %one = torch.constant.int 1
+  %two = torch.constant.int 2
+  %floatpad = torch.constant.float 3.5
+  %zp = torch.constant.int -128
+  %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[8,8],!torch.qint8>
+  %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[8,8],!torch.qint8> -> !torch.vtensor<[8,8],f32>
+  %list = torch.prim.ListConstruct %one, %two, %zero, %one : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %str = torch.constant.str "constant"
+  %pad = torch.aten.pad %7, %list, %str, %floatpad : !torch.vtensor<[8,8],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],f32>
+  %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
+  %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[11,4],f32>
+  %16 = torch.aten.mm %pad, %13 : !torch.vtensor<[9,11],f32>, !torch.vtensor<[11,4],f32> -> !torch.vtensor<[9,4],f32>
+  return %16 : !torch.vtensor<[9,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @convolution_bias
 func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5