mirror of https://github.com/llvm/torch-mlir
llvm: update tag to 061e0189 (#1180)
Summary of changes:
 - Switch to C++17 (similar to https://reviews.llvm.org/D131348)
 - Update MHLO to build with LLVM commit hash 061e0189
 - Replace deprecated `hasValue()` and `getValue()` with `has_value()` and `value()` respectively (https://reviews.llvm.org/D131349)
 - Use `TypedAttr` (https://reviews.llvm.org/D130092)
 - Use updated assembly format of `mhlo.compare` op (commit d03ef01e70fbf9afd0fa1976fbb7ed31838929b3 in MHLO repo)
parent 3e97a33c80
commit bb47c166a0
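Note: most of the diff below is a mechanical rename. llvm::Optional now favors the std::optional-style accessors has_value()/value() over the deprecated hasValue()/getValue(), and attributes that carry a type are now queried through TypedAttr rather than directly through Attribute. As a rough, self-contained illustration only (plain std::optional and a made-up variable name, not code from this patch):

    #include <cassert>
    #include <optional>

    int main() {
      std::optional<int> inferredDim = 3; // hypothetical value, for illustration only

      // Old torch-mlir spelling:       inferredDim.hasValue() / inferredDim.getValue()
      // New spelling used in this PR:  inferredDim.has_value() / inferredDim.value()
      if (inferredDim.has_value()) {
        assert(inferredDim.value() == 3);
      }
      return 0;
    }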
@@ -23,7 +23,7 @@ endif()
 
 project(torch-mlir LANGUAGES CXX C)
 set(CMAKE_C_STANDARD 11)
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 
 macro(torch_mlir_add_llvm_external_project name identifier location)
   message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")
@@ -1 +1 @@
-Subproject commit 02b3a358926e7bbcac9226cbecbfc3067c2ad61b
+Subproject commit 061e0189a3dab6b1831a80d489ff1b15ad93aafb
@@ -1 +1 @@
-Subproject commit ad54b43c623cc5ae69b0e90f395b3fba13ffa55a
+Subproject commit 0430519b7ebf11a3f44c469fce8b579561fa6052
@@ -54,12 +54,12 @@ public:
   Type getOptionalDtype() const;
 
   /// Return true if this type has a list of sizes.
-  bool hasSizes() const { return getOptionalSizes().hasValue(); }
+  bool hasSizes() const { return getOptionalSizes().has_value(); }
 
   /// Get the list of sizes. Requires `hasSizes()`.
   ArrayRef<int64_t> getSizes() const {
     assert(hasSizes() && "must have sizes");
-    return getOptionalSizes().getValue();
+    return getOptionalSizes().value();
   }
 
   /// Return true if all sizes of this tensor are known.
@@ -49,8 +49,8 @@ class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
     AttrOrTypeParameter<
       "::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
   let allocator = [{
-    if ($_self.hasValue()) {
-      $_dst.getValue() = $_allocator.copyInto($_self.getValue());
+    if ($_self.has_value()) {
+      $_dst.value() = $_allocator.copyInto($_self.value());
     }
   }];
 }
@@ -213,7 +213,8 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
 }
 
 MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
-  auto attrTensorType = unwrap(attr).getType().cast<RankedTensorType>();
+  auto attrTensorType =
+      unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
   return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
                                              attrTensorType.getShape(),
                                              attrTensorType.getElementType()));
@@ -342,7 +342,7 @@ public:
         continue;
       }
 
-      if (inferredDimension.hasValue()) {
+      if (inferredDimension.has_value()) {
         return rewriter.notifyMatchFailure(
             op, "at most one element in size list is allowed to be -1");
       }
@@ -363,7 +363,7 @@ public:
     // then we don't need to analyze the static information of the input
     // shape since the reassociation of dimensions only requires rank
     // information.
-    if (inferredDimension.hasValue() && outputShape.size() > 1) {
+    if (inferredDimension.has_value() && outputShape.size() > 1) {
       if (llvm::count(outputShape, kUnknownSize) != 1 ||
           llvm::count(inputShape, kUnknownSize) != 0) {
         return rewriter.notifyMatchFailure(
@@ -585,14 +585,14 @@ public:
       collapsedInput = rewriter
                            .create<tensor::ExpandShapeOp>(
                                loc, adjustedResultType,
-                               expandedInput.hasValue() ? expandedInput.value()
+                               expandedInput.has_value() ? expandedInput.value()
                                                         : castedInput,
                                outputAssociations)
                            .result();
     }
 
-    Value result = collapsedInput.hasValue() ? collapsedInput.value()
+    Value result = collapsedInput.has_value() ? collapsedInput.value()
                                              : expandedInput.value();
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
     return success();
   }
@@ -119,7 +119,7 @@ public:
 
     SmallVector<int32_t> values(size, fillVal);
     auto constOp =
-        mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).getValue();
+        mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
 
     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
     return success();
@@ -884,7 +884,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
       op->getLoc(), mhloBatchNormOutTy, input,
       mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
                            {static_cast<int64_t>(inputFlattenShape.size())})
-          .getValue());
+          .value());
 
   // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
   SmallVector<APFloat> zeroConstVec(
@@ -920,19 +920,19 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
       op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
       mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
                            {static_cast<int64_t>(outputTy.getShape().size())})
-          .getValue());
+          .value());
   auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
       op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
       mhlo::getConstTensor(
           rewriter, op, outputMeanOrVarTy.getShape(),
           {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
-          .getValue());
+          .value());
   auto var = rewriter.create<mhlo::DynamicReshapeOp>(
       op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
       mhlo::getConstTensor(
           rewriter, op, outputMeanOrVarTy.getShape(),
           {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
-          .getValue());
+          .value());
 
   // Apply affine transform: output x weight + bias [element-wise]
   auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
@@ -314,8 +314,7 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
                            initIndexTensor, inputShapeTensor)
                        .getResult();
 
-  Value initIdx =
-      mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
+  Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
 
   auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
       op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
@@ -491,7 +490,7 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
   if (countIncludePad) {
     Value divisor = mhlo::getConstTensor<int64_t>(
                         rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
-                        .getValue();
+                        .value();
     divisor = mhlo::promoteType(rewriter, divisor, outTy);
     DenseIntElementsAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
@@ -501,7 +500,7 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
 
   // Use another mhlo.ReduceWindowOp to get the divisor
   Value windowSizeConst =
-      mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
+      mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
   windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
   auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
   auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
@@ -87,7 +87,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
   if (!initValue) return llvm::None;
 
   Value initIndex =
-      mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
+      mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
 
   DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
       RankedTensorType::get({}, rewriter.getI64Type()), dim);
@@ -224,7 +224,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
   }
   auto inputShapeVec = *inputShapeInfo;
   auto mhloReduceResults =
-      getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
+      getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
 
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
@@ -301,7 +301,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
   }
   auto inputShapeVec = *inputShapeInfo;
   auto mhloReduceResults =
-      getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
+      getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
 
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
@@ -178,7 +178,7 @@ public:
       }));
       return success();
     }
-    if (auto elements = op.valueAttr().dyn_cast<OpaqueElementsAttr>()) {
+    if (auto elements = op.valueAttr().dyn_cast<SparseElementsAttr>()) {
       if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
         if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
           Type builtinTensorElemTy =
@@ -186,8 +186,7 @@ public:
           auto shapedType =
               RankedTensorType::get(type.getShape(), builtinTensorElemTy);
           rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-              op, OpaqueElementsAttr::get(elements.getDialect(), shapedType,
-                                          elements.getValue()));
+              op, DenseElementsAttr::get(shapedType, elements.getValues()));
           return success();
         }
       }
@@ -148,7 +148,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
   if (dtype.isa<mlir::FloatType>()) {
     tosaTensor = tosa::getConstTensor<float>(
                      rewriter, op, (isFloat ? doubleValue : intValue), dshape)
-                     .getValue();
+                     .value();
   } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
     auto w = intType.getWidth();
     if (w != 32 && w != 64)
@@ -165,7 +165,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
                           : static_cast<int32_t>(intValue);
       tosaTensor =
-          tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
+          tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
     } else if (w == 64) {
       if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
         return rewriter.notifyMatchFailure(
@@ -174,7 +174,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       }
       int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
       tosaTensor =
-          tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
+          tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
     }
   } else {
     return rewriter.notifyMatchFailure(op, "Usupported element type");
@@ -592,7 +592,7 @@ public:
 
     // TBD - support dtype casting.
 
-    rewriter.replaceOp(op, {result.getValue()});
+    rewriter.replaceOp(op, {result.value()});
 
    return success();
   }
@@ -1222,7 +1222,7 @@ public:
               op->getLoc(),
               OpConversionPattern<AtenOpT>::getTypeConverter()
                   ->convertType(transposedLhsType),
-              rankBroadcastedLhs, transposedLhsDimsConst.getValue())
+              rankBroadcastedLhs, transposedLhsDimsConst.value())
           .getResult();
     }
 
@@ -1301,7 +1301,7 @@ public:
               op->getLoc(),
               OpConversionPattern<AtenOpT>::getTypeConverter()
                   ->convertType(transposedRhsType),
-              rankBroadcastedRhs, transposedRhsDimsConst.getValue())
+              rankBroadcastedRhs, transposedRhsDimsConst.value())
          .getResult();
     }
 
@@ -1452,14 +1452,13 @@ public:
 
       auto transposedOpType =
           RankedTensorType::get(transposedOpShape, outputElemTy);
-      output =
-          rewriter
-              .create<tosa::TransposeOp>(
-                  op->getLoc(),
-                  OpConversionPattern<AtenOpT>::getTypeConverter()
-                      ->convertType(transposedOpType),
-                  reshapedOp.getResult(), transposedOpShapeConst.getValue())
-              .getResult();
+      output = rewriter
+                   .create<tosa::TransposeOp>(
+                       op->getLoc(),
+                       OpConversionPattern<AtenOpT>::getTypeConverter()
+                           ->convertType(transposedOpType),
+                       reshapedOp.getResult(), transposedOpShapeConst.value())
+                   .getResult();
 
     } else {
       output = reshapedOp.getResult();
@@ -1646,7 +1645,7 @@ public:
         op->getLoc(),
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             transposedRhsType),
-        rhs, transposedRhsShapeConst.getValue());
+        rhs, transposedRhsShapeConst.value());
 
     Value matmulOutput;
     if (failed(
@@ -1759,12 +1758,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
       SmallVector<int32_t> zeroVec(weightShape[0], 0);
       bias = tosa::getConstTensor<int32_t>(
                  rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
-                 .getValue();
+                 .value();
     } else {
       SmallVector<float> zeroVec(weightShape[0], 0);
       bias = tosa::getConstTensor<float>(rewriter, op, zeroVec,
                                          {static_cast<int32_t>(weightShape[0])})
-                 .getValue();
+                 .value();
     }
   } else {
     if (!bias.getType().cast<RankedTensorType>())
@@ -1808,7 +1807,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
           .create<tosa::TransposeOp>(
               op->getLoc(),
               getTypeConverter()->convertType(transposedInputType), input,
-              nchwToNhwcTransposeConst.getValue())
+              nchwToNhwcTransposeConst.value())
           .getResult();
 
   SmallVector<int64_t> transposedWeightShape(
@@ -1820,7 +1819,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
           .create<tosa::TransposeOp>(
              op->getLoc(),
              getTypeConverter()->convertType(transposedWeightType), weight,
-              nchwToNhwcTransposeConst.getValue())
+              nchwToNhwcTransposeConst.value())
           .getResult();
 
   int64_t outputHDim, outputWDim;
@@ -1867,7 +1866,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
           .create<tosa::TransposeOp>(
               op->getLoc(),
               getTypeConverter()->convertType(transposedOutputType),
-              convOpResult, nhwcToNchwTransposeConst.getValue())
+              convOpResult, nhwcToNchwTransposeConst.value())
           .getResult();
 
   Value rescaledResult = transposedOutput;
@@ -2146,7 +2145,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
   auto elemCntConst =
       tosa::getConstTensor<float>(rewriter, op.getOperation(),
                                   {static_cast<float>(elemCnt)}, {1})
-          .getValue();
+          .value();
   Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
       op.getLoc(), elemCntConst.getType(), elemCntConst);
 
@@ -2313,7 +2312,7 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
 
   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
       op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
-      transposeDimsConst.getValue());
+      transposeDimsConst.value());
 
   return success();
 }
@@ -2333,7 +2332,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
   SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
   auto ln2Op =
       tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
-          .getValue();
+          .value();
   auto rcpOp =
       rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
 
@@ -2523,24 +2522,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
   auto outType = x.getType().cast<TensorType>();
   auto loc = op->getLoc();
   auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
-  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
-  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
+  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
+  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
 
-  auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).getValue();
+  auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).value();
   auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
   auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
 
-  auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).getValue();
+  auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).value();
   auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
   auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
 
-  auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).getValue();
+  auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).value();
   auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
   auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
 
-  auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).getValue();
+  auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).value();
   auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
   auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
@@ -2564,8 +2563,8 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
 
 static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
                                 Operation *op, Value x) {
-  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
-  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
+  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
+  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
   auto loc = op->getLoc();
 
   // buildNormalCdf, mean = zero, sigma = one
|
@ -2574,12 +2573,12 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
|
||||||
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
|
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
|
||||||
// rsqrt of 2
|
// rsqrt of 2
|
||||||
Value rsqrt2 =
|
Value rsqrt2 =
|
||||||
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
|
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).value();
|
||||||
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
|
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
|
||||||
/*shift=*/0);
|
/*shift=*/0);
|
||||||
Value erf = approximateErfOp(rewriter, op, erfArg);
|
Value erf = approximateErfOp(rewriter, op, erfArg);
|
||||||
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
|
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
|
||||||
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).getValue();
|
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).value();
|
||||||
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
|
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
|
||||||
erfPlus1, /*shift=*/0);
|
erfPlus1, /*shift=*/0);
|
||||||
return normalCdf;
|
return normalCdf;
|
||||||
|
@@ -2651,10 +2650,9 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
   const double kAlpha = cstAlpha0 * cstAlpha1;
 
   Value kAlphaHalf =
-      tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {})
-          .getValue();
+      tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}).value();
   Value negOneHalf =
-      tosa::getConstTensor<float>(rewriter, op, -0.5, {}).getValue();
+      tosa::getConstTensor<float>(rewriter, op, -0.5, {}).value();
   Value inputSquared = rewriter.create<tosa::MulOp>(
       loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0);
   Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
@@ -2810,7 +2808,7 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
 
   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
       op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
-      transposeDimsConst.getValue());
+      transposeDimsConst.value());
 
   return success();
 }
@@ -2992,7 +2990,7 @@ public:
         RankedTensorType::get(transposedInputShape, inputElemTy);
     return rewriter
         .create<tosa::TransposeOp>(op->getLoc(), transposedInputType, input,
-                                   transposeDimsConst.getValue())
+                                   transposeDimsConst.value())
         .getResult();
   }
 
@@ -3319,7 +3317,7 @@ public:
 
     SmallVector<int32_t> values(size, fillVal);
     auto constOp =
-        tosa::getConstTensor<int32_t>(rewriter, op, values, shape).getValue();
+        tosa::getConstTensor<int32_t>(rewriter, op, values, shape).value();
 
     rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
 
@@ -297,13 +297,13 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
       output_zp);
 
-  if (!val.hasValue())
+  if (!val.has_value())
     return llvm::None;
 
   if (!input_is_qtype) {
     Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
     return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
-                                         val.getValue(), div_const, 0)
+                                         val.value(), div_const, 0)
         .getResult();
   }
 
@@ -65,7 +65,7 @@ Type Torch::parseTorchDialectType(AsmParser &parser) {
   StringRef mnemonic;
   Type genType;
   auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
-  if (parseResult.hasValue())
+  if (parseResult.has_value())
     return genType;
   parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `"
                             << TorchDialect::getDialectNamespace() << "`";
@@ -290,7 +290,7 @@ LogicalResult ClassTypeOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OperandRange PrimLoopOp::getSuccessorEntryOperands(Optional<unsigned int> index) {
-  assert(index.hasValue() && index.value() == 0);
+  assert(index.has_value() && index.value() == 0);
   return iterArgsInit();
 }
 
@@ -299,7 +299,7 @@ void PrimLoopOp::getSuccessorRegions(
     SmallVectorImpl<RegionSuccessor> &regions) {
   (void)operands;
 
-  if (!index.hasValue()) {
+  if (!index.has_value()) {
     regions.emplace_back(&region(), region().getArguments().slice(1));
     return;
   }
@@ -371,7 +371,7 @@ void PrimIfOp::getSuccessorRegions(Optional<unsigned> index,
                                    ArrayRef<Attribute> operands,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
   // The `then` and the `else` region branch back to the parent operation.
-  if (index.hasValue()) {
+  if (index.has_value()) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }
@@ -579,7 +579,7 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
   // r[i] = lo + step*i such that i >= 0 and r[i] < hi
   // So maximize `i` such that lo + step * i < hi
   // ==> i == ceildiv(hi - lo, step)
-  return IntegerAttr::get(lo.getType(),
+  return IntegerAttr::get(lo.cast<TypedAttr>().getType(),
                           llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
                                                        APInt::Rounding::UP));
 }
@@ -597,7 +597,8 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
   auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
   auto startInt = start.dyn_cast_or_null<IntegerAttr>().getValue();
   auto stepInt = step.dyn_cast_or_null<IntegerAttr>().getValue();
-  return IntegerAttr::get(index.getType(), startInt + stepInt * indexInt);
+  return IntegerAttr::get(index.cast<TypedAttr>().getType(),
+                          startInt + stepInt * indexInt);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1946,7 +1947,7 @@ void ShapeCalculateOp::getSuccessorRegions(
     SmallVectorImpl<RegionSuccessor> &regions) {
   (void)operands;
 
-  if (!index.hasValue()) {
+  if (!index.has_value()) {
     // First thing the op does is branch into the shape calculation.
     regions.emplace_back(&shapeCalculation());
     return;
@@ -236,7 +236,7 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser,
     }
     int64_t size;
     auto optionalInt = parser.parseOptionalInteger(size);
-    if (optionalInt.hasValue()) {
+    if (optionalInt.has_value()) {
      if (failed(*optionalInt))
        return Type();
      sizes.push_back(size);
@@ -646,7 +646,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
         monomorphization.argInstances[0].instance.getDefiningOp<NnModuleOp>(),
         monomorphization.func);
   }
-  if (linkageInfo.hasValue()) {
+  if (linkageInfo.has_value()) {
     // It's a method.
     newFunc.setVisibility(linkageInfo->isPrivate
                               ? SymbolTable::Visibility::Private
@@ -123,8 +123,8 @@ public:
                                   PatternRewriter &rewriter) {
 
     DenseMap<int, Type> originalReturnTypes;
-    if (ops.returnOp.hasValue()) {
-      auto returnOp = ops.returnOp.getValue();
+    if (ops.returnOp.has_value()) {
+      auto returnOp = ops.returnOp.value();
       for (auto operand : llvm::enumerate(returnOp->getOperands())) {
         auto type = operand.value().getType();
         if (!type.isa<NonValueTensorType>())
@@ -160,8 +160,8 @@ public:
         result.setType(resultType.getWithValueSemantics());
       });
     }
-    if (ops.returnOp.hasValue()) {
-      auto returnOp = ops.returnOp.getValue();
+    if (ops.returnOp.has_value()) {
+      auto returnOp = ops.returnOp.value();
       for (int i = 0, e = returnOp->getNumOperands(); i < e; i++) {
         OpOperand &operand = returnOp->getOpOperand(i);
         auto it = originalReturnTypes.find(i);
@@ -310,15 +310,15 @@ struct ValueKnowledge {
                                       const ValueKnowledge &rhs) {
     Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
 
-    if (!knowledge.hasValue())
+    if (!knowledge.has_value())
       return None;
-    ValueKnowledge result = knowledge.getValue();
+    ValueKnowledge result = knowledge.value();
 
     Optional<OptionalKnowledge> optional =
        meetOptionalKnowledge(lhs.optional, rhs.optional);
-    if (!optional.hasValue())
+    if (!optional.has_value())
       return None;
-    result.optional = optional.getValue();
+    result.optional = optional.value();
     return result;
   }
 
@@ -518,13 +518,13 @@ updateResultTypeState(const ValueKnowledge *tensor,
                       Optional<bool> rankIsNonZero,
                       const torch_upstream::ResultTypeState &inState,
                       bool skipRankCheck = false) {
-  if (!rankIsNonZero.hasValue() && !skipRankCheck)
+  if (!rankIsNonZero.has_value() && !skipRankCheck)
     return torch_upstream::ResultTypeState{};
   assert(tensor->dtype && "tensor.dtype must be not none");
 
   torch_upstream::ResultTypeState new_state = inState;
   torch_upstream::ScalarType current = getScalarTypeForType(tensor->dtype);
-  if (skipRankCheck || rankIsNonZero.getValue())
+  if (skipRankCheck || rankIsNonZero.value())
     new_state.dimResult = promote_skip_undefined(inState.dimResult, current);
   else
     new_state.zeroResult = promote_skip_undefined(inState.zeroResult, current);
@@ -1108,8 +1108,8 @@ void TypeAnalysis::incorporateKnowledge(Value v,
                                         const ValueKnowledge &knowledge) {
   auto updatedKnowledge = ValueKnowledge::meet(
       knowledge, ValueKnowledge::getPessimisticValueState(v));
-  assert(updatedKnowledge.hasValue() && "IR has contradictory type!");
-  getLatticeElement(v)->join(updatedKnowledge.getValue());
+  assert(updatedKnowledge.has_value() && "IR has contradictory type!");
+  getLatticeElement(v)->join(updatedKnowledge.value());
 }
 
 void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op,
@@ -1170,9 +1170,9 @@ void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op,
   // `dtype` is inferred to be the default dtype, see
   // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
   // be `torch.int64`
-  if ((start.hasValue() && (*start).getType().isa<Torch::FloatType>()) ||
+  if ((start.has_value() && (*start).getType().isa<Torch::FloatType>()) ||
       end.getType().isa<Torch::FloatType>() ||
-      (step.hasValue() && (*step).getType().isa<Torch::FloatType>())) {
+      (step.has_value() && (*step).getType().isa<Torch::FloatType>())) {
     // TODO: Should get the dtype from torch.get_default_dtype().
     // For now, use float32 which is the initial default dtype.
     knowledge.dtype = Float32Type::get(op->getContext());
@@ -1264,7 +1264,7 @@ void TypeAnalysis::visitConstantTensorAllocOp(OpTy op,
       ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   if (!dataType)
     dataType = Torch::FloatType::get(op->getContext());
-  fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue());
+  fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.value());
   incorporateKnowledge(op.getResult(), knowledge);
 }
 
@@ -1334,11 +1334,11 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
       }));
   for (auto tensor : tensors) {
     auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
-    if (!newDtype.hasValue()) {
+    if (!newDtype.has_value()) {
       incorporateKnowledge(op.getResult(), knowledge);
       return;
     }
-    knowledge.dtype = newDtype.getValue();
+    knowledge.dtype = newDtype.value();
   }
   incorporateKnowledge(op.getResult(), knowledge);
 }
@@ -98,9 +98,9 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor<i64>
 // CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({
 // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<i64>, %[[IVAL_2:.*]]: tensor<f32>, %[[IVAL_3:.*]]: tensor<i64>):
-// CHECK: %[[IVAL_4:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[IVAL_4:.*]] = mhlo.compare GE, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[IVAL_6:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[IVAL_6:.*]] = mhlo.compare EQ, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor<i64>
 // CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
 // CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@@ -215,4 +215,4 @@ func.func @torch.aten.avg_pool2d$count_include_pad(%arg0: !torch.vtensor<[?,?,?,
   %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
   %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
   return %3 : !torch.vtensor<[?,?,?,?],f32>
 }
@@ -17,9 +17,9 @@
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@@ -58,9 +58,9 @@ func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!tor
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@@ -95,9 +95,9 @@ func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtens
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@@ -134,9 +134,9 @@ func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor<?x?xf32>, tensor<?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?xf32>, tensor<?xi64>)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor<f32>, %[[VAL_13:.*]]: tensor<f32>) (%[[VAL_12:.*]]: tensor<i64>, %[[VAL_14:.*]]: tensor<i64>) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor<i64>
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
@@ -240,4 +240,4 @@ func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<
 func.func @torch.aten.max(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
   %0 = torch.aten.max %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[],f32>
   return %0 : !torch.vtensor<[],f32>
 }