Merge remote-tracking branch 'upstream/main' into emit_detach_

pull/3287/head
Xinyu Yang 2024-05-06 13:04:03 +08:00
commit 99e2f143f0
66 changed files with 1709 additions and 376 deletions

View File

@ -11,7 +11,7 @@ repos:
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.4.2
hooks:
- id: black

22
.yamllint.yml 100644
View File

@ -0,0 +1,22 @@
---
extends: default
rules:
# These do not appear to be conventional in GitHub actions.
document-end:
present: false
document-start:
present: false
# GitHub actions use "on" for triggers.
truthy: disable
# We have lots of long strings and command lines.
line-length: disable
comments:
# Formatters may do this (e.g. Prettier does) and it seems like the most
# trivial thing to get a failing check for.
min-spaces-from-content: 1
# This is not a useful check, especially when disabling entire blocks.
comments-indentation: disable
ignore: /third_party/*

View File

@ -50,6 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \
-DLLVM_TARGETS_TO_BUILD=host \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_LTC=ON
echo "::endgroup::"
echo "::group::Build"

View File

@ -432,6 +432,8 @@ function clean_build() {
}
function build_torch_mlir() {
# Disable LTC build for releases
export TORCH_MLIR_ENABLE_LTC=0
local torch_version="$1"
case $torch_version in
nightly)
@ -440,7 +442,7 @@ function build_torch_mlir() {
--extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir \
python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir \
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \
-r /main_checkout/torch-mlir/whl-requirements.txt
;;
@ -450,7 +452,7 @@ function build_torch_mlir() {
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
;;
*)
echo "Unrecognized torch version '$torch_version'"
@ -474,7 +476,7 @@ function build_torch_mlir_core() {
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir
}
function clean_wheels() {

View File

@ -2,6 +2,7 @@
See https://github.com/llvm/torch-mlir/issues/1374
"""
import argparse
import json

@ -1 +1 @@
Subproject commit a952c123880eb1168f1021b116485e27170d48ca
Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf

2
externals/stablehlo vendored

@ -1 +1 @@
Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97
Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91

View File

@ -30,8 +30,6 @@ namespace detail {
LogicalResult verifyTMTensorOpInterface(Operation *op);
}
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
/// Include the generated interface declarations.
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export
@ -39,4 +37,6 @@ LogicalResult verifyTMTensorOpInterface(Operation *op);
} // namespace torch
} // namespace mlir
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export
#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_

View File

@ -97,6 +97,31 @@ struct OpBinder {
return success();
}
ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
for (auto result : op->getResults()) {
auto t = toValidTensorType(result.getType());
if (!t)
return failure();
typeList.push_back(t);
}
return success();
}
// The importer imports Onnx.GraphProto attributes as regions attached to the
// op.
ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {
if (idx >= op->getNumRegions())
return failure();
region = &op->getRegion(idx);
if (region == nullptr) {
return failure();
}
return success();
}
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
int64_t idx) {
if (idx >= op->getNumResults())

View File

@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder,
Type getQTorchTypeFromTorchIntType(Type ty);
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
Value &ofItem) {
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
rewriter.getType<T>(), ofItem);
}
LogicalResult OnnxLstmExpander(OpBinder binder,
ConversionPatternRewriter &rewriter);

View File

@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
size_t dimSizeIndexBits);
// Get a tensor that collapse the specified dimensions of the input tensor
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits);
// Get a tensor that splits the specified dimensions of the input tensor
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits);
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType);

View File

@ -4810,6 +4810,53 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
}];
}
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCeluOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCelu_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenRealOp : Torch_Op<"aten.real", [
AllowsTypeRefinement,
ReadOnly
@ -6590,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
}];
}
def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
@ -6645,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices",
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
@ -13601,6 +13677,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [
}];
}
def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`";
let arguments = (ins
AnyTorchListOfTorchStringType:$l,
Torch_StringType:$item
);
let results = (outs
Torch_BoolType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
@ -15904,6 +16005,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_PrimsVarOp : Torch_Op<"prims.var", [

View File

@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl<bool> &bind_values) {
return detail::torch_list_of_constant_bools_op_binder(bind_values);
}
namespace detail {
/// Matches the constant strs stored in a `torch.ListConstruct`.
struct torch_list_of_constant_strs_op_binder {
SmallVectorImpl<std::string> &bind_values;
/// Creates a matcher instance that binds the value to bvs if match succeeds.
torch_list_of_constant_strs_op_binder(SmallVectorImpl<std::string> &bvs)
: bind_values(bvs) {}
bool match(Operation *op) {
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
if (!listConstruct)
return false;
for (Value value : listConstruct.getElements()) {
std::string str;
if (matchPattern(value, m_TorchConstantStr(str)))
bind_values.push_back(str);
else
return false;
}
return true;
}
};
} // namespace detail
/// Matches the constant strs stored in a `torch.prim.ListConstruct`.
inline detail::torch_list_of_constant_strs_op_binder
m_TorchListOfConstantStrs(SmallVectorImpl<std::string> &bind_values) {
return detail::torch_list_of_constant_strs_op_binder(bind_values);
}
namespace detail {
/// Matches the expected tensor and dim from `torch.aten.size.int`.
struct torch_tensor_size_int_op_binder {

View File

@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
return success();
}
namespace {
LogicalResult windowFunctionImpl(OpBinder binder,
ConversionPatternRewriter &rewriter,
Value size, Value a0, Value a1, Value a2,
Torch::ValueTensorType resultType,
int64_t output_datatype, int64_t periodic) {
Location loc = binder.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
double isPeriodicFp = static_cast<double>(periodic);
Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));
constexpr double pi = llvm::numbers::pi;
Value tau = b.create<Torch::ConstantFloatOp>(
rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
Value noneVal = b.create<Torch::ConstantNoneOp>();
Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
Value float32Type = b.create<Torch::ConstantIntOp>(
rewriter.getI64IntegerAttr(/*float32Type*/ 6));
// Create an f32 ValueTensorType with thse same size as size, the
// operand
auto shapeOfOperand =
size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
shapeOfOperand, rewriter.getF32Type());
Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
periodicSizeFloat.getType(), periodicSizeFloat, one, one);
Value isPeriodic =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
symmetricComponent.getType(), symmetricComponent, periodicComponent, one);
// Here, size can be used in the place of periodicSizeFloat, as the
// latter is just a float representation of the former.
Value scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, size);
Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);
Value rangeTimesTau =
b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
Value rangeAngular =
b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
Value twoRangeAngular =
b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);
Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
Value cosTwoRangeAngular =
b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);
Value a1Component =
b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
Value a2Component =
b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);
// AtenSubScalarOp actually requires a tensor operand as the LHS, that
// is, operand #1. Therefore, to avoid errors, the onnx implementation
// has been modified. a1 has been changed to negative half, and the
// AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
// operation is commutative.
Value subA1Component =
b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
a2Component, one);
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(output_datatype);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for the given dtype conversion");
}
Value outputDtype = b.create<Torch::ConstantIntOp>(
rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch.value()));
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, result, outputDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/noneVal);
return success();
}
} // namespace
// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
@ -2186,4 +2288,65 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
return success();
});
patterns.onOp(
"BlackmanWindow", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value size;
Torch::ValueTensorType resultType;
int64_t periodic, output_datatype;
if (binder.tensorOperand(size) ||
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
binder.s64IntegerAttr(periodic, "periodic", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
auto windowFunctionResult =
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
output_datatype, periodic);
if (failed(windowFunctionResult))
return failure();
return success();
});
patterns.onOp(
"HannWindow", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value size;
Torch::ValueTensorType resultType;
int64_t periodic, output_datatype;
if (binder.tensorOperand(size) ||
binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
binder.s64IntegerAttr(periodic, "periodic", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
Value a0 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
Value a1 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
Value a2 = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
auto windowFunctionResult =
windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
output_datatype, periodic);
if (failed(windowFunctionResult))
return failure();
return success();
});
}

View File

@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
alignCorners);
return success();
});
patterns.onOp(
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value conditionTensor;
if (binder.tensorOperand(conditionTensor)) {
return rewriter.notifyMatchFailure(binder.op,
"condition bind failure");
}
auto conditionType =
conditionTensor.getType().cast<Torch::ValueTensorType>();
if (!conditionType || conditionType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(
binder.op, "condition must have one single element per "
"https://onnx.ai/onnx/operators/onnx__If.html");
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
conditionTensor);
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
llvm::SmallVector<mlir::Type> resultTypes;
if (binder.tensorResultTypes(resultTypes)) {
return rewriter.notifyMatchFailure(binder.op,
"result type bind failure");
}
Region *thenRegion, *elseRegion;
if (binder.getRegionAtIndex(elseRegion, 0) ||
binder.getRegionAtIndex(thenRegion, 1)) {
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
}
auto primIfOp = rewriter.create<Torch::PrimIfOp>(
binder.getLoc(), TypeRange(resultTypes), conditionBool);
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
};
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
auto replaceTerminator = [&](Region &region) {
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = region.front().getTerminator();
rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
terminator, terminator->getOperands());
};
replaceTerminator(primIfOp.getThenRegion());
replaceTerminator(primIfOp.getElseRegion());
rewriter.replaceOp(binder.op, primIfOp.getResults());
return success();
});
patterns.onOp("Less", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c;
// thing here, so we simplify.
// utilities
// Templatized function to get an item op of a type
namespace {
template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
Value &ofItem) {
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
rewriter.getType<T>(), ofItem);
}
// In case the ReduceSum Op was not the first operation performed on the data,
// we provide the original operand through storeResult, which will be modified
// if the result will be passed onto another operation, and will be used for
@ -847,12 +839,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
patterns.onOp(
"Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0
Torch::ValueTensorType resultType;
float alpha, gamma;
Value operand;
// Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default
// alpha and gamma values.
if (binder.tensorOperand(operand) ||
binder.f32FloatAttr(alpha, "alpha") ||
binder.f32FloatAttr(gamma, "gamma") ||
binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
binder.tensorResultType(resultType))
return failure();
@ -945,22 +940,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*memory_format=*/noneVal);
return success();
});
patterns.onOp("ReduceSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
return reducedSumImpl(binder, rewriter, data, resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp("ReduceLogSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
@ -987,6 +966,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, data);
return success();
});
patterns.onOp("ReduceSum", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
return reducedSumImpl(binder, rewriter, data, resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp("ReduceSumSquare", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
binder.getLoc(), data.getType(), data, data);
return reducedSumImpl(binder, rewriter, dataSquare,
resultType,
/*storeValue=*/data, keepDims,
noop_with_empty_axes, false);
});
patterns.onOp(
"ReduceMean", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
if (!isUnsignedType)
return;
int64_t minSI = -(1 << (numBits - 1));
Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
@ -797,6 +798,8 @@ public:
auto resultTy = cast<ValueTensorType>(op.getType());
Value inputZp, weightZp;
bool inputUnsigned = false;
bool weightUnsigned = false;
if (auto make = op.getInput()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
input = make.getSelf();
@ -806,6 +809,8 @@ public:
inputZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(inputZp.getType()),
inputZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}
if (auto make = op.getWeight()
@ -818,6 +823,8 @@ public:
weightZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(weightZp.getType()),
weightZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}
if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
@ -916,15 +923,35 @@ public:
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
// convert any uint8 quantization to int8 quantization
if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
int64_t width = integerType.getWidth();
signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
}
if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
int64_t width = integerType.getWidth();
signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
}
// Pad the input tensor according to padding.
SmallVector<Value> outDims{inBatch, weightBatch};
Value paddedInput;
if (transposed) {
if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
!isa<mlir::FloatType>(resultDTy))
return rewriter.notifyMatchFailure(
op, "transpose does not support non-fp type yet");
Value pad = inputZp;
if (!pad) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
}
if (pad.getType() != inputDTy) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}
if (transposed) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value c1 =
@ -994,7 +1021,7 @@ public:
// Allocate padded input tensor
Value initTensor =
createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);
// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
@ -1017,24 +1044,6 @@ public:
strideInts.clear();
strideInts.append(numSpatialDims, 1);
} else {
Value pad = inputZp;
if (!pad) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
}
if (pad.getType() != inputDTy) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}
// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);

View File

@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
AtenCumsumOp>(op)) {
AtenAvgPool3dOp, AtenCumsumOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
}
// Max pooling
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
AtenMaxPool2dWithIndicesOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get(
constType,
@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr;
}
// AtenMaxPool2dOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto outTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (inputRank <= 2) {
return op.emitError(
"max_pooling2d only supports inputs with rank higher than 2");
}
SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
bool ceilMode = false;
if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(op,
"non-const bool ceil_mode unsupported!");
}
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
Block &block = reduceWindowOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArg = block.args_begin();
auto secondArg = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result =
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
// AtenMaxPool2dWithIndicesOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
return success();
}
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
public:
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto outTy = cast<RankedTensorType>(
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
if (inputRank <= Dim) {
return op.emitError(
"max_pooling1d/2d only supports inputs with rank higher than 1/2");
}
SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
bool ceilMode = false;
if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op,
"non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(
op, "non-const bool ceil_mode unsupported!");
}
if (stride.empty()) {
stride = kernelSize;
}
// prepend 1 to kernelSize, stride, dilation until they are of same rank
// as input
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - Dim);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - Dim);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - Dim);
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
} else if (Dim == 2) {
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
} else if (Dim == 3) {
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
} else {
assert(false && "Unsupported pooling dimension");
}
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
Block &block = reduceWindowOp.getBody().emplaceBlock();
// Add bb argument
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
block.addArgument(blockArgumentType, op->getLoc());
block.addArgument(blockArgumentType, op->getLoc());
auto *firstArg = block.args_begin();
auto secondArg = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
*secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
};
} // namespace
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
@ -375,8 +404,8 @@ public:
auto outShape = outTy.getShape();
if (inputRank <= Dim) {
return op.emitError(
"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
"higher than 1/2/3");
}
SmallVector<int64_t, Dim> padding, kernelSize, stride;
bool ceilMode = false;
@ -405,6 +434,10 @@ public:
op, "non-const bool count_include_pad unsupported!");
}
if (stride.empty()) {
stride = kernelSize;
}
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
return rewriter.notifyMatchFailure(
@ -425,11 +458,20 @@ public:
if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
} else {
} else if (Dim == 2) {
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
} else if (Dim == 3) {
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
} else {
assert(false && "Unsupported pooling dimension");
}
Value initVal =
@ -474,10 +516,17 @@ public:
divisor =
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
.value();
} else {
} else if (Dim == 2) {
divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value();
} else if (Dim == 3) {
divisor = hlo::getConstTensor<int64_t>(
rewriter, op,
{kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
.value();
} else {
assert(false && "Unsupported pooling dimension");
}
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
DenseI64ArrayAttr bcastDimensions;
@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenAvgPool1dOp>();
patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
context, options);
target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
#undef INSERT_ATEN_POOLING_PATTERN
#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context, \
options)
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
#undef INSERT_ATEN_MAXPOOL_PATTERN
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
options)
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
#undef INSERT_ATEN_AVGPOOL_PATTERN
}

View File

@ -9,6 +9,7 @@
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
.getResult();
}
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo;
int64_t rank = dimSizes.size();
collapseStartDim = toPositiveDim(collapseStartDim, rank);
collapseEndDim = toPositiveDim(collapseEndDim, rank);
int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
newDimSizes.reserve(newRank);
newShape.reserve(newRank);
Value collapseDimSize = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
int64_t collapseShape = 1;
for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
if (k < 0 || k >= rank) {
return rewriter.notifyMatchFailure(
op, "collapse dimensions must be within the rank of the tensor");
}
if (collapseShape == ShapedType::kDynamic ||
oldShape[k] == ShapedType::kDynamic) {
collapseShape = ShapedType::kDynamic;
} else {
collapseShape *= oldShape[k];
}
collapseDimSize =
rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
}
for (int64_t k = 0; k < collapseStartDim; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
newDimSizes.push_back(collapseDimSize);
newShape.push_back(collapseShape);
for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
// TODO: support splitDim & outerLength to be Value
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo;
int64_t rank = dimSizes.size();
splitDim = toPositiveDim(splitDim, rank);
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
if (splitDim < 0 || splitDim >= rank) {
return rewriter.notifyMatchFailure(
op, "split dimensions must be within the rank of the tensor");
}
int64_t newRank = rank + 1;
auto outerLengthValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, outerLength));
auto innerLengthValue = rewriter.create<arith::DivSIOp>(
loc, dimSizes[splitDim], outerLengthValue);
int64_t originShape = oldShape[splitDim];
int64_t outerShape = outerLength;
int64_t innerShape = originShape == ShapedType::kDynamic
? ShapedType::kDynamic
: originShape / outerLength;
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
newDimSizes.reserve(newRank);
newShape.reserve(newRank);
for (int64_t k = 0; k < splitDim; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
newDimSizes.push_back(outerLengthValue);
newShape.push_back(outerShape);
newDimSizes.push_back(innerLengthValue);
newShape.push_back(innerShape);
for (int64_t k = splitDim + 1; k < rank; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType) {

View File

@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "only constant end is currently supported");
start = toPositiveDim(start, rank);
end = toPositiveDim(end, rank);
SmallVector<int64_t, 4> dims;
dims.reserve(rank);
for (int r = 0; r < start; ++r)
dims.push_back(r);
int64_t collapsedDimSize = 1;
for (int r = start; r <= end; ++r) {
if (selfType.getShape()[r] == ShapedType::kDynamic)
return rewriter.notifyMatchFailure(
op, "the size of the dimension being collapsed is can't be unknown");
collapsedDimSize *= selfType.getShape()[r];
}
dims.push_back(collapsedDimSize);
for (int r = end + 1; r < rank; ++r)
dims.push_back(r);
auto collapseTensorInfo = hlo::collapseTensor(
rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
if (failed(collapseTensorInfo))
return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
auto newDimSizesInfo = hlo::getDimSizesOfTensor(
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
rewriter.replaceOp(op, *collapseTensorInfo);
return success();
}
template <>
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
PrimsSplitDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
if (!selfType) {
return op.emitError("only tensor types are currently supported");
}
auto rank = selfType.getRank();
if (rank == 0)
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo;
auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
stablehloShape);
op, "the rank of tensor must be greater than 0");
int64_t dim, outerLength;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
return rewriter.notifyMatchFailure(
op, "only constant outerLength is currently supported");
auto splitTensorInfo = hlo::splitTensor(
rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
if (failed(splitTensorInfo))
return rewriter.notifyMatchFailure(op, "failed to create split tensor");
rewriter.replaceOp(op, *splitTensorInfo);
return success();
}
@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_VIEW_OP_PATTERN(AtenOp) \

View File

@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
}
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
auto dtypeElemType = dtypeComplex.getElementType();
// Extract the real and imaginary parts of the scalar.
// Cast them to the target element type, and create a new complex
// value with the target complex type.
Value realVal = b.create<complex::ReOp>(loc, scalar);
Value imgVal = b.create<complex::ImOp>(loc, scalar);
realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
}
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype
<< "(dtype)";
}
llvm_unreachable("convertScalarToDtype should handle all the types");
}

View File

@ -936,7 +936,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
if (opOperand->get().isa<BlockArgument>())
if (isa<BlockArgument>(opOperand->get()))
return false;
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);

View File

@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr;
Type inputDtype = inputTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(64))
if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1)))
return nullptr;
std::optional<unsigned> inputRank = getTensorRank(input);
@ -148,10 +148,19 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
if (inputDtype.isInteger(64)) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseIntElementsAttr>()
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
} else {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseIntElementsAttr>()
.getSplatValue<bool>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
}
} else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
return primNumToTensorScalarOp.getA();
@ -2385,6 +2394,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// Aten__Contains__StrListOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
if (!item)
return nullptr;
if (auto listConstruct = getL().getDefiningOp<Torch::PrimListConstructOp>()) {
if (isListPotentiallyMutated(listConstruct))
return nullptr;
}
llvm::SmallVector<std::string> strs;
if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) {
for (const auto &str : strs) {
if (item.getValue().str() == str)
return getI1IntegerAttr(getContext(), true);
}
return getI1IntegerAttr(getContext(), false);
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenLtIntOp
//===----------------------------------------------------------------------===//
@ -4682,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// PrimsConvertElementTypeOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
auto inputType = cast<BaseTensorType>(getA().getType());
auto outputType = cast<BaseTensorType>(getResult().getType());
if (inputType != outputType)
return nullptr;
if (!inputType.hasDtype() || !outputType.hasDtype())
return nullptr;
if (inputType.getDtype() != outputType.getDtype())
return nullptr;
return getA();
}
//===----------------------------------------------------------------------===//
// AtenMaxPool2dWithIndicesOp
//===----------------------------------------------------------------------===//
void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
if (!op.getResult1().use_empty()) {
return rewriter.notifyMatchFailure(
op, "result1 of MaxPool2dWithIndices should be unused");
}
Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
op->getLoc(), op.getResult0().getType(), op.getSelf(),
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
op.getCeilMode());
op.getResult0().replaceAllUsesWith(result);
rewriter.eraseOp(op);
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//

View File

@ -6998,6 +6998,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.celu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -7841,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %arg2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
" func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %int-2 = torch.constant.int -2\n"
" %int-3 = torch.constant.int -3\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
@ -7936,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %23 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -10480,6 +10488,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.celu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"

View File

@ -1059,44 +1059,44 @@ public:
LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
int64_t n;
if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
return rewriter.notifyMatchFailure(op,
"unimplemented: n must be constant");
int64_t m;
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
return rewriter.notifyMatchFailure(op,
"unimplemented: m must be constant");
Value none = rewriter.create<ConstantNoneOp>(loc);
auto outType = dyn_cast<BaseTensorType>(op.getType());
auto outType = op.getType().dyn_cast<BaseTensorType>();
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
if (!outType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
if (n < 0) {
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
}
if (m < 0) {
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
}
Value none = rewriter.create<ConstantNoneOp>(loc);
auto context = op.getContext();
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
int64_t n = kUnknownSize;
int64_t m = kUnknownSize;
// prioritize getting shape from output shape
if (outType.hasSizes() && outType.getSizes().size() == 2) {
n = outType.getSizes().front();
m = outType.getSizes().back();
}
// if output shape is not available, try to get shape from input
if (n == kUnknownSize)
matchPattern(op.getN(), m_TorchConstantInt(&n));
if (m == kUnknownSize)
matchPattern(op.getM(), m_TorchConstantInt(&m));
// prepare two unsqueezed ranges that are equal on and only on the diagonal
auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
Value rangeN = rewriter.create<AtenArangeOp>(
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/op.getDevice(), /*pin_memory=*/none);
auto arangeType1 =
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
Value rangeM = rewriter.create<AtenArangeOp>(
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
@ -1109,7 +1109,6 @@ public:
}
Value unsqzRangeN = *unsqzTensorInfo;
// compare unsqueezed input with boundaries
auto eqType = ValueTensorType::get(
context, cast<BaseTensorType>(op.getType()).getSizes(),
IntegerType::get(context, 1));
@ -2415,6 +2414,50 @@ public:
} // namespace
// CELU(x)=max(0,x)+min(0,alpha(exp(x/alpha)1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCeluOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
Value alpha = op.getAlpha();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value constantOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
// positiveOutput = max(0,x)
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
// negativeOutput = min(0,alpha(exp(x/alpha)1))
Value scaledInput =
rewriter.create<AtenDivScalarOp>(loc, resType, input, alpha);
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledInput);
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
constantOne, constantOne);
Value scaledExpXM1 =
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
Value negativeOutput =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
Value celuOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOne);
rewriter.replaceOp(op, celuOutput);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
public:
@ -7705,6 +7748,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);

View File

@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenCeluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>();
target.addIllegalOp<AtenToPrimDeviceOp>();

View File

@ -272,6 +272,7 @@ TORCHDYNAMO_XFAIL_SET = {
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
"Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
"FloatImplicitModule_basic",
@ -372,6 +373,7 @@ FX_IMPORTER_XFAIL_SET = {
"Conv2dQInt8Module_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
@ -544,6 +546,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ContainsIntList_True",
"Conv2dQInt8Module_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
@ -572,6 +575,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseErfIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
@ -678,11 +682,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"NumToTensorIntModule_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"PixelShuffleModuleSpatiallyStatic_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
@ -951,6 +950,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseCeilModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
@ -1079,6 +1079,9 @@ STABLEHLO_PASS_SET = {
"Matmul_vecmat",
"MatmulStaticBroadcast_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool3dStaticModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic",
@ -1156,6 +1159,8 @@ STABLEHLO_PASS_SET = {
"Permute0RankModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimMaxIntModule_basic",
@ -1239,6 +1244,7 @@ STABLEHLO_PASS_SET = {
"SliceWholeTensorModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SplitDimStaticModule_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
@ -1571,6 +1577,8 @@ TOSA_PASS_SET = {
"ElementwiseBitwiseXorModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseCeilModule_basic",
"ElementwiseCeluModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",
@ -1916,11 +1924,6 @@ MAKE_FX_TOSA_PASS_SET = (
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
"MatmulStaticBroadcast_basic",
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
@ -2096,6 +2099,7 @@ LTC_XFAIL_SET = {
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
}
ONNX_XFAIL_SET = {
@ -2121,7 +2125,6 @@ ONNX_XFAIL_SET = {
"ElementwiseAtenFloorDivideTensorNegativeModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseSeluModule_basic",
"FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic",
"HardsigmoidModule_basic",
@ -2251,6 +2254,7 @@ ONNX_XFAIL_SET = {
"Conv2dWithPaddingModule_basic",
"Conv3dModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
"Convolution2DModule_basic",
"Convolution2DStridedModule_basic",
@ -2306,6 +2310,7 @@ ONNX_XFAIL_SET = {
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseOrTensorModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
@ -2554,16 +2559,12 @@ ONNX_XFAIL_SET = {
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"_SoftmaxModule_basic",
# Failure - onnx_import
# Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
# Failure - onnx_lowering: onnx.If
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
"DiagonalModule_with_dims",
"DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset",
# these diagonal modules are currently failing due to dynamic shape.
# We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
# when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
# Failure - onnx_lowering: onnx.MaxPool
@ -2634,8 +2635,6 @@ ONNX_XFAIL_SET = {
"CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic",
"DropoutTrainModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseAcosIntModule_basic",
"ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic",

View File

@ -3,6 +3,7 @@ from torch_mlir import torchscript
from transformers import BertForMaskedLM
# Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None:

View File

@ -257,9 +257,9 @@ class _FXGraphImporter:
# FakeTensor's in case of a tuple return with multiple elements.
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {}
self._module = ir.Module.create(ir.Location.unknown())
self._module.operation.attributes[
"torch.debug_module_name"
] = ir.StringAttr.get(func_name)
self._module.operation.attributes["torch.debug_module_name"] = (
ir.StringAttr.get(func_name)
)
function_type = _extract_function_type_from_graph(g)
func = func_dialect.FuncOp(
func_name,

View File

@ -526,6 +526,9 @@ def atenelu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
def atenprelu〡shape(self: List[int], weight: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atencelu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
return upstream_shape_functions.unary(self)
def atenselu〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
@ -958,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
kL = kernel_size[0]
assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
dL = kL if len(stride) == 0 else stride[0]
assert len(padding) == 1, "avg_pool1d: padding must be a single int"
assert len(padding) == 1, "pool1d: padding must be a single int"
padL = padding[0]
dilationL = 1
@ -1001,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
return shape
def atenavg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
return pool1d(self, kernel_size, stride, padding, ceil_mode)
def atenmax_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
return pool1d(self, kernel_size, stride, padding, ceil_mode)
def atenadaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
return adaptive_avg_pool1d(self, output_size)
@ -2652,6 +2658,11 @@ def atenprelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tu
assert self_dtype == weight_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, alpha=1.))
def atencelu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1.) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def atenrelu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
(ns, unqual + "_", overload if not is_functional_op else "")
),
emitter_td,
traits=["IsTrailingUnderscoreInplaceVariant"]
if not is_functional_op
else [],
traits=(
["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else []
),
)
# ==========================================================================
@ -472,6 +472,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
emit("aten::real : (Tensor) -> (Tensor)")
emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
@ -590,9 +591,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
has_canonicalizer=True,
)
emit(
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
@ -973,6 +976,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::format : (...) -> (str)")
emit("aten::join : (str, str[]) -> (str)")
emit("aten::warn : (str, int) -> ()")
emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True)
# Type conversion ops.
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
@ -1101,7 +1105,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# `prims::` namespace.
# ==========================================================================
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True)
emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
emit("prims::sqrt : (Tensor) -> (Tensor)")
emit("prims::collapse : (Tensor, int, int) -> (Tensor)")

View File

@ -46,7 +46,7 @@ def convert_onnx(model, inputs):
examples = []
input_names = []
dynamic_tensors = {}
for (index, arg) in enumerate(inputs):
for index, arg in enumerate(inputs):
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
shape = tuple(shape)
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
@ -55,7 +55,7 @@ def convert_onnx(model, inputs):
input_names.append(input_name)
dynamic_dims = {}
for (dimindex, dim) in enumerate(arg.shape):
for dimindex, dim in enumerate(arg.shape):
if dim < 0:
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)

View File

@ -101,11 +101,13 @@ class RefBackendInvoker:
def consume_return_funcs(*args):
self.result = tuple(
[
(
arg
if type in elemental_type_to_ctype
else unranked_memref_to_numpy(
arg, memref_type_to_np_dtype[type]
)
)
for arg, type in zip(args, ret_types)
]
)
@ -178,6 +180,7 @@ LOWERING_PIPELINE = (
"func.func(tm-tensor-to-loops)",
"func.func(refback-munge-memref-copy)",
"func.func(convert-linalg-to-loops)",
"func.func(expand-realloc)",
"func.func(lower-affine)",
"convert-scf-to-cf",
"func.func(refback-expand-ops-for-llvm)",
@ -191,6 +194,7 @@ LOWERING_PIPELINE = (
"convert-bufferization-to-memref",
"finalize-memref-to-llvm",
"func.func(convert-arith-to-llvm)",
"convert-vector-to-llvm",
"convert-func-to-llvm",
"convert-cf-to-llvm",
"convert-complex-to-llvm",

View File

@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, bias)
N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2
class ConvTranspose2DQInt8Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
]
)
def forward(self, input, weight, bias):
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
qinput = torch.dequantize(qinput)
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
qweight = torch.dequantize(qweight)
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
qbias = torch.dequantize(qbias)
qz = torch.ops.aten.convolution(
qinput,
qweight,
bias=qbias,
stride=[2, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=True,
output_padding=[0, 0],
groups=1,
)
return qz
@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
module.forward(
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
torch.rand(Cout),
)

View File

@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
# ==============================================================================
class DiagonalWithStaticShapeModule(torch.nn.Module):
"""
Diagonal with static shape. The other diagonal modules are failing in onnx
because DecomoposeAtenEyeMOp requires constants n, m, which are only constant
when the shape is static.
Please remove this module and associated test once the issue is fixed.
"""
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([5, 9], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diagonal(a)
@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 9))
# ==============================================================================
class DiagonalTransposedModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module):
@register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils):
module.forward(
tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)
)
module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32))
# ==============================================================================
@ -1016,6 +1014,52 @@ def ElementwisePreluStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseCeluModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.celu(x, 0.5)
@register_test_case(module_factory=lambda: ElementwiseCeluModule())
def ElementwiseCeluModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))
# ==============================================================================
class ElementwiseCeluStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([5, 3], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.celu(x)
@register_test_case(module_factory=lambda: ElementwiseCeluStaticModule())
def ElementwiseCeluStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))
# ==============================================================================
class ElementwiseGeluModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -1795,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
# ==============================================================================
# torch.complex32 is not supported by the refbackend.
class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1], torch.complex64, True),
([-1], torch.complex128, True),
]
)
def forward(self, a, b):
return torch.mul(a, b)
@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.complex64),
tu.randint(4, high=10).type(torch.complex128),
)
# ==============================================================================
class ElementwiseMishModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module):

View File

@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK: module attributes {torch.debug_module_name = "TestModule"}
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -18,6 +18,7 @@ mb = ModuleBuilder()
# `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
# naively duplicating a Tensor retains the identity of the TensorImpl.
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: torch.class_type @__torch__.TestModule {
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: @__torch__.f
@mb.import_function
@torch.jit.script

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

View File

@ -13,6 +13,7 @@ mb = ModuleBuilder()
# else branch and making all defined values optional, so no special handling
# is needed.
# CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int {

View File

@ -11,6 +11,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true

View File

@ -15,6 +15,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor

View File

@ -13,6 +13,7 @@ from utils import create_script_function
mb = ModuleBuilder()
NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])])
# CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK: @__torch__.returns_bool
@mb.import_function
@torch.jit.script

View File

@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder
mb = ModuleBuilder()
# CHECK: @__torch__.returns_none
@mb.import_function
@torch.jit.script

View File

@ -9,6 +9,7 @@ from torch._C import CompilationUnit
# RUN: %PYTHON %s
# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit()

View File

@ -236,12 +236,6 @@ _IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
if _IS_TORCH_2_1_OR_EARLIER:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_numel,
}
SYMBOLIC_OP_TO_TORCH_OP = {
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
@ -249,13 +243,9 @@ if _IS_TORCH_2_1_OR_EARLIER:
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
}
else:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_numel.default,
}
SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP}
else:
SYMBOLIC_OP_TO_TORCH_OP = {
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
@ -264,6 +254,8 @@ else:
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
}
SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
@dataclass(frozen=True)
class SparsityMeta:
@ -1857,8 +1849,7 @@ def _emit_operation(
# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
...
class EmptyType: ...
Empty = EmptyType()

View File

@ -156,8 +156,7 @@ class GraphInfo:
return ""
class OnnxImportError(Exception):
...
class OnnxImportError(Exception): ...
class NodeImporter:
@ -235,22 +234,22 @@ class NodeImporter:
else:
default_opset_version = opset_import.version
if default_opset_version:
container_op.attributes[
"torch.onnx_meta.opset_version"
] = IntegerAttr.get(i64_type, default_opset_version)
container_op.attributes["torch.onnx_meta.opset_version"] = (
IntegerAttr.get(i64_type, default_opset_version)
)
if opset_versions:
container_op.attributes[
"torch.onnx_meta.opset_versions"
] = DictAttr.get(opset_versions)
container_op.attributes["torch.onnx_meta.opset_versions"] = (
DictAttr.get(opset_versions)
)
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
IntegerType.get_signed(64), m.ir_version
)
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
m.producer_name
)
container_op.attributes[
"torch.onnx_meta.producer_version"
] = StringAttr.get(m.producer_version)
container_op.attributes["torch.onnx_meta.producer_version"] = (
StringAttr.get(m.producer_version)
)
def import_all(self, func=True):
"""Imports all nodes topologically."""
@ -348,8 +347,14 @@ class NodeImporter:
continue
elif handler is False:
# Active error.
# try matching attribute type ID to name for a more descriptive error message
try:
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
except ValueError:
attr_type_name = "UNKNOWN"
raise OnnxImportError(
f"ONNX importer does not support generic node attribute type {attr_type}. "
f"ONNX importer does not support generic node attribute type {attr_type_name} "
f"with ID {attr_type}. "
f"This likely means that this is a special node which requires specific "
f"handling in the importer: {onnx_attr}"
)
@ -658,9 +663,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
RankedTensorType.get(shape, IntegerType.get_signed(64)),
IntegerAttr.get(
IntegerType.get_signed(64),
(
int.from_bytes(tp.raw_data, "little", signed=True)
if tp.HasField("raw_data")
else tp.int64_data[0],
else tp.int64_data[0]
),
),
),
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
@ -703,7 +710,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
),
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
)
),
# Intentionally unsupported: STRING
}

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Optional, Union, Dict, Tuple, Any
from typing import Optional, Union, Dict, Tuple, Any, Callable
import warnings
@ -25,7 +25,7 @@ def export_and_import(
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False,
hooks: Optional[FxImporterHooks] = None,
decomposition_table: Optional[list] = None,
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
func_name: str = "main",
enable_graph_printing: bool = False,
**kwargs,

View File

@ -1 +1 @@
0a3e5f5badd8a0cb7fac97f5ec9d48c304e5c0b7
34ade3521ca41f20af3469bba276c2b0499c3892

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.4.0.dev20240422
torch==2.4.0.dev20240428

View File

@ -0,0 +1,20 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
// CHECK-LABEL: func.func @test_ifop_basic
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
// CHECK-DAG: } else {
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}, {
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}
return %0 : !torch.vtensor<[1],f32>
}

View File

@ -1996,3 +1996,160 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten
%0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32>
return %0 : !torch.vtensor<[3,?],f32>
}
// -----
// CHECK-LABEL: func.func @test_blackmanwindow_symmetric
func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_blackmanwindow
func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02
// CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_hannwindow
func.func @test_hannwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}
// -----
// CHECK-LABEL: func.func @test_hannwindow_symmetric
func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01
// CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00
// CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6
// CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int
// CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6
// CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32>
%0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32>
}

View File

@ -860,6 +860,57 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
@ -942,41 +993,24 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_reduce_sum_square_default_axes_keepdims_example
func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example
func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
return %0 : !torch.vtensor<[3,2,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example
func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
func.func @test_reduce_sum_square_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
@ -984,15 +1018,65 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero
func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 8: si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[2,0,4],f32>, !torch.vtensor<[2,0,4],f32> -> !torch.vtensor<[2,0,4],f32>
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[2,0,1],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32>
return %0 : !torch.vtensor<[2,0,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_example
func.func @test_reduce_sum_square_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_int_example
func.func @test_reduce_sum_square_keepdims_int_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
// CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32>
%0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
}
// -----
// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>

View File

@ -55,7 +55,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
// CHECK-LABEL: func.func @torch.aten.reciprocal(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
@ -124,7 +124,7 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -152,7 +152,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
@ -185,7 +185,7 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -214,7 +214,7 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>

View File

@ -4,9 +4,9 @@
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "none"
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 1.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 2.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 5.000000e-01 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
@ -487,7 +487,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
// CHECK-LABEL: func.func @torch.aten.relu(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 0.000000e+00 : f32}> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>

View File

@ -10,7 +10,7 @@
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
@ -31,7 +31,7 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
@ -53,7 +53,7 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false}> : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>

View File

@ -14,11 +14,11 @@
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@ -46,12 +46,12 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>}> ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: })
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 2, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 1, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@ -96,7 +96,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>}> ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array<i64: 1, 1, 1>, window_dimensions = array<i64: 1, 3, 3>, window_strides = array<i64: 1, 2, 2>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: }) : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
// CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
@ -137,11 +137,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
@ -158,11 +159,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
@ -194,11 +196,12 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]])
// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>}> ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 2, 2>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: }) : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>

View File

@ -22,10 +22,10 @@
// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor<?x?x1xi64>
// CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor<?x?x1xi64>, tensor<?x?x1xi64>) -> tensor<?x?x2xi64>
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) ({
// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
// CHECK: ^bb0(%arg3: tensor<i64>, %[[ARG_4:.*]]: tensor<i64>):
// CHECK: stablehlo.return %[[ARG_4]] : tensor<i64>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false} : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: }) : (tensor<?x?xi64>, tensor<?x?x2xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64>
func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {

View File

@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool {
// CHECK-LABEL: func.func @torch.aten.eq.str$same_operand(
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true
// CHECK-NEXT: return %[[F]] : !torch.bool
// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-NEXT: return %[[TRUE]] : !torch.bool
func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool {
%0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
return %0 : !torch.bool
@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool {
}
// CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool true
// CHECK: return %[[FALSE]] : !torch.bool
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func.func @torch.aten.ne.str$different_value() -> !torch.bool {
%str4 = torch.constant.str "4"
%str5 = torch.constant.str "5"
@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool {
// CHECK-LABEL: func.func @torch.aten.ne.str$same_operand(
// CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool {
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false
// CHECK-NEXT: return %[[F]] : !torch.bool
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-NEXT: return %[[FALSE]] : !torch.bool
func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool {
%0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool false
// CHECK: return %[[TRUE]] : !torch.bool
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func.func @torch.aten.ne.str$same_value() -> !torch.bool {
%str4 = torch.constant.str "4"
%str4_0 = torch.constant.str "4"
@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int {
return %2 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool
func.func @torch.aten.__contains__.str_list$false() -> !torch.bool {
%str = torch.constant.str "c"
%str_0 = torch.constant.str "b"
%str_1 = torch.constant.str "a"
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
return %2 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func.func @torch.aten.__contains__.str_list$true() -> !torch.bool {
%str = torch.constant.str "aa"
%str_0 = torch.constant.str "aa"
%str_1 = torch.constant.str "ccc"
%1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list<str>
%2 = torch.aten.__contains__.str_list %1, %str : !torch.list<str>, !torch.str -> !torch.bool
return %2 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.__not__
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
@ -2950,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}
// -----
// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
%int6 = torch.constant.int 6
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
return %0 : !torch.vtensor<[64],f32>
}
// -----
// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
%int6 = torch.constant.int 6
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
return %0 : !torch.vtensor<[64],si32>
}
// -----
// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
return %result0 : !torch.vtensor<[10,64,56,56],f32>
}

View File

@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
print(m)
@run
# CHECK-LABEL: test_broadcast_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
def test_broadcast_with_dynamic_shapes():
class Basic(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return torch.broadcast_to(x, (y.shape[0], -1))
# Sample inputs
x = torch.randn(1, 2)
y = torch.randn(10)
dim_0 = Dim("dim_0")
dynamic_shapes = {
"x": {},
"y": {0: dim_0},
}
m = fx.export_and_import(
Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net"
)
print(m)
@make_boxed_compiler
def fx_import_aot_autograd_backend(
gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend(
@run
# CHECK-LABEL: test_stateless_fx_import
# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
def test_stateless_fx_import():

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision==0.19.0.dev20240422
torchvision==0.19.0.dev20240428