llvm: update tag to 061e0189 (#1180)

Summary of changes:
 - Switch to C++17 (similar to https://reviews.llvm.org/D131348)
 - Update MHLO to build with LLVM commit hash 061e0189
 - Replace deprecated `hasValue()` and `getValue()` with `has_value()`
   and `value()` respectively (https://reviews.llvm.org/D131349)
 - Use `TypedAttr` (https://reviews.llvm.org/D130092)
 - Use updated assembly format of `mhlo.compare` op (commit
   d03ef01e70fbf9afd0fa1976fbb7ed31838929b3 in MHLO repo)
pull/1187/head
Ashay Rane 2022-08-08 20:17:35 -07:00 committed by GitHub
parent 3e97a33c80
commit bb47c166a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 106 additions and 108 deletions

View File

@ -23,7 +23,7 @@ endif()
project(torch-mlir LANGUAGES CXX C) project(torch-mlir LANGUAGES CXX C)
set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD 17)
macro(torch_mlir_add_llvm_external_project name identifier location) macro(torch_mlir_add_llvm_external_project name identifier location)
message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}") message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")

@ -1 +1 @@
Subproject commit 02b3a358926e7bbcac9226cbecbfc3067c2ad61b Subproject commit 061e0189a3dab6b1831a80d489ff1b15ad93aafb

2
externals/mlir-hlo vendored

@ -1 +1 @@
Subproject commit ad54b43c623cc5ae69b0e90f395b3fba13ffa55a Subproject commit 0430519b7ebf11a3f44c469fce8b579561fa6052

View File

@ -54,12 +54,12 @@ public:
Type getOptionalDtype() const; Type getOptionalDtype() const;
/// Return true if this type has a list of sizes. /// Return true if this type has a list of sizes.
bool hasSizes() const { return getOptionalSizes().hasValue(); } bool hasSizes() const { return getOptionalSizes().has_value(); }
/// Get the list of sizes. Requires `hasSizes()`. /// Get the list of sizes. Requires `hasSizes()`.
ArrayRef<int64_t> getSizes() const { ArrayRef<int64_t> getSizes() const {
assert(hasSizes() && "must have sizes"); assert(hasSizes() && "must have sizes");
return getOptionalSizes().getValue(); return getOptionalSizes().value();
} }
/// Return true if all sizes of this tensor are known. /// Return true if all sizes of this tensor are known.

View File

@ -49,8 +49,8 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
AttrOrTypeParameter< AttrOrTypeParameter<
"::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> { "::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
let allocator = [{ let allocator = [{
if ($_self.hasValue()) { if ($_self.has_value()) {
$_dst.getValue() = $_allocator.copyInto($_self.getValue()); $_dst.value() = $_allocator.copyInto($_self.value());
} }
}]; }];
} }

View File

@ -213,7 +213,8 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
} }
MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) { MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
auto attrTensorType = unwrap(attr).getType().cast<RankedTensorType>(); auto attrTensorType =
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(), return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
attrTensorType.getShape(), attrTensorType.getShape(),
attrTensorType.getElementType())); attrTensorType.getElementType()));

View File

@ -342,7 +342,7 @@ public:
continue; continue;
} }
if (inferredDimension.hasValue()) { if (inferredDimension.has_value()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "at most one element in size list is allowed to be -1"); op, "at most one element in size list is allowed to be -1");
} }
@ -363,7 +363,7 @@ public:
// then we don't need to analyze the static information of the input // then we don't need to analyze the static information of the input
// shape since the reassociation of dimensions only requires rank // shape since the reassociation of dimensions only requires rank
// information. // information.
if (inferredDimension.hasValue() && outputShape.size() > 1) { if (inferredDimension.has_value() && outputShape.size() > 1) {
if (llvm::count(outputShape, kUnknownSize) != 1 || if (llvm::count(outputShape, kUnknownSize) != 1 ||
llvm::count(inputShape, kUnknownSize) != 0) { llvm::count(inputShape, kUnknownSize) != 0) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -585,13 +585,13 @@ public:
collapsedInput = rewriter collapsedInput = rewriter
.create<tensor::ExpandShapeOp>( .create<tensor::ExpandShapeOp>(
loc, adjustedResultType, loc, adjustedResultType,
expandedInput.hasValue() ? expandedInput.value() expandedInput.has_value() ? expandedInput.value()
: castedInput, : castedInput,
outputAssociations) outputAssociations)
.result(); .result();
} }
Value result = collapsedInput.hasValue() ? collapsedInput.value() Value result = collapsedInput.has_value() ? collapsedInput.value()
: expandedInput.value(); : expandedInput.value();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success(); return success();

View File

@ -119,7 +119,7 @@ public:
SmallVector<int32_t> values(size, fillVal); SmallVector<int32_t> values(size, fillVal);
auto constOp = auto constOp =
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).getValue(); mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp); rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
return success(); return success();
@ -884,7 +884,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
op->getLoc(), mhloBatchNormOutTy, input, op->getLoc(), mhloBatchNormOutTy, input,
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape), mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())}) {static_cast<int64_t>(inputFlattenShape.size())})
.getValue()); .value());
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec( SmallVector<APFloat> zeroConstVec(
@ -920,19 +920,19 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
mhlo::getConstTensor(rewriter, op, outputTy.getShape(), mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())}) {static_cast<int64_t>(outputTy.getShape().size())})
.getValue()); .value());
auto mean = rewriter.create<mhlo::DynamicReshapeOp>( auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
mhlo::getConstTensor( mhlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(), rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())}) {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.getValue()); .value());
auto var = rewriter.create<mhlo::DynamicReshapeOp>( auto var = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
mhlo::getConstTensor( mhlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(), rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())}) {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.getValue()); .value());
// Apply affine transform: output x weight + bias [element-wise] // Apply affine transform: output x weight + bias [element-wise]
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy); auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);

View File

@ -314,8 +314,7 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
initIndexTensor, inputShapeTensor) initIndexTensor, inputShapeTensor)
.getResult(); .getResult();
Value initIdx = Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>( auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
@ -491,7 +490,7 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
if (countIncludePad) { if (countIncludePad) {
Value divisor = mhlo::getConstTensor<int64_t>( Value divisor = mhlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.getValue(); .value();
divisor = mhlo::promoteType(rewriter, divisor, outTy); divisor = mhlo::promoteType(rewriter, divisor, outTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>( rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
@ -501,7 +500,7 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
// Use another mhlo.ReduceWindowOp to get the divisor // Use another mhlo.ReduceWindowOp to get the divisor
Value windowSizeConst = Value windowSizeConst =
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue(); mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input); auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(

View File

@ -87,7 +87,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
if (!initValue) return llvm::None; if (!initValue) return llvm::None;
Value initIndex = Value initIndex =
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue(); mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
RankedTensorType::get({}, rewriter.getI64Type()), dim); RankedTensorType::get({}, rewriter.getI64Type()), dim);
@ -224,7 +224,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = auto mhloReduceResults =
getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue(); getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
if (keepDim) { if (keepDim) {
auto outShapeVec = inputShapeVec; auto outShapeVec = inputShapeVec;
@ -301,7 +301,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = auto mhloReduceResults =
getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue(); getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
if (keepDim) { if (keepDim) {
auto outShapeVec = inputShapeVec; auto outShapeVec = inputShapeVec;

View File

@ -178,7 +178,7 @@ public:
})); }));
return success(); return success();
} }
if (auto elements = op.valueAttr().dyn_cast<OpaqueElementsAttr>()) { if (auto elements = op.valueAttr().dyn_cast<SparseElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) { if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) { if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
Type builtinTensorElemTy = Type builtinTensorElemTy =
@ -186,8 +186,7 @@ public:
auto shapedType = auto shapedType =
RankedTensorType::get(type.getShape(), builtinTensorElemTy); RankedTensorType::get(type.getShape(), builtinTensorElemTy);
rewriter.replaceOpWithNewOp<arith::ConstantOp>( rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, OpaqueElementsAttr::get(elements.getDialect(), shapedType, op, DenseElementsAttr::get(shapedType, elements.getValues()));
elements.getValue()));
return success(); return success();
} }
} }

View File

@ -148,7 +148,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
if (dtype.isa<mlir::FloatType>()) { if (dtype.isa<mlir::FloatType>()) {
tosaTensor = tosa::getConstTensor<float>( tosaTensor = tosa::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape) rewriter, op, (isFloat ? doubleValue : intValue), dshape)
.getValue(); .value();
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) { } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
auto w = intType.getWidth(); auto w = intType.getWidth();
if (w != 32 && w != 64) if (w != 32 && w != 64)
@ -165,7 +165,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
int32_t d = isFloat ? static_cast<int32_t>(doubleValue) int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue); : static_cast<int32_t>(intValue);
tosaTensor = tosaTensor =
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue(); tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
} else if (w == 64) { } else if (w == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) { if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -174,7 +174,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
} }
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue); int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
tosaTensor = tosaTensor =
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue(); tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
} }
} else { } else {
return rewriter.notifyMatchFailure(op, "Usupported element type"); return rewriter.notifyMatchFailure(op, "Usupported element type");
@ -592,7 +592,7 @@ public:
// TBD - support dtype casting. // TBD - support dtype casting.
rewriter.replaceOp(op, {result.getValue()}); rewriter.replaceOp(op, {result.value()});
return success(); return success();
} }
@ -1222,7 +1222,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter() OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(transposedLhsType), ->convertType(transposedLhsType),
rankBroadcastedLhs, transposedLhsDimsConst.getValue()) rankBroadcastedLhs, transposedLhsDimsConst.value())
.getResult(); .getResult();
} }
@ -1301,7 +1301,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter() OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(transposedRhsType), ->convertType(transposedRhsType),
rankBroadcastedRhs, transposedRhsDimsConst.getValue()) rankBroadcastedRhs, transposedRhsDimsConst.value())
.getResult(); .getResult();
} }
@ -1452,13 +1452,12 @@ public:
auto transposedOpType = auto transposedOpType =
RankedTensorType::get(transposedOpShape, outputElemTy); RankedTensorType::get(transposedOpShape, outputElemTy);
output = output = rewriter
rewriter
.create<tosa::TransposeOp>( .create<tosa::TransposeOp>(
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter() OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(transposedOpType), ->convertType(transposedOpType),
reshapedOp.getResult(), transposedOpShapeConst.getValue()) reshapedOp.getResult(), transposedOpShapeConst.value())
.getResult(); .getResult();
} else { } else {
@ -1646,7 +1645,7 @@ public:
op->getLoc(), op->getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
transposedRhsType), transposedRhsType),
rhs, transposedRhsShapeConst.getValue()); rhs, transposedRhsShapeConst.value());
Value matmulOutput; Value matmulOutput;
if (failed( if (failed(
@ -1759,12 +1758,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
SmallVector<int32_t> zeroVec(weightShape[0], 0); SmallVector<int32_t> zeroVec(weightShape[0], 0);
bias = tosa::getConstTensor<int32_t>( bias = tosa::getConstTensor<int32_t>(
rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])}) rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
.getValue(); .value();
} else { } else {
SmallVector<float> zeroVec(weightShape[0], 0); SmallVector<float> zeroVec(weightShape[0], 0);
bias = tosa::getConstTensor<float>(rewriter, op, zeroVec, bias = tosa::getConstTensor<float>(rewriter, op, zeroVec,
{static_cast<int32_t>(weightShape[0])}) {static_cast<int32_t>(weightShape[0])})
.getValue(); .value();
} }
} else { } else {
if (!bias.getType().cast<RankedTensorType>()) if (!bias.getType().cast<RankedTensorType>())
@ -1808,7 +1807,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
.create<tosa::TransposeOp>( .create<tosa::TransposeOp>(
op->getLoc(), op->getLoc(),
getTypeConverter()->convertType(transposedInputType), input, getTypeConverter()->convertType(transposedInputType), input,
nchwToNhwcTransposeConst.getValue()) nchwToNhwcTransposeConst.value())
.getResult(); .getResult();
SmallVector<int64_t> transposedWeightShape( SmallVector<int64_t> transposedWeightShape(
@ -1820,7 +1819,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
.create<tosa::TransposeOp>( .create<tosa::TransposeOp>(
op->getLoc(), op->getLoc(),
getTypeConverter()->convertType(transposedWeightType), weight, getTypeConverter()->convertType(transposedWeightType), weight,
nchwToNhwcTransposeConst.getValue()) nchwToNhwcTransposeConst.value())
.getResult(); .getResult();
int64_t outputHDim, outputWDim; int64_t outputHDim, outputWDim;
@ -1867,7 +1866,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
.create<tosa::TransposeOp>( .create<tosa::TransposeOp>(
op->getLoc(), op->getLoc(),
getTypeConverter()->convertType(transposedOutputType), getTypeConverter()->convertType(transposedOutputType),
convOpResult, nhwcToNchwTransposeConst.getValue()) convOpResult, nhwcToNchwTransposeConst.value())
.getResult(); .getResult();
Value rescaledResult = transposedOutput; Value rescaledResult = transposedOutput;
@ -2146,7 +2145,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto elemCntConst = auto elemCntConst =
tosa::getConstTensor<float>(rewriter, op.getOperation(), tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(elemCnt)}, {1}) {static_cast<float>(elemCnt)}, {1})
.getValue(); .value();
Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>( Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), elemCntConst.getType(), elemCntConst); op.getLoc(), elemCntConst.getType(), elemCntConst);
@ -2313,7 +2312,7 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<tosa::TransposeOp>( rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self(), op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
transposeDimsConst.getValue()); transposeDimsConst.value());
return success(); return success();
} }
@ -2333,7 +2332,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1); SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
auto ln2Op = auto ln2Op =
tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape) tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
.getValue(); .value();
auto rcpOp = auto rcpOp =
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op); rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
@ -2523,24 +2522,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
auto outType = x.getType().cast<TensorType>(); auto outType = x.getType().cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x); auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue(); auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue(); auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).getValue(); auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).value();
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0); auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one); auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).getValue(); auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).value();
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0); auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0); auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X); sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).getValue(); auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).value();
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0); auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0); auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X); sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).getValue(); auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).value();
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0); auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0); auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X); sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
@ -2564,8 +2563,8 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Operation *op, Value x) { Operation *op, Value x) {
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue(); auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue(); auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
auto loc = op->getLoc(); auto loc = op->getLoc();
// buildNormalCdf, mean = zero, sigma = one // buildNormalCdf, mean = zero, sigma = one
@ -2574,12 +2573,12 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean); Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
// rsqrt of 2 // rsqrt of 2
Value rsqrt2 = Value rsqrt2 =
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue(); tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).value();
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2, Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
/*shift=*/0); /*shift=*/0);
Value erf = approximateErfOp(rewriter, op, erfArg); Value erf = approximateErfOp(rewriter, op, erfArg);
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf); Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).getValue(); Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).value();
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf, Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
erfPlus1, /*shift=*/0); erfPlus1, /*shift=*/0);
return normalCdf; return normalCdf;
@ -2651,10 +2650,9 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
const double kAlpha = cstAlpha0 * cstAlpha1; const double kAlpha = cstAlpha0 * cstAlpha1;
Value kAlphaHalf = Value kAlphaHalf =
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}) tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}).value();
.getValue();
Value negOneHalf = Value negOneHalf =
tosa::getConstTensor<float>(rewriter, op, -0.5, {}).getValue(); tosa::getConstTensor<float>(rewriter, op, -0.5, {}).value();
Value inputSquared = rewriter.create<tosa::MulOp>( Value inputSquared = rewriter.create<tosa::MulOp>(
loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0); loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0);
Value negHalfInputSquared = rewriter.create<tosa::MulOp>( Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
@ -2810,7 +2808,7 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<tosa::TransposeOp>( rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self(), op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
transposeDimsConst.getValue()); transposeDimsConst.value());
return success(); return success();
} }
@ -2992,7 +2990,7 @@ public:
RankedTensorType::get(transposedInputShape, inputElemTy); RankedTensorType::get(transposedInputShape, inputElemTy);
return rewriter return rewriter
.create<tosa::TransposeOp>(op->getLoc(), transposedInputType, input, .create<tosa::TransposeOp>(op->getLoc(), transposedInputType, input,
transposeDimsConst.getValue()) transposeDimsConst.value())
.getResult(); .getResult();
} }
@ -3319,7 +3317,7 @@ public:
SmallVector<int32_t> values(size, fillVal); SmallVector<int32_t> values(size, fillVal);
auto constOp = auto constOp =
tosa::getConstTensor<int32_t>(rewriter, op, values, shape).getValue(); tosa::getConstTensor<int32_t>(rewriter, op, values, shape).value();
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp); rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);

View File

@ -297,13 +297,13 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale, reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
output_zp); output_zp);
if (!val.hasValue()) if (!val.has_value())
return llvm::None; return llvm::None;
if (!input_is_qtype) { if (!input_is_qtype) {
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale); Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type, return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
val.getValue(), div_const, 0) val.value(), div_const, 0)
.getResult(); .getResult();
} }

View File

@ -65,7 +65,7 @@ Type Torch::parseTorchDialectType(AsmParser &parser) {
StringRef mnemonic; StringRef mnemonic;
Type genType; Type genType;
auto parseResult = generatedTypeParser(parser, &mnemonic, genType); auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
if (parseResult.hasValue()) if (parseResult.has_value())
return genType; return genType;
parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `" parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `"
<< TorchDialect::getDialectNamespace() << "`"; << TorchDialect::getDialectNamespace() << "`";

View File

@ -290,7 +290,7 @@ LogicalResult ClassTypeOp::verify() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OperandRange PrimLoopOp::getSuccessorEntryOperands(Optional<unsigned int> index) { OperandRange PrimLoopOp::getSuccessorEntryOperands(Optional<unsigned int> index) {
assert(index.hasValue() && index.value() == 0); assert(index.has_value() && index.value() == 0);
return iterArgsInit(); return iterArgsInit();
} }
@ -299,7 +299,7 @@ void PrimLoopOp::getSuccessorRegions(
SmallVectorImpl<RegionSuccessor> &regions) { SmallVectorImpl<RegionSuccessor> &regions) {
(void)operands; (void)operands;
if (!index.hasValue()) { if (!index.has_value()) {
regions.emplace_back(&region(), region().getArguments().slice(1)); regions.emplace_back(&region(), region().getArguments().slice(1));
return; return;
} }
@ -371,7 +371,7 @@ void PrimIfOp::getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) { SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation. // The `then` and the `else` region branch back to the parent operation.
if (index.hasValue()) { if (index.has_value()) {
regions.push_back(RegionSuccessor(getResults())); regions.push_back(RegionSuccessor(getResults()));
return; return;
} }
@ -579,7 +579,7 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
// r[i] = lo + step*i such that i >= 0 and r[i] < hi // r[i] = lo + step*i such that i >= 0 and r[i] < hi
// So maximize `i` such that lo + step * i < hi // So maximize `i` such that lo + step * i < hi
// ==> i == ceildiv(hi - lo, step) // ==> i == ceildiv(hi - lo, step)
return IntegerAttr::get(lo.getType(), return IntegerAttr::get(lo.cast<TypedAttr>().getType(),
llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt, llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
APInt::Rounding::UP)); APInt::Rounding::UP));
} }
@ -597,7 +597,8 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue(); auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue(); auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue();
auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue(); auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
return IntegerAttr::get(index.getType(), startInt + stepInt * indexInt); return IntegerAttr::get(index.cast<TypedAttr>().getType(),
startInt + stepInt * indexInt);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1946,7 +1947,7 @@ void ShapeCalculateOp::getSuccessorRegions(
SmallVectorImpl<RegionSuccessor> &regions) { SmallVectorImpl<RegionSuccessor> &regions) {
(void)operands; (void)operands;
if (!index.hasValue()) { if (!index.has_value()) {
// First thing the op does is branch into the shape calculation. // First thing the op does is branch into the shape calculation.
regions.emplace_back(&shapeCalculation()); regions.emplace_back(&shapeCalculation());
return; return;

View File

@ -236,7 +236,7 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser,
} }
int64_t size; int64_t size;
auto optionalInt = parser.parseOptionalInteger(size); auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) { if (optionalInt.has_value()) {
if (failed(*optionalInt)) if (failed(*optionalInt))
return Type(); return Type();
sizes.push_back(size); sizes.push_back(size);

View File

@ -646,7 +646,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
monomorphization.argInstances[0].instance.getDefiningOp<NnModuleOp>(), monomorphization.argInstances[0].instance.getDefiningOp<NnModuleOp>(),
monomorphization.func); monomorphization.func);
} }
if (linkageInfo.hasValue()) { if (linkageInfo.has_value()) {
// It's a method. // It's a method.
newFunc.setVisibility(linkageInfo->isPrivate newFunc.setVisibility(linkageInfo->isPrivate
? SymbolTable::Visibility::Private ? SymbolTable::Visibility::Private

View File

@ -123,8 +123,8 @@ public:
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
DenseMap<int, Type> originalReturnTypes; DenseMap<int, Type> originalReturnTypes;
if (ops.returnOp.hasValue()) { if (ops.returnOp.has_value()) {
auto returnOp = ops.returnOp.getValue(); auto returnOp = ops.returnOp.value();
for (auto operand : llvm::enumerate(returnOp->getOperands())) { for (auto operand : llvm::enumerate(returnOp->getOperands())) {
auto type = operand.value().getType(); auto type = operand.value().getType();
if (!type.isa<NonValueTensorType>()) if (!type.isa<NonValueTensorType>())
@ -160,8 +160,8 @@ public:
result.setType(resultType.getWithValueSemantics()); result.setType(resultType.getWithValueSemantics());
}); });
} }
if (ops.returnOp.hasValue()) { if (ops.returnOp.has_value()) {
auto returnOp = ops.returnOp.getValue(); auto returnOp = ops.returnOp.value();
for (int i = 0, e = returnOp->getNumOperands(); i < e; i++) { for (int i = 0, e = returnOp->getNumOperands(); i < e; i++) {
OpOperand &operand = returnOp->getOpOperand(i); OpOperand &operand = returnOp->getOpOperand(i);
auto it = originalReturnTypes.find(i); auto it = originalReturnTypes.find(i);

View File

@ -310,15 +310,15 @@ struct ValueKnowledge {
const ValueKnowledge &rhs) { const ValueKnowledge &rhs) {
Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs); Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
if (!knowledge.hasValue()) if (!knowledge.has_value())
return None; return None;
ValueKnowledge result = knowledge.getValue(); ValueKnowledge result = knowledge.value();
Optional<OptionalKnowledge> optional = Optional<OptionalKnowledge> optional =
meetOptionalKnowledge(lhs.optional, rhs.optional); meetOptionalKnowledge(lhs.optional, rhs.optional);
if (!optional.hasValue()) if (!optional.has_value())
return None; return None;
result.optional = optional.getValue(); result.optional = optional.value();
return result; return result;
} }
@ -518,13 +518,13 @@ updateResultTypeState(const ValueKnowledge *tensor,
Optional<bool> rankIsNonZero, Optional<bool> rankIsNonZero,
const torch_upstream::ResultTypeState &inState, const torch_upstream::ResultTypeState &inState,
bool skipRankCheck = false) { bool skipRankCheck = false) {
if (!rankIsNonZero.hasValue() && !skipRankCheck) if (!rankIsNonZero.has_value() && !skipRankCheck)
return torch_upstream::ResultTypeState{}; return torch_upstream::ResultTypeState{};
assert(tensor->dtype && "tensor.dtype must be not none"); assert(tensor->dtype && "tensor.dtype must be not none");
torch_upstream::ResultTypeState new_state = inState; torch_upstream::ResultTypeState new_state = inState;
torch_upstream::ScalarType current = getScalarTypeForType(tensor->dtype); torch_upstream::ScalarType current = getScalarTypeForType(tensor->dtype);
if (skipRankCheck || rankIsNonZero.getValue()) if (skipRankCheck || rankIsNonZero.value())
new_state.dimResult = promote_skip_undefined(inState.dimResult, current); new_state.dimResult = promote_skip_undefined(inState.dimResult, current);
else else
new_state.zeroResult = promote_skip_undefined(inState.zeroResult, current); new_state.zeroResult = promote_skip_undefined(inState.zeroResult, current);
@ -1108,8 +1108,8 @@ void TypeAnalysis::incorporateKnowledge(Value v,
const ValueKnowledge &knowledge) { const ValueKnowledge &knowledge) {
auto updatedKnowledge = ValueKnowledge::meet( auto updatedKnowledge = ValueKnowledge::meet(
knowledge, ValueKnowledge::getPessimisticValueState(v)); knowledge, ValueKnowledge::getPessimisticValueState(v));
assert(updatedKnowledge.hasValue() && "IR has contradictory type!"); assert(updatedKnowledge.has_value() && "IR has contradictory type!");
getLatticeElement(v)->join(updatedKnowledge.getValue()); getLatticeElement(v)->join(updatedKnowledge.value());
} }
void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op, void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op,
@ -1170,9 +1170,9 @@ void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op,
// `dtype` is inferred to be the default dtype, see // `dtype` is inferred to be the default dtype, see
// `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
// be `torch.int64` // be `torch.int64`
if ((start.hasValue() && (*start).getType().isa<Torch::FloatType>()) || if ((start.has_value() && (*start).getType().isa<Torch::FloatType>()) ||
end.getType().isa<Torch::FloatType>() || end.getType().isa<Torch::FloatType>() ||
(step.hasValue() && (*step).getType().isa<Torch::FloatType>())) { (step.has_value() && (*step).getType().isa<Torch::FloatType>())) {
// TODO: Should get the dtype from torch.get_default_dtype(). // TODO: Should get the dtype from torch.get_default_dtype().
// For now, use float32 which is the initial default dtype. // For now, use float32 which is the initial default dtype.
knowledge.dtype = Float32Type::get(op->getContext()); knowledge.dtype = Float32Type::get(op->getContext());
@ -1264,7 +1264,7 @@ void TypeAnalysis::visitConstantTensorAllocOp(OpTy op,
ValueKnowledge::getTensorPessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
if (!dataType) if (!dataType)
dataType = Torch::FloatType::get(op->getContext()); dataType = Torch::FloatType::get(op->getContext());
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue()); fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.value());
incorporateKnowledge(op.getResult(), knowledge); incorporateKnowledge(op.getResult(), knowledge);
} }
@ -1334,11 +1334,11 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
})); }));
for (auto tensor : tensors) { for (auto tensor : tensors) {
auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype); auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
if (!newDtype.hasValue()) { if (!newDtype.has_value()) {
incorporateKnowledge(op.getResult(), knowledge); incorporateKnowledge(op.getResult(), knowledge);
return; return;
} }
knowledge.dtype = newDtype.getValue(); knowledge.dtype = newDtype.value();
} }
incorporateKnowledge(op.getResult(), knowledge); incorporateKnowledge(op.getResult(), knowledge);
} }

View File

@ -98,9 +98,9 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64> // CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({ // CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>): // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>):
// CHECK: %[[IVAL_4:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[IVAL_4:.*]] = mhlo.compare GE, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[IVAL_6:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[IVAL_6:.*]] = mhlo.compare EQ, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor<i64> // CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor<i64>
// CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>

View File

@ -17,9 +17,9 @@
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64> // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>) // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) { // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64> // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@ -58,9 +58,9 @@ func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!tor
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64> // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>) // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) { // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64> // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@ -95,9 +95,9 @@ func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtens
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64> // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>) // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) { // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64> // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@ -134,9 +134,9 @@ func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch
// CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64> // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>) // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
// CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) { // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64> // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
// CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>