mirror of https://github.com/llvm/torch-mlir
Bump stablehlo to openxla/stablehlo@fd52182f76 (#2821)
With the recent LLVM integrate and changes from https://github.com/llvm/llvm-project/pull/78260, we hit this build error in Stablehlo (which is quite old). ``` external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter' rewriter.startRootUpdate(op); ~~~~~~~~ ^ external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter' rewriter.finalizeRootUpdate(op); ~~~~~~~~ ^ external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter' rewriter.cancelRootUpdate(op); ~~~~~~~~ ^ external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter' rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; }); ~~~~~~~~ ^ 4 errors generated. Target @torch-mlir//:torch-mlir-opt failed to build ``` I'm still puzzled as to how this didn't fail with the CMake merge gating CI (do we not test Stablehlo builds/tests?). In any case, bumping our submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it. It exposes a new failing lit test in TorchToStablehlo though, that I have looped stablehlo developers into ([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)). ``` bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test ...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64> ^ LLVM ERROR: Failed to infer result type(s). ``` Bazel CI: https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228pull/2847/head
parent
54e258792c
commit
8a17c98b74
|
@ -1 +1 @@
|
|||
Subproject commit ab709fe48de88c67717abfbd7ef17425eb95ddaf
|
||||
Subproject commit fd52182f76cadb82f2064fe5fc49a4fb4347a826
|
|
@ -377,12 +377,12 @@ public:
|
|||
if (!skipMultiplyAlpha(op.getAlpha())) {
|
||||
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
|
||||
adaptor.getAlpha(), outElemTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
|
||||
bcastDimensions);
|
||||
}
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
return success();
|
||||
|
@ -424,7 +424,7 @@ public:
|
|||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
auto loc = op.getLoc();
|
||||
|
@ -542,7 +542,7 @@ public:
|
|||
} else {
|
||||
return op.emitError("operator haven't been supported");
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
|
||||
op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
|
||||
compareTypeAttr);
|
||||
|
@ -570,7 +570,7 @@ public:
|
|||
Value rhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
return success();
|
||||
|
@ -757,7 +757,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op, outType, self, bcastShapeTensor,
|
||||
rewriter.getI64TensorAttr(dimensionNumbers));
|
||||
rewriter.getDenseI64ArrayAttr(dimensionNumbers));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -887,7 +887,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
if (!rhsType) {
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
auto loc = op.getLoc();
|
||||
|
@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
|
||||
Value window =
|
||||
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
||||
DenseIntElementsAttr broadcastDimensions;
|
||||
DenseI64ArrayAttr broadcastDimensions;
|
||||
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
|
||||
broadcastDimensions);
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, mulOut, start,
|
||||
|
@ -1721,7 +1721,7 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
|||
rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
|
||||
Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op->getLoc(), outType, scalarTensor, shapeTensor,
|
||||
rewriter.getI64TensorAttr({}));
|
||||
rewriter.getDenseI64ArrayAttr({}));
|
||||
rewriter.replaceOp(op, bcastScalar);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
|
|||
return failure();
|
||||
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0}));
|
||||
op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}),
|
||||
elementTy);
|
||||
|
||||
Region ®ion = stablehloReduceOp.getBody();
|
||||
Block &block = region.emplaceBlock();
|
||||
|
@ -510,7 +511,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||||
op, input, gatherIndicies, dimsAttr,
|
||||
rewriter.getI64TensorAttr(sliceSizes));
|
||||
rewriter.getDenseI64ArrayAttr(sliceSizes));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite(
|
|||
/*indexVectorDim=*/indexVecDim);
|
||||
|
||||
auto stablehloScatterOp = rewriter.create<stablehlo::ScatterOp>(
|
||||
loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false);
|
||||
loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers,
|
||||
false, false);
|
||||
|
||||
// config update computation function: just return the element from src.
|
||||
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock();
|
||||
|
@ -833,7 +835,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||||
op, resultType, input, finalIndexTensor, dimsAttr,
|
||||
rewriter.getI64TensorAttr(sliceSizes));
|
||||
rewriter.getDenseI64ArrayAttr(sliceSizes));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -39,10 +39,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
|||
RankedTensorType outTy =
|
||||
RankedTensorType::get(shape, tensorTy.getElementType());
|
||||
|
||||
RankedTensorType attrTy =
|
||||
RankedTensorType::get({static_cast<int64_t>(broadcastDims.size())},
|
||||
rewriter.getIntegerType(64));
|
||||
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
|
||||
auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims);
|
||||
|
||||
auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
loc, outTy, tensor, stablehloShape, broadcastAttr);
|
||||
|
@ -549,8 +546,7 @@ public:
|
|||
|
||||
// Prepare for transposed convolution
|
||||
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
|
||||
DenseIntElementsAttr stablehloStride =
|
||||
rewriter.getI64TensorAttr(stablehloStrideVec);
|
||||
auto stablehloStride = rewriter.getDenseI64ArrayAttr(stablehloStrideVec);
|
||||
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
|
||||
for (int i = 0; i < nSpatialDims; ++i) {
|
||||
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
|
||||
|
@ -563,15 +559,15 @@ public:
|
|||
stablehloPaddingVec);
|
||||
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
|
||||
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
|
||||
DenseIntElementsAttr stablehloLhsDilation =
|
||||
rewriter.getI64TensorAttr(stablehloLhsDilationVec);
|
||||
auto stablehloLhsDilation =
|
||||
rewriter.getDenseI64ArrayAttr(stablehloLhsDilationVec);
|
||||
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
stablehloRhsDilationVec.begin());
|
||||
DenseIntElementsAttr stablehloRhsDilation =
|
||||
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
|
||||
auto stablehloRhsDilation =
|
||||
rewriter.getDenseI64ArrayAttr(stablehloRhsDilationVec);
|
||||
|
||||
DenseElementsAttr windowReversal;
|
||||
DenseBoolArrayAttr windowReversal;
|
||||
ArrayAttr precisionConfig;
|
||||
|
||||
SmallVector<int64_t> spatialDims;
|
||||
|
@ -614,10 +610,7 @@ public:
|
|||
int64_t nDims = outType.getRank();
|
||||
|
||||
// Get stablehlo::ConvolutionOp attributes
|
||||
DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(stride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stride);
|
||||
auto stablehloWindowStride = rewriter.getDenseI64ArrayAttr(stride);
|
||||
std::vector<int64_t> stablehloPaddingVec;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
stablehloPaddingVec.emplace_back(padding[i]);
|
||||
|
@ -628,10 +621,7 @@ public:
|
|||
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
|
||||
rewriter.getI64Type()),
|
||||
stablehloPaddingVec);
|
||||
DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(dilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
dilation);
|
||||
auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(dilation);
|
||||
SmallVector<int64_t> spatialDimensions;
|
||||
for (int64_t i = 2; i < nDims; i++) {
|
||||
spatialDimensions.emplace_back(i);
|
||||
|
@ -648,8 +638,8 @@ public:
|
|||
/*outputSpatialDimensions=*/spatialDimensions);
|
||||
|
||||
// stablehlo::ConvolutionOp's optional attributes, leave them as default
|
||||
DenseIntElementsAttr stablehloLhsDilation;
|
||||
DenseElementsAttr windowReversal;
|
||||
DenseI64ArrayAttr stablehloLhsDilation;
|
||||
DenseBoolArrayAttr windowReversal;
|
||||
ArrayAttr precisionConfig;
|
||||
|
||||
auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
|
||||
|
@ -781,7 +771,7 @@ public:
|
|||
options.dimSizeIndexBits);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
|
||||
op, outTy, stablehloConvResult, bias, bcastDimensions);
|
||||
return success();
|
||||
|
|
|
@ -136,19 +136,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloDilation);
|
||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||
DenseI64ArrayAttr baseDilations;
|
||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
|
@ -242,19 +233,10 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloDilation);
|
||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||
DenseI64ArrayAttr baseDilations;
|
||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
|
@ -453,20 +435,10 @@ public:
|
|||
Value initVal =
|
||||
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloDilation);
|
||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||
DenseI64ArrayAttr baseDilations;
|
||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
|
@ -508,7 +480,7 @@ public:
|
|||
.value();
|
||||
}
|
||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
|
||||
return success();
|
||||
|
@ -528,7 +500,7 @@ public:
|
|||
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
|
||||
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
|
||||
windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({}));
|
||||
|
||||
Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
|
@ -599,19 +571,10 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
|||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
stablehloPadding[dim * 2] = inputShape[dim] - 1;
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloDilation);
|
||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||
DenseI64ArrayAttr baseDilations;
|
||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
|
|
|
@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
initValue,
|
||||
initIndex,
|
||||
},
|
||||
rewriter.getI64TensorAttr(dim));
|
||||
rewriter.getDenseI64ArrayAttr(dim));
|
||||
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
|
||||
|
@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
return failure();
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
|
|||
return failure();
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Region ®ion = stablehloReduceOp.getBody();
|
||||
Block &block = region.emplaceBlock();
|
||||
|
@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
|
||||
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op->getLoc(), squareOp.getResult(), initValue,
|
||||
rewriter.getI64TensorAttr(dims));
|
||||
rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Region ®ion = reduceOp.getBody();
|
||||
Block &block = region.emplaceBlock();
|
||||
|
@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
|||
ord, nullptr);
|
||||
|
||||
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims));
|
||||
op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));
|
||||
|
||||
Region ®ion = reduceOp.getBody();
|
||||
Block &block = region.emplaceBlock();
|
||||
|
|
|
@ -241,10 +241,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
|||
if (!do_bcast) {
|
||||
return input;
|
||||
}
|
||||
DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
|
||||
rewriter.getI64Type()),
|
||||
bcastDims);
|
||||
auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
|
||||
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
|
||||
op->getLoc(), outType, input, bcast_attr);
|
||||
return bcast_op.getResult();
|
||||
|
@ -360,7 +357,7 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
|||
auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
|
||||
return rewriter
|
||||
.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
|
||||
loc, outType, constTensor, shape, rewriter.getDenseI64ArrayAttr({}))
|
||||
.getResult();
|
||||
}
|
||||
} // namespace hlo
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
||||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
||||
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
|
@ -51,7 +51,7 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
|||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
||||
// CHECK: })
|
||||
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
|
@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
|||
// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
|
||||
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
|
||||
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
|
||||
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
|
||||
// CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
|
||||
|
@ -141,7 +141,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
|
|||
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
|
||||
// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
|
||||
|
@ -162,7 +162,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
|
|||
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
|
||||
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
|
@ -198,7 +198,7 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
|||
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
||||
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
|
||||
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
|
||||
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue