mirror of https://github.com/llvm/torch-mlir
llvm: update tag to 061e0189 (#1180)
Summary of changes: - Switch to C++17 (similar to https://reviews.llvm.org/D131348) - Update MHLO to build with LLVM commit hash 061e0189 - Replace deprecated `hasValue()` and `getValue()` with `has_value()` and `value()` respectively (https://reviews.llvm.org/D131349) - Use `TypedAttr` (https://reviews.llvm.org/D130092) - Use updated assembly format of `mhlo.compare` op (commit d03ef01e70fbf9afd0fa1976fbb7ed31838929b3 in MHLO repo)pull/1187/head
parent
3e97a33c80
commit
bb47c166a0
|
@ -23,7 +23,7 @@ endif()
|
|||
|
||||
project(torch-mlir LANGUAGES CXX C)
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
macro(torch_mlir_add_llvm_external_project name identifier location)
|
||||
message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 02b3a358926e7bbcac9226cbecbfc3067c2ad61b
|
||||
Subproject commit 061e0189a3dab6b1831a80d489ff1b15ad93aafb
|
|
@ -1 +1 @@
|
|||
Subproject commit ad54b43c623cc5ae69b0e90f395b3fba13ffa55a
|
||||
Subproject commit 0430519b7ebf11a3f44c469fce8b579561fa6052
|
|
@ -54,12 +54,12 @@ public:
|
|||
Type getOptionalDtype() const;
|
||||
|
||||
/// Return true if this type has a list of sizes.
|
||||
bool hasSizes() const { return getOptionalSizes().hasValue(); }
|
||||
bool hasSizes() const { return getOptionalSizes().has_value(); }
|
||||
|
||||
/// Get the list of sizes. Requires `hasSizes()`.
|
||||
ArrayRef<int64_t> getSizes() const {
|
||||
assert(hasSizes() && "must have sizes");
|
||||
return getOptionalSizes().getValue();
|
||||
return getOptionalSizes().value();
|
||||
}
|
||||
|
||||
/// Return true if all sizes of this tensor are known.
|
||||
|
|
|
@ -49,8 +49,8 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
|
|||
AttrOrTypeParameter<
|
||||
"::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
|
||||
let allocator = [{
|
||||
if ($_self.hasValue()) {
|
||||
$_dst.getValue() = $_allocator.copyInto($_self.getValue());
|
||||
if ($_self.has_value()) {
|
||||
$_dst.value() = $_allocator.copyInto($_self.value());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -213,7 +213,8 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
|||
}
|
||||
|
||||
MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
|
||||
auto attrTensorType = unwrap(attr).getType().cast<RankedTensorType>();
|
||||
auto attrTensorType =
|
||||
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
|
||||
return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
|
||||
attrTensorType.getShape(),
|
||||
attrTensorType.getElementType()));
|
||||
|
|
|
@ -342,7 +342,7 @@ public:
|
|||
continue;
|
||||
}
|
||||
|
||||
if (inferredDimension.hasValue()) {
|
||||
if (inferredDimension.has_value()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "at most one element in size list is allowed to be -1");
|
||||
}
|
||||
|
@ -363,7 +363,7 @@ public:
|
|||
// then we don't need to analyze the static information of the input
|
||||
// shape since the reassociation of dimensions only requires rank
|
||||
// information.
|
||||
if (inferredDimension.hasValue() && outputShape.size() > 1) {
|
||||
if (inferredDimension.has_value() && outputShape.size() > 1) {
|
||||
if (llvm::count(outputShape, kUnknownSize) != 1 ||
|
||||
llvm::count(inputShape, kUnknownSize) != 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -585,14 +585,14 @@ public:
|
|||
collapsedInput = rewriter
|
||||
.create<tensor::ExpandShapeOp>(
|
||||
loc, adjustedResultType,
|
||||
expandedInput.hasValue() ? expandedInput.value()
|
||||
: castedInput,
|
||||
expandedInput.has_value() ? expandedInput.value()
|
||||
: castedInput,
|
||||
outputAssociations)
|
||||
.result();
|
||||
}
|
||||
|
||||
Value result = collapsedInput.hasValue() ? collapsedInput.value()
|
||||
: expandedInput.value();
|
||||
Value result = collapsedInput.has_value() ? collapsedInput.value()
|
||||
: expandedInput.value();
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -119,7 +119,7 @@ public:
|
|||
|
||||
SmallVector<int32_t> values(size, fillVal);
|
||||
auto constOp =
|
||||
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).getValue();
|
||||
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
|
||||
return success();
|
||||
|
@ -884,7 +884,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
op->getLoc(), mhloBatchNormOutTy, input,
|
||||
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
|
||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||
.getValue());
|
||||
.value());
|
||||
|
||||
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
|
||||
SmallVector<APFloat> zeroConstVec(
|
||||
|
@ -920,19 +920,19 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
||||
mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
||||
{static_cast<int64_t>(outputTy.getShape().size())})
|
||||
.getValue());
|
||||
.value());
|
||||
auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
|
||||
mhlo::getConstTensor(
|
||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||
.getValue());
|
||||
.value());
|
||||
auto var = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
|
||||
mhlo::getConstTensor(
|
||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||
.getValue());
|
||||
.value());
|
||||
|
||||
// Apply affine transform: output x weight + bias [element-wise]
|
||||
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
||||
|
|
|
@ -314,8 +314,7 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
initIndexTensor, inputShapeTensor)
|
||||
.getResult();
|
||||
|
||||
Value initIdx =
|
||||
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
|
||||
Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
|
||||
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
|
||||
|
@ -491,7 +490,7 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
if (countIncludePad) {
|
||||
Value divisor = mhlo::getConstTensor<int64_t>(
|
||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||
.getValue();
|
||||
.value();
|
||||
divisor = mhlo::promoteType(rewriter, divisor, outTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
|
||||
|
@ -501,7 +500,7 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
|
||||
// Use another mhlo.ReduceWindowOp to get the divisor
|
||||
Value windowSizeConst =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
|
||||
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||
auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
|
|
|
@ -87,7 +87,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
if (!initValue) return llvm::None;
|
||||
|
||||
Value initIndex =
|
||||
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
|
||||
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
|
||||
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({}, rewriter.getI64Type()), dim);
|
||||
|
@ -224,7 +224,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults =
|
||||
getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
|
||||
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
|
@ -301,7 +301,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults =
|
||||
getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
|
||||
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
|
|
|
@ -178,7 +178,7 @@ public:
|
|||
}));
|
||||
return success();
|
||||
}
|
||||
if (auto elements = op.valueAttr().dyn_cast<OpaqueElementsAttr>()) {
|
||||
if (auto elements = op.valueAttr().dyn_cast<SparseElementsAttr>()) {
|
||||
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
|
||||
Type builtinTensorElemTy =
|
||||
|
@ -186,8 +186,7 @@ public:
|
|||
auto shapedType =
|
||||
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, OpaqueElementsAttr::get(elements.getDialect(), shapedType,
|
||||
elements.getValue()));
|
||||
op, DenseElementsAttr::get(shapedType, elements.getValues()));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -148,7 +148,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
if (dtype.isa<mlir::FloatType>()) {
|
||||
tosaTensor = tosa::getConstTensor<float>(
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
|
||||
.getValue();
|
||||
.value();
|
||||
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
auto w = intType.getWidth();
|
||||
if (w != 32 && w != 64)
|
||||
|
@ -165,7 +165,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||
: static_cast<int32_t>(intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
|
||||
} else if (w == 64) {
|
||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -174,7 +174,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
|
||||
}
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "Usupported element type");
|
||||
|
@ -592,7 +592,7 @@ public:
|
|||
|
||||
// TBD - support dtype casting.
|
||||
|
||||
rewriter.replaceOp(op, {result.getValue()});
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -1222,7 +1222,7 @@ public:
|
|||
op->getLoc(),
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(transposedLhsType),
|
||||
rankBroadcastedLhs, transposedLhsDimsConst.getValue())
|
||||
rankBroadcastedLhs, transposedLhsDimsConst.value())
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
@ -1301,7 +1301,7 @@ public:
|
|||
op->getLoc(),
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(transposedRhsType),
|
||||
rankBroadcastedRhs, transposedRhsDimsConst.getValue())
|
||||
rankBroadcastedRhs, transposedRhsDimsConst.value())
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
@ -1452,14 +1452,13 @@ public:
|
|||
|
||||
auto transposedOpType =
|
||||
RankedTensorType::get(transposedOpShape, outputElemTy);
|
||||
output =
|
||||
rewriter
|
||||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(transposedOpType),
|
||||
reshapedOp.getResult(), transposedOpShapeConst.getValue())
|
||||
.getResult();
|
||||
output = rewriter
|
||||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(transposedOpType),
|
||||
reshapedOp.getResult(), transposedOpShapeConst.value())
|
||||
.getResult();
|
||||
|
||||
} else {
|
||||
output = reshapedOp.getResult();
|
||||
|
@ -1646,7 +1645,7 @@ public:
|
|||
op->getLoc(),
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
transposedRhsType),
|
||||
rhs, transposedRhsShapeConst.getValue());
|
||||
rhs, transposedRhsShapeConst.value());
|
||||
|
||||
Value matmulOutput;
|
||||
if (failed(
|
||||
|
@ -1759,12 +1758,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
SmallVector<int32_t> zeroVec(weightShape[0], 0);
|
||||
bias = tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
|
||||
.getValue();
|
||||
.value();
|
||||
} else {
|
||||
SmallVector<float> zeroVec(weightShape[0], 0);
|
||||
bias = tosa::getConstTensor<float>(rewriter, op, zeroVec,
|
||||
{static_cast<int32_t>(weightShape[0])})
|
||||
.getValue();
|
||||
.value();
|
||||
}
|
||||
} else {
|
||||
if (!bias.getType().cast<RankedTensorType>())
|
||||
|
@ -1808,7 +1807,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
getTypeConverter()->convertType(transposedInputType), input,
|
||||
nchwToNhwcTransposeConst.getValue())
|
||||
nchwToNhwcTransposeConst.value())
|
||||
.getResult();
|
||||
|
||||
SmallVector<int64_t> transposedWeightShape(
|
||||
|
@ -1820,7 +1819,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
getTypeConverter()->convertType(transposedWeightType), weight,
|
||||
nchwToNhwcTransposeConst.getValue())
|
||||
nchwToNhwcTransposeConst.value())
|
||||
.getResult();
|
||||
|
||||
int64_t outputHDim, outputWDim;
|
||||
|
@ -1867,7 +1866,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
|||
.create<tosa::TransposeOp>(
|
||||
op->getLoc(),
|
||||
getTypeConverter()->convertType(transposedOutputType),
|
||||
convOpResult, nhwcToNchwTransposeConst.getValue())
|
||||
convOpResult, nhwcToNchwTransposeConst.value())
|
||||
.getResult();
|
||||
|
||||
Value rescaledResult = transposedOutput;
|
||||
|
@ -2146,7 +2145,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
auto elemCntConst =
|
||||
tosa::getConstTensor<float>(rewriter, op.getOperation(),
|
||||
{static_cast<float>(elemCnt)}, {1})
|
||||
.getValue();
|
||||
.value();
|
||||
Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op.getLoc(), elemCntConst.getType(), elemCntConst);
|
||||
|
||||
|
@ -2313,7 +2312,7 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
|
||||
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
|
||||
transposeDimsConst.getValue());
|
||||
transposeDimsConst.value());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -2333,7 +2332,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
|||
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
|
||||
auto ln2Op =
|
||||
tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
|
||||
.getValue();
|
||||
.value();
|
||||
auto rcpOp =
|
||||
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
|
||||
|
||||
|
@ -2523,24 +2522,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
|
|||
auto outType = x.getType().cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
|
||||
|
||||
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).getValue();
|
||||
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).value();
|
||||
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
|
||||
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
|
||||
|
||||
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).getValue();
|
||||
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).value();
|
||||
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
|
||||
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
|
||||
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
|
||||
|
||||
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).getValue();
|
||||
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).value();
|
||||
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
|
||||
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
|
||||
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
|
||||
|
||||
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).getValue();
|
||||
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).value();
|
||||
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
|
||||
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
|
||||
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
|
||||
|
@ -2564,8 +2563,8 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
|
|||
|
||||
static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value x) {
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// buildNormalCdf, mean = zero, sigma = one
|
||||
|
@ -2574,12 +2573,12 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
|
|||
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
|
||||
// rsqrt of 2
|
||||
Value rsqrt2 =
|
||||
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
|
||||
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).value();
|
||||
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
|
||||
/*shift=*/0);
|
||||
Value erf = approximateErfOp(rewriter, op, erfArg);
|
||||
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
|
||||
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).getValue();
|
||||
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).value();
|
||||
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
|
||||
erfPlus1, /*shift=*/0);
|
||||
return normalCdf;
|
||||
|
@ -2651,10 +2650,9 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
const double kAlpha = cstAlpha0 * cstAlpha1;
|
||||
|
||||
Value kAlphaHalf =
|
||||
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {})
|
||||
.getValue();
|
||||
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}).value();
|
||||
Value negOneHalf =
|
||||
tosa::getConstTensor<float>(rewriter, op, -0.5, {}).getValue();
|
||||
tosa::getConstTensor<float>(rewriter, op, -0.5, {}).value();
|
||||
Value inputSquared = rewriter.create<tosa::MulOp>(
|
||||
loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0);
|
||||
Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
|
||||
|
@ -2810,7 +2808,7 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
|
|||
|
||||
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
|
||||
transposeDimsConst.getValue());
|
||||
transposeDimsConst.value());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -2992,7 +2990,7 @@ public:
|
|||
RankedTensorType::get(transposedInputShape, inputElemTy);
|
||||
return rewriter
|
||||
.create<tosa::TransposeOp>(op->getLoc(), transposedInputType, input,
|
||||
transposeDimsConst.getValue())
|
||||
transposeDimsConst.value())
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
@ -3319,7 +3317,7 @@ public:
|
|||
|
||||
SmallVector<int32_t> values(size, fillVal);
|
||||
auto constOp =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, values, shape).getValue();
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, values, shape).value();
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
|
||||
|
||||
|
|
|
@ -297,13 +297,13 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
|||
reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
|
||||
output_zp);
|
||||
|
||||
if (!val.hasValue())
|
||||
if (!val.has_value())
|
||||
return llvm::None;
|
||||
|
||||
if (!input_is_qtype) {
|
||||
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
|
||||
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
|
||||
val.getValue(), div_const, 0)
|
||||
val.value(), div_const, 0)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ Type Torch::parseTorchDialectType(AsmParser &parser) {
|
|||
StringRef mnemonic;
|
||||
Type genType;
|
||||
auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
|
||||
if (parseResult.hasValue())
|
||||
if (parseResult.has_value())
|
||||
return genType;
|
||||
parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `"
|
||||
<< TorchDialect::getDialectNamespace() << "`";
|
||||
|
|
|
@ -290,7 +290,7 @@ LogicalResult ClassTypeOp::verify() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperandRange PrimLoopOp::getSuccessorEntryOperands(Optional<unsigned int> index) {
|
||||
assert(index.hasValue() && index.value() == 0);
|
||||
assert(index.has_value() && index.value() == 0);
|
||||
return iterArgsInit();
|
||||
}
|
||||
|
||||
|
@ -299,7 +299,7 @@ void PrimLoopOp::getSuccessorRegions(
|
|||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
(void)operands;
|
||||
|
||||
if (!index.hasValue()) {
|
||||
if (!index.has_value()) {
|
||||
regions.emplace_back(®ion(), region().getArguments().slice(1));
|
||||
return;
|
||||
}
|
||||
|
@ -371,7 +371,7 @@ void PrimIfOp::getSuccessorRegions(Optional<unsigned> index,
|
|||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// The `then` and the `else` region branch back to the parent operation.
|
||||
if (index.hasValue()) {
|
||||
if (index.has_value()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
@ -579,7 +579,7 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
|
|||
// r[i] = lo + step*i such that i >= 0 and r[i] < hi
|
||||
// So maximize `i` such that lo + step * i < hi
|
||||
// ==> i == ceildiv(hi - lo, step)
|
||||
return IntegerAttr::get(lo.getType(),
|
||||
return IntegerAttr::get(lo.cast<TypedAttr>().getType(),
|
||||
llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
|
||||
APInt::Rounding::UP));
|
||||
}
|
||||
|
@ -597,7 +597,8 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
|
|||
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
|
||||
auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue();
|
||||
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
|
||||
return IntegerAttr::get(index.getType(), startInt + stepInt * indexInt);
|
||||
return IntegerAttr::get(index.cast<TypedAttr>().getType(),
|
||||
startInt + stepInt * indexInt);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1946,7 +1947,7 @@ void ShapeCalculateOp::getSuccessorRegions(
|
|||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
(void)operands;
|
||||
|
||||
if (!index.hasValue()) {
|
||||
if (!index.has_value()) {
|
||||
// First thing the op does is branch into the shape calculation.
|
||||
regions.emplace_back(&shapeCalculation());
|
||||
return;
|
||||
|
|
|
@ -236,7 +236,7 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser,
|
|||
}
|
||||
int64_t size;
|
||||
auto optionalInt = parser.parseOptionalInteger(size);
|
||||
if (optionalInt.hasValue()) {
|
||||
if (optionalInt.has_value()) {
|
||||
if (failed(*optionalInt))
|
||||
return Type();
|
||||
sizes.push_back(size);
|
||||
|
|
|
@ -646,7 +646,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
|||
monomorphization.argInstances[0].instance.getDefiningOp<NnModuleOp>(),
|
||||
monomorphization.func);
|
||||
}
|
||||
if (linkageInfo.hasValue()) {
|
||||
if (linkageInfo.has_value()) {
|
||||
// It's a method.
|
||||
newFunc.setVisibility(linkageInfo->isPrivate
|
||||
? SymbolTable::Visibility::Private
|
||||
|
|
|
@ -123,8 +123,8 @@ public:
|
|||
PatternRewriter &rewriter) {
|
||||
|
||||
DenseMap<int, Type> originalReturnTypes;
|
||||
if (ops.returnOp.hasValue()) {
|
||||
auto returnOp = ops.returnOp.getValue();
|
||||
if (ops.returnOp.has_value()) {
|
||||
auto returnOp = ops.returnOp.value();
|
||||
for (auto operand : llvm::enumerate(returnOp->getOperands())) {
|
||||
auto type = operand.value().getType();
|
||||
if (!type.isa<NonValueTensorType>())
|
||||
|
@ -160,8 +160,8 @@ public:
|
|||
result.setType(resultType.getWithValueSemantics());
|
||||
});
|
||||
}
|
||||
if (ops.returnOp.hasValue()) {
|
||||
auto returnOp = ops.returnOp.getValue();
|
||||
if (ops.returnOp.has_value()) {
|
||||
auto returnOp = ops.returnOp.value();
|
||||
for (int i = 0, e = returnOp->getNumOperands(); i < e; i++) {
|
||||
OpOperand &operand = returnOp->getOpOperand(i);
|
||||
auto it = originalReturnTypes.find(i);
|
||||
|
|
|
@ -310,15 +310,15 @@ struct ValueKnowledge {
|
|||
const ValueKnowledge &rhs) {
|
||||
Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
|
||||
|
||||
if (!knowledge.hasValue())
|
||||
if (!knowledge.has_value())
|
||||
return None;
|
||||
ValueKnowledge result = knowledge.getValue();
|
||||
ValueKnowledge result = knowledge.value();
|
||||
|
||||
Optional<OptionalKnowledge> optional =
|
||||
meetOptionalKnowledge(lhs.optional, rhs.optional);
|
||||
if (!optional.hasValue())
|
||||
if (!optional.has_value())
|
||||
return None;
|
||||
result.optional = optional.getValue();
|
||||
result.optional = optional.value();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -518,13 +518,13 @@ updateResultTypeState(const ValueKnowledge *tensor,
|
|||
Optional<bool> rankIsNonZero,
|
||||
const torch_upstream::ResultTypeState &inState,
|
||||
bool skipRankCheck = false) {
|
||||
if (!rankIsNonZero.hasValue() && !skipRankCheck)
|
||||
if (!rankIsNonZero.has_value() && !skipRankCheck)
|
||||
return torch_upstream::ResultTypeState{};
|
||||
assert(tensor->dtype && "tensor.dtype must be not none");
|
||||
|
||||
torch_upstream::ResultTypeState new_state = inState;
|
||||
torch_upstream::ScalarType current = getScalarTypeForType(tensor->dtype);
|
||||
if (skipRankCheck || rankIsNonZero.getValue())
|
||||
if (skipRankCheck || rankIsNonZero.value())
|
||||
new_state.dimResult = promote_skip_undefined(inState.dimResult, current);
|
||||
else
|
||||
new_state.zeroResult = promote_skip_undefined(inState.zeroResult, current);
|
||||
|
@ -1108,8 +1108,8 @@ void TypeAnalysis::incorporateKnowledge(Value v,
|
|||
const ValueKnowledge &knowledge) {
|
||||
auto updatedKnowledge = ValueKnowledge::meet(
|
||||
knowledge, ValueKnowledge::getPessimisticValueState(v));
|
||||
assert(updatedKnowledge.hasValue() && "IR has contradictory type!");
|
||||
getLatticeElement(v)->join(updatedKnowledge.getValue());
|
||||
assert(updatedKnowledge.has_value() && "IR has contradictory type!");
|
||||
getLatticeElement(v)->join(updatedKnowledge.value());
|
||||
}
|
||||
|
||||
void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op,
|
||||
|
@ -1170,9 +1170,9 @@ void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op,
|
|||
// `dtype` is inferred to be the default dtype, see
|
||||
// `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
|
||||
// be `torch.int64`
|
||||
if ((start.hasValue() && (*start).getType().isa<Torch::FloatType>()) ||
|
||||
if ((start.has_value() && (*start).getType().isa<Torch::FloatType>()) ||
|
||||
end.getType().isa<Torch::FloatType>() ||
|
||||
(step.hasValue() && (*step).getType().isa<Torch::FloatType>())) {
|
||||
(step.has_value() && (*step).getType().isa<Torch::FloatType>())) {
|
||||
// TODO: Should get the dtype from torch.get_default_dtype().
|
||||
// For now, use float32 which is the initial default dtype.
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
|
@ -1264,7 +1264,7 @@ void TypeAnalysis::visitConstantTensorAllocOp(OpTy op,
|
|||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
if (!dataType)
|
||||
dataType = Torch::FloatType::get(op->getContext());
|
||||
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue());
|
||||
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.value());
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
}
|
||||
|
||||
|
@ -1334,11 +1334,11 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
|
|||
}));
|
||||
for (auto tensor : tensors) {
|
||||
auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
|
||||
if (!newDtype.hasValue()) {
|
||||
if (!newDtype.has_value()) {
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
return;
|
||||
}
|
||||
knowledge.dtype = newDtype.getValue();
|
||||
knowledge.dtype = newDtype.value();
|
||||
}
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
}
|
||||
|
|
|
@ -98,9 +98,9 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
|||
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({
|
||||
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>):
|
||||
// CHECK: %[[IVAL_4:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[IVAL_4:.*]] = mhlo.compare GE, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[IVAL_6:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[IVAL_6:.*]] = mhlo.compare EQ, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor<i64>
|
||||
// CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
|
@ -215,4 +215,4 @@ func.func @torch.aten.avg_pool2d$count_include_pad(%arg0: !torch.vtensor<[?,?,?,
|
|||
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
|
||||
return %3 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
|
@ -58,9 +58,9 @@ func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!tor
|
|||
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
|
@ -95,9 +95,9 @@ func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtens
|
|||
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
|
@ -134,9 +134,9 @@ func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch
|
|||
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
|
||||
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
|
||||
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
|
@ -240,4 +240,4 @@ func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<
|
|||
func.func @torch.aten.max(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
|
||||
%0 = torch.aten.max %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue