Merge remote-tracking branch 'upstream/main' into emit_detach_

pull/3287/head
Xinyu Yang 2024-05-06 13:04:03 +08:00
commit 99e2f143f0
66 changed files with 1709 additions and 376 deletions

View File

@@ -11,7 +11,7 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
-rev: 22.10.0
+rev: 24.4.2
hooks:
- id: black

.yamllint.yml (new file, mode 100644, 22 lines)
View File

@@ -0,0 +1,22 @@
---
extends: default
rules:
# These do not appear to be conventional in GitHub actions.
document-end:
present: false
document-start:
present: false
# GitHub actions use "on" for triggers.
truthy: disable
# We have lots of long strings and command lines.
line-length: disable
comments:
# Formatters may do this (e.g. Prettier does) and it seems like the most
# trivial thing to get a failing check for.
min-spaces-from-content: 1
# This is not a useful check, especially when disabling entire blocks.
comments-indentation: disable
ignore: /third_party/*

View File

@@ -50,6 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
-DLLVM_TARGETS_TO_BUILD=host \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+-DTORCH_MLIR_ENABLE_LTC=ON
echo "::endgroup::"
echo "::group::Build"

View File

@@ -432,6 +432,8 @@ function clean_build() {
}
function build_torch_mlir() {
+# Disable LTC build for releases
+export TORCH_MLIR_ENABLE_LTC=0
local torch_version="$1"
case $torch_version in
nightly)
@@ -440,7 +442,7 @@ function build_torch_mlir() {
--extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
-python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir \
+python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir \
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \
-r /main_checkout/torch-mlir/whl-requirements.txt
;;
@@ -450,7 +452,7 @@ function build_torch_mlir() {
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
-python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
+python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
;;
*)
echo "Unrecognized torch version '$torch_version'"
@@ -474,7 +476,7 @@ function build_torch_mlir_core() {
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
-python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
+python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
}
function clean_wheels() {

View File

@@ -2,6 +2,7 @@
See https://github.com/llvm/torch-mlir/issues/1374
"""
import argparse
import json

@@ -1 +1 @@
-Subproject commit a952c123880eb1168f1021b116485e27170d48ca
+Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf

externals/stablehlo (vendored, 2 lines changed)

@@ -1 +1 @@
-Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97
+Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91

View File

@@ -30,8 +30,6 @@ namespace detail {
LogicalResult verifyTMTensorOpInterface(Operation *op);
}
-#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
/// Include the generated interface declarations.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
@@ -39,4 +37,6 @@ LogicalResult verifyTMTensorOpInterface(Operation *op);
} // namespace torch
} // namespace mlir
+#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_

View File

@@ -97,6 +97,31 @@ struct OpBinder {
return success();
}
ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
for (auto result : op->getResults()) {
auto t = toValidTensorType(result.getType());
if (!t)
return failure();
typeList.push_back(t);
}
return success();
}
// The importer imports Onnx.GraphProto attributes as regions attached to the
// op.
ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {
if (idx >= op->getNumRegions())
return failure();
region = &op->getRegion(idx);
if (region == nullptr) {
return failure();
}
return success();
}
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
int64_t idx) {
if (idx >= op->getNumResults())

View File

@@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder,
Type getQTorchTypeFromTorchIntType(Type ty);
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
Value &ofItem) {
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
rewriter.getType<T>(), ofItem);
}
LogicalResult OnnxLstmExpander(OpBinder binder,
ConversionPatternRewriter &rewriter);

View File

@@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
size_t dimSizeIndexBits);
// Get a tensor that collapse the specified dimensions of the input tensor
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits);
// Get a tensor that splits the specified dimensions of the input tensor
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits);
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType);
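For intuition, a small Python sketch of the shape arithmetic these two helpers perform (names and static shapes here are illustrative only; the actual helpers emit stablehlo.dynamic_reshape ops and handle dynamic dimensions):

    def collapse_shape(shape, start, end):
        # Fold dims [start, end] of `shape` into a single dimension.
        folded = 1
        for d in shape[start:end + 1]:
            folded *= d
        return shape[:start] + [folded] + shape[end + 1:]

    def split_shape(shape, dim, outer_length):
        # Split `dim` into (outer_length, inner_length).
        return shape[:dim] + [outer_length, shape[dim] // outer_length] + shape[dim + 1:]

    assert collapse_shape([2, 3, 4, 5], 1, 2) == [2, 12, 5]
    assert split_shape([2, 12, 5], 1, 3) == [2, 3, 4, 5]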

View File

@@ -4810,6 +4810,53 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
}];
}
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCeluOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCelu_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
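For reference, the two generated ops correspond to PyTorch's CELU, whose standard definition is

    \mathrm{CELU}(x, \alpha) = \max(0, x) + \min\bigl(0,\ \alpha\,(e^{x/\alpha} - 1)\bigr)

with aten.celu_ being the in-place variant of the same computation (a sketch of the math, not part of the patch).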
def Torch_AtenRealOp : Torch_Op<"aten.real", [
AllowsTypeRefinement,
ReadOnly
@@ -6590,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
}];
}
def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -6645,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices",
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
+let hasCanonicalizer = 1;
}
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
@@ -13601,6 +13677,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [
}];
}
def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`";
let arguments = (ins
AnyTorchListOfTorchStringType:$l,
Torch_StringType:$item
);
let results = (outs
Torch_BoolType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -15904,6 +16005,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
+let hasFolder = 1;
}
def Torch_PrimsVarOp : Torch_Op<"prims.var", [

View File

@@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl<bool> &bind_values) {
return detail::torch_list_of_constant_bools_op_binder(bind_values);
}
namespace detail {
/// Matches the constant strs stored in a `torch.ListConstruct`.
struct torch_list_of_constant_strs_op_binder {
SmallVectorImpl<std::string> &bind_values;
/// Creates a matcher instance that binds the value to bvs if match succeeds.
torch_list_of_constant_strs_op_binder(SmallVectorImpl<std::string> &bvs)
: bind_values(bvs) {}
bool match(Operation *op) {
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
if (!listConstruct)
return false;
for (Value value : listConstruct.getElements()) {
std::string str;
if (matchPattern(value, m_TorchConstantStr(str)))
bind_values.push_back(str);
else
return false;
}
return true;
}
};
} // namespace detail
/// Matches the constant strs stored in a `torch.prim.ListConstruct`.
inline detail::torch_list_of_constant_strs_op_binder
m_TorchListOfConstantStrs(SmallVectorImpl<std::string> &bind_values) {
return detail::torch_list_of_constant_strs_op_binder(bind_values);
}
namespace detail {
/// Matches the expected tensor and dim from `torch.aten.size.int`.
struct torch_tensor_size_int_op_binder {

View File

@@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
return success();
}
namespace {
LogicalResult windowFunctionImpl(OpBinder binder,
ConversionPatternRewriter &rewriter,
Value size, Value a0, Value a1, Value a2,
Torch::ValueTensorType resultType,
int64_t output_datatype, int64_t periodic) {
Location loc = binder.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
double isPeriodicFp = static_cast<double>(periodic);
Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));
constexpr double pi = llvm::numbers::pi;
Value tau = b.create<Torch::ConstantFloatOp>(
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
Value noneVal = b.create<Torch::ConstantNoneOp>();
Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
Value float32Type = b.create<Torch::ConstantIntOp>(
rewriter.getI64IntegerAttr(/*float32Type*/ 6));
// Create an f32 ValueTensorType with the same sizes as the size operand.
auto shapeOfOperand =
size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
shapeOfOperand, rewriter.getF32Type());
Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
periodicSizeFloat.getType(), periodicSizeFloat, one, one);
Value isPeriodic =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
symmetricComponent.getType(), symmetricComponent, periodicComponent, one);
// Here, size can be used in the place of periodicSizeFloat, as the
// latter is just a float representation of the former.
Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);
Value rangeTimesTau =
b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
Value rangeAngular =
b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
Value twoRangeAngular =
b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);
Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
Value cosTwoRangeAngular =
b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);
Value a1Component =
b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
Value a2Component =
b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);
// AtenSubScalarOp actually requires a tensor operand as the LHS, that
// is, operand #1. Therefore, to avoid errors, the onnx implementation
// has been modified. a1 has been changed to negative half, and the
// AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
// operation is commutative.
Value subA1Component =
b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
a2Component, one);
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(output_datatype);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for the given dtype conversion");
}
Value outputDtype = b.create<Torch::ConstantIntOp>(
rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch.value()));
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, result, outputDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/noneVal);
return success();
}
} // namespace
// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
@@ -2186,4 +2288,65 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
return success();
});
patterns.onOp(
"BlackmanWindow", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value size;
Torch::ValueTensorType resultType;
int64_t periodic, output_datatype;
if (binder.tensorOperand(size) ||
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
binder.s64IntegerAttr(periodic, "periodic", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
auto windowFunctionResult =
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
output_datatype, periodic);
if (failed(windowFunctionResult))
return failure();
return success();
});
patterns.onOp(
"HannWindow", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value size;
Torch::ValueTensorType resultType;
int64_t periodic, output_datatype;
if (binder.tensorOperand(size) ||
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
binder.s64IntegerAttr(periodic, "periodic", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
auto windowFunctionResult =
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
output_datatype, periodic);
if (failed(windowFunctionResult))
return failure();
return success();
});
}
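Both window lowerings above evaluate the same generalized cosine window; in the notation of windowFunctionImpl (a sketch of the math, with N the requested size):

    w[n] = a_0 + a_1 \cos\!\left(\frac{2\pi n}{N'}\right) + a_2 \cos\!\left(\frac{4\pi n}{N'}\right),
    \qquad N' = \begin{cases} N & \text{periodic} = 1 \\ N-1 & \text{periodic} = 0 \end{cases}

with (a_0, a_1, a_2) = (0.42, -0.5, 0.08) for BlackmanWindow and (0.5, -0.5, 0) for HannWindow. The sign of a_1 is folded into the constant, which is why the code adds the a_1 term instead of subtracting it.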

View File

@@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
alignCorners);
return success();
});
patterns.onOp(
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value conditionTensor;
if (binder.tensorOperand(conditionTensor)) {
return rewriter.notifyMatchFailure(binder.op,
"condition bind failure");
}
auto conditionType =
conditionTensor.getType().cast<Torch::ValueTensorType>();
if (!conditionType || conditionType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(
binder.op, "condition must have one single element per "
"https://onnx.ai/onnx/operators/onnx__If.html");
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
conditionTensor);
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
llvm::SmallVector<mlir::Type> resultTypes;
if (binder.tensorResultTypes(resultTypes)) {
return rewriter.notifyMatchFailure(binder.op,
"result type bind failure");
}
Region *thenRegion, *elseRegion;
if (binder.getRegionAtIndex(elseRegion, 0) ||
binder.getRegionAtIndex(thenRegion, 1)) {
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
}
auto primIfOp = rewriter.create<Torch::PrimIfOp>(
binder.getLoc(), TypeRange(resultTypes), conditionBool);
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
};
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
auto replaceTerminator = [&](Region &region) {
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = region.front().getTerminator();
rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
terminator, terminator->getOperands());
};
replaceTerminator(primIfOp.getThenRegion());
replaceTerminator(primIfOp.getElseRegion());
rewriter.replaceOp(binder.op, primIfOp.getResults());
return success();
});
patterns.onOp("Less", 13, patterns.onOp("Less", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;

View File

@@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c;
// thing here, so we simplify.
// utilities
-// Templatized function to get an item op of a type
namespace {
-template <typename T>
-Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
-Value &ofItem) {
-return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
-rewriter.getType<T>(), ofItem);
-}
// In case the ReduceSum Op was not the first operation performed on the data,
// we provide the original operand through storeResult, which will be modified
// if the result will be passed onto another operation, and will be used for
@@ -847,12 +839,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
patterns.onOp(
"Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+// y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0
Torch::ValueTensorType resultType;
float alpha, gamma;
Value operand;
+// Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default
+// alpha and gamma values.
if (binder.tensorOperand(operand) ||
-binder.f32FloatAttr(alpha, "alpha") ||
-binder.f32FloatAttr(gamma, "gamma") ||
+binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
+binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
binder.tensorResultType(resultType))
return failure();
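A compact restatement of the Selu comment above, with the defaults now wired into the binder (values truncated as in the code; a sketch, not part of the patch):

    y = \gamma\,\alpha\,(e^{x} - 1) \ \text{for}\ x \le 0, \qquad
    y = \gamma x \ \text{for}\ x > 0, \qquad
    \alpha \approx 1.67326,\ \gamma \approx 1.0507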
@@ -945,22 +940,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*memory_format=*/noneVal);
return success();
});
patterns.onOp("ReduceSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
return reducedSumImpl(binder, rewriter, data, resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp("ReduceLogSum", 1, patterns.onOp("ReduceLogSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;
@ -987,6 +966,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, data); binder.op, resultType, data);
return success(); return success();
}); });
patterns.onOp("ReduceSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
return reducedSumImpl(binder, rewriter, data, resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp("ReduceSumSquare", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
binder.getLoc(), data.getType(), data, data);
return reducedSumImpl(binder, rewriter, dataSquare,
resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
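A small NumPy check of the identity this ReduceSumSquare lowering relies on (illustrative only; axes and keepdims follow the same ReduceSum semantics):

    import numpy as np

    x = np.arange(12, dtype=np.float32).reshape(3, 4)
    # ReduceSumSquare(x) is lowered as ReduceSum(x * x).
    lhs = np.sum(np.square(x), axis=1, keepdims=True)
    rhs = np.sum(x * x, axis=1, keepdims=True)
    assert np.allclose(lhs, rhs)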
patterns.onOp(
"ReduceMean", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
if (!isUnsignedType)
return;
int64_t minSI = -(1 << (numBits - 1));
-Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
+Value minSIValue = rewriter.create<arith::ConstantIntOp>(
+loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
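A minimal numeric sketch (plain Python, not the lowering itself) of why shifting both the stored value and the zero point by minSI = -2^(numBits-1) is safe: the dequantized result is unchanged.

    def dequant(q, zp, scale):
        return (q - zp) * scale

    num_bits = 8
    shift = 1 << (num_bits - 1)                  # 128; the code adds minSI = -shift
    q_u8, zp_u8, scale = 200, 128, 0.05          # a uint8-quantized value
    q_i8, zp_i8 = q_u8 - shift, zp_u8 - shift    # now in the int8 range
    assert dequant(q_u8, zp_u8, scale) == dequant(q_i8, zp_i8, scale)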
@@ -797,6 +798,8 @@ public:
auto resultTy = cast<ValueTensorType>(op.getType());
Value inputZp, weightZp;
+bool inputUnsigned = false;
+bool weightUnsigned = false;
if (auto make = op.getInput()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
input = make.getSelf();
@@ -806,6 +809,8 @@ public:
inputZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(inputZp.getType()),
inputZp);
+auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}
if (auto make = op.getWeight()
@@ -818,6 +823,8 @@ public:
weightZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(weightZp.getType()),
weightZp);
+auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
+weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}
if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
@@ -916,15 +923,35 @@ public:
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
+// convert any uint8 quantization to int8 quantization
+if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
+int64_t width = integerType.getWidth();
+signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
+}
+if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
+int64_t width = integerType.getWidth();
+signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
+}
// Pad the input tensor according to padding.
SmallVector<Value> outDims{inBatch, weightBatch};
Value paddedInput;
-if (transposed) {
-if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
-!isa<mlir::FloatType>(resultDTy))
-return rewriter.notifyMatchFailure(
-op, "transpose does not support non-fp type yet");
+Value pad = inputZp;
+if (!pad) {
+if (isa<mlir::FloatType>(inputDTy))
+pad = rewriter.create<arith::ConstantOp>(
+op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
+if (isa<mlir::IntegerType>(inputDTy))
+pad = rewriter.create<arith::ConstantOp>(
+op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
+}
+if (pad.getType() != inputDTy) {
+if (isa<mlir::FloatType>(inputDTy))
+pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
+if (isa<mlir::IntegerType>(inputDTy))
+pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
+}
+if (transposed) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value c1 =
@@ -994,7 +1021,7 @@ public:
// Allocate padded input tensor
Value initTensor =
-createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
+createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
@@ -1017,24 +1044,6 @@ public:
strideInts.clear();
strideInts.append(numSpatialDims, 1);
} else {
-Value pad = inputZp;
-if (!pad) {
-if (isa<mlir::FloatType>(inputDTy))
-pad = rewriter.create<arith::ConstantOp>(
-op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
-if (isa<mlir::IntegerType>(inputDTy))
-pad = rewriter.create<arith::ConstantOp>(
-op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
-}
-if (pad.getType() != inputDTy) {
-if (isa<mlir::FloatType>(inputDTy))
-pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
-if (isa<mlir::IntegerType>(inputDTy))
-pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
-}
// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);

View File

@@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
-AtenCumsumOp>(op)) {
+AtenAvgPool3dOp, AtenCumsumOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
@@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
}
// Max pooling
-if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
+if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
+AtenMaxPool2dWithIndicesOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
@@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr;
}
// AtenMaxPool2dOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto outTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (inputRank <= 2) {
return op.emitError(
"max_pooling2d only supports inputs with rank higher than 2");
}
SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
bool ceilMode = false;
if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(op,
"non-const bool ceil_mode unsupported!");
}
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
Block &block = reduceWindowOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArg = block.args_begin();
auto secondArg = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result =
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
// AtenMaxPool2dWithIndicesOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
return success();
}
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
public:
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto outTy = cast<RankedTensorType>(
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
if (inputRank <= Dim) {
return op.emitError(
"max_pooling1d/2d only supports inputs with rank higher than 1/2");
}
SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
bool ceilMode = false;
if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op,
"non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(
op, "non-const bool ceil_mode unsupported!");
}
if (stride.empty()) {
stride = kernelSize;
}
// prepend 1 to kernelSize, stride, dilation until they are of same rank
// as input
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - Dim);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - Dim);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - Dim);
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
} else if (Dim == 2) {
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
} else if (Dim == 3) {
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
} else {
assert(false && "Unsupported pooling dimension");
}
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
Block &block = reduceWindowOp.getBody().emplaceBlock();
// Add bb argument
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
block.addArgument(blockArgumentType, op->getLoc());
block.addArgument(blockArgumentType, op->getLoc());
auto *firstArg = block.args_begin();
auto secondArg = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
*secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
};
} // namespace
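The padding scatter above is easiest to read as a flattened (low, high) pair per input dimension, with only the trailing Dim spatial dimensions padded. A short Python sketch of the same indexing (illustrative only; the real code packs this into a DenseIntElementsAttr of shape [inputRank, 2]):

    def stablehlo_padding(input_rank, padding, dim):
        # One (low, high) pair per input dimension, flattened row-major.
        pad = [0] * (input_rank * 2)
        for i, p in enumerate(padding[:dim]):
            base = 2 * (input_rank - dim + i)
            pad[base] = pad[base + 1] = p
        return pad

    # NCHW max_pool2d with padding [1, 2] on a rank-4 input:
    assert stablehlo_padding(4, [1, 2], 2) == [0, 0, 0, 0, 1, 1, 2, 2]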
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
@@ -375,8 +404,8 @@ public:
auto outShape = outTy.getShape();
if (inputRank <= Dim) {
-return op.emitError(
-"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
+return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
+"higher than 1/2/3");
}
SmallVector<int64_t, Dim> padding, kernelSize, stride;
bool ceilMode = false;
@@ -405,6 +434,10 @@ public:
op, "non-const bool count_include_pad unsupported!");
}
+if (stride.empty()) {
+stride = kernelSize;
+}
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
return rewriter.notifyMatchFailure(
@@ -425,11 +458,20 @@ public:
if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
-} else {
+} else if (Dim == 2) {
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
+} else if (Dim == 3) {
+stablehloPadding[stablehloPadding.size() - 6] = padding[0];
+stablehloPadding[stablehloPadding.size() - 5] = padding[0];
+stablehloPadding[stablehloPadding.size() - 4] = padding[1];
+stablehloPadding[stablehloPadding.size() - 3] = padding[1];
+stablehloPadding[stablehloPadding.size() - 2] = padding[2];
+stablehloPadding[stablehloPadding.size() - 1] = padding[2];
+} else {
+assert(false && "Unsupported pooling dimension");
}
Value initVal =
@@ -474,10 +516,17 @@ public:
divisor =
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
.value();
-} else {
+} else if (Dim == 2) {
divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value();
+} else if (Dim == 3) {
+divisor = hlo::getConstTensor<int64_t>(
+rewriter, op,
+{kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
+.value();
+} else {
+assert(false && "Unsupported pooling dimension");
}
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
DenseI64ArrayAttr bcastDimensions;
@@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
-target.addIllegalOp<AtenAvgPool1dOp>();
-patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
-target.addIllegalOp<AtenMaxPool2dOp>();
-patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
-target.addIllegalOp<AtenAvgPool2dOp>();
-patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
-target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
-patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
-context, options);
-target.addIllegalOp<AtenCumsumOp>();
-patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
+#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
+target.addIllegalOp<AtenOp>(); \
+patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
+INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
+INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
+#undef INSERT_ATEN_POOLING_PATTERN
+#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \
+target.addIllegalOp<AtenOp>(); \
+patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context, \
+options)
+INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
+INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
+INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
+#undef INSERT_ATEN_MAXPOOL_PATTERN
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
options)
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
+INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
#undef INSERT_ATEN_AVGPOOL_PATTERN
}

View File

@@ -9,6 +9,7 @@
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
@@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
.getResult();
}
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo;
int64_t rank = dimSizes.size();
collapseStartDim = toPositiveDim(collapseStartDim, rank);
collapseEndDim = toPositiveDim(collapseEndDim, rank);
int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
newDimSizes.reserve(newRank);
newShape.reserve(newRank);
Value collapseDimSize = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
int64_t collapseShape = 1;
for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
if (k < 0 || k >= rank) {
return rewriter.notifyMatchFailure(
op, "collapse dimensions must be within the rank of the tensor");
}
if (collapseShape == ShapedType::kDynamic ||
oldShape[k] == ShapedType::kDynamic) {
collapseShape = ShapedType::kDynamic;
} else {
collapseShape *= oldShape[k];
}
collapseDimSize =
rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
}
for (int64_t k = 0; k < collapseStartDim; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
newDimSizes.push_back(collapseDimSize);
newShape.push_back(collapseShape);
for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
// TODO: support splitDim & outerLength to be Value
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo;
int64_t rank = dimSizes.size();
splitDim = toPositiveDim(splitDim, rank);
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
if (splitDim < 0 || splitDim >= rank) {
return rewriter.notifyMatchFailure(
op, "split dimensions must be within the rank of the tensor");
}
int64_t newRank = rank + 1;
auto outerLengthValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, outerLength));
auto innerLengthValue = rewriter.create<arith::DivSIOp>(
loc, dimSizes[splitDim], outerLengthValue);
int64_t originShape = oldShape[splitDim];
int64_t outerShape = outerLength;
int64_t innerShape = originShape == ShapedType::kDynamic
? ShapedType::kDynamic
: originShape / outerLength;
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
newDimSizes.reserve(newRank);
newShape.reserve(newRank);
for (int64_t k = 0; k < splitDim; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
newDimSizes.push_back(outerLengthValue);
newShape.push_back(outerShape);
newDimSizes.push_back(innerLengthValue);
newShape.push_back(innerShape);
for (int64_t k = splitDim + 1; k < rank; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType) {

View File

@@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "only constant end is currently supported");
-start = toPositiveDim(start, rank);
-end = toPositiveDim(end, rank);
-SmallVector<int64_t, 4> dims;
-dims.reserve(rank);
-for (int r = 0; r < start; ++r)
-dims.push_back(r);
-int64_t collapsedDimSize = 1;
-for (int r = start; r <= end; ++r) {
-if (selfType.getShape()[r] == ShapedType::kDynamic)
-return rewriter.notifyMatchFailure(
-op, "the size of the dimension being collapsed is can't be unknown");
-collapsedDimSize *= selfType.getShape()[r];
-}
-dims.push_back(collapsedDimSize);
-for (int r = end + 1; r < rank; ++r)
-dims.push_back(r);
-auto newDimSizesInfo = hlo::getDimSizesOfTensor(
-rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
-if (failed(newDimSizesInfo))
-return rewriter.notifyMatchFailure(
-op, "failed to get dimension sizes of the input");
-auto newDimSizes = *newDimSizesInfo;
-auto stablehloShape =
-rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
-rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-stablehloShape);
+auto collapseTensorInfo = hlo::collapseTensor(
+rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
+if (failed(collapseTensorInfo))
+return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
+rewriter.replaceOp(op, *collapseTensorInfo);
+return success();
+}
+template <>
+LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
+PrimsSplitDimOp op, OpAdaptor adaptor,
+ConversionPatternRewriter &rewriter) const {
+auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
+if (!selfType) {
+return op.emitError("only tensor types are currently supported");
+}
+auto rank = selfType.getRank();
+if (rank == 0)
+return rewriter.notifyMatchFailure(
+op, "the rank of tensor must be greater than 0");
+int64_t dim, outerLength;
+if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+return rewriter.notifyMatchFailure(
+op, "only constant dim is currently supported");
+if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
+return rewriter.notifyMatchFailure(
+op, "only constant outerLength is currently supported");
+auto splitTensorInfo = hlo::splitTensor(
+rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
+if (failed(splitTensorInfo))
+return rewriter.notifyMatchFailure(op, "failed to create split tensor");
+rewriter.replaceOp(op, *splitTensorInfo);
return success();
}
@@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
+INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_VIEW_OP_PATTERN(AtenOp) \

View File

@@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
}
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
auto dtypeElemType = dtypeComplex.getElementType();
// Extract the real and imaginary parts of the scalar.
// Cast them to the target element type, and create a new complex
// value with the target complex type.
Value realVal = b.create<complex::ReOp>(loc, scalar);
Value imgVal = b.create<complex::ImOp>(loc, scalar);
realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
}
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype
<< "(dtype)";
}
llvm_unreachable("convertScalarToDtype should handle all the types"); llvm_unreachable("convertScalarToDtype should handle all the types");
} }
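The new complex branch follows the obvious manual recipe: take the real and imaginary parts of the scalar, cast each to the target element type, and rebuild a complex value of the target type. A rough NumPy equivalent (illustrative only):

import numpy as np

def convert_complex(scalar, target=np.complex64):
    # Cast the real and imaginary parts to the target element type separately,
    # then recombine them into the target complex type.
    elem = np.float32 if target == np.complex64 else np.float64
    return target(complex(elem(scalar.real), elem(scalar.imag)))

z = convert_complex(np.complex128(1.5 + 2.5j))
assert z.dtype == np.complex64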

View File

@ -936,7 +936,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
// If no operand comes from a tensor::CastOp and can be folded then fail. // If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand = bool hasTensorCastOperand =
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
-        if (opOperand->get().isa<BlockArgument>())
+        if (isa<BlockArgument>(opOperand->get()))
return false; return false;
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>(); auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp); return castOp && canFoldIntoConsumerOp(castOp);

View File

@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr; return nullptr;
Type inputDtype = inputTensorType.getOptionalDtype(); Type inputDtype = inputTensorType.getOptionalDtype();
-  if (!inputDtype || !inputDtype.isInteger(64))
+  if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
return nullptr; return nullptr;
std::optional<unsigned> inputRank = getTensorRank(input); std::optional<unsigned> inputRank = getTensorRank(input);
@ -148,10 +148,19 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr; return nullptr;
   if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
-    auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
-                   .getSplatValue<int64_t>();
-    return rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(val));
+    if (inputDtype.isInteger(64)) {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<int64_t>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    } else {
+      auto val = valueTensorLiteralOp.getValue()
+                     .cast<DenseIntElementsAttr>()
+                     .getSplatValue<bool>();
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(val));
+    }
} else if (auto primNumToTensorScalarOp = } else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) { input.getDefiningOp<PrimNumToTensorScalarOp>()) {
return primNumToTensorScalarOp.getA(); return primNumToTensorScalarOp.getA();
@ -2385,6 +2394,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// Aten__Contains__StrListOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
if (!item)
return nullptr;
if (auto listConstruct = getL().getDefiningOp<Torch::PrimListConstructOp>()) {
if (isListPotentiallyMutated(listConstruct))
return nullptr;
}
llvm::SmallVector<std::string> strs;
if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) {
for (const auto &str : strs) {
if (item.getValue().str() == str)
return getI1IntegerAttr(getContext(), true);
}
return getI1IntegerAttr(getContext(), false);
}
return nullptr;
}
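The folder only fires when the item is a constant string and the list operand is built from constant strings and never mutated; the membership test then becomes a constant i1. A small Python sketch of the fold logic (names are illustrative):

def fold_contains_str_list(strs, item):
    # strs is the list of constant strings when the list operand is constant
    # and unmutated, otherwise None; item is the constant string or None.
    if strs is None or item is None:
        return None          # cannot fold, keep the op
    return item in strs      # fold to a constant bool

assert fold_contains_str_list(["relu", "gelu"], "gelu") is True
assert fold_contains_str_list(None, "gelu") is None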
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenLtIntOp // AtenLtIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -4682,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// PrimsConvertElementTypeOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
auto inputType = cast<BaseTensorType>(getA().getType());
auto outputType = cast<BaseTensorType>(getResult().getType());
if (inputType != outputType)
return nullptr;
if (!inputType.hasDtype() || !outputType.hasDtype())
return nullptr;
if (inputType.getDtype() != outputType.getDtype())
return nullptr;
return getA();
}
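In other words, prims.convert_element_type folds to its input only when the conversion is a no-op: both tensor types are fully known and identical. A tiny sketch of that check (types modeled as plain tuples for illustration):

from collections import namedtuple

TensorType = namedtuple("TensorType", ["shape", "dtype"])

def fold_convert_element_type(input_type, output_type, operand):
    # Fold only when both dtypes are known and the types match exactly,
    # i.e. the op is an identity conversion.
    if input_type.dtype is None or output_type.dtype is None:
        return None
    return operand if input_type == output_type else None

t = TensorType((2, 3), "f32")
assert fold_convert_element_type(t, t, "value") == "value"
assert fold_convert_element_type(t, TensorType((2, 3), "f64"), "value") is None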
//===----------------------------------------------------------------------===//
// AtenMaxPool2dWithIndicesOp
//===----------------------------------------------------------------------===//
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
if (!op.getResult1().use_empty()) {
return rewriter.notifyMatchFailure(
op, "result1 of MaxPool2dWithIndices should be unused");
}
Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
op->getLoc(), op.getResult0().getType(), op.getSelf(),
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
op.getCeilMode());
op.getResult0().replaceAllUsesWith(result);
rewriter.eraseOp(op);
return success();
});
}
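The rewrite is sound because, when the indices result is dead, aten.max_pool2d_with_indices computes exactly the values that aten.max_pool2d would. A quick eager-mode cross-check (illustrative; argument order is kernel_size, stride, padding, dilation, ceil_mode):

import torch

x = torch.randn(1, 3, 8, 8)
values, _indices = torch.ops.aten.max_pool2d_with_indices(
    x, [2, 2], [2, 2], [0, 0], [1, 1], False)
pooled = torch.ops.aten.max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1], False)
assert torch.equal(values, pooled)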
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenLinalgCrossOp // AtenLinalgCrossOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -6998,6 +6998,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.celu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -7841,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %arg2 : !torch.list<int>\n" " return %arg2 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n" " %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n" " func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n" " %int-1 = torch.constant.int -1\n"
" %int-2 = torch.constant.int -2\n" " %int-2 = torch.constant.int -2\n"
" %int-3 = torch.constant.int -3\n" " %int-3 = torch.constant.int -3\n"
" %str = torch.constant.str \"AssertionError: \"\n" " %str = torch.constant.str \"AssertionError: \"\n"
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n" " %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n" " %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n" " %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
" %int1 = torch.constant.int 1\n" " %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n" " %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n" " %int2 = torch.constant.int 2\n"
@ -7936,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %23 : !torch.list<int>\n" " return %23 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -10480,6 +10488,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.celu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n" " %str = torch.constant.str \"AssertionError: \"\n"

View File

@ -1059,44 +1059,44 @@ public:
LogicalResult matchAndRewrite(AtenEyeMOp op, LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
-    int64_t n;
-    if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: n must be constant");
-    int64_t m;
-    if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: m must be constant");
-    Value none = rewriter.create<ConstantNoneOp>(loc);
-    auto outType = dyn_cast<BaseTensorType>(op.getType());
+    auto outType = op.getType().dyn_cast<BaseTensorType>();
     if (!outType)
       return rewriter.notifyMatchFailure(
           op, "Only tensor types input are currently supported");
     if (!outType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
-    if (n < 0) {
-      return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
-    }
-    if (m < 0) {
-      return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
-    }
+    Value none = rewriter.create<ConstantNoneOp>(loc);
     auto context = op.getContext();
     auto int64Dtype = getDtypeIntValueForType(
         rewriter, loc,
         rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
     auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
-    auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
+    int64_t n = kUnknownSize;
+    int64_t m = kUnknownSize;
+    // prioritize getting shape from output shape
+    if (outType.hasSizes() && outType.getSizes().size() == 2) {
+      n = outType.getSizes().front();
+      m = outType.getSizes().back();
+    }
+    // if output shape is not available, try to get shape from input
+    if (n == kUnknownSize)
+      matchPattern(op.getN(), m_TorchConstantInt(&n));
+    if (m == kUnknownSize)
+      matchPattern(op.getM(), m_TorchConstantInt(&m));
+    // prepare two unsqueezed ranges that are equal on and only on the diagonal
+    auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
+    Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
     Value rangeN = rewriter.create<AtenArangeOp>(
-        loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/op.getDevice(), /*pin_memory=*/none);
-    auto arangeType1 =
-        outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
+    auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
+    Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
     Value rangeM = rewriter.create<AtenArangeOp>(
-        loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/none, /*pin_memory=*/none);
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>( Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
@ -1109,7 +1109,6 @@ public:
} }
Value unsqzRangeN = *unsqzTensorInfo; Value unsqzRangeN = *unsqzTensorInfo;
// compare unsqueezed input with boundaries
auto eqType = ValueTensorType::get( auto eqType = ValueTensorType::get(
context, cast<BaseTensorType>(op.getType()).getSizes(), context, cast<BaseTensorType>(op.getType()).getSizes(),
IntegerType::get(context, 1)); IntegerType::get(context, 1));
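The reworked decomposition prefers the static result shape for n and m, falling back to constant operands, and then materializes the identity pattern by comparing an unsqueezed row range against a column range. The same trick in NumPy terms (illustrative):

import numpy as np

def eye_like(n: int, m: int) -> np.ndarray:
    rows = np.arange(n).reshape(n, 1)   # unsqueezed range over rows
    cols = np.arange(m).reshape(1, m)   # range over columns
    # equal exactly on the diagonal; cast the boolean mask to the result dtype
    return (rows == cols).astype(np.float32)

assert np.array_equal(eye_like(3, 4), np.eye(3, 4, dtype=np.float32))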
@ -2415,6 +2414,50 @@ public:
} // namespace } // namespace
// CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCeluOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
Value alpha = op.getAlpha();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value constantOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
// positiveOutput = max(0,x)
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
    // negativeOutput = min(0, alpha * (exp(x / alpha) - 1))
Value scaledInput =
rewriter.create<AtenDivScalarOp>(loc, resType, input, alpha);
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledInput);
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
constantOne, constantOne);
Value scaledExpXM1 =
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
Value negativeOutput =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
Value celuOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOne);
rewriter.replaceOp(op, celuOutput);
return success();
}
};
} // namespace
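The decomposition implements CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)) out of existing elementwise ops. A quick eager-mode sanity check of that identity against PyTorch's reference (illustrative):

import torch

def celu_decomposed(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    positive = torch.maximum(torch.zeros_like(x), x)                 # max(0, x)
    negative = torch.minimum(torch.zeros_like(x),
                             alpha * (torch.exp(x / alpha) - 1))     # min(0, ...)
    return positive + negative

x = torch.randn(5, 3)
assert torch.allclose(celu_decomposed(x, 0.5),
                      torch.nn.functional.celu(x, alpha=0.5), atol=1e-6)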
namespace { namespace {
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> { class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
public: public:
@ -7705,6 +7748,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);

View File

@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>(); target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>(); target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenPreluOp>(); target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenCeluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>(); target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>(); target.addIllegalOp<AtenToDeviceOp>();
target.addIllegalOp<AtenToPrimDeviceOp>(); target.addIllegalOp<AtenToPrimDeviceOp>();

View File

@ -272,6 +272,7 @@ TORCHDYNAMO_XFAIL_SET = {
"QuantizedReluInt8_basic", "QuantizedReluInt8_basic",
"QuantizedReluUint8_basic", "QuantizedReluUint8_basic",
"Conv2dQInt8Module_basic", "Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
# Dynamo not supporting conv_tbc # Dynamo not supporting conv_tbc
"ConvTbcModule_basic", "ConvTbcModule_basic",
"FloatImplicitModule_basic", "FloatImplicitModule_basic",
@ -372,6 +373,7 @@ FX_IMPORTER_XFAIL_SET = {
"Conv2dQInt8Module_basic", "Conv2dQInt8Module_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"ConvTbcModule_basic", "ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2D_basic",
@ -544,6 +546,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ContainsIntList_True", "ContainsIntList_True",
"Conv2dQInt8Module_basic", "Conv2dQInt8Module_basic",
"ConvTbcModule_basic", "ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2D_basic",
@ -572,6 +575,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseErfIntModule_basic", "ElementwiseErfIntModule_basic",
"ElementwiseLogitModule_basic", "ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic", "ElementwiseReciprocalIntModule_basic",
@ -678,11 +682,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"NumToTensorIntModule_basic", "NumToTensorIntModule_basic",
"NumelModule_basic", "NumelModule_basic",
"NumelZeroRankModule_basic", "NumelZeroRankModule_basic",
"PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"PixelShuffleModuleSpatiallyStatic_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic", "PowIntFloatModule_basic",
"PrimMaxIntModule_basic", "PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic", "PrimMinIntDynamicModule_basic",
@ -951,6 +950,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseBitwiseRightShiftInt64Module_basic", "ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseCeilModule_basic", "ElementwiseCeilModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic", "ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic", "ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic", "ElementwiseClampMinTensorFloatModule_basic",
@ -1079,6 +1079,9 @@ STABLEHLO_PASS_SET = {
"Matmul_vecmat", "Matmul_vecmat",
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
"MaxPool2dStaticModule_basic", "MaxPool2dStaticModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool3dStaticModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MeanDimAllReduceModule_basic", "MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic", "MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic", "MeanDimNoneDimModule_basic",
@ -1156,6 +1159,8 @@ STABLEHLO_PASS_SET = {
"Permute0RankModule_basic", "Permute0RankModule_basic",
"PermuteModule_basic", "PermuteModule_basic",
"PermuteNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic", "PowIntFloatModule_basic",
"PrimListUnpackNumMismatchModule_basic", "PrimListUnpackNumMismatchModule_basic",
"PrimMaxIntModule_basic", "PrimMaxIntModule_basic",
@ -1239,6 +1244,7 @@ STABLEHLO_PASS_SET = {
"SliceWholeTensorModule_basic", "SliceWholeTensorModule_basic",
"SortIntListReverse_basic", "SortIntListReverse_basic",
"SortIntList_basic", "SortIntList_basic",
"SplitDimStaticModule_basic",
"SplitTensorGetItem_Module_basic", "SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic", "SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic", "SplitTensorListUnpackModule_basic",
@ -1571,6 +1577,8 @@ TOSA_PASS_SET = {
"ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseCeilModule_basic", "ElementwiseCeilModule_basic",
"ElementwiseCeluModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic", "ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic", "ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic", "ElementwiseClampModule_basic",
@ -1916,11 +1924,6 @@ MAKE_FX_TOSA_PASS_SET = (
# Dynamic shape, has extra unsupported broadcast ops # Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d", "Matmul_3d",
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin' # Unimplemented operator 'aten._index_put_impl_.hacked_twin'
"IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic",
@ -2096,6 +2099,7 @@ LTC_XFAIL_SET = {
"ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv2dQInt8Module_basic", "Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
} }
ONNX_XFAIL_SET = { ONNX_XFAIL_SET = {
@ -2121,7 +2125,6 @@ ONNX_XFAIL_SET = {
"ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseLog10IntModule_basic", "ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic", "ElementwiseLog2IntModule_basic",
"ElementwiseSeluModule_basic",
"FlipModuleStaticShape_basic", "FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic", "FlipNegativeIndexModule_basic",
"HardsigmoidModule_basic", "HardsigmoidModule_basic",
@ -2251,6 +2254,7 @@ ONNX_XFAIL_SET = {
"Conv2dWithPaddingModule_basic", "Conv2dWithPaddingModule_basic",
"Conv3dModule_basic", "Conv3dModule_basic",
"ConvTbcModule_basic", "ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic", "Conv_Transpose2dModule_basic",
"Convolution2DModule_basic", "Convolution2DModule_basic",
"Convolution2DStridedModule_basic", "Convolution2DStridedModule_basic",
@ -2306,6 +2310,7 @@ ONNX_XFAIL_SET = {
"ElementwiseExpm1Module_basic", "ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Int_basic", "ElementwiseFmodTensor_Int_basic",
"ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseOrTensorModule_basic", "ElementwiseOrTensorModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic",
@ -2554,16 +2559,12 @@ ONNX_XFAIL_SET = {
"_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic",
"_SoftmaxModule_basic", "_SoftmaxModule_basic",
# Failure - onnx_import
# Failure - onnx_lowering: onnx.AveragePool # Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
-    # Failure - onnx_lowering: onnx.If
-    "DiagonalModule_basic",
-    "DiagonalModule_nonsquare",
+    # these diagonal modules are currently failing due to dynamic shape.
+    # We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
+    # when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
"DiagonalModule_transposed",
"DiagonalModule_with_dims",
"DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset",
"TileBigDimsSizeModule_basic", "TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic", "TileSmallDimsSizeModule_basic",
# Failure - onnx_lowering: onnx.MaxPool # Failure - onnx_lowering: onnx.MaxPool
@ -2634,8 +2635,6 @@ ONNX_XFAIL_SET = {
"CopyWithDifferentDTypesModule_basic", "CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic", "CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic", "CumsumInputDtypeInt32Module_basic",
"DropoutTrainModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAcosIntModule_basic", "ElementwiseAcosIntModule_basic",
"ElementwiseAsinIntModule_basic", "ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanTensorIntModule_basic",

View File

@ -3,6 +3,7 @@ from torch_mlir import torchscript
from transformers import BertForMaskedLM from transformers import BertForMaskedLM
# Wrap the bert model to avoid multiple returns problem # Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module): class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None: def __init__(self) -> None:

View File

@ -257,9 +257,9 @@ class _FXGraphImporter:
# FakeTensor's in case of a tuple return with multiple elements. # FakeTensor's in case of a tuple return with multiple elements.
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {} self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
self._module = ir.Module.create(ir.Location.unknown()) self._module = ir.Module.create(ir.Location.unknown())
-        self._module.operation.attributes[
-            "torch.debug_module_name"
-        ] = ir.StringAttr.get(func_name)
+        self._module.operation.attributes["torch.debug_module_name"] = (
+            ir.StringAttr.get(func_name)
+        )
function_type = _extract_function_type_from_graph(g) function_type = _extract_function_type_from_graph(g)
func = func_dialect.FuncOp( func = func_dialect.FuncOp(
func_name, func_name,

View File

@ -526,6 +526,9 @@ def atenelu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
def atenprelu〡shape(self: List[int], weight: List[int]) -> List[int]: def atenprelu〡shape(self: List[int], weight: List[int]) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
def atencelu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
return upstream_shape_functions.unary(self)
def atenselu〡shape(self: List[int]) -> List[int]: def atenselu〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
@ -958,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
# TODO: This should be upstreamed. # TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example. # See https://github.com/pytorch/pytorch/pull/76889 for an example.
-def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
-    assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
+def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
+    assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
kL = kernel_size[0] kL = kernel_size[0]
-    assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
+    assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
dL = kL if len(stride) == 0 else stride[0] dL = kL if len(stride) == 0 else stride[0]
-    assert len(padding) == 1, "avg_pool1d: padding must be a single int"
+    assert len(padding) == 1, "pool1d: padding must be a single int"
padL = padding[0] padL = padding[0]
dilationL = 1 dilationL = 1
@ -1001,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
return shape return shape
def atenavg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]: def atenavg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
-    return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
+    return pool1d(self, kernel_size, stride, padding, ceil_mode)
def atenmax_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
return pool1d(self, kernel_size, stride, padding, ceil_mode)
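Both avg_pool1d and max_pool1d now route through the shared pool1d shape helper, which applies the usual 1-d pooling length formula. A simplified sketch of that computation (dilation fixed at 1; the ceil-mode clamp that keeps the last window from starting inside the padding is omitted):

import math

def pool1d_out_length(l_in: int, kernel: int, stride: int = 0,
                      padding: int = 0, ceil_mode: bool = False) -> int:
    stride = stride or kernel                 # omitted stride defaults to kernel
    numerator = l_in + 2 * padding - kernel
    return (math.ceil(numerator / stride) if ceil_mode
            else numerator // stride) + 1

assert pool1d_out_length(10, kernel=3, stride=2) == 4   # floor((10 - 3) / 2) + 1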
def atenadaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]: def atenadaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
return adaptive_avg_pool1d(self, output_size) return adaptive_avg_pool1d(self, output_size)
@ -2652,6 +2658,11 @@ def atenprelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tu
assert self_dtype == weight_dtype assert self_dtype == weight_dtype
return self_dtype return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, alpha=1.))
def atencelu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1.) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def atenrelu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def atenrelu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype

View File

@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
(ns, unqual + "_", overload if not is_functional_op else "") (ns, unqual + "_", overload if not is_functional_op else "")
), ),
emitter_td, emitter_td,
-            traits=["IsTrailingUnderscoreInplaceVariant"]
-            if not is_functional_op
-            else [],
+            traits=(
+                ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
+            ),
) )
# ========================================================================== # ==========================================================================
@ -472,6 +472,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)") emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)") emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
emit("aten::real : (Tensor) -> (Tensor)") emit("aten::real : (Tensor) -> (Tensor)")
emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)")
@ -590,9 +591,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit( emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
) )
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit( emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
has_canonicalizer=True,
) )
emit( emit(
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
@ -973,6 +976,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::format : (...) -> (str)") emit("aten::format : (...) -> (str)")
emit("aten::join : (str, str[]) -> (str)") emit("aten::join : (str, str[]) -> (str)")
emit("aten::warn : (str, int) -> ()") emit("aten::warn : (str, int) -> ()")
emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True)
# Type conversion ops. # Type conversion ops.
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
@ -1101,7 +1105,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# `prims::` namespace. # `prims::` namespace.
# ========================================================================== # ==========================================================================
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)") emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True)
emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)") emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
emit("prims::sqrt : (Tensor) -> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)")
emit("prims::collapse : (Tensor, int, int) -> (Tensor)") emit("prims::collapse : (Tensor, int, int) -> (Tensor)")

View File

@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
examples = [] examples = []
input_names = [] input_names = []
dynamic_tensors = {} dynamic_tensors = {}
-    for (index, arg) in enumerate(inputs):
+    for index, arg in enumerate(inputs):
shape = map(lambda d: d if d >= 0 else 1, arg.shape) shape = map(lambda d: d if d >= 0 else 1, arg.shape)
shape = tuple(shape) shape = tuple(shape)
examples.append(torch.zeros(size=shape, dtype=arg.dtype)) examples.append(torch.zeros(size=shape, dtype=arg.dtype))
@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
input_names.append(input_name) input_names.append(input_name)
dynamic_dims = {} dynamic_dims = {}
-        for (dimindex, dim) in enumerate(arg.shape):
+        for dimindex, dim in enumerate(arg.shape):
if dim < 0: if dim < 0:
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

View File

@ -101,11 +101,13 @@ class RefBackendInvoker:
def consume_return_funcs(*args): def consume_return_funcs(*args):
self.result = tuple( self.result = tuple(
[ [
(
arg arg
if type in elemental_type_to_ctype if type in elemental_type_to_ctype
else unranked_memref_to_numpy( else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type] arg, memref_type_to_np_dtype[type]
) )
)
for arg, type in zip(args, ret_types) for arg, type in zip(args, ret_types)
] ]
) )
@ -178,6 +180,7 @@ LOWERING_PIPELINE = (
"func.func(tm-tensor-to-loops)", "func.func(tm-tensor-to-loops)",
"func.func(refback-munge-memref-copy)", "func.func(refback-munge-memref-copy)",
"func.func(convert-linalg-to-loops)", "func.func(convert-linalg-to-loops)",
"func.func(expand-realloc)",
"func.func(lower-affine)", "func.func(lower-affine)",
"convert-scf-to-cf", "convert-scf-to-cf",
"func.func(refback-expand-ops-for-llvm)", "func.func(refback-expand-ops-for-llvm)",
@ -191,6 +194,7 @@ LOWERING_PIPELINE = (
"convert-bufferization-to-memref", "convert-bufferization-to-memref",
"finalize-memref-to-llvm", "finalize-memref-to-llvm",
"func.func(convert-arith-to-llvm)", "func.func(convert-arith-to-llvm)",
"convert-vector-to-llvm",
"convert-func-to-llvm", "convert-func-to-llvm",
"convert-cf-to-llvm", "convert-cf-to-llvm",
"convert-complex-to-llvm", "convert-complex-to-llvm",

View File

@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(3) bias = torch.rand(3)
module.forward(inputVec, weight, bias) module.forward(inputVec, weight, bias)
N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2
class ConvTranspose2DQInt8Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
]
)
def forward(self, input, weight, bias):
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
qinput = torch.dequantize(qinput)
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
qweight = torch.dequantize(qweight)
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
qbias = torch.dequantize(qbias)
qz = torch.ops.aten.convolution(
qinput,
qweight,
bias=qbias,
stride=[2, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=True,
output_padding=[0, 0],
groups=1,
)
return qz
@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
module.forward(
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
torch.rand(Cout),
)

View File

@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class DiagonalWithStaticShapeModule(torch.nn.Module):
"""
Diagonal with static shape. The other diagonal modules are failing in onnx
    because DecomposeAtenEyeMOp requires constants n, m, which are only constant
when the shape is static.
Please remove this module and associated test once the issue is fixed.
"""
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([5, 9], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diagonal(a)
@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 9))
# ==============================================================================
class DiagonalTransposedModule(torch.nn.Module): class DiagonalTransposedModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):
@register_test_case(module_factory=lambda: QuantizedReluInt32()) @register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils): def QuantizedReluInt32_basic(module, tu: TestUtils):
-    module.forward(
-        tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
-    )
+    module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))
# ============================================================================== # ==============================================================================
@ -1016,6 +1014,52 @@ def ElementwisePreluStaticModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseCeluModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.celu(x, 0.5)
@register_test_case(module_factory=lambda: ElementwiseCeluModule())
def ElementwiseCeluModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))
# ==============================================================================
class ElementwiseCeluStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([5, 3], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.celu(x)
@register_test_case(module_factory=lambda: ElementwiseCeluStaticModule())
def ElementwiseCeluStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))
# ==============================================================================
class ElementwiseGeluModule(torch.nn.Module): class ElementwiseGeluModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -1795,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
# torch.complex32 is not supported by the refbackend.
class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1], torch.complex64, True),
([-1], torch.complex128, True),
]
)
def forward(self, a, b):
return torch.mul(a, b)
@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.complex64),
tu.randint(4, high=10).type(torch.complex128),
)
# ==============================================================================
class ElementwiseMishModule(torch.nn.Module): class ElementwiseMishModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1). # For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index). # For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module): class SliceScatterModule(torch.nn.Module):

View File

@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK: module attributes {torch.debug_module_name = "TestModule"} # CHECK: module attributes {torch.debug_module_name = "TestModule"}
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -18,6 +18,7 @@ mb = ModuleBuilder()
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so # `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
# naively duplicating a Tensor retains the identity of the TensorImpl. # naively duplicating a Tensor retains the identity of the TensorImpl.
# CHECK-LABEL: torch.class_type @__torch__.TestModule { # CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: torch.class_type @__torch__.TestModule { # CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.add3 # CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch # Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently # graph are ascribed to the first op that carries source information. Presently

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: @__torch__.f # CHECK-LABEL: @__torch__.f
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.optional_return( # CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int> # CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

View File

@ -13,6 +13,7 @@ mb = ModuleBuilder()
# else branch and making all defined values optional, so no special handling # else branch and making all defined values optional, so no special handling
# is needed. # is needed.
# CHECK-LABEL: @__torch__.prim_If( # CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int { # CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike( # CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float { # CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true # CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true

View File

@ -15,6 +15,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor( # CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor # CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor

View File

@ -13,6 +13,7 @@ from utils import create_script_function
mb = ModuleBuilder() mb = ModuleBuilder()
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])]) NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
# CHECK-LABEL: func.func @__torch__.tuple( # CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> # CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK: @__torch__.returns_bool # CHECK: @__torch__.returns_bool
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK: @__torch__.returns_none # CHECK: @__torch__.returns_none
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script

View File

@ -9,6 +9,7 @@ from torch._C import CompilationUnit
# RUN: %PYTHON %s # RUN: %PYTHON %s
# Import TorchScript IR string as ScriptFunction. # Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str, **kwargs): def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit() cu = CompilationUnit()

View File

@ -236,12 +236,6 @@ _IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP # set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
if _IS_TORCH_2_1_OR_EARLIER: if _IS_TORCH_2_1_OR_EARLIER:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_numel,
}
SYMBOLIC_OP_TO_TORCH_OP = { SYMBOLIC_OP_TO_TORCH_OP = {
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
@ -249,13 +243,9 @@ if _IS_TORCH_2_1_OR_EARLIER:
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
} }
else:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_numel.default,
}
SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP}
else:
SYMBOLIC_OP_TO_TORCH_OP = { SYMBOLIC_OP_TO_TORCH_OP = {
torch.ops.aten.sym_size.default: torch.ops.aten.size.default, torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
torch.ops.aten.sym_size.int: torch.ops.aten.size.int, torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
@ -264,6 +254,8 @@ else:
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default, torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
} }
SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
@dataclass(frozen=True) @dataclass(frozen=True)
class SparsityMeta: class SparsityMeta:
@ -1857,8 +1849,7 @@ def _emit_operation(
# Opaque value to indicate something is empty. Used in cases where 'None' # Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning. # may have a different meaning.
-class EmptyType:
-    ...
+class EmptyType: ...
Empty = EmptyType() Empty = EmptyType()

View File

@ -156,8 +156,7 @@ class GraphInfo:
return "" return ""
-class OnnxImportError(Exception):
-    ...
+class OnnxImportError(Exception): ...
class NodeImporter: class NodeImporter:
@ -235,22 +234,22 @@ class NodeImporter:
else: else:
default_opset_version = opset_import.version default_opset_version = opset_import.version
if default_opset_version: if default_opset_version:
-            container_op.attributes[
-                "torch.onnx_meta.opset_version"
-            ] = IntegerAttr.get(i64_type, default_opset_version)
+            container_op.attributes["torch.onnx_meta.opset_version"] = (
+                IntegerAttr.get(i64_type, default_opset_version)
+            )
if opset_versions: if opset_versions:
-            container_op.attributes[
-                "torch.onnx_meta.opset_versions"
-            ] = DictAttr.get(opset_versions)
+            container_op.attributes["torch.onnx_meta.opset_versions"] = (
+                DictAttr.get(opset_versions)
+            )
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
IntegerType.get_signed(64), m.ir_version IntegerType.get_signed(64), m.ir_version
) )
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
m.producer_name m.producer_name
) )
-        container_op.attributes[
-            "torch.onnx_meta.producer_version"
-        ] = StringAttr.get(m.producer_version)
+        container_op.attributes["torch.onnx_meta.producer_version"] = (
+            StringAttr.get(m.producer_version)
+        )
def import_all(self, func=True): def import_all(self, func=True):
"""Imports all nodes topologically.""" """Imports all nodes topologically."""
@ -348,8 +347,14 @@ class NodeImporter:
continue continue
elif handler is False: elif handler is False:
# Active error. # Active error.
# try matching attribute type ID to name for a more descriptive error message
try:
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
except ValueError:
attr_type_name = "UNKNOWN"
raise OnnxImportError( raise OnnxImportError(
f"ONNX importer does not support generic node attribute type {attr_type}. " f"ONNX importer does not support generic node attribute type {attr_type_name} "
f"with ID {attr_type}. "
f"This likely means that this is a special node which requires specific " f"This likely means that this is a special node which requires specific "
f"handling in the importer: {onnx_attr}" f"handling in the importer: {onnx_attr}"
) )
@ -658,9 +663,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
RankedTensorType.get(shape, IntegerType.get_signed(64)), RankedTensorType.get(shape, IntegerType.get_signed(64)),
IntegerAttr.get( IntegerAttr.get(
IntegerType.get_signed(64), IntegerType.get_signed(64),
(
int.from_bytes(tp.raw_data, "little", signed=True) int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data") if tp.HasField("raw_data")
else tp.int64_data[0], else tp.int64_data[0]
),
), ),
), ),
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
@ -703,7 +710,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
), ),
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
) ),
# Intentionally unsupported: STRING # Intentionally unsupported: STRING
} }

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE. # Also available under a BSD-style license. See LICENSE.
-from typing import Optional, Union, Dict, Tuple, Any
+from typing import Optional, Union, Dict, Tuple, Any, Callable
import warnings import warnings
@ -25,7 +25,7 @@ def export_and_import(
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False, experimental_support_mutation: bool = False,
hooks: Optional[FxImporterHooks] = None, hooks: Optional[FxImporterHooks] = None,
decomposition_table: Optional[list] = None, decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
func_name: str = "main", func_name: str = "main",
enable_graph_printing: bool = False, enable_graph_printing: bool = False,
**kwargs, **kwargs,
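A sketch of how the tightened decomposition_table annotation is typically satisfied: torch._decomp.get_decompositions returns exactly such a mapping from operator overloads to callables. The import path and the positional (module, example input) arguments below follow the existing fx importer tests rather than anything shown in this hunk, so treat them as assumptions:

    import torch
    from torch._decomp import get_decompositions
    from torch_mlir import fx

    class Silu(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.silu(x)

    # get_decompositions returns a Dict[torch._ops.OperatorBase, Callable],
    # matching the new annotation above.
    table = get_decompositions([torch.ops.aten.silu.default])
    module = fx.export_and_import(Silu(), torch.randn(3, 4), decomposition_table=table)
    print(module)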

View File

@ -1 +1 @@
-0a3e5f5badd8a0cb7fac97f5ec9d48c304e5c0b7
+34ade3521ca41f20af3469bba276c2b0499c3892

View File

@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torch==2.4.0.dev20240422
+torch==2.4.0.dev20240428

View File

@ -0,0 +1,20 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
// CHECK-LABEL: func.func @test_ifop_basic
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
// CHECK-DAG: } else {
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}, {
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}
return %0 : !torch.vtensor<[1],f32>
}
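For reference, an ONNX model of the shape this onnx.If test expects can be built with the onnx helper API roughly as below; the graph and value names are illustrative, not the actual test generator:

    import onnx
    from onnx import TensorProto, helper

    then_out = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [1])
    else_out = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [1])
    # Branch subgraphs capture the outer-scope values "a" and "b".
    then_graph = helper.make_graph(
        [helper.make_node("Sub", ["a", "b"], ["then_out"])], "then_body", [], [then_out]
    )
    else_graph = helper.make_graph(
        [helper.make_node("Add", ["a", "b"], ["else_out"])], "else_body", [], [else_out]
    )
    if_node = helper.make_node(
        "If", ["cond"], ["res"], then_branch=then_graph, else_branch=else_graph
    )
    graph = helper.make_graph(
        [if_node],
        "conditional_example",
        [
            helper.make_tensor_value_info("cond", TensorProto.BOOL, [1]),
            helper.make_tensor_value_info("a", TensorProto.FLOAT, [1]),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, [1]),
        ],
        [helper.make_tensor_value_info("res", TensorProto.FLOAT, [1])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    onnx.checker.check_model(model)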

View File

@ -1996,3 +1996,160 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten
%0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32> %0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
return %0 : !torch.vtensor<[3,?],f32> return %0 : !torch.vtensor<[3,?],f32>
} }
// -----
// CHECK-LABEL: func.func @test_blackmanwindow_symmetric
func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_blackmanwindow
func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_hannwindow
func.func @test_hannwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_hannwindow_symmetric
func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
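The CHECK sequences above all follow the generalized cosine-window form w[n] = a0 + a1*cos(2*pi*n/N) + a2*cos(4*pi*n/N), where N is the window size for periodic windows and size - 1 for symmetric ones (Blackman: a0 = 0.42, a1 = -0.5, a2 = 0.08; Hann: a0 = 0.5, a1 = -0.5, a2 = 0). A NumPy sketch of the same computation, useful for cross-checking the expected IR:

    import numpy as np

    def cosine_window(size, a0, a1, a2, periodic):
        # Mirrors SIZEFP in the IR: the denominator is size for periodic
        # windows and size - 1 for symmetric ones.
        denom = float(size) if periodic else float(size - 1)
        n = np.arange(size, dtype=np.float32)
        angular = 2.0 * np.pi * n / denom
        return (a0 + a1 * np.cos(angular) + a2 * np.cos(2.0 * angular)).astype(np.float32)

    # The symmetric variants agree with NumPy's reference windows.
    print(np.allclose(cosine_window(10, 0.42, -0.5, 0.08, periodic=False), np.blackman(10), atol=1e-5))
    print(np.allclose(cosine_window(10, 0.5, -0.5, 0.0, periodic=False), np.hanning(10), atol=1e-5))
    # Periodic Hann, as exercised by @test_hannwindow.
    print(cosine_window(10, 0.5, -0.5, 0.0, periodic=True)[:3])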

View File

@ -860,6 +860,57 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2
// ----- // -----
// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0:.+]] = torch.constant.int 0
@ -942,41 +993,24 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
// -----
-// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
-func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INT0:.+]] = torch.constant.int 0
-// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[TRUE:.+]] = torch.constant.bool true
-// CHECK: %[[NONE:.+]] = torch.constant.none
-// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
-// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
-// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
-%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
-return %0 : !torch.vtensor<[1,1,1],f32>
-}
-// -----
-// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
-func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INT0:.+]] = torch.constant.int 0
-// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
-// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
-// CHECK: %[[TRUE:.+]] = torch.constant.bool true
-// CHECK: %[[NONE:.+]] = torch.constant.none
-// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
-// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
-// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
-%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
-return %0 : !torch.vtensor<[3,2,1],f32>
-}
-// -----
-// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
-func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-// CHECK: %[[INT0:.+]] = torch.constant.int 0
-// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
@ -984,15 +1018,65 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2
-// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
-// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
-// CHECK: %[[FALSE:.+]] = torch.constant.bool false
-// CHECK: %[[NONE:.+]] = torch.constant.none
-// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
-// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
-// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
-%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
-return %0 : !torch.vtensor<[3,2],f32>
-}
// CHECK-LABEL: func.func @test_reduce_sum_square_default_axes_keepdims_example
func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
func.func @test_reduce_sum_square_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero
func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 8: si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[2,0,4],f32>, !torch.vtensor<[2,0,4],f32> -> !torch.vtensor<[2,0,4],f32>
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[2,0,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32>
return %0 : !torch.vtensor<[2,0,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_example
func.func @test_reduce_sum_square_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_int_example
func.func @test_reduce_sum_square_keepdims_int_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}
// -----
// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example // CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
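The reduce tests above encode straightforward reference semantics: onnx.ReduceLogSum lowers to log(sum(x, dims, keepdim)) and onnx.ReduceSumSquare to sum(x*x, dims, keepdim). A NumPy cross-check with shapes matching the tests (the axis choices here are illustrative):

    import numpy as np

    x = np.random.rand(3, 2, 2).astype(np.float32)

    # onnx.ReduceLogSum, keepdims=1, reducing the last axis: result shape (3, 2, 1).
    reduce_log_sum = np.log(np.sum(x, axis=2, keepdims=True))

    # onnx.ReduceSumSquare, keepdims=1, reducing axis 1: result shape (3, 1, 2).
    reduce_sum_square = np.sum(x * x, axis=1, keepdims=True)

    print(reduce_log_sum.shape, reduce_sum_square.shape)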

View File

@ -55,7 +55,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
// CHECK-LABEL: func.func @torch.aten.reciprocal( // CHECK-LABEL: func.func @torch.aten.reciprocal(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32> // CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
@ -124,7 +124,7 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32> // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -152,7 +152,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32> // CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex> // CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32> // CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32> // CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32> // CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
@ -185,7 +185,7 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32> // CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32> // CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -214,7 +214,7 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> // CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32> // CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) // CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> // CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> // CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> // CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>

View File

@ -4,9 +4,9 @@
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "none" // CHECK: %[[STR:.*]] = torch.constant.str "none"
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 2.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 5.000000e-01 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32> // CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32> // CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32> // CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
@ -487,7 +487,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
// CHECK-LABEL: func.func @torch.aten.relu( // CHECK-LABEL: func.func @torch.aten.relu(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 0.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32> // CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>

View File

@ -10,7 +10,7 @@
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32> // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> // CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32> // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32> // CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
@ -31,7 +31,7 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32> // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32> // CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32> // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
@ -53,7 +53,7 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32> // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32> // CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32> // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>

View File

@ -14,11 +14,11 @@
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32> // CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ // CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>): // CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32> // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32> // CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@ -46,12 +46,12 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
-// CHECK: })
-// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@@ -96,7 +96,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
 // CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
 // CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
-// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
+// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>}> ({
 // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
 // CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
 // CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
@@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
 // CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
 // CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
-// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
+// CHECK: }) : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
 // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
 // CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
@@ -137,11 +137,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
 // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
 // CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
 // CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
-// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
 // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
@@ -158,11 +159,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
 // CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
+// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
 // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
 // CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
 // CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
-// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
@@ -194,11 +196,12 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
+// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
 // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
 // CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
 // CHECK: stablehlo.return %[[T10]] : tensor<f32>
-// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
 // CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
 // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>

View File

@@ -22,10 +22,10 @@
 // CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64>
 // CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64>
 // CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
-// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) ({
+// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
 // CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):
 // CHECK: stablehlo.return %[[ARG_4]] : tensor<i64>
-// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false} : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
+// CHECK: }) : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
 // CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
 // CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64>
 func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {

View File

@@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool {
 
 // CHECK-LABEL: func.func @torch.aten.eq.str$same_operand(
 // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
-// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true
-// CHECK-NEXT: return %[[F]] : !torch.bool
+// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK-NEXT: return %[[TRUE]] : !torch.bool
 func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
   %0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
   return %0 : !torch.bool
@@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool {
 }
 
 // CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
-// CHECK: %[[FALSE:.*]] = torch.constant.bool true
-// CHECK: return %[[FALSE]] : !torch.bool
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: return %[[TRUE]] : !torch.bool
 func.func @torch.aten.ne.str$different_value() -> !torch.bool {
   %str4 = torch.constant.str "4"
   %str5 = torch.constant.str "5"
@@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool {
 
 // CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
 // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
-// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false
-// CHECK-NEXT: return %[[F]] : !torch.bool
+// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK-NEXT: return %[[FALSE]] : !torch.bool
 func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
   %0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
   return %0 : !torch.bool
 }
 
 // CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
-// CHECK: %[[TRUE:.*]] = torch.constant.bool false
-// CHECK: return %[[TRUE]] : !torch.bool
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: return %[[FALSE]] : !torch.bool
 func.func @torch.aten.ne.str$same_value() -> !torch.bool {
   %str4 = torch.constant.str "4"
   %str4_0 = torch.constant.str "4"
@@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int {
   return %2 : !torch.int
 }
 
+// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: return %[[FALSE]] : !torch.bool
+func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
+  %str = torch.constant.str "c"
+  %str_0 = torch.constant.str "b"
+  %str_1 = torch.constant.str "a"
+  %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
+  %2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
+  return %2 : !torch.bool
+}
+
+// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: return %[[TRUE]] : !torch.bool
+func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
+  %str = torch.constant.str "aa"
+  %str_0 = torch.constant.str "aa"
+  %str_1 = torch.constant.str "ccc"
+  %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
+  %2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
+  return %2 : !torch.bool
+}
+
 // CHECK-LABEL: func.func @torch.aten.__not__
 // CHECK: %[[TRUE:.*]] = torch.constant.bool true
 // CHECK: return %[[TRUE]] : !torch.bool
@@ -2950,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
   %result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
   return %result : !torch.vtensor<[4], f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
+// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
+func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
+  %int6 = torch.constant.int 6
+  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+  return %0 : !torch.vtensor<[64],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
+// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
+// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
+func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
+  %int6 = torch.constant.int 6
+  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
+  return %0 : !torch.vtensor<[64],si32>
+}
+
+// -----
+
+// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
+// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
+// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
+func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
+  return %result0 : !torch.vtensor<[10,64,56,56],f32>
+}

View File

@@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
     print(m)
 
 
+@run
+# CHECK-LABEL: test_broadcast_with_dynamic_shapes
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
+def test_broadcast_with_dynamic_shapes():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.broadcast_to(x, (y.shape[0], -1))
+
+    # Sample inputs
+    x = torch.randn(1, 2)
+    y = torch.randn(10)
+
+    dim_0 = Dim("dim_0")
+    dynamic_shapes = {
+        "x": {},
+        "y": {0: dim_0},
+    }
+
+    m = fx.export_and_import(
+        Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
+    )
+    print(m)
+
+
 @make_boxed_compiler
 def fx_import_aot_autograd_backend(
     gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
@@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend(
 
 @run
 # CHECK-LABEL: test_stateless_fx_import
-# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
 # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
 def test_stateless_fx_import():

View File

@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torchvision==0.19.0.dev20240422
+torchvision==0.19.0.dev20240428